Coverage for enpt/processors/spatial_optimization/spatial_optimization.py: 93%
134 statements
« prev ^ index » next coverage.py v7.4.1, created at 2024-03-07 11:39 +0000
« prev ^ index » next coverage.py v7.4.1, created at 2024-03-07 11:39 +0000
1# -*- coding: utf-8 -*-
3# EnPT, EnMAP Processing Tool - A Python package for pre-processing of EnMAP Level-1B data
4#
5# Copyright (C) 2018-2024 Karl Segl (GFZ Potsdam, segl@gfz-potsdam.de), Daniel Scheffler
6# (GFZ Potsdam, danschef@gfz-potsdam.de), Niklas Bohn (GFZ Potsdam, nbohn@gfz-potsdam.de),
7# Stéphane Guillaso (GFZ Potsdam, stephane.guillaso@gfz-potsdam.de)
8#
9# This software was developed within the context of the EnMAP project supported
10# by the DLR Space Administration with funds of the German Federal Ministry of
11# Economic Affairs and Energy (on the basis of a decision by the German Bundestag:
12# 50 EE 1529) and contributions from DLR, GFZ and OHB System AG.
13#
14# This program is free software: you can redistribute it and/or modify it under
15# the terms of the GNU General Public License as published by the Free Software
16# Foundation, either version 3 of the License, or (at your option) any later
17# version. Please note the following exception: `EnPT` depends on tqdm, which
18# is distributed under the Mozilla Public Licence (MPL) v2.0 except for the files
19# "tqdm/_tqdm.py", "setup.py", "README.rst", "MANIFEST.in" and ".gitignore".
20# Details can be found here: https://github.com/tqdm/tqdm/blob/master/LICENCE.
21#
22# This program is distributed in the hope that it will be useful, but WITHOUT
23# ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or FITNESS
24# FOR A PARTICULAR PURPOSE. See the GNU Lesser General Public License for more
25# details.
26#
27# You should have received a copy of the GNU Lesser General Public License along
28# with this program. If not, see <https://www.gnu.org/licenses/>.
"""EnPT spatial optimization module.

Adapts the geolayer (longitude/latitude arrays) of an EnMAP Level-1B image in sensor geometry to a
spatial reference image (e.g., a Sentinel-2 L2A dataset). Tie points are computed between one EnMAP VNIR
band and the first band of the reference image, both warped to a temporary 15 m UTM grid. The resulting
shifts are interpolated to a dense shift field, transformed back to sensor geometry and applied to the
VNIR and SWIR geolayers.

NOTE(review): earlier versions of this docstring mentioned a keystone correction, which is not performed
by the code in this module - confirm whether it is implemented elsewhere.
"""

__author__ = 'Daniel Scheffler'
38import os
39from typing import Optional
41import numpy as np
42from osgeo import gdal # noqa
44from arosics import COREG_LOCAL
45from geoarray import GeoArray
46from py_tools_ds.geo.coord_trafo import reproject_shapelyGeometry, transform_coordArray, get_utm_zone
47from py_tools_ds.geo.projection import EPSG2WKT
48from py_tools_ds.processing.progress_mon import ProgressBar
50from ...options.config import EnPTConfig
51from ...model.images.images_sensorgeo import EnMAPL1Product_SensorGeo
52from ..spatial_transform import Geometry_Transformer
class Spatial_Optimizer(object):
    """Optimize the geolayer of an EnMAP L1B image in sensor geometry based on a spatial reference image.

    Workflow (see :meth:`optimize_geolayer`):
        1. warp one EnMAP VNIR band and the VNIR water mask to a temporary 15 m UTM grid
        2. warp the first band of the reference image to the same grid
        3. compute tie points with AROSICS (COREG_LOCAL)
        4. interpolate the X/Y shifts to a dense shift field, transform it back to sensor geometry
           and apply it to the VNIR and SWIR longitude/latitude arrays
    """

    def __init__(self, config: EnPTConfig):
        """Create an instance of Spatial_Optimizer.

        :param config:  EnPT configuration (provides the reference image path, the number of CPUs,
                        progress bar settings and the gap-fill options)
        """
        self.cfg = config
        self._ref_Im: Optional[GeoArray] = GeoArray(self.cfg.path_reference_image)

        # intermediate products, populated while running optimize_geolayer()
        self._EnMAP_Im: Optional[EnMAPL1Product_SensorGeo] = None
        self._EnMAP_band: Optional[GeoArray] = None
        self._EnMAP_mask: Optional[GeoArray] = None
        self._ref_band_prep: Optional[GeoArray] = None
        self._EnMAP_bandIdx = 39  # FIXME hardcoded (zero-based index, i.e., band 40)

    def _get_enmap_band_for_matching(self) -> GeoArray:
        """Return the EnMAP band to be used in co-registration in UTM projection at 15m resolution.

        The result is also stored in ``self._EnMAP_band`` (nodata = 0).
        """
        band_idx = self._EnMAP_bandIdx
        self._EnMAP_Im.logger.warning(f'Statically using band {band_idx + 1} for co-registration.')

        enmap_band_sensorgeo = self._EnMAP_Im.vnir.data[:, :, band_idx]

        # transform from sensor to map geometry to make it usable for tie point detection
        # -> co-registration runs at 15m resolution to minimize information loss from resampling
        # -> a UTM projection is used to have shift length in meters
        self._EnMAP_Im.logger.info('Temporarily transforming EnMAP band %d to map geometry for co-registration.'
                                   % (band_idx + 1))
        GT = Geometry_Transformer(lons=self._EnMAP_Im.meta.vnir.lons[:, :, band_idx],
                                  lats=self._EnMAP_Im.meta.vnir.lats[:, :, band_idx],
                                  backend='gdal',
                                  resamp_alg='bilinear',
                                  nprocs=self.cfg.CPUs)

        self._EnMAP_band = \
            GeoArray(*GT.to_map_geometry(enmap_band_sensorgeo,
                                         tgt_prj=self._get_optimal_utm_epsg(),
                                         tgt_coordgrid=((0, 15), (0, -15)),
                                         src_nodata=self._EnMAP_Im.vnir.data.nodata,
                                         tgt_nodata=0),
                     nodata=0)

        return self._EnMAP_band

    def _get_enmap_mask_for_matching(self) -> GeoArray:
        """Return the EnMAP water mask to be used in co-registration in UTM projection at 15m resolution.

        The mask is True where the VNIR land/water mask equals 2 (water). It is later passed to AROSICS as
        ``mask_baddata_tgt``. The result is also stored in ``self._EnMAP_mask``.
        """
        # use the water mask
        enmap_mask_sensorgeo = self._EnMAP_Im.vnir.mask_landwater[:] == 2  # 2 is water here

        # transform from sensor to map geometry to make it usable for tie point detection
        # (nearest neighbour resampling keeps the mask binary)
        self._EnMAP_Im.logger.info('Temporarily transforming EnMAP water mask to map geometry for co-registration.')
        GT = Geometry_Transformer(lons=self._EnMAP_Im.meta.vnir.lons[:, :, self._EnMAP_bandIdx],
                                  lats=self._EnMAP_Im.meta.vnir.lats[:, :, self._EnMAP_bandIdx],
                                  backend='gdal',
                                  resamp_alg='nearest',
                                  nprocs=self.cfg.CPUs)

        self._EnMAP_mask = \
            GeoArray(*GT.to_map_geometry(enmap_mask_sensorgeo,
                                         tgt_prj=self._get_optimal_utm_epsg(),
                                         tgt_coordgrid=((0, 15), (0, -15)),
                                         src_nodata=0,  # 0=background
                                         tgt_nodata=0),
                     nodata=0)

        return self._EnMAP_mask

    def _get_optimal_utm_epsg(self) -> int:
        """Return the EPSG code of the WGS84/UTM zone containing the centroid of the EnMAP footprint.

        ``ll_mapPoly`` is given in lon/lat (EPSG:4326). Northern hemisphere (lat >= 0) maps to 326zz,
        southern hemisphere to 327zz, where zz is the zero-padded UTM zone number.
        """
        centroid = self._EnMAP_Im.meta.vnir.ll_mapPoly.centroid
        lon, lat = centroid.x, centroid.y

        # add the zone number arithmetically - string concatenation would yield e.g. 3265 instead of 32605
        return (32600 if lat >= 0 else 32700) + int(get_utm_zone(lon))

    @staticmethod
    def _get_suitable_nodata_value(arr: np.ndarray):
        """Get a suitable nodata value for the input array, which is not contained in the array data.

        :param arr: input array
        :return:    a default nodata value for the array's data type if it does not occur in the data,
                    otherwise an alternative value if that one does not occur; None if both occur, if the
                    data type is not supported (e.g., bool, complex, float16) or if ``arr`` has no dtype
        """
        try:
            dtype = str(np.dtype(arr.dtype))
        except AttributeError:
            return None

        preferred = dict(int8=-128, uint8=0, int16=-9999, uint16=9999, int32=-9999, int64=-9999,
                         uint32=9999, uint64=9999, float32=-9999., float64=-9999.)
        alternative = dict(int8=127, uint8=255, int16=32767, uint16=65535, int32=65535, int64=65535,
                           uint32=65535, uint64=65535, float32=65535, float64=65535)

        for candidates in (preferred, alternative):
            nodata = candidates.get(dtype)
            if nodata is None:
                return None  # unsupported data type
            if nodata not in arr:
                return nodata

        return None

    def _get_reference_band_for_matching(self) -> GeoArray:
        """Return the reference image band to be used in co-registration in UTM projection at 15m resolution.

        The first band of the reference image is warped (cubic) onto the grid of ``self._EnMAP_band``.
        The result is also stored in ``self._ref_band_prep``.

        :raises RuntimeError:   if the EnMAP band for matching has not been computed yet
        """
        if self._EnMAP_band is None:
            raise RuntimeError('To prepare the reference image band, '
                               'the EnMAP band for matching needs to be computed first.')

        self._EnMAP_Im.logger.info('Preparing reference image for co-registration.')

        # get input nodata value if there is one defined in the metadata
        # (the private attribute avoids GeoArray's automatic nodata detection on the full image)
        src_nodata = GeoArray(self.cfg.path_reference_image)._nodata  # noqa

        # initialized here so that the cleanup in `finally` never hits unbound names
        src_ds = dst_ds = None
        prev_num_threads = gdal.GetConfigOption('GDAL_NUM_THREADS')
        try:
            # only use the first band of the reference image for co-registration
            # TODO: select the most similar wavelength if CWLs are contained in the metadata
            #       and provide a user option to specify the band
            src_ds = gdal.Translate('', self.cfg.path_reference_image, format='MEM', bandList=[1])

            # set up the target dataset and coordinate grid (identical to self._EnMAP_band)
            driver = gdal.GetDriverByName('MEM')
            rows, cols = self._EnMAP_band.shape
            dst_ds = driver.Create('', cols, rows, 1, gdal.GDT_Float32)
            dst_ds.SetProjection(self._EnMAP_band.prj)
            dst_ds.SetGeoTransform(self._EnMAP_band.gt)
            if src_nodata is not None:
                # the target band is Float32 -> keep non-integer/NaN nodata values intact
                dst_ds.GetRasterBand(1).SetNoDataValue(float(src_nodata))
                dst_ds.GetRasterBand(1).Fill(float(src_nodata))

            # reproject the needed subset from the reference image to a UTM 15m grid (like self._EnMAP_band)
            cb = ProgressBar(prefix='Warping progress ') if not self.cfg.disable_progress_bars else None
            gdal.SetConfigOption('GDAL_NUM_THREADS', f'{self.cfg.CPUs}')
            gdal.ReprojectImage(
                src_ds,
                dst_ds,
                src_ds.GetProjection(),
                self._EnMAP_band.prj, gdal.GRA_Cubic,
                callback=cb
            )
            out_data = dst_ds.GetRasterBand(1).ReadAsArray()

        finally:
            del src_ds, dst_ds
            # restore the previous (possibly user-defined) setting instead of unconditionally unsetting it
            gdal.SetConfigOption('GDAL_NUM_THREADS', prev_num_threads)

        self._ref_band_prep = GeoArray(out_data, self._EnMAP_band.gt, self._EnMAP_band.prj, nodata=src_nodata)

        # try to set a meaningful nodata value if it cannot be auto-detected
        # -> only needed for AROSICS where setting a nodata value avoids warnings
        if self._ref_band_prep.nodata is None:
            self._ref_band_prep.nodata = self._get_suitable_nodata_value(out_data)

        return self._ref_band_prep

    def _compute_tie_points(self):
        """Compute tie points between the prepared reference band and the EnMAP band using AROSICS.

        :return:    copy of the AROSICS tie point table (``CoRegPoints_table``) containing only the
                    rows whose OUTLIER flag equals False
        """
        # avoid to run RANSAC within AROSICS if more than 50 lines of a gapfill image were appended
        # (since the RPC coefficients are computed for the main image, there may be decreasing geometry accuracy with
        # increasing distance from the main image)
        path_gapfill = self.cfg.path_l1b_enmap_image_gapfill
        n_lines_appended = self.cfg.n_lines_to_append
        if path_gapfill and os.path.exists(path_gapfill) and \
                (n_lines_appended is None or n_lines_appended > 50):
            tieP_filter_level = 2
        else:
            tieP_filter_level = 3

        # compute tie points within AROSICS
        # -> the nodata value of the reference must match the band actually passed in (self._ref_band_prep)
        CRL = COREG_LOCAL(self._ref_band_prep,
                          self._EnMAP_band,
                          grid_res=40,
                          max_shift=10,  # 5 EnMAP pixels (co-registration is running at 15m UTM grid)
                          nodata=(self._ref_band_prep.nodata, 0),
                          footprint_poly_tgt=reproject_shapelyGeometry(self._EnMAP_Im.meta.vnir.ll_mapPoly,
                                                                       4326, self._EnMAP_band.epsg),
                          mask_baddata_tgt=self._EnMAP_mask,
                          tieP_filter_level=tieP_filter_level,
                          progress=self.cfg.disable_progress_bars is False
                          )
        TPG = CRL.tiepoint_grid

        valid_tiepoints = TPG.CoRegPoints_table[TPG.CoRegPoints_table.OUTLIER.__eq__(False)].copy()

        return valid_tiepoints

    @staticmethod
    def _interpolate_tiepoints_into_space(tiepoints, outshape, metric='ABS_SHIFT', lowres_step=5):
        """Interpolate a tie point metric to a dense 2D array.

        The metric is first interpolated with a radial basis function onto a coarse grid (every
        ``lowres_step`` pixels) and then linearly upsampled to the full output shape.
        See https://github.com/agile-geoscience/xlines/blob/master/notebooks/11_Gridding_map_data.ipynb

        :param tiepoints:   tie point table with the columns X_IM, Y_IM (image coordinates) and ``metric``
        :param outshape:    (rows, cols) of the output array
        :param metric:      name of the column to interpolate
        :param lowres_step: pixel spacing of the intermediate coarse grid
        :return:            array of shape ``outshape``
        :raises ValueError: if no tie points are given
        """
        from scipy.interpolate import Rbf, RegularGridInterpolator

        if len(tiepoints) == 0:
            raise ValueError(f"No valid tie points available to interpolate '{metric}'.")

        rows = np.array(tiepoints.Y_IM)
        cols = np.array(tiepoints.X_IM)
        data = np.array(tiepoints[metric])

        # RBF interpolation onto a coarse grid (the grid extends to at least the last row/column)
        rows_lowres = np.arange(0, outshape[0] + lowres_step, lowres_step)
        cols_lowres = np.arange(0, outshape[1] + lowres_step, lowres_step)
        f = Rbf(cols, rows, data)
        data_interp_lowres = f(*np.meshgrid(cols_lowres, rows_lowres))

        # linear upsampling to the full resolution
        RGI = RegularGridInterpolator(points=[cols_lowres, rows_lowres],
                                      values=data_interp_lowres.T,  # must be in shape [x, y]
                                      method='linear',
                                      bounds_error=False)
        rows_full = np.arange(outshape[0])
        cols_full = np.arange(outshape[1])
        data_full = RGI(np.dstack(np.meshgrid(cols_full, rows_full)))

        return data_full

    def optimize_geolayer(self,
                          enmap_ImageL1: EnMAPL1Product_SensorGeo):
        """Adapt the VNIR and SWIR geolayers of the given EnMAP image to the spatial reference image.

        :param enmap_ImageL1:   EnMAP L1B image in sensor geometry; its ``meta.vnir``/``meta.swir`` lons and
                                lats are replaced by corrected arrays
        :return:                the same (modified) image object
        """
        self._EnMAP_Im = enmap_ImageL1
        self._get_enmap_band_for_matching()
        self._get_enmap_mask_for_matching()
        self._get_reference_band_for_matching()

        enmap_ImageL1.logger.info('Computing tie points between the EnMAP image and the given spatial reference image.')
        tiepoints = self._compute_tie_points()

        enmap_ImageL1.logger.info('Generating misregistration array.')
        xshift_map = self._interpolate_tiepoints_into_space(tiepoints,
                                                            self._EnMAP_band.shape,
                                                            metric='X_SHIFT_M')
        yshift_map = self._interpolate_tiepoints_into_space(tiepoints,
                                                            self._EnMAP_band.shape,
                                                            metric='Y_SHIFT_M')

        # map coordinate grids of the 15m UTM band
        # (integer-indexed to guarantee exactly `rows` x `cols` elements; float-step np.arange may yield one more)
        ULx, ULy = self._EnMAP_band.box.boxMapXY[0]
        xgsd, ygsd = self._EnMAP_band.xgsd, self._EnMAP_band.ygsd
        rows, cols = self._EnMAP_band.shape
        xgrid_map, ygrid_map = np.meshgrid(ULx + np.arange(cols) * xgsd,
                                           ULy - np.arange(rows) * ygsd)

        xgrid_map_coreg = xgrid_map + xshift_map
        ygrid_map_coreg = ygrid_map + yshift_map

        # transform map to sensor geometry
        enmap_ImageL1.logger.info('Transforming spatial optimization results back to sensor geometry.')
        lons_band = self._EnMAP_Im.meta.vnir.lons[:, :, self._EnMAP_bandIdx]
        lats_band = self._EnMAP_Im.meta.vnir.lats[:, :, self._EnMAP_bandIdx]
        GT = Geometry_Transformer(lons=lons_band,
                                  lats=lats_band,
                                  backend='gdal',
                                  resamp_alg='bilinear',
                                  nprocs=self.cfg.CPUs)

        geolayer_sensorgeo = \
            GT.to_sensor_geometry(GeoArray(np.dstack([xgrid_map_coreg,
                                                      ygrid_map_coreg]),
                                           geotransform=self._EnMAP_band.gt,
                                           projection=self._EnMAP_band.prj),
                                  tgt_nodata=0)

        enmap_ImageL1.logger.info('Applying results of spatial optimization to geolayer.')
        lons_coreg, lats_coreg = transform_coordArray(prj_src=self._ref_band_prep.prj,
                                                      prj_tgt=EPSG2WKT(4326),
                                                      Xarr=geolayer_sensorgeo[:, :, 0],
                                                      Yarr=geolayer_sensorgeo[:, :, 1])

        diffs_lons_coreg = lons_band - lons_coreg
        diffs_lats_coreg = lats_band - lats_coreg

        # the same 2D correction (derived from one VNIR band) is applied to all bands of VNIR and SWIR;
        # new arrays are assigned instead of subtracting in place
        # NOTE(review): presumably to go through the metadata attribute setters - confirm
        enmap_ImageL1.meta.vnir.lons = enmap_ImageL1.meta.vnir.lons - diffs_lons_coreg[:, :, np.newaxis]
        enmap_ImageL1.meta.vnir.lats = enmap_ImageL1.meta.vnir.lats - diffs_lats_coreg[:, :, np.newaxis]
        enmap_ImageL1.meta.swir.lons = enmap_ImageL1.meta.swir.lons - diffs_lons_coreg[:, :, np.newaxis]
        enmap_ImageL1.meta.swir.lats = enmap_ImageL1.meta.swir.lats - diffs_lats_coreg[:, :, np.newaxis]

        return enmap_ImageL1