Coverage for enpt/processors/orthorectification/orthorectification.py: 99%
163 statements
« prev ^ index » next coverage.py v7.4.1, created at 2024-03-07 11:39 +0000
« prev ^ index » next coverage.py v7.4.1, created at 2024-03-07 11:39 +0000
1# -*- coding: utf-8 -*-
3# EnPT, EnMAP Processing Tool - A Python package for pre-processing of EnMAP Level-1B data
4#
5# Copyright (C) 2018-2024 Karl Segl (GFZ Potsdam, segl@gfz-potsdam.de), Daniel Scheffler
6# (GFZ Potsdam, danschef@gfz-potsdam.de), Niklas Bohn (GFZ Potsdam, nbohn@gfz-potsdam.de),
7# Stéphane Guillaso (GFZ Potsdam, stephane.guillaso@gfz-potsdam.de)
8#
9# This software was developed within the context of the EnMAP project supported
10# by the DLR Space Administration with funds of the German Federal Ministry of
11# Economic Affairs and Energy (on the basis of a decision by the German Bundestag:
12# 50 EE 1529) and contributions from DLR, GFZ and OHB System AG.
13#
14# This program is free software: you can redistribute it and/or modify it under
15# the terms of the GNU General Public License as published by the Free Software
16# Foundation, either version 3 of the License, or (at your option) any later
17# version. Please note the following exception: `EnPT` depends on tqdm, which
18# is distributed under the Mozilla Public Licence (MPL) v2.0 except for the files
19# "tqdm/_tqdm.py", "setup.py", "README.rst", "MANIFEST.in" and ".gitignore".
20# Details can be found here: https://github.com/tqdm/tqdm/blob/master/LICENCE.
21#
22# This program is distributed in the hope that it will be useful, but WITHOUT
23# ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or FITNESS
24# FOR A PARTICULAR PURPOSE. See the GNU Lesser General Public License for more
25# details.
26#
27# You should have received a copy of the GNU Lesser General Public License along
28# with this program. If not, see <https://www.gnu.org/licenses/>.
30"""EnPT module 'orthorectification' for transforming an EnMAP image from sensor to map geometry
31based on a pixel- and band-wise coordinate-layer (geolayer).
32"""
35from typing import Tuple, Union # noqa: F401
36from types import SimpleNamespace
38import numpy as np
39from pyproj import Geod
40from mvgavg import mvgavg
41from geoarray import GeoArray
42from py_tools_ds.geo.coord_trafo import transform_any_prj
43from py_tools_ds.geo.projection import prj_equal
45from ...options.config import EnPTConfig
46from ...model.images import EnMAPL1Product_SensorGeo, EnMAPL2Product_MapGeo
47from ...model.metadata import EnMAP_Metadata_L2A_MapGeo
48from ..spatial_transform import \
49 Geometry_Transformer, \
50 move_extent_to_coord_grid
# module author metadata
__author__ = 'Daniel Scheffler'
class Orthorectifier(object):
    """Transform an EnMAP Level-1B image from sensor geometry to map geometry (EnMAP Level-2A product)."""

    def __init__(self, config: EnPTConfig):
        """Create an instance of Orthorectifier.

        :param config:  EnPT configuration (provides resampling algorithm, target grid, nodata value, CPUs, ...)
        """
        self.cfg = config

    @staticmethod
    def validate_input(enmap_ImageL1: EnMAPL1Product_SensorGeo):
        """Validate the input image before orthorectification.

        :param enmap_ImageL1:   EnMAP L1B image in sensor geometry
        :raises TypeError:      if the input is not an EnMAPL1Product_SensorGeo instance
        :raises RuntimeError:   if a VNIR/SWIR lon/lat geolayer neither matches the 3D data shape
                                nor its 2D (rows, columns) shape
        """
        # check type
        if not isinstance(enmap_ImageL1, EnMAPL1Product_SensorGeo):
            raise TypeError(enmap_ImageL1, "The Orthorectifier expects an instance of 'EnMAPL1Product_SensorGeo'. "
                                           "Received a '%s' instance." % type(enmap_ImageL1))

        # check geolayer shapes (band-wise 3D geolayers or band-independent 2D geolayers are accepted)
        for detector in [enmap_ImageL1.vnir, enmap_ImageL1.swir]:
            datashape = detector.data.shape
            for XY in [detector.detector_meta.lons, detector.detector_meta.lats]:
                if XY.shape not in [datashape, datashape[:2]]:
                    raise RuntimeError('Expected a %s geolayer shape of %s or %s. Received %s.'
                                       % (detector.detector_name, str(datashape), str(datashape[:2]), str(XY.shape)))

    @staticmethod
    def get_enmap_coordinate_grid_ll(lon: float, lat: float
                                     ) -> Tuple[Tuple[float, float], Tuple[float, float]]:
        """Return EnMAP-like (30x30m) longitude/latitude pixel grid specs at the given position.

        The grid spacing in degrees is derived on the WGS84 ellipsoid by moving 30 m eastwards (azimuth 90)
        and 30 m northwards (azimuth 0) from the given position.

        :param lon: longitude of the reference position in decimal degrees
        :param lat: latitude of the reference position in decimal degrees
        :return:    ((lon, lon + delta_lon), (lat, lat + delta_lat)), i.e., two grid x- and two grid y-coordinates
        """
        geod = Geod(ellps="WGS84")
        delta_lon = abs(lon - geod.fwd(lon, lat, az=90, dist=30)[0])
        delta_lat = abs(lat - geod.fwd(lon, lat, az=0, dist=30)[1])

        return (lon, lon + delta_lon), (lat, lat + delta_lat)

    def run_transformation(self, enmap_ImageL1: EnMAPL1Product_SensorGeo) -> EnMAPL2Product_MapGeo:
        """Orthorectify the given EnMAP L1B image.

        VNIR and SWIR are transformed separately to the same target grid and merged afterwards according to
        the configured 'vswir_overlap_algorithm'. Masks and Polymer outputs attached to the VNIR detector are
        transformed as well. Pixels that are only covered by one of the two detectors are set to nodata.

        :param enmap_ImageL1:   EnMAP L1B image in sensor geometry
        :return:                EnMAP L2A image in map geometry
        """
        self.validate_input(enmap_ImageL1)

        enmap_ImageL1.logger.info('Starting orthorectification...')

        # get a new instance of EnMAPL2Product_MapGeo
        L2_obj = EnMAPL2Product_MapGeo(config=self.cfg, logger=enmap_ImageL1.logger)

        # geometric transformations #
        #############################

        lons_vnir, lats_vnir = enmap_ImageL1.vnir.detector_meta.lons, enmap_ImageL1.vnir.detector_meta.lats
        lons_swir, lats_swir = enmap_ImageL1.swir.detector_meta.lons, enmap_ImageL1.swir.detector_meta.lats
        # if a 3D geolayer is flagged as keystone-free, only its first band is used
        if not enmap_ImageL1.vnir.detector_meta.geolayer_has_keystone and lons_vnir.ndim == 3:
            lons_vnir, lats_vnir = lons_vnir[:, :, 0], lats_vnir[:, :, 0]
        if not enmap_ImageL1.swir.detector_meta.geolayer_has_keystone and lons_swir.ndim == 3:
            lons_swir, lats_swir = lons_swir[:, :, 0], lats_swir[:, :, 0]

        # 2D version of the VNIR geolayer (first band in case of 3D geolayers); used to derive scalar center
        # coordinates and to warp 2D attributes
        # TODO allow to set geolayer band to be used for warping of 2D arrays
        lons_vnir_2D = lons_vnir if lons_vnir.ndim == 2 else lons_vnir[:, :, 0]
        lats_vnir_2D = lats_vnir if lats_vnir.ndim == 2 else lats_vnir[:, :, 0]

        # get target EPSG code and common extent
        # (VNIR/SWIR overlap, i.e., INNER extent - non-overlapping parts are cleared later)
        tgt_epsg = enmap_ImageL1.meta.vnir.epsg_ortho
        tgt_extent = self._get_common_extent(enmap_ImageL1, tgt_epsg, enmap_grid=True)

        # set up parameters for Geometry_Transformer initialization and execution of the transformation
        kw_init = dict(
            backend='gdal' if self.cfg.ortho_resampAlg != 'gauss' else 'pyresample',
            resamp_alg=self.cfg.ortho_resampAlg,
            nprocs=self.cfg.CPUs
        )
        kw_trafo = dict(
            tgt_prj=tgt_epsg,
            tgt_extent=tgt_extent,
            tgt_coordgrid=((self.cfg.target_coord_grid['x'],
                            self.cfg.target_coord_grid['y'])
                           if self.cfg.target_coord_grid else
                           None),
            src_nodata=enmap_ImageL1.vnir.data.nodata,
            tgt_nodata=self.cfg.output_nodata_value
        )
        # make sure VNIR and SWIR are also transformed to the same lon/lat pixel grid
        if self.cfg.target_projection_type == 'Geographic' and kw_trafo['tgt_coordgrid'] is None:
            # use the 2D geolayer so that the center coordinates are scalars also for 3D (keystone) geolayers
            center_row, center_col = lons_vnir_2D.shape[0] // 2, lons_vnir_2D.shape[1] // 2
            center_lon = lons_vnir_2D[center_row, center_col]
            center_lat = lats_vnir_2D[center_row, center_col]
            kw_trafo['tgt_coordgrid'] = self.get_enmap_coordinate_grid_ll(center_lon, center_lat)

        # transform VNIR and SWIR to map geometry
        enmap_ImageL1.logger.info("Orthorectifying VNIR data using '%s' resampling algorithm..."
                                  % self.cfg.ortho_resampAlg)
        GT_vnir = Geometry_Transformer(lons=lons_vnir, lats=lats_vnir, **kw_init)
        vnir_mapgeo_gA = GeoArray(*GT_vnir.to_map_geometry(enmap_ImageL1.vnir.data, **kw_trafo),
                                  nodata=self.cfg.output_nodata_value)

        enmap_ImageL1.logger.info("Orthorectifying SWIR data using '%s' resampling algorithm..."
                                  % self.cfg.ortho_resampAlg)
        GT_swir = Geometry_Transformer(lons=lons_swir, lats=lats_swir, **kw_init)
        swir_mapgeo_gA = GeoArray(*GT_swir.to_map_geometry(enmap_ImageL1.swir.data, **kw_trafo),
                                  nodata=self.cfg.output_nodata_value)

        # combine VNIR and SWIR
        enmap_ImageL1.logger.info('Merging VNIR and SWIR data...')
        L2_obj.data = VNIR_SWIR_Stacker(vnir=vnir_mapgeo_gA,
                                        swir=swir_mapgeo_gA,
                                        vnir_wvls=enmap_ImageL1.meta.vnir.wvl_center,
                                        swir_wvls=enmap_ImageL1.meta.swir.wvl_center
                                        ).compute_stack(algorithm=self.cfg.vswir_overlap_algorithm)

        # transform masks and additional AC results from Acwater/Polymer #
        ##################################################################

        # always use nearest neighbour resampling for masks and bitmasks with discrete values
        rsp_nearest_list = ['mask_landwater', 'mask_clouds', 'mask_cloudshadow',
                            'mask_haze', 'mask_snow', 'mask_cirrus', 'polymer_bitmask']
        kw_init_nearest = dict(backend='gdal', resamp_alg='nearest', nprocs=self.cfg.CPUs)

        # run the orthorectification
        for attrName in ['mask_landwater', 'mask_clouds', 'mask_cloudshadow', 'mask_haze', 'mask_snow', 'mask_cirrus',
                         'polymer_logchl', 'polymer_logfb', 'polymer_rgli', 'polymer_rnir', 'polymer_bitmask']:
            attr = getattr(enmap_ImageL1.vnir, attrName)

            if attr is None:
                continue

            kw_init_attr = kw_init.copy() if attrName not in rsp_nearest_list else kw_init_nearest
            kw_trafo_attr = kw_trafo.copy()
            kw_trafo_attr['src_nodata'] = attr.nodata
            kw_trafo_attr['tgt_nodata'] = attr.nodata

            GT = Geometry_Transformer(lons=lons_vnir_2D, lats=lats_vnir_2D, **kw_init_attr)

            enmap_ImageL1.logger.info("Orthorectifying '%s' attribute..." % attrName)
            attr_ortho = GeoArray(*GT.to_map_geometry(attr, **kw_trafo_attr), nodata=attr.nodata)
            setattr(L2_obj, attrName, attr_ortho)

        # TODO transform dead pixel map, quality test flags?

        # set all pixels to nodata that don't have values in all bands #
        ################################################################

        enmap_ImageL1.logger.info("Setting all pixels to nodata that have values in the VNIR or the SWIR only...")
        # True where both VNIR and SWIR carry data; everything else is cleared below
        mask_nodata_common = np.all(np.dstack([vnir_mapgeo_gA.mask_nodata[:],
                                               swir_mapgeo_gA.mask_nodata[:]]), axis=2)
        L2_obj.data[~mask_nodata_common] = L2_obj.data.nodata

        for attr_gA in [L2_obj.mask_landwater, L2_obj.mask_clouds, L2_obj.mask_cloudshadow, L2_obj.mask_haze,
                        L2_obj.mask_snow, L2_obj.mask_cirrus]:
            if attr_gA is not None:
                attr_gA[~mask_nodata_common] = attr_gA.nodata

        # metadata adjustments #
        ########################

        enmap_ImageL1.logger.info('Generating L2A metadata...')
        L2_obj.meta = EnMAP_Metadata_L2A_MapGeo(config=self.cfg,
                                                meta_l1b=enmap_ImageL1.meta,
                                                wvls_l2a=L2_obj.data.meta.band_meta['wavelength'],
                                                dims_mapgeo=L2_obj.data.shape,
                                                grid_res_l2a=(L2_obj.data.gt[1], abs(L2_obj.data.gt[5])),
                                                logger=L2_obj.logger)
        L2_obj.meta.add_band_statistics(L2_obj.data)

        L2_obj.data.meta.band_meta['fwhm'] = list(L2_obj.meta.fwhm)
        L2_obj.data.meta.global_meta['wavelength_units'] = 'nanometers'

        # Get the paths according information delivered in the metadata
        L2_obj.paths = L2_obj.get_paths(self.cfg.output_dir)

        return L2_obj

    @staticmethod
    def _get_inner_corner_coords(lons: np.ndarray,
                                 lats: np.ndarray,
                                 tgt_epsg: int) -> Tuple[tuple, tuple]:
        """Return the X and Y coordinates of the UL, UR, LL and LR geolayer corners in the target projection.

        For 3D (band-wise) geolayers, each corner has one coordinate per band. In that case, the innermost
        coordinate per corner is used to avoid pixels with VNIR-only/SWIR-only values due to keystone
        (these pixels would be set to nodata later anyway, so the extent does not need to include them).

        :param lons:        longitude geolayer (2D or 3D)
        :param lats:        latitude geolayer (same shape as lons)
        :param tgt_epsg:    EPSG code of the target projection
        :return:            (X coordinates, Y coordinates), each ordered as (UL, UR, LL, LR)
        """
        UL_UR_LL_LR_ll = [(lons[y, x], lats[y, x]) for y, x in [(0, 0), (0, -1), (-1, 0), (-1, -1)]]

        # transform them to tgt_epsg
        if tgt_epsg != 4326:
            UL_UR_LL_LR_prj = [transform_any_prj(4326, tgt_epsg, x, y) for x, y in UL_UR_LL_LR_ll]
        else:
            UL_UR_LL_LR_prj = UL_UR_LL_LR_ll

        # separate X and Y
        X_prj, Y_prj = zip(*UL_UR_LL_LR_prj)

        # decided per geolayer, so VNIR and SWIR geolayers may differ in their number of dimensions
        if lons.ndim == 3:
            X_prj = (np.max(X_prj[0]), np.min(X_prj[1]), np.max(X_prj[2]), np.min(X_prj[3]))
            Y_prj = (np.min(Y_prj[0]), np.min(Y_prj[1]), np.max(Y_prj[2]), np.max(Y_prj[3]))

        return X_prj, Y_prj

    def _get_common_extent(self,
                           enmap_ImageL1: EnMAPL1Product_SensorGeo,
                           tgt_epsg: int,
                           enmap_grid: bool = True) -> Tuple[float, float, float, float]:
        """Get common target extent for VNIR and SWIR.

        :param enmap_ImageL1:   EnMAP L1B image in sensor geometry (its VNIR/SWIR lon/lat geolayers are used)
        :param tgt_epsg:        EPSG code of the target projection
        :param enmap_grid:      whether to move the extent to the configured target coordinate grid
                                (only has an effect if 'target_coord_grid' is set in the configuration)
        :return:                (xmin, ymin, xmax, ymax) of the VNIR/SWIR overlap in the target projection
        """
        # get geolayers - 2D for dummy data format else 3D
        V_lons, V_lats = enmap_ImageL1.meta.vnir.lons, enmap_ImageL1.meta.vnir.lats
        S_lons, S_lats = enmap_ImageL1.meta.swir.lons, enmap_ImageL1.meta.swir.lats

        # get the (innermost) corner coordinates of both geolayers in the target projection
        V_X_prj, V_Y_prj = self._get_inner_corner_coords(V_lons, V_lats, tgt_epsg)
        S_X_prj, S_Y_prj = self._get_inner_corner_coords(S_lons, S_lats, tgt_epsg)

        # use inner coordinates of VNIR and SWIR as common extent
        xmin_prj = max(min(V_X_prj), min(S_X_prj))
        ymin_prj = max(min(V_Y_prj), min(S_Y_prj))
        xmax_prj = min(max(V_X_prj), max(S_X_prj))
        ymax_prj = min(max(V_Y_prj), max(S_Y_prj))
        common_extent_prj = (xmin_prj, ymin_prj, xmax_prj, ymax_prj)

        # move the extent to the EnMAP coordinate grid
        if enmap_grid and self.cfg.target_coord_grid:
            common_extent_prj = move_extent_to_coord_grid(common_extent_prj,
                                                          self.cfg.target_coord_grid['x'],
                                                          self.cfg.target_coord_grid['y'],)

        enmap_ImageL1.logger.info('Computed common target extent of orthorectified image (xmin, ymin, xmax, ymax in '
                                  'EPSG %s): %s' % (tgt_epsg, str(common_extent_prj)))

        return common_extent_prj
class VNIR_SWIR_Stacker(object):
    """Merge orthorectified VNIR and SWIR data into one data cube with respect to their spectral overlap."""

    def __init__(self,
                 vnir: GeoArray,
                 swir: GeoArray,
                 vnir_wvls: Union[list, np.ndarray],
                 swir_wvls: Union[list, np.ndarray])\
            -> None:
        """Get an instance of VNIR_SWIR_Stacker.

        :param vnir:        VNIR data in map geometry
        :param swir:        SWIR data in map geometry (same geotransform and projection as 'vnir')
        :param vnir_wvls:   center wavelengths of the VNIR bands (one element per band of 'vnir')
        :param swir_wvls:   center wavelengths of the SWIR bands (one element per band of 'swir')
        :raises ValueError: if geoinformation or band numbers are inconsistent
        """
        self.vnir = vnir
        self.swir = swir
        # convert to arrays so that lists are supported as well
        # (element-wise comparisons and .min()/.max() are used by the stacking algorithms)
        self.wvls = SimpleNamespace(vnir=np.asarray(vnir_wvls), swir=np.asarray(swir_wvls))

        self.wvls.vswir = np.hstack([self.wvls.vnir, self.wvls.swir])
        # band indices (into the VNIR+SWIR stack) ordered by wavelength; the stable sort keeps every band exactly
        # once and puts VNIR before SWIR in case of identical wavelengths
        self._bandidx_order = np.argsort(self.wvls.vswir, kind='stable')
        self.wvls.vswir_sorted = self.wvls.vswir[self._bandidx_order]

        self._validate_input()

    def _validate_input(self):
        """Check that VNIR and SWIR share the same geoinformation and that band numbers match the wavelengths."""
        if self.vnir.gt != self.swir.gt:
            raise ValueError((self.vnir.gt, self.swir.gt), 'VNIR and SWIR geoinformation should be equal.')
        if not prj_equal(self.vnir.prj, self.swir.prj):
            raise ValueError((self.vnir.prj, self.swir.prj), 'VNIR and SWIR projection should be equal.')
        if self.vnir.bands != len(self.wvls.vnir):
            raise ValueError("The number of VNIR bands must be equal to the number of elements in 'vnir_wvls': "
                             "%d != %d" % (self.vnir.bands, len(self.wvls.vnir)))
        if self.swir.bands != len(self.wvls.swir):
            raise ValueError("The number of SWIR bands must be equal to the number of elements in 'swir_wvls': "
                             "%d != %d" % (self.swir.bands, len(self.wvls.swir)))

    def _get_stack_order_by_wvl(self) -> Tuple[np.ndarray, np.ndarray]:
        """Stack bands ordered by wavelengths.

        :return: (stacked data cube, sorted wavelengths)
        """
        return np.dstack([self.vnir[:], self.swir[:]])[:, :, self._bandidx_order], self.wvls.vswir_sorted

    def _get_stack_average(self, filterwidth: int = 3) -> Tuple[np.ndarray, np.ndarray]:
        """Stack bands and use averaging to compute the spectral information in the VNIR/SWIR overlap.

        :param filterwidth: number of bands to be included in the averaging - must be an uneven number
        :return:            (stacked data cube, sorted wavelengths)
        :raises ValueError: if 'filterwidth' is not a positive uneven number or if the averaging window
                            exceeds the available bands
        """
        if filterwidth < 1 or filterwidth % 2 == 0:
            raise ValueError("'filterwidth' must be a positive uneven number. Received %s." % filterwidth)

        # FIXME this has to respect nodata values - especially for pixels where one detector has no data.
        data_stacked = self._get_stack_order_by_wvl()[0]
        wvls_sorted = self.wvls.vswir_sorted

        # get the positions of the overlapping bands within the wavelength-sorted stack, i.e.,
        # VNIR bands above the SWIR minimum wavelength and SWIR bands below the VNIR maximum wavelength
        is_vnir_sorted = self._bandidx_order < len(self.wvls.vnir)
        mask_overlap = np.where(is_vnir_sorted,
                                wvls_sorted > self.wvls.swir.min(),
                                wvls_sorted < self.wvls.vnir.max())
        bandidxs_overlap = np.flatnonzero(mask_overlap)

        if not bandidxs_overlap.size:
            # no spectral overlap -> nothing to average
            return data_stacked, wvls_sorted

        # apply a spectral moving average to the overlapping VNIR/SWIR bands; the window is padded by half the
        # filter width on both sides so that the moving average returns exactly one value per overlapping band
        halfwidth = (filterwidth - 1) // 2
        idx_first, idx_last = int(bandidxs_overlap.min()), int(bandidxs_overlap.max())
        if idx_first - halfwidth < 0 or idx_last + halfwidth >= data_stacked.shape[2]:
            raise ValueError("The spectral averaging window (filterwidth=%d) exceeds the available bands."
                             % filterwidth)

        data2average = data_stacked[:, :, idx_first - halfwidth: idx_last + halfwidth + 1]
        data_stacked[:, :, idx_first: idx_last + 1] = mvgavg(data2average,
                                                             n=filterwidth,
                                                             axis=2).astype(data_stacked.dtype)

        return data_stacked, wvls_sorted

    def _get_stack_vnir_only(self) -> Tuple[np.ndarray, np.ndarray]:
        """Stack bands while removing overlapping SWIR bands.

        :return: (stacked data cube, wavelengths of the stacked bands)
        """
        wvls_swir_cut = self.wvls.swir[self.wvls.swir > self.wvls.vnir.max()]
        wvls_vswir_sorted = np.hstack([self.wvls.vnir, wvls_swir_cut])
        idx_swir_firstband = np.argmin(np.abs(self.wvls.swir - wvls_swir_cut.min()))

        return np.dstack([self.vnir[:], self.swir[:, :, idx_swir_firstband:]]), wvls_vswir_sorted

    def _get_stack_swir_only(self) -> Tuple[np.ndarray, np.ndarray]:
        """Stack bands while removing overlapping VNIR bands.

        :return: (stacked data cube, wavelengths of the stacked bands)
        """
        wvls_vnir_cut = self.wvls.vnir[self.wvls.vnir < self.wvls.swir.min()]
        wvls_vswir_sorted = np.hstack([wvls_vnir_cut, self.wvls.swir])
        idx_vnir_lastband = np.argmin(np.abs(self.wvls.vnir - wvls_vnir_cut.max()))

        return np.dstack([self.vnir[:, :, :idx_vnir_lastband + 1], self.swir[:]]), wvls_vswir_sorted

    def compute_stack(self, algorithm: str) -> GeoArray:
        """Stack VNIR and SWIR bands with respect to their spectral overlap.

        :param algorithm:   'order_by_wvl': keep spectral bands unchanged, order bands by wavelength
                            'average': average the spectral information within the overlap
                            'vnir_only': only use the VNIR bands (cut overlapping SWIR bands)
                            'swir_only': only use the SWIR bands (cut overlapping VNIR bands)
        :return:            the stacked data cube as GeoArray instance
        :raises ValueError: if the algorithm is unknown
        """
        # TODO: This should also set an output nodata value.
        if algorithm == 'order_by_wvl':
            data_stacked, wvls = self._get_stack_order_by_wvl()
        elif algorithm == 'average':
            data_stacked, wvls = self._get_stack_average()
        elif algorithm == 'vnir_only':
            data_stacked, wvls = self._get_stack_vnir_only()
        elif algorithm == 'swir_only':
            data_stacked, wvls = self._get_stack_swir_only()
        else:
            raise ValueError("Unknown VNIR/SWIR stacking algorithm '%s'. Choose one of 'order_by_wvl', 'average', "
                             "'vnir_only' or 'swir_only'." % algorithm)

        gA_stacked = GeoArray(data_stacked,
                              geotransform=self.vnir.gt, projection=self.vnir.prj, nodata=self.vnir.nodata)
        gA_stacked.meta.band_meta['wavelength'] = list(wvls)

        return gA_stacked