Source code for sicor.Mask.Mask

from matplotlib import pyplot as plt
import numpy as np
from copy import deepcopy
import glymur
from PIL import Image


class S2Mask(object):
    """Sentinel-2 MSI classification mask with its legend, optional colours and export helpers."""

    def __init__(self, mask_array, mask_legend, img=None, clf_to_col=None, mask_confidence_array=None,
                 geo_coding=None, novelty=None):
        """
        Sentinel-2 MSI masking object

        :param mask_array: masking result, numpy array of class ids
        :param mask_legend: dictionary of mask_name:mask_id; flag names are looked up in it to obtain the
                            ids that are compared against mask_array
        :param img: instance of S2_MSI_Image; its metadata is deep-copied and, when geo_coding is not
                    supplied, metadata["spatial_samplings"][img.target_resolution] is used as geo_coding
        :param clf_to_col: optional dictionary of mask_id:RGB colour used by mask_rgb_array; colours are
                           multiplied by 255 for uint8 output, so presumably floats in [0, 1] (verify)
        :param mask_confidence_array: optional per-pixel confidence values, exported by
                                      export_confidence_to_jpeg2000
        :param geo_coding: dictionary containing at least "XDIM" and "YDIM", the sampling of mask_array
        :param novelty: stored unchanged on the instance
        :raises ValueError: if geo_coding is not supplied and cannot be derived from img
        """
        self.metadata = deepcopy(img.metadata) if img is not None else {}
        if geo_coding is not None:
            self.geo_coding = geo_coding
        else:
            try:
                self.geo_coding = self.metadata["spatial_samplings"][img.target_resolution]
            except Exception as err:
                raise ValueError("If geo_coding is not supplied, S2_img must be given.") from err
        self.clf_to_col = clf_to_col
        self.mask_array = mask_array
        self.mask_confidence_array = mask_confidence_array
        self.mask_legend = dict(mask_legend)
        # reverse lookup: mask_id -> mask_name
        self.mask_legend_inv = {value: key for key, value in mask_legend.items()}
        self.novelty = novelty

    def mk_mask_at_spatial_scales(self, flags, samplings):
        """
        use s2msi mask object to create boolean mask at desired spatial sampling

        :param flags: list of flag names in mask_legend which are True for mask (may be empty)
        :param samplings: list of desired spatial scales
        :return: dictionary of spatial sampling -> boolean mask array; the native sampling is keyed as float
        :raises ValueError: if abs(XDIM) != abs(YDIM) in geo_coding
        """
        from scipy.ndimage import zoom  # import here to avoid static TLS ImportError

        x_dim = abs(self.geo_coding["XDIM"])
        y_dim = abs(self.geo_coding["YDIM"])
        # explicit check instead of assert, which is stripped when Python runs with -O
        if x_dim != y_dim:
            raise ValueError("Only square pixels are supported (XDIM=%s, YDIM=%s)." % (x_dim, y_dim))
        tr = float(x_dim)
        flag_ids = [self.mask_legend[flag] for flag in flags]
        # np.isin keeps the array shape even for an empty flag list (all False)
        mask = {tr: np.isin(self.mask_array, flag_ids)}
        for key in samplings:
            if key not in mask:
                # nearest-neighbour resampling keeps the mask boolean
                mask[key] = zoom(mask[tr], zoom=tr / key, order=0)
        return mask

    def mask_rgb_array(self, dtype=np.float16):
        """
        Render the mask as an RGB image using clf_to_col.

        :param dtype: output dtype; for uint8 the colours are multiplied by 255 and truncated
        :return: array of shape mask_array.shape + (3,), or None if clf_to_col is not set;
                 pixels whose id is not in clf_to_col stay 0
        """
        if self.clf_to_col is None:
            return None
        mask_rgb = np.zeros(tuple(self.mask_array.shape) + (3,), dtype=dtype)
        as_uint8 = np.issubdtype(dtype, np.uint8)
        for class_id, color in self.clf_to_col.items():
            if as_uint8:
                color = np.array(np.array(color) * 255, dtype=np.uint8)
            mask_rgb[self.mask_array == class_id, :] = color
        return mask_rgb

    @staticmethod
    def _figsize_inches(shape, dpi):
        """Return the matplotlib figsize (width, height) in inches for an image array of the given shape."""
        # numpy shapes are (rows, cols, ...) = (height, width, ...); matplotlib expects (width, height)
        return shape[1] / dpi, shape[0] / dpi

    @staticmethod
    def _resample_like(mask_rgb, target_shape):
        """Nearest-neighbour resample mask_rgb to target_shape; returned unchanged if shapes already agree."""
        from scipy.ndimage import zoom  # import here to avoid static TLS ImportError

        zoom_fac = np.array([float(s1) / s2 for s1, s2 in zip(target_shape, mask_rgb.shape)])
        # resample if ANY axis differs; the colour axis factor is normally 1
        if np.any(zoom_fac != 1.0):
            return zoom(input=mask_rgb, order=0, zoom=zoom_fac)
        return mask_rgb

    def export_mask_rgb(self, fn_img, rgb_img):
        """
        Save the colour-coded mask as an image file via matplotlib.

        :param fn_img: output file name
        :param rgb_img: image array whose shape sets the figure size
        :raises ValueError: if clf_to_col is not set
        """
        mask_rgb = self.mask_rgb_array()
        if mask_rgb is None:
            raise ValueError("clf_to_col must be set to export an RGB mask.")
        dpi = 100.0
        fig = plt.figure(figsize=self._figsize_inches(rgb_img.shape, dpi))
        try:
            fig.subplots_adjust(top=1, bottom=0, right=1, left=0, hspace=0, wspace=0)
            ax = fig.add_subplot(1, 1, 1)
            ax.imshow(np.array(mask_rgb, dtype=np.float32), interpolation="none")
            ax.set_axis_off()
            fig.savefig(fn_img, dpi=dpi)
        finally:
            # always release the figure, even if drawing or saving failed
            fig.clear()
            plt.close(fig)

    def export_mask_blend(self, fn_img, rgb_img, alpha=0.6, plot_method="PIL", quality=60):
        """
        Save rgb_img with the colour-coded mask blended on top.

        :param fn_img: output file name
        :param rgb_img: RGB image array of the scene (used directly by PIL.Image.fromarray)
        :param alpha: mask opacity; for alpha <= 0 only rgb_img is written
        :param plot_method: "PIL" or "mpl"
        :param quality: passed to PIL's Image.save (PIL method only)
        :raises ValueError: for an unknown plot_method, or if alpha > 0 and clf_to_col is not set
        """
        if plot_method not in ("PIL", "mpl"):
            raise ValueError("Plot method: %s not implemented." % str(plot_method))

        mask_rgb = None
        if alpha > 0.0:
            mask_rgb = self.mask_rgb_array(dtype=np.uint8)
            if mask_rgb is None:
                raise ValueError("clf_to_col must be set to blend the mask (alpha > 0).")
            mask_rgb = self._resample_like(mask_rgb, rgb_img.shape)

        if plot_method == "PIL":
            img = Image.fromarray(rgb_img)
            if mask_rgb is not None:
                img = Image.blend(img, Image.fromarray(mask_rgb), alpha)
            img.save(fn_img, quality=quality)
        else:  # "mpl"
            dpi = 100.0
            fig = plt.figure(figsize=self._figsize_inches(rgb_img.shape, dpi))
            try:
                fig.subplots_adjust(top=1, bottom=0, right=1, left=0, hspace=0, wspace=0)
                ax = fig.add_subplot(1, 1, 1)
                # RGB image of scene as background
                ax.imshow(rgb_img, interpolation="none")
                if mask_rgb is not None:
                    # mask colors above, but transparent
                    ax.imshow(mask_rgb, interpolation="none", alpha=alpha)
                ax.set_axis_off()
                fig.savefig(fn_img, dpi=dpi)
            finally:
                fig.clear()
                plt.close(fig)

    @staticmethod
    def _confidence_to_percent(confidence):
        """Linearly rescale confidence to 0..100 and truncate to uint8; NaN/inf results become 0."""
        # float copy: in-place division would fail on integer input and must not touch the caller's array
        scaled = np.array(confidence, dtype=np.float64)
        scaled -= np.nanmin(scaled)
        span = np.nanmax(scaled)
        if span > 0:  # a constant array would otherwise divide 0 / 0
            scaled /= span
        scaled *= 100
        scaled[~np.isfinite(scaled)] = 0  # casting NaN to uint8 is undefined
        return np.array(scaled, dtype=np.uint8)

    def export_confidence_to_jpeg2000(self, fn_img):
        """
        Write mask_confidence_array, rescaled to 0..100 as uint8, to a JPEG2000 file.

        Does nothing if mask_confidence_array is None.

        :param fn_img: output file name
        """
        if self.mask_confidence_array is None:
            return
        glymur.Jp2k(fn_img, data=self._confidence_to_percent(self.mask_confidence_array))

    @staticmethod
    def _metadata_lines(mapping, delimiter, max_len=100):
        """Yield 'key<delimiter>value' lines, sorted by key, for values shorter than max_len once flattened."""
        for key, value in sorted(mapping.items()):
            value_str = " ".join(str(value).replace("\n", "").split())
            if len(value_str) < max_len:
                yield str(key) + delimiter + value_str + "\n"

    def export_to_jpeg200(self, fn_img, fn_metadata=None, delimiter=","):
        """
        Write mask_array (cast to uint8) as JPEG2000 and/or metadata and geo_coding as a text file.

        :param fn_img: JPEG2000 output file name, or None to skip the image
        :param fn_metadata: text output file name, or None to skip the metadata
        :param delimiter: separator between key and value in the metadata file
        """
        if fn_img is not None:
            glymur.Jp2k(fn_img, data=np.array(self.mask_array, dtype=np.uint8))
        if fn_metadata is not None:
            with open(fn_metadata, 'w') as outfile:
                outfile.writelines(self._metadata_lines(self.metadata, delimiter))
                outfile.writelines(self._metadata_lines(self.geo_coding, delimiter))