Source code for enpt.processors.orthorectification.orthorectification

# -*- coding: utf-8 -*-

# EnPT, EnMAP Processing Tool - A Python package for pre-processing of EnMAP Level-1B data
#
# Copyright (C) 2018-2024 Karl Segl (GFZ Potsdam, segl@gfz-potsdam.de), Daniel Scheffler
# (GFZ Potsdam, danschef@gfz-potsdam.de), Niklas Bohn (GFZ Potsdam, nbohn@gfz-potsdam.de),
# Stéphane Guillaso (GFZ Potsdam, stephane.guillaso@gfz-potsdam.de)
#
# This software was developed within the context of the EnMAP project supported
# by the DLR Space Administration with funds of the German Federal Ministry of
# Economic Affairs and Energy (on the basis of a decision by the German Bundestag:
# 50 EE 1529) and contributions from DLR, GFZ and OHB System AG.
#
# This program is free software: you can redistribute it and/or modify it under
# the terms of the GNU General Public License as published by the Free Software
# Foundation, either version 3 of the License, or (at your option) any later
# version. Please note the following exception: `EnPT` depends on tqdm, which
# is distributed under the Mozilla Public Licence (MPL) v2.0 except for the files
# "tqdm/_tqdm.py", "setup.py", "README.rst", "MANIFEST.in" and ".gitignore".
# Details can be found here: https://github.com/tqdm/tqdm/blob/master/LICENCE.
#
# This program is distributed in the hope that it will be useful, but WITHOUT
# ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or FITNESS
# FOR A PARTICULAR PURPOSE. See the GNU Lesser General Public License for more
# details.
#
# You should have received a copy of the GNU Lesser General Public License along
# with this program. If not, see <https://www.gnu.org/licenses/>.

"""EnPT module 'orthorectification' for transforming an EnMAP image from sensor to map geometry
based on a pixel- and band-wise coordinate-layer (geolayer).
"""


from typing import Tuple, Union  # noqa: F401
from types import SimpleNamespace

import numpy as np
from pyproj import Geod
from mvgavg import mvgavg
from geoarray import GeoArray
from py_tools_ds.geo.coord_trafo import transform_any_prj
from py_tools_ds.geo.projection import prj_equal

from ...options.config import EnPTConfig
from ...model.images import EnMAPL1Product_SensorGeo, EnMAPL2Product_MapGeo
from ...model.metadata import EnMAP_Metadata_L2A_MapGeo
from ..spatial_transform import \
    Geometry_Transformer, \
    move_extent_to_coord_grid

__author__ = 'Daniel Scheffler'


[docs] class Orthorectifier(object): def __init__(self, config: EnPTConfig): """Create an instance of Orthorectifier.""" self.cfg = config
[docs] @staticmethod def validate_input(enmap_ImageL1: EnMAPL1Product_SensorGeo): # check type if not isinstance(enmap_ImageL1, EnMAPL1Product_SensorGeo): raise TypeError(enmap_ImageL1, "The Orthorectifier expects an instance of 'EnMAPL1Product_SensorGeo'." "Received a '%s' instance." % type(enmap_ImageL1)) # check geolayer shapes for detector in [enmap_ImageL1.vnir, enmap_ImageL1.swir]: for XY in [detector.detector_meta.lons, detector.detector_meta.lats]: datashape = detector.data.shape if XY.shape not in [datashape, datashape[:2]]: raise RuntimeError('Expected a %s geolayer shape of %s or %s. Received %s.' % (detector.detector_name, str(datashape), str(datashape[:2]), str(XY.shape)))
[docs] @staticmethod def get_enmap_coordinate_grid_ll(lon: float, lat: float ) -> (Tuple[float, float], Tuple[float, float]): """Return EnMAP-like (30x30m) longitude/latitude pixel grid specs at the given position.""" geod = Geod(ellps="WGS84") delta_lon = abs(lon - geod.fwd(lon, lat, az=90, dist=30)[0]) delta_lat = abs(lat - geod.fwd(lon, lat, az=0, dist=30)[1]) return (lon, lon + delta_lon), (lat, lat + delta_lat)
[docs] def run_transformation(self, enmap_ImageL1: EnMAPL1Product_SensorGeo) -> EnMAPL2Product_MapGeo: self.validate_input(enmap_ImageL1) enmap_ImageL1.logger.info('Starting orthorectification...') # get a new instance of EnMAPL2Product_MapGeo L2_obj = EnMAPL2Product_MapGeo(config=self.cfg, logger=enmap_ImageL1.logger) # geometric transformations # ############################# lons_vnir, lats_vnir = enmap_ImageL1.vnir.detector_meta.lons, enmap_ImageL1.vnir.detector_meta.lats lons_swir, lats_swir = enmap_ImageL1.swir.detector_meta.lons, enmap_ImageL1.swir.detector_meta.lats if not enmap_ImageL1.vnir.detector_meta.geolayer_has_keystone and lons_vnir.ndim == 3: lons_vnir, lats_vnir = lons_vnir[:, :, 0], lats_vnir[:, :, 0] if not enmap_ImageL1.swir.detector_meta.geolayer_has_keystone and lons_swir.ndim == 3: lons_swir, lats_swir = lons_swir[:, :, 0], lats_swir[:, :, 0] # get target EPSG code and common extent # (VNIR/SWIR overlap, i.e., INNER extent - non-overlapping parts are cleared later) tgt_epsg = enmap_ImageL1.meta.vnir.epsg_ortho tgt_extent = self._get_common_extent(enmap_ImageL1, tgt_epsg, enmap_grid=True) # set up parameters for Geometry_Transformer initialization and execution of the transformation kw_init = dict( backend='gdal' if self.cfg.ortho_resampAlg != 'gauss' else 'pyresample', resamp_alg=self.cfg.ortho_resampAlg, nprocs=self.cfg.CPUs ) kw_trafo = dict( tgt_prj=tgt_epsg, tgt_extent=tgt_extent, tgt_coordgrid=((self.cfg.target_coord_grid['x'], self.cfg.target_coord_grid['y']) if self.cfg.target_coord_grid else None), src_nodata=enmap_ImageL1.vnir.data.nodata, tgt_nodata=self.cfg.output_nodata_value ) # make sure VNIR and SWIR are also transformed to the same lon/lat pixel grid if self.cfg.target_projection_type == 'Geographic' and kw_trafo['tgt_coordgrid'] is None: center_row, center_col = lons_vnir.shape[0] // 2, lons_vnir.shape[1] // 2 center_lon, center_lat = lons_vnir[center_row, center_col], lats_vnir[center_row, center_col] kw_trafo['tgt_coordgrid'] = self.get_enmap_coordinate_grid_ll(center_lon, center_lat) # transform VNIR and SWIR to map geometry enmap_ImageL1.logger.info("Orthorectifying VNIR data using '%s' resampling algorithm..." % self.cfg.ortho_resampAlg) GT_vnir = Geometry_Transformer(lons=lons_vnir, lats=lats_vnir, **kw_init) vnir_mapgeo_gA = GeoArray(*GT_vnir.to_map_geometry(enmap_ImageL1.vnir.data, **kw_trafo), nodata=self.cfg.output_nodata_value) enmap_ImageL1.logger.info("Orthorectifying SWIR data using '%s' resampling algorithm..." % self.cfg.ortho_resampAlg) GT_swir = Geometry_Transformer(lons=lons_swir, lats=lats_swir, **kw_init) swir_mapgeo_gA = GeoArray(*GT_swir.to_map_geometry(enmap_ImageL1.swir.data, **kw_trafo), nodata=self.cfg.output_nodata_value) # combine VNIR and SWIR enmap_ImageL1.logger.info('Merging VNIR and SWIR data...') L2_obj.data = VNIR_SWIR_Stacker(vnir=vnir_mapgeo_gA, swir=swir_mapgeo_gA, vnir_wvls=enmap_ImageL1.meta.vnir.wvl_center, swir_wvls=enmap_ImageL1.meta.swir.wvl_center ).compute_stack(algorithm=self.cfg.vswir_overlap_algorithm) # transform masks and additional AC results from Acwater/Polymer # ################################################################## # TODO allow to set geolayer band to be used for warping of 2D arrays # always use nearest neighbour resampling for masks and bitmasks with discrete values rsp_nearest_list = ['mask_landwater', 'mask_clouds', 'mask_cloudshadow', 'mask_haze', 'mask_snow', 'mask_cirrus', 'polymer_bitmask'] kw_init_nearest = dict(backend='gdal', resamp_alg='nearest', nprocs=self.cfg.CPUs) # run the orthorectification for attrName in ['mask_landwater', 'mask_clouds', 'mask_cloudshadow', 'mask_haze', 'mask_snow', 'mask_cirrus', 'polymer_logchl', 'polymer_logfb', 'polymer_rgli', 'polymer_rnir', 'polymer_bitmask']: attr = getattr(enmap_ImageL1.vnir, attrName) if attr is not None: kw_init_attr = kw_init.copy() if attrName not in rsp_nearest_list else kw_init_nearest kw_trafo_attr = kw_trafo.copy() kw_trafo_attr['src_nodata'] = attr.nodata kw_trafo_attr['tgt_nodata'] = attr.nodata GT = Geometry_Transformer( lons=lons_vnir if lons_vnir.ndim == 2 else lons_vnir[:, :, 0], lats=lats_vnir if lats_vnir.ndim == 2 else lats_vnir[:, :, 0], **kw_init_attr) enmap_ImageL1.logger.info("Orthorectifying '%s' attribute..." % attrName) attr_ortho = GeoArray(*GT.to_map_geometry(attr, **kw_trafo_attr), nodata=attr.nodata) setattr(L2_obj, attrName, attr_ortho) # TODO transform dead pixel map, quality test flags? # set all pixels to nodata that don't have values in all bands # ################################################################ enmap_ImageL1.logger.info("Setting all pixels to nodata that have values in the VNIR or the SWIR only...") mask_nodata_common = np.all(np.dstack([vnir_mapgeo_gA.mask_nodata[:], swir_mapgeo_gA.mask_nodata[:]]), axis=2) L2_obj.data[~mask_nodata_common] = L2_obj.data.nodata for attr_gA in [L2_obj.mask_landwater, L2_obj.mask_clouds, L2_obj.mask_cloudshadow, L2_obj.mask_haze, L2_obj.mask_snow, L2_obj.mask_cirrus]: if attr_gA is not None: attr_gA[~mask_nodata_common] = attr_gA.nodata # metadata adjustments # ######################## enmap_ImageL1.logger.info('Generating L2A metadata...') L2_obj.meta = EnMAP_Metadata_L2A_MapGeo(config=self.cfg, meta_l1b=enmap_ImageL1.meta, wvls_l2a=L2_obj.data.meta.band_meta['wavelength'], dims_mapgeo=L2_obj.data.shape, grid_res_l2a=(L2_obj.data.gt[1], abs(L2_obj.data.gt[5])), logger=L2_obj.logger) L2_obj.meta.add_band_statistics(L2_obj.data) L2_obj.data.meta.band_meta['fwhm'] = list(L2_obj.meta.fwhm) L2_obj.data.meta.global_meta['wavelength_units'] = 'nanometers' # Get the paths according information delivered in the metadata L2_obj.paths = L2_obj.get_paths(self.cfg.output_dir) return L2_obj
[docs] def _get_common_extent(self, enmap_ImageL1: EnMAPL1Product_SensorGeo, tgt_epsg: int, enmap_grid: bool = True) -> Tuple[float, float, float, float]: """Get common target extent for VNIR and SWIR. :para enmap_ImageL1: :param tgt_epsg: :param enmap_grid: :return: """ # get geolayers - 2D for dummy data format else 3D V_lons, V_lats = enmap_ImageL1.meta.vnir.lons, enmap_ImageL1.meta.vnir.lats S_lons, S_lats = enmap_ImageL1.meta.swir.lons, enmap_ImageL1.meta.swir.lats # get Lon/Lat corner coordinates of geolayers V_UL_UR_LL_LR_ll = [(V_lons[y, x], V_lats[y, x]) for y, x in [(0, 0), (0, -1), (-1, 0), (-1, -1)]] S_UL_UR_LL_LR_ll = [(S_lons[y, x], S_lats[y, x]) for y, x in [(0, 0), (0, -1), (-1, 0), (-1, -1)]] # transform them to tgt_epsg if tgt_epsg != 4326: V_UL_UR_LL_LR_prj = [transform_any_prj(4326, tgt_epsg, x, y) for x, y in V_UL_UR_LL_LR_ll] S_UL_UR_LL_LR_prj = [transform_any_prj(4326, tgt_epsg, x, y) for x, y in S_UL_UR_LL_LR_ll] else: V_UL_UR_LL_LR_prj = V_UL_UR_LL_LR_ll S_UL_UR_LL_LR_prj = S_UL_UR_LL_LR_ll # separate X and Y V_X_prj, V_Y_prj = zip(*V_UL_UR_LL_LR_prj) S_X_prj, S_Y_prj = zip(*S_UL_UR_LL_LR_prj) # in case of 3D geolayers, the corner coordinates have multiple values for multiple bands # -> use the innermost coordinates to avoid pixels with VNIR-only/SWIR-only values due to keystone # (these pixels would be set to nodata later anyway, so we don't need to increase the extent for them) if V_lons.ndim == 3: V_X_prj = (V_X_prj[0].max(), V_X_prj[1].min(), V_X_prj[2].max(), V_X_prj[3].min()) V_Y_prj = (V_Y_prj[0].min(), V_Y_prj[1].min(), V_Y_prj[2].max(), V_Y_prj[3].max()) S_X_prj = (S_X_prj[0].max(), S_X_prj[1].min(), S_X_prj[2].max(), S_X_prj[3].min()) S_Y_prj = (S_Y_prj[0].min(), S_Y_prj[1].min(), S_Y_prj[2].max(), S_Y_prj[3].max()) # use inner coordinates of VNIR and SWIR as common extent xmin_prj = max([min(V_X_prj), min(S_X_prj)]) ymin_prj = max([min(V_Y_prj), min(S_Y_prj)]) xmax_prj = min([max(V_X_prj), max(S_X_prj)]) ymax_prj = min([max(V_Y_prj), max(S_Y_prj)]) common_extent_prj = (xmin_prj, ymin_prj, xmax_prj, ymax_prj) # move the extent to the EnMAP coordinate grid if enmap_grid and self.cfg.target_coord_grid: common_extent_prj = move_extent_to_coord_grid(common_extent_prj, self.cfg.target_coord_grid['x'], self.cfg.target_coord_grid['y'],) enmap_ImageL1.logger.info('Computed common target extent of orthorectified image (xmin, ymin, xmax, ymax in ' 'EPSG %s): %s' % (tgt_epsg, str(common_extent_prj))) return common_extent_prj
[docs] class VNIR_SWIR_Stacker(object): def __init__(self, vnir: GeoArray, swir: GeoArray, vnir_wvls: Union[list, np.ndarray], swir_wvls: Union[list, np.ndarray])\ -> None: """Get an instance of VNIR_SWIR_Stacker. :param vnir: :param swir: :param vnir_wvls: :param swir_wvls: """ self.vnir = vnir self.swir = swir self.wvls = SimpleNamespace(vnir=vnir_wvls, swir=swir_wvls) self.wvls.vswir = np.hstack([self.wvls.vnir, self.wvls.swir]) self.wvls.vswir_sorted = np.array(sorted(self.wvls.vswir)) self._validate_input()
[docs] def _validate_input(self): if self.vnir.gt != self.swir.gt: raise ValueError((self.vnir.gt, self.swir.gt), 'VNIR and SWIR geoinformation should be equal.') if not prj_equal(self.vnir.prj, self.swir.prj): raise ValueError((self.vnir.prj, self.swir.prj), 'VNIR and SWIR projection should be equal.') if self.vnir.bands != len(self.wvls.vnir): raise ValueError("The number of VNIR bands must be equal to the number of elements in 'vnir_wvls': " "%d != %d" % (self.vnir.bands, len(self.wvls.vnir))) if self.swir.bands != len(self.wvls.swir): raise ValueError("The number of SWIR bands must be equal to the number of elements in 'swir_wvls': " "%d != %d" % (self.swir.bands, len(self.wvls.swir)))
[docs] def _get_stack_order_by_wvl(self) -> Tuple[np.ndarray, np.ndarray]: """Stack bands ordered by wavelengths.""" bandidx_order = np.array([np.argmin(np.abs(self.wvls.vswir - cwl)) for cwl in self.wvls.vswir_sorted]) return np.dstack([self.vnir[:], self.swir[:]])[:, :, bandidx_order], self.wvls.vswir_sorted
[docs] def _get_stack_average(self, filterwidth: int = 3) -> Tuple[np.ndarray, np.ndarray]: """Stack bands and use averaging to compute the spectral information in the VNIR/SWIR overlap. :param filterwidth: number of bands to be included in the averaging - must be an uneven number """ # FIXME this has to respect nodata values - especially for pixels where one detector has no data. data_stacked = self._get_stack_order_by_wvl()[0] # get wavelenghts and indices of overlapping bands wvls_overlap_vnir = self.wvls.vnir[self.wvls.vnir > self.wvls.swir.min()] wvls_overlap_swir = self.wvls.swir[self.wvls.swir < self.wvls.vnir.max()] wvls_overlap_all = np.array(sorted(np.hstack([wvls_overlap_vnir, wvls_overlap_swir]))) bandidxs_overlap = np.array([np.argmin(np.abs(self.wvls.vswir_sorted - cwl)) for cwl in wvls_overlap_all]) # apply a spectral moving average to the overlapping VNIR/SWIR band bandidxs2average = np.array([np.min(bandidxs_overlap) - int((filterwidth - 1) / 2)] + list(bandidxs_overlap) + [np.max(bandidxs_overlap) + int((filterwidth - 1) / 2)]) data2average = data_stacked[:, :, bandidxs2average] data_stacked[:, :, bandidxs_overlap] = mvgavg(data2average, n=filterwidth, axis=2).astype(data_stacked.dtype) return data_stacked, self.wvls.vswir_sorted
[docs] def _get_stack_vnir_only(self) -> Tuple[np.ndarray, np.ndarray]: """Stack bands while removing overlapping SWIR bands.""" wvls_swir_cut = self.wvls.swir[self.wvls.swir > self.wvls.vnir.max()] wvls_vswir_sorted = np.hstack([self.wvls.vnir, wvls_swir_cut]) idx_swir_firstband = np.argmin(np.abs(self.wvls.swir - wvls_swir_cut.min())) return np.dstack([self.vnir[:], self.swir[:, :, idx_swir_firstband:]]), wvls_vswir_sorted
[docs] def _get_stack_swir_only(self) -> Tuple[np.ndarray, np.ndarray]: """Stack bands while removing overlapping VNIR bands.""" wvls_vnir_cut = self.wvls.vnir[self.wvls.vnir < self.wvls.swir.min()] wvls_vswir_sorted = np.hstack([wvls_vnir_cut, self.wvls.swir]) idx_vnir_lastband = np.argmin(np.abs(self.wvls.vnir - wvls_vnir_cut.max())) return np.dstack([self.vnir[:, :, :idx_vnir_lastband + 1], self.swir[:]]), wvls_vswir_sorted
[docs] def compute_stack(self, algorithm: str) -> GeoArray: """Stack VNIR and SWIR bands with respect to their spectral overlap. :param algorithm: 'order_by_wvl': keep spectral bands unchanged, order bands by wavelength 'average': average the spectral information within the overlap 'vnir_only': only use the VNIR bands (cut overlapping SWIR bands) 'swir_only': only use the SWIR bands (cut overlapping VNIR bands) :return: the stacked data cube as GeoArray instance """ # TODO: This should also set an output nodata value. if algorithm == 'order_by_wvl': data_stacked, wvls = self._get_stack_order_by_wvl() elif algorithm == 'average': data_stacked, wvls = self._get_stack_average() elif algorithm == 'vnir_only': data_stacked, wvls = self._get_stack_vnir_only() elif algorithm == 'swir_only': data_stacked, wvls = self._get_stack_swir_only() else: raise ValueError(algorithm) gA_stacked = GeoArray(data_stacked, geotransform=self.vnir.gt, projection=self.vnir.prj, nodata=self.vnir.nodata) gA_stacked.meta.band_meta['wavelength'] = list(wvls) return gA_stacked