Coverage for enpt/processors/spatial_optimization/spatial_optimization.py: 93%

134 statements  

« prev     ^ index     » next       coverage.py v7.4.1, created at 2024-03-07 11:39 +0000

1# -*- coding: utf-8 -*- 

2 

3# EnPT, EnMAP Processing Tool - A Python package for pre-processing of EnMAP Level-1B data 

4# 

5# Copyright (C) 2018-2024 Karl Segl (GFZ Potsdam, segl@gfz-potsdam.de), Daniel Scheffler 

6# (GFZ Potsdam, danschef@gfz-potsdam.de), Niklas Bohn (GFZ Potsdam, nbohn@gfz-potsdam.de), 

7# Stéphane Guillaso (GFZ Potsdam, stephane.guillaso@gfz-potsdam.de) 

8# 

9# This software was developed within the context of the EnMAP project supported 

10# by the DLR Space Administration with funds of the German Federal Ministry of 

11# Economic Affairs and Energy (on the basis of a decision by the German Bundestag: 

12# 50 EE 1529) and contributions from DLR, GFZ and OHB System AG. 

13# 

14# This program is free software: you can redistribute it and/or modify it under 

15# the terms of the GNU General Public License as published by the Free Software 

16# Foundation, either version 3 of the License, or (at your option) any later 

17# version. Please note the following exception: `EnPT` depends on tqdm, which 

18# is distributed under the Mozilla Public Licence (MPL) v2.0 except for the files 

19# "tqdm/_tqdm.py", "setup.py", "README.rst", "MANIFEST.in" and ".gitignore". 

20# Details can be found here: https://github.com/tqdm/tqdm/blob/master/LICENCE. 

21# 

22# This program is distributed in the hope that it will be useful, but WITHOUT 

23# ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or FITNESS 

24# FOR A PARTICULAR PURPOSE. See the GNU Lesser General Public License for more 

25# details. 

26# 

27# You should have received a copy of the GNU Lesser General Public License along 

28# with this program. If not, see <https://www.gnu.org/licenses/>. 

29 

30"""EnPT spatial optimization module. 

31 

32Adapts the EnMAP image geometry to a given Sentinel-2 L2A dataset. 

33Fits the VNIR detector data to the reference image. Corrects for keystone. 

34""" 

35 

36__author__ = 'Daniel Scheffler' 

37 

38import os 

39from typing import Optional 

40 

41import numpy as np 

42from osgeo import gdal # noqa 

43 

44from arosics import COREG_LOCAL 

45from geoarray import GeoArray 

46from py_tools_ds.geo.coord_trafo import reproject_shapelyGeometry, transform_coordArray, get_utm_zone 

47from py_tools_ds.geo.projection import EPSG2WKT 

48from py_tools_ds.processing.progress_mon import ProgressBar 

49 

50from ...options.config import EnPTConfig 

51from ...model.images.images_sensorgeo import EnMAPL1Product_SensorGeo 

52from ..spatial_transform import Geometry_Transformer 

53 

54 

class Spatial_Optimizer(object):
    """Adapt the geometry of an EnMAP L1B image (sensor geometry) to a spatial reference image.

    A single EnMAP VNIR band is temporarily transformed to a 15 m UTM grid, tie points against the reference image
    are computed with AROSICS (COREG_LOCAL) and the resulting shifts are applied to the VNIR and SWIR geolayers.
    """

    def __init__(self, config: 'EnPTConfig'):
        """Create an instance of Spatial_Optimizer.

        :param config:  EnPT configuration object; `path_reference_image` is opened as GeoArray here
        """
        self.cfg = config
        self._ref_Im: Optional[GeoArray] = GeoArray(self.cfg.path_reference_image)

        # intermediate results, filled by the _get_*_for_matching() methods and optimize_geolayer()
        self._EnMAP_Im: Optional[EnMAPL1Product_SensorGeo] = None
        self._EnMAP_band: Optional[GeoArray] = None
        self._EnMAP_mask: Optional[GeoArray] = None
        self._ref_band_prep: Optional[GeoArray] = None
        self._EnMAP_bandIdx = 39  # FIXME hardcoded

    def _get_enmap_band_for_matching(self) \
            -> 'GeoArray':
        """Return the EnMAP band to be used in co-registration in UTM projection at 15m resolution.

        The result is also stored in `self._EnMAP_band`.
        """
        self._EnMAP_Im.logger.warning(f'Statically using band {self._EnMAP_bandIdx + 1} for co-registration.')

        enmap_band_sensorgeo = self._EnMAP_Im.vnir.data[:, :, self._EnMAP_bandIdx]

        # transform from sensor to map geometry to make it usable for tie point detection
        # -> co-registration runs at 15m resolution to minimize information loss from resampling
        # -> a UTM projection is used to have shift length in meters
        self._EnMAP_Im.logger.info('Temporarily transforming EnMAP band %d to map geometry for co-registration.'
                                   % (self._EnMAP_bandIdx + 1))
        GT = Geometry_Transformer(lons=self._EnMAP_Im.meta.vnir.lons[:, :, self._EnMAP_bandIdx],
                                  lats=self._EnMAP_Im.meta.vnir.lats[:, :, self._EnMAP_bandIdx],
                                  backend='gdal',
                                  resamp_alg='bilinear',
                                  nprocs=self.cfg.CPUs)

        self._EnMAP_band = \
            GeoArray(*GT.to_map_geometry(enmap_band_sensorgeo,
                                         tgt_prj=self._get_optimal_utm_epsg(),
                                         tgt_coordgrid=((0, 15), (0, -15)),
                                         src_nodata=self._EnMAP_Im.vnir.data.nodata,
                                         tgt_nodata=0
                                         ),
                     nodata=0)

        return self._EnMAP_band

    def _get_enmap_mask_for_matching(self) \
            -> 'GeoArray':
        """Return the EnMAP mask to be used in co-registration in UTM projection at 15m resolution.

        The result is also stored in `self._EnMAP_mask`.
        """
        # use the water mask
        enmap_mask_sensorgeo = self._EnMAP_Im.vnir.mask_landwater[:] == 2  # 2 is water here

        # transform from sensor to map geometry to make it usable for tie point detection
        # -> nearest neighbour resampling is used here (the mask is a boolean array)
        self._EnMAP_Im.logger.info('Temporarily transforming EnMAP water mask to map geometry for co-registration.')
        GT = Geometry_Transformer(lons=self._EnMAP_Im.meta.vnir.lons[:, :, self._EnMAP_bandIdx],
                                  lats=self._EnMAP_Im.meta.vnir.lats[:, :, self._EnMAP_bandIdx],
                                  backend='gdal',
                                  resamp_alg='nearest',
                                  nprocs=self.cfg.CPUs)

        self._EnMAP_mask = \
            GeoArray(*GT.to_map_geometry(enmap_mask_sensorgeo,
                                         tgt_prj=self._get_optimal_utm_epsg(),
                                         tgt_coordgrid=((0, 15), (0, -15)),
                                         src_nodata=0,  # 0=background
                                         tgt_nodata=0
                                         ),
                     nodata=0)

        return self._EnMAP_mask

    def _get_optimal_utm_epsg(self) -> int:
        """Return the EPSG code of the UTM zone that optimally covers the given EnMAP image.

        The zone is derived from the centroid of the VNIR footprint polygon (`ll_mapPoly`).
        """
        x, y = [c.tolist()[0] for c in self._EnMAP_Im.meta.vnir.ll_mapPoly.centroid.xy]

        # WGS84/UTM EPSG codes are 326xx (north) or 327xx (south) with a TWO-digit zone number
        # -> use arithmetic instead of string concatenation, otherwise zones 1-9 yield a 4-digit code like 3265
        zone = int(get_utm_zone(x))

        return (32600 if y > 0 else 32700) + zone

    @staticmethod
    def _get_suitable_nodata_value(arr: np.ndarray):
        """Get a suitable nodata value for the input array, which is not contained in the array data.

        :param arr: array to find a nodata value for
        :return:    the preferred nodata value of the array data type, the alternative value if the preferred one is
                    contained in the array, or None if both are contained or if the data type is not supported
        """
        dtype = str(np.dtype(arr.dtype))

        # (preferred nodata value, alternative nodata value) per data type
        candidates_by_dtype = dict(
            int8=(-128, 127),
            uint8=(0, 255),
            int16=(-9999, 32767),
            uint16=(9999, 65535),
            int32=(-9999, 65535),
            uint32=(9999, 65535),
            float32=(-9999., 65535),
            float64=(-9999., 65535),
        )

        # unsupported data types (e.g., int64) have no candidates and fall through to None
        for candidate in candidates_by_dtype.get(dtype, ()):
            if candidate not in arr:
                return candidate

        return None

    def _get_reference_band_for_matching(self) \
            -> 'GeoArray':
        """Return the reference image band to be used in co-registration in UTM projection at 15m resolution.

        The first band of the reference image is warped (cubic resampling) to the grid of `self._EnMAP_band`.
        The result is also stored in `self._ref_band_prep`.

        :raises RuntimeError:   if the EnMAP band for matching was not computed before or
                                if the reference image cannot be opened by GDAL
        """
        if self._EnMAP_band is None:
            raise RuntimeError('To prepare the reference image band, '
                               'the EnMAP band for matching needs to be computed first.')

        self._EnMAP_Im.logger.info('Preparing reference image for co-registration.')

        # pre-define the dataset handles so that the cleanup below cannot fail with an UnboundLocalError
        # (which would mask the actual exception) in case something goes wrong before they are assigned
        src_ds = None
        dst_ds = None

        try:
            # get input nodata value if there is one defined in the metadata
            src_nodata = GeoArray(self.cfg.path_reference_image)._nodata  # noqa

            # only use the first band of the reference image for co-registration
            # TODO: select the most similar wavelength if CWLs are contained in the metadata
            #       and provide a user option to specify the band
            src_ds = gdal.Translate('', self.cfg.path_reference_image, format='MEM', bandList=[1])
            if src_ds is None:
                raise RuntimeError(f"GDAL could not open the reference image at '{self.cfg.path_reference_image}'.")

            # set up the target dataset and coordinate grid
            driver = gdal.GetDriverByName('MEM')
            rows, cols = self._EnMAP_band.shape
            dst_ds = driver.Create('', cols, rows, 1, gdal.GDT_Float32)
            dst_ds.SetProjection(self._EnMAP_band.prj)
            dst_ds.SetGeoTransform(self._EnMAP_band.gt)
            if src_nodata is not None:
                # the target band is Float32 -> pass the nodata value as float to avoid truncating
                # non-integer nodata values (and to stay consistent with the nodata value set below)
                dst_band = dst_ds.GetRasterBand(1)
                dst_band.SetNoDataValue(float(src_nodata))
                dst_band.Fill(float(src_nodata))

            # reproject the needed subset from the reference image to a UTM 15m grid (like self._EnMAP_band)
            cb = ProgressBar(prefix='Warping progress    ') if not self.cfg.disable_progress_bars else None
            gdal.SetConfigOption('GDAL_NUM_THREADS', f'{self.cfg.CPUs}')
            gdal.ReprojectImage(
                src_ds,
                dst_ds,
                src_ds.GetProjection(),
                self._EnMAP_band.prj, gdal.GRA_Cubic,
                callback=cb
            )
            out_data = dst_ds.GetRasterBand(1).ReadAsArray()

        finally:
            # release the GDAL datasets and reset the threading option, also in case of an exception
            src_ds = None  # noqa
            dst_ds = None  # noqa
            gdal.SetConfigOption('GDAL_NUM_THREADS', None)

        self._ref_band_prep = GeoArray(out_data, self._EnMAP_band.gt, self._EnMAP_band.prj, nodata=src_nodata)

        # try to set a meaningful nodata value if it cannot be auto-detected
        # -> only needed for AROSICS where setting a nodata value avoids warnings
        if self._ref_band_prep.nodata is None:
            self._ref_band_prep.nodata = self._get_suitable_nodata_value(out_data)

        return self._ref_band_prep

    def _compute_tie_points(self):
        """Compute tie points between the prepared reference band and the EnMAP band using AROSICS.

        Requires `self._ref_band_prep`, `self._EnMAP_band` and `self._EnMAP_mask` to be computed before.

        :return:    copy of the AROSICS tie point table, reduced to the records where OUTLIER is False
        """
        # avoid to run RANSAC within AROSICS if more than 50 lines of a gapfill image were appended
        # (since the RPC coefficients are computed for the main image, there may be decreasing geometry accuracy with
        #  increasing distance from the main image)
        many_gapfill_lines_appended = \
            os.path.exists(self.cfg.path_l1b_enmap_image_gapfill) and \
            (self.cfg.n_lines_to_append is None or self.cfg.n_lines_to_append > 50)
        tieP_filter_level = 2 if many_gapfill_lines_appended else 3

        footprint_poly_tgt = reproject_shapelyGeometry(self._EnMAP_Im.meta.vnir.ll_mapPoly,
                                                       4326, self._EnMAP_band.epsg)

        # compute tie points within AROSICS
        CRL = COREG_LOCAL(self._ref_band_prep,
                          self._EnMAP_band,
                          grid_res=40,
                          max_shift=10,  # 5 EnMAP pixels (co-registration is running at 15m UTM grid)
                          nodata=(self._ref_Im.nodata, 0),
                          footprint_poly_tgt=footprint_poly_tgt,
                          mask_baddata_tgt=self._EnMAP_mask,
                          tieP_filter_level=tieP_filter_level,
                          progress=self.cfg.disable_progress_bars is False
                          )
        tiepoint_table = CRL.tiepoint_grid.CoRegPoints_table

        # only keep the tie points that were not flagged as outliers
        valid_tiepoints = tiepoint_table[tiepoint_table.OUTLIER == False].copy()  # noqa: E712

        return valid_tiepoints

    @staticmethod
    def _interpolate_tiepoints_into_space(tiepoints, outshape, metric='ABS_SHIFT'):
        """Interpolate the values of a tie point attribute into a full array of the given shape.

        The tie point values are first interpolated to a coarse grid (every 5th row/column) using radial basis
        functions and then linearly interpolated to the full output grid.

        :param tiepoints:   tie point table providing the columns 'X_IM', 'Y_IM' (image coordinates) and <metric>
        :param outshape:    shape of the output array (rows, columns)
        :param metric:      name of the tie point table column to be interpolated
        :return:            2D array of shape <outshape> containing the interpolated values
        """
        import logging
        from time import time
        from scipy.interpolate import Rbf, RegularGridInterpolator

        rows = np.array(tiepoints.Y_IM)
        cols = np.array(tiepoints.X_IM)
        data = np.array(tiepoints[metric])

        t0 = time()

        # RBF interpolation to a low resolution grid (5 pixels spacing)
        # -> the last grid node is always >= the last row/column of the output array
        #    so that the linear interpolation below never has to extrapolate
        # https://github.com/agile-geoscience/xlines/blob/master/notebooks/11_Gridding_map_data.ipynb
        rows_lowres = np.arange(0, outshape[0] + 5, 5)
        cols_lowres = np.arange(0, outshape[1] + 5, 5)
        f = Rbf(cols, rows, data)
        data_interp_lowres = f(*np.meshgrid(cols_lowres, rows_lowres))

        # linear interpolation from the low resolution grid to the full output grid
        # https://stackoverflow.com/questions/24978052/interpolation-over-regular-grid-in-python
        RGI = RegularGridInterpolator(points=[cols_lowres, rows_lowres],
                                      values=data_interp_lowres.T,  # must be in shape [x, y]
                                      method='linear',
                                      bounds_error=False)
        rows_full = np.arange(outshape[0])
        cols_full = np.arange(outshape[1])
        data_full = RGI(np.dstack(np.meshgrid(cols_full, rows_full)))

        # report the runtime via logging instead of printing to stdout from library code
        logging.getLogger(__name__).debug('interpolation runtime: %.2fs', time() - t0)

        return data_full

    def optimize_geolayer(self,
                          enmap_ImageL1: 'EnMAPL1Product_SensorGeo'):
        """Correct the VNIR and SWIR geolayers of the given EnMAP image based on the spatial reference image.

        Tie points between the EnMAP band and the reference image are computed at a 15m UTM grid, their X/Y shifts
        are interpolated to full arrays, transformed back to sensor geometry and converted into longitude/latitude
        corrections, which are applied to all bands of the VNIR and SWIR geolayers.

        :param enmap_ImageL1:   EnMAP L1 product in sensor geometry
        :return:                the input object with updated `meta.vnir` and `meta.swir` lons/lats
        """
        self._EnMAP_Im = enmap_ImageL1
        self._get_enmap_band_for_matching()
        self._get_enmap_mask_for_matching()
        self._get_reference_band_for_matching()

        enmap_ImageL1.logger.info('Computing tie points between the EnMAP image and the given spatial reference image.')
        tiepoints = self._compute_tie_points()

        enmap_ImageL1.logger.info('Generating misregistration array.')
        xshift_map = self._interpolate_tiepoints_into_space(tiepoints,
                                                            self._EnMAP_band.shape,
                                                            metric='X_SHIFT_M')
        yshift_map = self._interpolate_tiepoints_into_space(tiepoints,
                                                            self._EnMAP_band.shape,
                                                            metric='Y_SHIFT_M')

        # generate the map coordinate grids of the EnMAP band
        # -> computed as UL + index * gsd instead of np.arange with float start/stop/step, because the length of
        #    the latter is not guaranteed and could mismatch the shape of the shift maps
        ULx, ULy = self._EnMAP_band.box.boxMapXY[0]
        xgsd, ygsd = self._EnMAP_band.xgsd, self._EnMAP_band.ygsd
        rows, cols = self._EnMAP_band.shape
        xgrid_map, ygrid_map = np.meshgrid(ULx + np.arange(cols) * xgsd,
                                           ULy - np.arange(rows) * ygsd)

        xgrid_map_coreg = xgrid_map + xshift_map
        ygrid_map_coreg = ygrid_map + yshift_map

        # transform map to sensor geometry
        enmap_ImageL1.logger.info('Transforming spatial optimization results back to sensor geometry.')
        lons_band = self._EnMAP_Im.meta.vnir.lons[:, :, self._EnMAP_bandIdx]
        lats_band = self._EnMAP_Im.meta.vnir.lats[:, :, self._EnMAP_bandIdx]
        GT = Geometry_Transformer(lons=lons_band,
                                  lats=lats_band,
                                  backend='gdal',
                                  resamp_alg='bilinear',
                                  nprocs=self.cfg.CPUs)

        geolayer_sensorgeo = \
            GT.to_sensor_geometry(GeoArray(np.dstack([xgrid_map_coreg,
                                                      ygrid_map_coreg]),
                                           geotransform=self._EnMAP_band.gt,
                                           projection=self._EnMAP_band.prj),
                                  tgt_nodata=0)

        enmap_ImageL1.logger.info('Applying results of spatial optimization to geolayer.')
        lons_coreg, lats_coreg = transform_coordArray(prj_src=self._ref_band_prep.prj,
                                                      prj_tgt=EPSG2WKT(4326),
                                                      Xarr=geolayer_sensorgeo[:, :, 0],
                                                      Yarr=geolayer_sensorgeo[:, :, 1])

        diffs_lons_coreg = lons_band - lons_coreg
        diffs_lats_coreg = lats_band - lats_coreg

        # apply the differences computed for the matching band to all bands of both detectors
        # -> the geolayers are re-assigned here (an in-place variant using '-=' was commented out in the
        #    original implementation)
        enmap_ImageL1.meta.vnir.lons = enmap_ImageL1.meta.vnir.lons - diffs_lons_coreg[:, :, np.newaxis]
        enmap_ImageL1.meta.vnir.lats = enmap_ImageL1.meta.vnir.lats - diffs_lats_coreg[:, :, np.newaxis]
        enmap_ImageL1.meta.swir.lons = enmap_ImageL1.meta.swir.lons - diffs_lons_coreg[:, :, np.newaxis]
        enmap_ImageL1.meta.swir.lats = enmap_ImageL1.meta.swir.lats - diffs_lats_coreg[:, :, np.newaxis]

        return enmap_ImageL1