Coverage for enpt/processors/orthorectification/orthorectification.py: 99%

163 statements  

« prev     ^ index     » next       coverage.py v7.4.1, created at 2024-03-07 11:39 +0000

1# -*- coding: utf-8 -*- 

2 

3# EnPT, EnMAP Processing Tool - A Python package for pre-processing of EnMAP Level-1B data 

4# 

5# Copyright (C) 2018-2024 Karl Segl (GFZ Potsdam, segl@gfz-potsdam.de), Daniel Scheffler 

6# (GFZ Potsdam, danschef@gfz-potsdam.de), Niklas Bohn (GFZ Potsdam, nbohn@gfz-potsdam.de), 

7# Stéphane Guillaso (GFZ Potsdam, stephane.guillaso@gfz-potsdam.de) 

8# 

9# This software was developed within the context of the EnMAP project supported 

10# by the DLR Space Administration with funds of the German Federal Ministry of 

11# Economic Affairs and Energy (on the basis of a decision by the German Bundestag: 

12# 50 EE 1529) and contributions from DLR, GFZ and OHB System AG. 

13# 

14# This program is free software: you can redistribute it and/or modify it under 

15# the terms of the GNU General Public License as published by the Free Software 

16# Foundation, either version 3 of the License, or (at your option) any later 

17# version. Please note the following exception: `EnPT` depends on tqdm, which 

18# is distributed under the Mozilla Public Licence (MPL) v2.0 except for the files 

19# "tqdm/_tqdm.py", "setup.py", "README.rst", "MANIFEST.in" and ".gitignore". 

20# Details can be found here: https://github.com/tqdm/tqdm/blob/master/LICENCE. 

21# 

22# This program is distributed in the hope that it will be useful, but WITHOUT 

23# ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or FITNESS 

24# FOR A PARTICULAR PURPOSE. See the GNU Lesser General Public License for more 

25# details. 

26# 

27# You should have received a copy of the GNU Lesser General Public License along 

28# with this program. If not, see <https://www.gnu.org/licenses/>. 

29 

30"""EnPT module 'orthorectification' for transforming an EnMAP image from sensor to map geometry 

31based on a pixel- and band-wise coordinate-layer (geolayer). 

32""" 

33 

34 

35from typing import Tuple, Union # noqa: F401 

36from types import SimpleNamespace 

37 

38import numpy as np 

39from pyproj import Geod 

40from mvgavg import mvgavg 

41from geoarray import GeoArray 

42from py_tools_ds.geo.coord_trafo import transform_any_prj 

43from py_tools_ds.geo.projection import prj_equal 

44 

45from ...options.config import EnPTConfig 

46from ...model.images import EnMAPL1Product_SensorGeo, EnMAPL2Product_MapGeo 

47from ...model.metadata import EnMAP_Metadata_L2A_MapGeo 

48from ..spatial_transform import \ 

49 Geometry_Transformer, \ 

50 move_extent_to_coord_grid 

51 

52__author__ = 'Daniel Scheffler' 

53 

54 

class Orthorectifier(object):
    """Transform an EnMAP Level-1B image from sensor geometry to map geometry (EnMAPL2Product_MapGeo)."""

    def __init__(self, config: EnPTConfig):
        """Create an instance of Orthorectifier.

        :param config:  EnPT configuration object
        """
        self.cfg = config

    @staticmethod
    def validate_input(enmap_ImageL1: EnMAPL1Product_SensorGeo) -> None:
        """Validate the input image before orthorectification.

        :param enmap_ImageL1:   EnMAP Level-1B image in sensor geometry
        :raises TypeError:      if the input is not an instance of EnMAPL1Product_SensorGeo
        :raises RuntimeError:   if a VNIR/SWIR geolayer neither matches the shape of the detector data
                                nor its first two dimensions
        """
        # check type
        if not isinstance(enmap_ImageL1, EnMAPL1Product_SensorGeo):
            raise TypeError(enmap_ImageL1, "The Orthorectifier expects an instance of 'EnMAPL1Product_SensorGeo'. "
                                           "Received a '%s' instance." % type(enmap_ImageL1))

        # check geolayer shapes (a geolayer may either match the full data shape or only its first two dimensions)
        for detector in [enmap_ImageL1.vnir, enmap_ImageL1.swir]:
            datashape = detector.data.shape
            for XY in [detector.detector_meta.lons, detector.detector_meta.lats]:
                if XY.shape not in [datashape, datashape[:2]]:
                    raise RuntimeError('Expected a %s geolayer shape of %s or %s. Received %s.'
                                       % (detector.detector_name, str(datashape), str(datashape[:2]), str(XY.shape)))

    @staticmethod
    def get_enmap_coordinate_grid_ll(lon: float, lat: float
                                     ) -> Tuple[Tuple[float, float], Tuple[float, float]]:
        """Return EnMAP-like (30x30m) longitude/latitude pixel grid specs at the given position.

        The pixel size in degrees is derived from the WGS84 geodesic distance of 30 m towards the east
        (longitude) and towards the north (latitude), starting at the given position.

        :param lon: longitude of the reference position (degrees)
        :param lat: latitude of the reference position (degrees)
        :return:    ((lon, lon + delta_lon), (lat, lat + delta_lat))
        """
        geod = Geod(ellps="WGS84")
        delta_lon = abs(lon - geod.fwd(lon, lat, az=90, dist=30)[0])
        delta_lat = abs(lat - geod.fwd(lon, lat, az=0, dist=30)[1])

        return (lon, lon + delta_lon), (lat, lat + delta_lat)

    def run_transformation(self, enmap_ImageL1: EnMAPL1Product_SensorGeo) -> EnMAPL2Product_MapGeo:
        """Orthorectify an EnMAP Level-1B image and return the result as EnMAPL2Product_MapGeo.

        VNIR and SWIR are transformed to map geometry (common target extent and grid), merged into one cube
        according to the configured overlap algorithm and all pixels without data in both detectors are set to
        nodata. Masks and additional Acwater/Polymer results attached to the VNIR detector are transformed too.

        :param enmap_ImageL1:   EnMAP Level-1B image in sensor geometry
        :return:                orthorectified image including L2A metadata and output paths
        """
        self.validate_input(enmap_ImageL1)

        enmap_ImageL1.logger.info('Starting orthorectification...')

        # get a new instance of EnMAPL2Product_MapGeo
        L2_obj = EnMAPL2Product_MapGeo(config=self.cfg, logger=enmap_ImageL1.logger)

        # geometric transformations #
        #############################

        lons_vnir, lats_vnir = enmap_ImageL1.vnir.detector_meta.lons, enmap_ImageL1.vnir.detector_meta.lats
        lons_swir, lats_swir = enmap_ImageL1.swir.detector_meta.lons, enmap_ImageL1.swir.detector_meta.lats
        # without keystone, the first band of a 3D geolayer is used for all bands
        if not enmap_ImageL1.vnir.detector_meta.geolayer_has_keystone and lons_vnir.ndim == 3:
            lons_vnir, lats_vnir = lons_vnir[:, :, 0], lats_vnir[:, :, 0]
        if not enmap_ImageL1.swir.detector_meta.geolayer_has_keystone and lons_swir.ndim == 3:
            lons_swir, lats_swir = lons_swir[:, :, 0], lats_swir[:, :, 0]

        # 2D VNIR geolayers (first band in case of 3D geolayers), used wherever a single, band-independent
        # position grid is needed (coordinate grid anchor, 2D masks/AC results)
        lons_vnir_2d = lons_vnir if lons_vnir.ndim == 2 else lons_vnir[:, :, 0]
        lats_vnir_2d = lats_vnir if lats_vnir.ndim == 2 else lats_vnir[:, :, 0]

        # get target EPSG code and common extent
        # (VNIR/SWIR overlap, i.e., INNER extent - non-overlapping parts are cleared later)
        tgt_epsg = enmap_ImageL1.meta.vnir.epsg_ortho
        tgt_extent = self._get_common_extent(enmap_ImageL1, tgt_epsg, enmap_grid=True)

        # set up parameters for Geometry_Transformer initialization and execution of the transformation
        kw_init = dict(
            backend='gdal' if self.cfg.ortho_resampAlg != 'gauss' else 'pyresample',
            resamp_alg=self.cfg.ortho_resampAlg,
            nprocs=self.cfg.CPUs
        )
        kw_trafo = dict(
            tgt_prj=tgt_epsg,
            tgt_extent=tgt_extent,
            tgt_coordgrid=((self.cfg.target_coord_grid['x'],
                            self.cfg.target_coord_grid['y'])
                           if self.cfg.target_coord_grid else
                           None),
            src_nodata=enmap_ImageL1.vnir.data.nodata,
            tgt_nodata=self.cfg.output_nodata_value
        )
        # make sure VNIR and SWIR are also transformed to the same lon/lat pixel grid
        # (the 2D geolayer is used so that the grid is anchored at ONE position - indexing a 3D keystone
        #  geolayer would return one coordinate per band)
        if self.cfg.target_projection_type == 'Geographic' and kw_trafo['tgt_coordgrid'] is None:
            center_row, center_col = lons_vnir_2d.shape[0] // 2, lons_vnir_2d.shape[1] // 2
            center_lon, center_lat = lons_vnir_2d[center_row, center_col], lats_vnir_2d[center_row, center_col]
            kw_trafo['tgt_coordgrid'] = self.get_enmap_coordinate_grid_ll(center_lon, center_lat)

        # transform VNIR and SWIR to map geometry
        enmap_ImageL1.logger.info("Orthorectifying VNIR data using '%s' resampling algorithm..."
                                  % self.cfg.ortho_resampAlg)
        GT_vnir = Geometry_Transformer(lons=lons_vnir, lats=lats_vnir, **kw_init)
        vnir_mapgeo_gA = GeoArray(*GT_vnir.to_map_geometry(enmap_ImageL1.vnir.data, **kw_trafo),
                                  nodata=self.cfg.output_nodata_value)

        enmap_ImageL1.logger.info("Orthorectifying SWIR data using '%s' resampling algorithm..."
                                  % self.cfg.ortho_resampAlg)
        # the SWIR data are read with their own nodata value (not necessarily equal to the VNIR one)
        kw_trafo_swir = dict(kw_trafo, src_nodata=enmap_ImageL1.swir.data.nodata)
        GT_swir = Geometry_Transformer(lons=lons_swir, lats=lats_swir, **kw_init)
        swir_mapgeo_gA = GeoArray(*GT_swir.to_map_geometry(enmap_ImageL1.swir.data, **kw_trafo_swir),
                                  nodata=self.cfg.output_nodata_value)

        # combine VNIR and SWIR
        enmap_ImageL1.logger.info('Merging VNIR and SWIR data...')
        L2_obj.data = VNIR_SWIR_Stacker(vnir=vnir_mapgeo_gA,
                                        swir=swir_mapgeo_gA,
                                        vnir_wvls=enmap_ImageL1.meta.vnir.wvl_center,
                                        swir_wvls=enmap_ImageL1.meta.swir.wvl_center
                                        ).compute_stack(algorithm=self.cfg.vswir_overlap_algorithm)

        # transform masks and additional AC results from Acwater/Polymer #
        ##################################################################

        # TODO allow to set geolayer band to be used for warping of 2D arrays

        # always use nearest neighbour resampling for masks and bitmasks with discrete values
        rsp_nearest_list = ['mask_landwater', 'mask_clouds', 'mask_cloudshadow',
                            'mask_haze', 'mask_snow', 'mask_cirrus', 'polymer_bitmask']
        kw_init_nearest = dict(backend='gdal', resamp_alg='nearest', nprocs=self.cfg.CPUs)

        # run the orthorectification
        for attrName in ['mask_landwater', 'mask_clouds', 'mask_cloudshadow', 'mask_haze', 'mask_snow', 'mask_cirrus',
                         'polymer_logchl', 'polymer_logfb', 'polymer_rgli', 'polymer_rnir', 'polymer_bitmask']:
            attr = getattr(enmap_ImageL1.vnir, attrName)

            if attr is None:
                continue

            kw_init_attr = kw_init_nearest if attrName in rsp_nearest_list else kw_init
            # the attribute's own nodata value is used as source and target nodata value
            kw_trafo_attr = dict(kw_trafo, src_nodata=attr.nodata, tgt_nodata=attr.nodata)

            GT = Geometry_Transformer(lons=lons_vnir_2d, lats=lats_vnir_2d, **kw_init_attr)

            enmap_ImageL1.logger.info("Orthorectifying '%s' attribute..." % attrName)
            attr_ortho = GeoArray(*GT.to_map_geometry(attr, **kw_trafo_attr), nodata=attr.nodata)
            setattr(L2_obj, attrName, attr_ortho)

        # TODO transform dead pixel map, quality test flags?

        # set all pixels to nodata that don't have values in all bands #
        ################################################################

        enmap_ImageL1.logger.info("Setting all pixels to nodata that have values in the VNIR or the SWIR only...")
        # True where both VNIR and SWIR contain data
        mask_nodata_common = np.all(np.dstack([vnir_mapgeo_gA.mask_nodata[:],
                                               swir_mapgeo_gA.mask_nodata[:]]), axis=2)
        L2_obj.data[~mask_nodata_common] = L2_obj.data.nodata

        for attr_gA in [L2_obj.mask_landwater, L2_obj.mask_clouds, L2_obj.mask_cloudshadow, L2_obj.mask_haze,
                        L2_obj.mask_snow, L2_obj.mask_cirrus]:
            if attr_gA is not None:
                attr_gA[~mask_nodata_common] = attr_gA.nodata

        # metadata adjustments #
        ########################

        enmap_ImageL1.logger.info('Generating L2A metadata...')
        L2_obj.meta = EnMAP_Metadata_L2A_MapGeo(config=self.cfg,
                                                meta_l1b=enmap_ImageL1.meta,
                                                wvls_l2a=L2_obj.data.meta.band_meta['wavelength'],
                                                dims_mapgeo=L2_obj.data.shape,
                                                grid_res_l2a=(L2_obj.data.gt[1], abs(L2_obj.data.gt[5])),
                                                logger=L2_obj.logger)
        L2_obj.meta.add_band_statistics(L2_obj.data)

        L2_obj.data.meta.band_meta['fwhm'] = list(L2_obj.meta.fwhm)
        L2_obj.data.meta.global_meta['wavelength_units'] = 'nanometers'

        # Get the paths according information delivered in the metadata
        L2_obj.paths = L2_obj.get_paths(self.cfg.output_dir)

        return L2_obj

    def _get_common_extent(self,
                           enmap_ImageL1: EnMAPL1Product_SensorGeo,
                           tgt_epsg: int,
                           enmap_grid: bool = True) -> Tuple[float, float, float, float]:
        """Get common target extent for VNIR and SWIR.

        The extent is the intersection (INNER extent) of the bounding boxes of the VNIR and SWIR geolayer
        corner coordinates, expressed in the target projection.

        :param enmap_ImageL1:   EnMAP Level-1B image in sensor geometry
        :param tgt_epsg:        EPSG code of the target projection
        :param enmap_grid:      move the extent to the configured target coordinate grid
                                (only applied if 'target_coord_grid' is set in the configuration)
        :return:                common extent as (xmin, ymin, xmax, ymax) in the target projection
        """
        # get geolayers - 2D for dummy data format else 3D
        V_lons, V_lats = enmap_ImageL1.meta.vnir.lons, enmap_ImageL1.meta.vnir.lats
        S_lons, S_lats = enmap_ImageL1.meta.swir.lons, enmap_ImageL1.meta.swir.lats

        # get Lon/Lat corner coordinates of geolayers
        V_UL_UR_LL_LR_ll = [(V_lons[y, x], V_lats[y, x]) for y, x in [(0, 0), (0, -1), (-1, 0), (-1, -1)]]
        S_UL_UR_LL_LR_ll = [(S_lons[y, x], S_lats[y, x]) for y, x in [(0, 0), (0, -1), (-1, 0), (-1, -1)]]

        # transform them to tgt_epsg
        if tgt_epsg != 4326:
            V_UL_UR_LL_LR_prj = [transform_any_prj(4326, tgt_epsg, x, y) for x, y in V_UL_UR_LL_LR_ll]
            S_UL_UR_LL_LR_prj = [transform_any_prj(4326, tgt_epsg, x, y) for x, y in S_UL_UR_LL_LR_ll]
        else:
            V_UL_UR_LL_LR_prj = V_UL_UR_LL_LR_ll
            S_UL_UR_LL_LR_prj = S_UL_UR_LL_LR_ll

        # separate X and Y
        V_X_prj, V_Y_prj = zip(*V_UL_UR_LL_LR_prj)
        S_X_prj, S_Y_prj = zip(*S_UL_UR_LL_LR_prj)

        # in case of 3D geolayers, the corner coordinates have multiple values for multiple bands
        # -> use the innermost coordinates to avoid pixels with VNIR-only/SWIR-only values due to keystone
        # (these pixels would be set to nodata later anyway, so we don't need to increase the extent for them)
        if V_lons.ndim == 3:
            V_X_prj = (V_X_prj[0].max(), V_X_prj[1].min(), V_X_prj[2].max(), V_X_prj[3].min())
            V_Y_prj = (V_Y_prj[0].min(), V_Y_prj[1].min(), V_Y_prj[2].max(), V_Y_prj[3].max())
            S_X_prj = (S_X_prj[0].max(), S_X_prj[1].min(), S_X_prj[2].max(), S_X_prj[3].min())
            S_Y_prj = (S_Y_prj[0].min(), S_Y_prj[1].min(), S_Y_prj[2].max(), S_Y_prj[3].max())

        # use inner coordinates of VNIR and SWIR as common extent
        xmin_prj = max([min(V_X_prj), min(S_X_prj)])
        ymin_prj = max([min(V_Y_prj), min(S_Y_prj)])
        xmax_prj = min([max(V_X_prj), max(S_X_prj)])
        ymax_prj = min([max(V_Y_prj), max(S_Y_prj)])
        common_extent_prj = (xmin_prj, ymin_prj, xmax_prj, ymax_prj)

        # move the extent to the EnMAP coordinate grid
        if enmap_grid and self.cfg.target_coord_grid:
            common_extent_prj = move_extent_to_coord_grid(common_extent_prj,
                                                          self.cfg.target_coord_grid['x'],
                                                          self.cfg.target_coord_grid['y'],)

        enmap_ImageL1.logger.info('Computed common target extent of orthorectified image (xmin, ymin, xmax, ymax in '
                                  'EPSG %s): %s' % (tgt_epsg, str(common_extent_prj)))

        return common_extent_prj

273 

274 

class VNIR_SWIR_Stacker(object):
    """Merge VNIR and SWIR data (same map geometry) into one cube with respect to their spectral overlap."""

    def __init__(self,
                 vnir: GeoArray,
                 swir: GeoArray,
                 vnir_wvls: Union[list, np.ndarray],
                 swir_wvls: Union[list, np.ndarray])\
            -> None:
        """Get an instance of VNIR_SWIR_Stacker.

        :param vnir:        VNIR data (must have the same geotransform and projection as the SWIR data)
        :param swir:        SWIR data
        :param vnir_wvls:   center wavelengths of the VNIR bands (one value per VNIR band)
        :param swir_wvls:   center wavelengths of the SWIR bands (one value per SWIR band)
        :raises ValueError: if geoinformation, projection or the number of wavelengths do not match
        """
        self.vnir = vnir
        self.swir = swir
        # arrays are needed for the element-wise comparisons and .min()/.max() calls in the stacking methods
        self.wvls = SimpleNamespace(vnir=np.asarray(vnir_wvls), swir=np.asarray(swir_wvls))

        self.wvls.vswir = np.hstack([self.wvls.vnir, self.wvls.swir])
        self.wvls.vswir_sorted = np.sort(self.wvls.vswir)

        self._validate_input()

    def _validate_input(self):
        if self.vnir.gt != self.swir.gt:
            raise ValueError((self.vnir.gt, self.swir.gt), 'VNIR and SWIR geoinformation should be equal.')
        if not prj_equal(self.vnir.prj, self.swir.prj):
            raise ValueError((self.vnir.prj, self.swir.prj), 'VNIR and SWIR projection should be equal.')
        if self.vnir.bands != len(self.wvls.vnir):
            raise ValueError("The number of VNIR bands must be equal to the number of elements in 'vnir_wvls': "
                             "%d != %d" % (self.vnir.bands, len(self.wvls.vnir)))
        if self.swir.bands != len(self.wvls.swir):
            raise ValueError("The number of SWIR bands must be equal to the number of elements in 'swir_wvls': "
                             "%d != %d" % (self.swir.bands, len(self.wvls.swir)))

    def _get_band_order(self) -> np.ndarray:
        """Return the indices that sort the stacked VNIR+SWIR bands by wavelength.

        A stable sort is used, so bands with equal wavelengths keep their VNIR-before-SWIR order and
        no band is duplicated or dropped.
        """
        return np.argsort(self.wvls.vswir, kind='stable')

    def _get_stack_order_by_wvl(self) -> Tuple[np.ndarray, np.ndarray]:
        """Stack bands ordered by wavelengths."""
        return np.dstack([self.vnir[:], self.swir[:]])[:, :, self._get_band_order()], self.wvls.vswir_sorted

    def _get_stack_average(self, filterwidth: int = 3) -> Tuple[np.ndarray, np.ndarray]:
        """Stack bands and use averaging to compute the spectral information in the VNIR/SWIR overlap.

        :param filterwidth: number of bands to be included in the averaging - must be an uneven number
        :raises ValueError: if filterwidth is not a positive uneven number or if the overlap is too close to the
                            first/last band to add (filterwidth - 1) / 2 neighbouring bands on each side
        """
        if filterwidth < 1 or filterwidth % 2 == 0:
            raise ValueError("'filterwidth' must be a positive uneven number. Received %s." % filterwidth)

        # FIXME this has to respect nodata values - especially for pixels where one detector has no data.
        data_stacked = self._get_stack_order_by_wvl()[0]

        # flag the overlapping bands (VNIR bands above the SWIR minimum, SWIR bands below the VNIR maximum) and
        # get their positions within the wavelength-sorted stack - these positions form one contiguous range
        is_overlap = np.hstack([self.wvls.vnir > self.wvls.swir.min(),
                                self.wvls.swir < self.wvls.vnir.max()])[self._get_band_order()]
        bandidxs_overlap = np.flatnonzero(is_overlap)

        if not bandidxs_overlap.size:
            # no spectral overlap -> nothing to average
            return data_stacked, self.wvls.vswir_sorted

        # apply a spectral moving average to the overlapping VNIR/SWIR bands
        # -> include (filterwidth - 1) / 2 neighbouring bands on each side, so that the moving average
        #    (output length = input length - filterwidth + 1) yields exactly one value per overlapping band
        margin = (filterwidth - 1) // 2
        idx_start = bandidxs_overlap.min() - margin
        idx_end = bandidxs_overlap.max() + margin
        if idx_start < 0 or idx_end >= data_stacked.shape[2]:
            raise ValueError("Not enough bands outside the VNIR/SWIR overlap to apply a spectral moving average "
                             "with a filter width of %d." % filterwidth)

        data2average = data_stacked[:, :, idx_start:idx_end + 1]
        data_stacked[:, :, bandidxs_overlap] = mvgavg(data2average,
                                                      n=filterwidth,
                                                      axis=2).astype(data_stacked.dtype)

        return data_stacked, self.wvls.vswir_sorted

    def _get_stack_vnir_only(self) -> Tuple[np.ndarray, np.ndarray]:
        """Stack bands while removing overlapping SWIR bands."""
        wvls_swir_cut = self.wvls.swir[self.wvls.swir > self.wvls.vnir.max()]
        wvls_vswir_sorted = np.hstack([self.wvls.vnir, wvls_swir_cut])
        idx_swir_firstband = np.argmin(np.abs(self.wvls.swir - wvls_swir_cut.min()))

        return np.dstack([self.vnir[:], self.swir[:, :, idx_swir_firstband:]]), wvls_vswir_sorted

    def _get_stack_swir_only(self) -> Tuple[np.ndarray, np.ndarray]:
        """Stack bands while removing overlapping VNIR bands."""
        wvls_vnir_cut = self.wvls.vnir[self.wvls.vnir < self.wvls.swir.min()]
        wvls_vswir_sorted = np.hstack([wvls_vnir_cut, self.wvls.swir])
        idx_vnir_lastband = np.argmin(np.abs(self.wvls.vnir - wvls_vnir_cut.max()))

        return np.dstack([self.vnir[:, :, :idx_vnir_lastband + 1], self.swir[:]]), wvls_vswir_sorted

    def compute_stack(self, algorithm: str) -> GeoArray:
        """Stack VNIR and SWIR bands with respect to their spectral overlap.

        :param algorithm: 'order_by_wvl': keep spectral bands unchanged, order bands by wavelength
                          'average': average the spectral information within the overlap
                          'vnir_only': only use the VNIR bands (cut overlapping SWIR bands)
                          'swir_only': only use the SWIR bands (cut overlapping VNIR bands)
        :return: the stacked data cube as GeoArray instance (geotransform, projection and nodata value of the
                 VNIR data; band wavelengths in meta.band_meta['wavelength'])
        :raises ValueError: if the algorithm is unknown
        """
        # TODO: This should also set an output nodata value.
        if algorithm == 'order_by_wvl':
            data_stacked, wvls = self._get_stack_order_by_wvl()
        elif algorithm == 'average':
            data_stacked, wvls = self._get_stack_average()
        elif algorithm == 'vnir_only':
            data_stacked, wvls = self._get_stack_vnir_only()
        elif algorithm == 'swir_only':
            data_stacked, wvls = self._get_stack_swir_only()
        else:
            raise ValueError(algorithm)

        gA_stacked = GeoArray(data_stacked,
                              geotransform=self.vnir.gt, projection=self.vnir.prj, nodata=self.vnir.nodata)
        gA_stacked.meta.band_meta['wavelength'] = list(wvls)

        return gA_stacked