Source code for sicor.Tools.cB.cB

import logging
import numpy as np

from sicor.Mask import S2Mask
from sicor.Tools import majority_mask_filter


class S2cB(object):
    """Run a pixel classifier on an S2 image and package the result as an ``S2Mask``.

    The instance is callable: ``S2cB(...)(img)`` classifies ``img`` with
    ``cb_clf.predict_and_conf`` (optionally in row segments), marks no-data
    pixels, optionally applies majority filtering and returns an ``S2Mask``.
    """

    def __init__(self, cb_clf, mask_legend, clf_to_col, processing_tiles=11, logger=None):
        """Store the classifier and the presentation settings.

        :param cb_clf: classifier object; must provide ``mk_clf`` (with
            ``classifiers_id``), ``predict_and_conf`` and ``classes``. An
            optional ``nvc`` attribute enables the novelty output.
        :param mask_legend: passed through unchanged to ``S2Mask``.
        :param clf_to_col: passed through unchanged to ``S2Mask``.
        :param processing_tiles: ``0`` classifies the whole image in a single
            call. Any other value is the number of row boundaries created with
            ``np.linspace``, i.e. the image is processed in
            ``processing_tiles - 1`` row segments; ``1`` is treated as one
            segment spanning all rows.
        :param logger: logger to use; defaults to ``logging.getLogger(__name__)``.
        """
        self.logger = logger or logging.getLogger(__name__)
        self.cb_clf = cb_clf
        self.mask_legend = mask_legend
        self.clf_to_col = clf_to_col
        self.processing_tiles = processing_tiles

        self.S2_MSI_channels = ['B01', 'B02', 'B03', 'B04', 'B05', 'B06', 'B07',
                                'B08', 'B8A', 'B09', 'B10', 'B11', 'B12']

        # Flatten the per-classifier channel id tuples and keep each id once.
        unique_channel_list = [
            channel_id
            for clf_ids in cb_clf.mk_clf.classifiers_id
            for channel_id in clf_ids
        ]
        self.unique_channel_ids = np.unique(unique_channel_list)
        self.unique_channel_str = [self.S2_MSI_channels[ii] for ii in self.unique_channel_ids]

    def _segment_bounds(self, n_lines):
        """Return the ``(start, stop)`` row ranges used for tiled processing."""
        # A single linspace sample yields no (start, stop) pair at all, which
        # would silently skip classification; use one full-height segment.
        n_edges = 2 if self.processing_tiles == 1 else self.processing_tiles
        edges = np.linspace(0, n_lines, n_edges, dtype=int)
        return list(zip(edges[:-1], edges[1:]))

    def __mk_nv__(self, data):
        """Compute the binary array that is handed to ``S2Mask`` as ``novelty``.

        :param data: array whose first two axes are the image axes and whose
            last axis is fed to ``cb_clf.mk_clf`` per pixel.
        :return: ``uint8`` array of shape ``data.shape[:2]`` where negative
            predictions of ``cb_clf.nvc`` are 0 and positive ones are 1, or
            ``None`` if ``cb_clf`` has no ``nvc`` attribute.
        """
        if not hasattr(self.cb_clf, "nvc"):
            return None

        bf = self.cb_clf.mk_clf(data.reshape((-1, data.shape[-1])))
        nv = self.cb_clf.nvc.predict(bf).reshape(data.shape[:2])
        nv[nv < 0] = 0
        nv[nv > 0] = 1
        return np.array(nv, dtype=np.uint8)

    def __call__(self, img, target_resolution=None, majority_filter_options=None, nodata_value=255):
        """Classify ``img`` and return the result as an ``S2Mask``.

        :param img: image object. If ``img.target_resolution`` is ``None`` the
            needed channels are obtained via ``img.image_subsample`` at
            ``target_resolution``; otherwise ``img.data`` is classified as is.
        :param target_resolution: key into ``img.nodata`` and
            ``img.metadata["spatial_samplings"]``; must stay ``None`` when
            ``img.target_resolution`` is already set.
        :param majority_filter_options: ``None`` (no filtering), a dict of
            keyword arguments for ``majority_mask_filter``, or an iterable of
            such dicts which are applied one after another.
        :param nodata_value: value written to mask, confidence and novelty
            arrays at no-data pixels.
        :return: ``S2Mask`` instance.
        :raises ValueError: if ``target_resolution`` is given for an image that
            already has a target resolution, or if the final mask contains a
            value that is neither in ``cb_clf.classes`` nor ``nodata_value``.
        """
        if img.target_resolution is None:
            mk_clf = self.cb_clf.mk_clf
            channel_ids = np.unique([item for sublist in mk_clf.classifiers_id_full for item in sublist])
            cb_channels = [mk_clf.id2name[channel_id] for channel_id in channel_ids]
            mk_clf.adjust_classifier_ids(full_bands=mk_clf.id2name, band_lists=cb_channels)

            data = img.image_subsample(channels=cb_channels, target_resolution=target_resolution)
            nv = self.__mk_nv__(data)

            bad_values = img.nodata[target_resolution]
            good_data = bad_values == np.False_

            # NOTE(review): the shape is assembled as (NCOLS, NROWS) while the
            # arrays are later sliced along axis 0 as lines -- confirm this is
            # intended for non-square images.
            mask_shape = [img.metadata["spatial_samplings"][target_resolution][ii] for ii in ["NCOLS", "NROWS"]]
            # NaN marks pixels that never receive a prediction.
            mask_array = np.full(mask_shape, np.nan, dtype=np.float32)
            mask_conf = np.full(mask_shape, np.nan, dtype=np.float32)

            if self.processing_tiles == 0:
                mask_array[good_data], mask_conf[good_data] = self.cb_clf.predict_and_conf(
                    data[good_data, :], bad_data_value=nodata_value)
            else:
                segments = self._segment_bounds(mask_shape[0])
                for seg_no, (i1, i2) in enumerate(segments, start=1):
                    self.logger.info("Processing lines segment %i of %i -> %i:%i",
                                     seg_no, len(segments), i1, i2)
                    seg_good = good_data[i1:i2, :]
                    # Only valid pixels of this segment are passed to the classifier.
                    ma, mc = self.cb_clf.predict_and_conf(
                        data[i1:i2, :, :][seg_good, :], bad_data_value=nodata_value)

                    maf = np.full(seg_good.shape, np.nan, dtype=np.float32)
                    mcf = np.full(seg_good.shape, np.nan, dtype=np.float32)
                    maf[seg_good], mcf[seg_good] = ma, mc
                    mask_array[i1:i2, :], mask_conf[i1:i2, :] = maf, mcf
        else:
            if target_resolution is not None:
                raise ValueError("target_resolution should only be given if target_resolution=None for the S2 image.")

            self.cb_clf.mk_clf.adjust_classifier_ids(full_bands=img.full_band_list, band_lists=img.band_list)

            if self.processing_tiles == 0:
                mask_array, mask_conf = self.cb_clf.predict_and_conf(img.data, bad_data_value=nodata_value)
                mask_array = np.array(mask_array, dtype=float)
                mask_conf = np.array(mask_conf, dtype=float)
            else:
                mask_array = np.full(img.data.shape[:2], np.nan, dtype=np.float32)
                mask_conf = np.full(img.data.shape[:2], np.nan, dtype=np.float32)

                segments = self._segment_bounds(img.data.shape[0])
                for seg_no, (i1, i2) in enumerate(segments, start=1):
                    self.logger.info("Processing lines segment %i of %i -> %i:%i",
                                     seg_no, len(segments), i1, i2)
                    mask_array[i1:i2, :], mask_conf[i1:i2, :] = self.cb_clf.predict_and_conf(
                        img.data[i1:i2, :, :], bad_data_value=nodata_value)

            # A pixel is bad as soon as any of its channels is NaN.
            bad_values = np.sum(np.isnan(img.data), axis=2) != 0
            nv = self.__mk_nv__(img.data)

        # conversion to final data type
        # NOTE(review): NaN placeholders are cast to uint8 here and only
        # overwritten with nodata_value afterwards.
        mask_array = np.array(mask_array, dtype=np.uint8)
        mask_array[bad_values] = nodata_value
        mask_conf[bad_values] = nodata_value
        if nv is not None:
            nv[bad_values] = nodata_value

        if majority_filter_options is not None:
            self.logger.info("Applying majority filter:%s", str(majority_filter_options))
            # isinstance so that dict subclasses count as a single option set
            # instead of being iterated key by key.
            if isinstance(majority_filter_options, dict):
                mask_array = majority_mask_filter(mask_array, **majority_filter_options)
            else:
                for opts in majority_filter_options:
                    mask_array = majority_mask_filter(mask_array, **opts)

        # Allowed values are the class ids and the nodata_value.
        avals = list(self.cb_clf.classes) + [nodata_value]
        for uval in np.unique(mask_array):
            if uval not in avals:
                raise ValueError("Value:%f encountered in mask array which is not allowed." % float(uval))

        gc = img.metadata["spatial_samplings"][target_resolution] if target_resolution is not None else None
        return S2Mask(img=img, mask_array=mask_array, clf_to_col=self.clf_to_col, novelty=nv,
                      mask_legend=self.mask_legend, mask_confidence_array=mask_conf, geo_coding=gc)