Coverage for enpt/processors/dead_pixel_correction/dead_pixel_correction.py: 91%
144 statements
« prev ^ index » next coverage.py v7.4.1, created at 2024-03-07 11:39 +0000
« prev ^ index » next coverage.py v7.4.1, created at 2024-03-07 11:39 +0000
1# -*- coding: utf-8 -*-
3# EnPT, EnMAP Processing Tool - A Python package for pre-processing of EnMAP Level-1B data
4#
5# Copyright (C) 2018-2024 Karl Segl (GFZ Potsdam, segl@gfz-potsdam.de), Daniel Scheffler
6# (GFZ Potsdam, danschef@gfz-potsdam.de), Niklas Bohn (GFZ Potsdam, nbohn@gfz-potsdam.de),
7# Stéphane Guillaso (GFZ Potsdam, stephane.guillaso@gfz-potsdam.de)
8#
9# This software was developed within the context of the EnMAP project supported
10# by the DLR Space Administration with funds of the German Federal Ministry of
11# Economic Affairs and Energy (on the basis of a decision by the German Bundestag:
12# 50 EE 1529) and contributions from DLR, GFZ and OHB System AG.
13#
14# This program is free software: you can redistribute it and/or modify it under
15# the terms of the GNU General Public License as published by the Free Software
16# Foundation, either version 3 of the License, or (at your option) any later
17# version. Please note the following exception: `EnPT` depends on tqdm, which
18# is distributed under the Mozilla Public Licence (MPL) v2.0 except for the files
19# "tqdm/_tqdm.py", "setup.py", "README.rst", "MANIFEST.in" and ".gitignore".
20# Details can be found here: https://github.com/tqdm/tqdm/blob/master/LICENCE.
21#
22# This program is distributed in the hope that it will be useful, but WITHOUT
23# ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or FITNESS
24# FOR A PARTICULAR PURPOSE. See the GNU Lesser General Public License for more
25# details.
26#
27# You should have received a copy of the GNU Lesser General Public License along
28# with this program. If not, see <https://www.gnu.org/licenses/>.
30"""EnPT 'dead pixel correction' module.
32Performs the Dead Pixel Correction using a linear interpolation in spectral dimension.
33"""
34from typing import Union
35from numbers import Number # noqa: F401
36import logging
37from warnings import filterwarnings
39import numpy as np
40import numpy_indexed as npi
41from multiprocessing import Pool, cpu_count
42from scipy.interpolate import griddata, make_interp_spline
43filterwarnings("ignore", "\nPyArrow", DeprecationWarning) # mute pandas warning
44from pandas import DataFrame # noqa: E402
45from geoarray import GeoArray # noqa: E402
47__author__ = 'Daniel Scheffler'
class Dead_Pixel_Corrector(object):
    """Replace dead pixels of an EnMAP image by interpolated values.

    The positions to be replaced are taken from the pixel masks provided by DLR (passed to `correct`
    as dead pixel map). Two correction algorithms are available:

    1. 'spectral'
        * All flagged positions are interpolated along the band axis.
    2. 'spatial'
        * Flagged positions are interpolated along the column axis first; positions that are still NaN
          afterwards are interpolated along the row axis.
        * Positions that are still NaN after both spatial passes (e.g., outermost columns) are
          interpolated along the band axis.
    """

    def __init__(self,
                 algorithm: str = 'spatial',
                 interp_spectral: str = 'linear',
                 interp_spatial: str = 'linear',
                 CPUs: int = None,
                 logger: logging.Logger = None):
        """Create a Dead_Pixel_Corrector.

        :param algorithm:       'spectral' selects the interpolation in the spectral domain;
                                any other value selects the interpolation in the spatial domain
        :param interp_spectral: interpolation method used along the band axis
                                ('linear', 'quadratic', 'cubic')
        :param interp_spatial:  interpolation method handed over to the spatial interpolation
        :param CPUs:            number of CPUs for the spatial interpolation
                                (a falsy value is replaced by the number of available CPUs)
        :param logger:          logger to report to (a falsy value is replaced by the root logger)
        """
        self.algorithm = algorithm
        self.interp_alg_spectral = interp_spectral
        self.interp_alg_spatial = interp_spatial
        self.CPUs = CPUs if CPUs else cpu_count()
        self.logger = logger if logger else logging.getLogger()

    @staticmethod
    def _validate_inputs(image2correct: GeoArray,
                         deadpixel_map: GeoArray):
        """Raise a ValueError if the shape of the dead pixel map does not fit the image.

        A 2D map must have the shape (bands, columns) of the image, a 3D map must have exactly
        the shape of the image. Other dimensionalities are rejected.
        """
        n_dims = deadpixel_map.ndim

        if n_dims not in (2, 3):
            raise ValueError('Unexpected number of dimensions of dead column map.')

        if n_dims == 2:
            expected_shape = (image2correct.bands, image2correct.columns)
            if expected_shape != deadpixel_map.shape:
                raise ValueError('The given image to be corrected (shape: %s) requires a dead column map with shape '
                                 '(%s, %s). Received %s.'
                                 % (image2correct.shape, image2correct.bands,
                                    image2correct.columns, deadpixel_map.shape))
            return

        # n_dims == 3
        if image2correct.shape != deadpixel_map.shape:
            raise ValueError('The given image to be corrected (shape: %s) requires a dead pixel map with equal '
                             'shape. Received %s.' % (image2correct.shape, deadpixel_map.shape))

    def _interpolate_nodata_spectrally(self,
                                       image2correct: GeoArray,
                                       deadpixel_map: GeoArray):
        """Interpolate all positions flagged in the 3D dead pixel map along the band axis."""
        assert deadpixel_map.ndim == 3, "3D dead pixel map expected."
        if image2correct.shape != deadpixel_map.shape:
            raise ValueError("Dead pixel map and image to be corrected must have equal shape.")

        return interp_nodata_along_axis(image2correct,
                                        axis=2,
                                        nodata=deadpixel_map[:],
                                        method=self.interp_alg_spectral)

    def _interpolate_nodata_spatially(self,
                                      image2correct: GeoArray,
                                      deadpixel_map: GeoArray):
        """Interpolate flagged positions spatially and fill the remainder spectrally.

        Only the bands that contain at least one flagged position are passed to the spatial
        interpolation. The result is written back into `image2correct`.
        """
        assert deadpixel_map.ndim == 3, "3D dead pixel map expected."
        if image2correct.shape != deadpixel_map.shape:
            raise ValueError("Dead pixel map and image to be corrected must have equal shape.")

        # indices of the bands that have at least one flagged position
        band_is_affected = np.any(np.any(deadpixel_map[:], axis=0), axis=0)
        affected_bands = np.argwhere(band_is_affected).flatten()

        subset_image = image2correct[:, :, affected_bands]
        subset_mask = deadpixel_map[:, :, affected_bands]

        spatial_opts = {
            'method': self.interp_alg_spatial,
            'fill_value': np.nan,
            'implementation': 'pandas',
            'CPUs': self.CPUs,
        }

        # first pass along axis 1 => corrects dead columns
        subset_interp = interp_nodata_spatially_3d(subset_image, axis=1, nodata=subset_mask, **spatial_opts)

        # second pass along axis 0 => corrects dead rows (only needed if the first pass left NaNs)
        leftover = np.isnan(subset_interp)
        if leftover.any():
            subset_interp = interp_nodata_spatially_3d(subset_interp, axis=0, nodata=leftover, **spatial_opts)

        image2correct[:, :, affected_bands] = subset_interp

        # whatever is still NaN (e.g., outermost columns) is interpolated along the band axis
        remaining = np.isnan(image2correct)
        if remaining.any():
            image2correct = interp_nodata_along_axis(image2correct,
                                                     axis=2,
                                                     nodata=remaining,
                                                     method=self.interp_alg_spectral)

        return image2correct

    def correct(self,
                image2correct: Union[np.ndarray, GeoArray],
                deadpixel_map: Union[np.ndarray, GeoArray]):
        """Run the dead pixel correction.

        :param image2correct: image to correct (converted to a GeoArray if it is not one already)
        :param deadpixel_map: dead pixel map, either 2D (bands x columns) or 3D (rows x columns x bands)
        :return: corrected image; the unchanged image if the map does not contain the value 1
        """
        if not isinstance(image2correct, GeoArray):
            image2correct = GeoArray(image2correct)

        self._validate_inputs(image2correct, deadpixel_map)

        # nothing to correct if no position is labelled as defective
        if 1 not in list(np.unique(deadpixel_map)):
            self.logger.info("Dead pixel correction skipped because dead pixel mask labels no pixels as 'defective'.")
            return image2correct

        n_rows = image2correct.shape[0]
        n_pixels_per_band = np.dot(*image2correct.shape[:2])

        if deadpixel_map.ndim == 2:
            deadcolumn_map = deadpixel_map

            # a dead column affects all rows => scale the column counts by the number of rows
            share_anyband = np.any(deadcolumn_map, axis=0).sum() * n_rows / n_pixels_per_band
            share_total = deadcolumn_map.sum() * n_rows / image2correct.size

            # expand the 2D dead column map (bands x columns) to a 3D map (rows x columns x bands)
            n_bands, n_cols = deadcolumn_map.shape
            deadpixel_map = np.empty((n_rows, n_cols, n_bands), bool)
            deadpixel_map[:, :, :] = deadcolumn_map.T

        else:
            share_anyband = np.any(deadpixel_map, axis=2).sum() / n_pixels_per_band
            share_total = deadpixel_map.sum() / image2correct.size

        self.logger.info('Percentage of defective pixels: %.2f' % (share_total * 100))
        self.logger.debug('Percentage of pixels with a defect in any band: %.2f' % (share_anyband * 100))

        # run the selected correction
        if self.algorithm == 'spectral':
            interpolate = self._interpolate_nodata_spectrally
        else:
            interpolate = self._interpolate_nodata_spatially

        return interpolate(image2correct, deadpixel_map)
190def _get_baddata_mask(data: np.ndarray,
191 nodata: Union[np.ndarray, Number] = np.nan,
192 transpose_inNodata: bool = False):
193 if isinstance(nodata, np.ndarray):
194 badmask = nodata.T if transpose_inNodata else nodata
196 if badmask.shape != data.shape:
197 raise ValueError('No-data mask and data must have the same shape.')
199 else:
200 badmask = ~np.isfinite(data) if ~np.isfinite(nodata) else data == nodata
202 return badmask
def interp_nodata_along_axis_2d(data_2d: np.ndarray,
                                axis: int = 0,
                                nodata: Union[np.ndarray, Number] = np.nan,
                                method: str = 'linear',
                                **kw):
    """Interpolate a 2D array along the given axis (based on scipy.interpolate.make_interp_spline).

    The interpolated values are written into the given array.

    :param data_2d: data to interpolate
    :param axis:    axis to interpolate (0: along columns; 1: along rows);
                    negative values count from the last axis (-1 equals 1, -2 equals 0)
    :param nodata:  nodata array in the shape of data or nodata value
    :param method:  interpolation method ('linear', 'quadratic', 'cubic')
    :param kw:      keyword arguments to be passed to scipy.interpolate.make_interp_spline
    :return:        interpolated array
    :raises ValueError: if data_2d is not 2D, if axis is out of bounds or if method is unknown
    """
    if data_2d.ndim != 2:
        raise ValueError('Expected a 2D array. Received a %dD array.' % data_2d.ndim)
    # valid axes of a 2D array are -2..1 (the previous check 'axis > ndim' let axis=2 slip through)
    if not -data_2d.ndim <= axis < data_2d.ndim:
        raise ValueError("axis=%d is out of bounds for data with %d dimensions." % (axis, data_2d.ndim))
    degrees = {'linear': 1, 'quadratic': 2, 'cubic': 3}
    if method not in degrees:
        raise ValueError(f"'{method}' is not a valid interpolation method. "
                         f"Choose between 'linear', 'quadratic', and 'cubic'.")
    degree = degrees[method]

    # map negative axes to 0/1 so that the data and the no-data mask are transposed consistently
    axis = axis % data_2d.ndim

    # from here on, the interpolation always runs along axis 1 (a transposed view is used for axis 0)
    data_2d = data_2d if axis == 1 else data_2d.T

    badmask_full = _get_baddata_mask(data_2d, nodata, transpose_inNodata=axis == 0)

    # call 1D interpolation vectorized
    #   => group the dataset by rows that have nodata at the same column position
    #   => remember the row positions, call the interpolation for these rows at once (vectorized)
    #      and substitute the original data at the previously grouped row positions
    groups_unique_rows = npi.group_by(badmask_full).split(np.arange(len(badmask_full)))

    for indices_unique_rows in groups_unique_rows:
        badmask_grouped_rows = badmask_full[indices_unique_rows, :]

        # all rows of a group share the same mask => the first row is representative
        if np.any(badmask_grouped_rows[0, :]):
            badpos = np.argwhere(badmask_grouped_rows[0, :]).flatten()
            goodpos = np.delete(np.arange(data_2d.shape[1]), badpos)

            # rows with less than two valid positions are left untouched
            if goodpos.size > 1:
                data_2d_grouped_rows = data_2d[indices_unique_rows]
                data_2d_grouped_rows[:, badpos] = \
                    make_interp_spline(goodpos, data_2d_grouped_rows[:, goodpos], axis=1, k=degree, **kw)(badpos)

                data_2d[indices_unique_rows, :] = data_2d_grouped_rows

    return data_2d if axis == 1 else data_2d.T
def interp_nodata_along_axis(data,
                             axis=0,
                             nodata: Union[np.ndarray, Number] = np.nan,
                             method: str = 'linear',
                             **kw):
    """Interpolate a 2D or 3D array along the given axis (based on scipy.interpolate.make_interp_spline).

    3D input (rows x columns x bands) is rearranged into a 2D array in which the axis to interpolate
    is kept as one dimension and the two remaining axes are merged into the other one. The 2D result
    is rearranged back to the shape of the input afterwards.

    :param data:    data to interpolate
    :param axis:    axis to interpolate (0: along columns; 1: along rows, 2: along bands)
    :param nodata:  nodata array in the shape of data or nodata value
    :param method:  interpolation method ('linear', 'quadratic', 'cubic')
    :param kw:      keyword arguments to be passed to scipy.interpolate.make_interp_spline
    :return:        interpolated array
    :raises ValueError: if data is neither 2D nor 3D or if a given nodata array has a different shape
    """
    assert axis <= 2
    if data.ndim not in [2, 3]:
        raise ValueError('Expected a 2D or 3D array. Received a %dD array.' % data.ndim)
    if isinstance(nodata, np.ndarray) and nodata.shape != data.shape:
        raise ValueError('No-data mask and data must have the same shape.')

    if data.ndim == 2:
        return interp_nodata_along_axis_2d(data, axis=axis, nodata=nodata, method=method, **kw)

    R, C, B = data.shape

    def reshape_input(In):
        if axis == 0:
            # rows stay on axis 0, columns and bands are merged: (R, C*B)
            # (the former reshape to (C, R*B) was only correct for R == C)
            return In.reshape(R, C * B)
        if axis == 1:
            # bring columns to the front, merge rows and bands, put columns on axis 1: (R*B, C)
            return np.transpose(In, axes=[1, 0, 2]).reshape(C, R * B).T
        # axis == 2: rows and columns are merged, bands stay on axis 1: (R*C, B)
        return In.reshape(R * C, B)

    def reshape_output(out):
        if axis == 1:
            # revert reshape_input step by step: (R*B, C) -> (C, R*B) -> (C, R, B) -> (R, C, B)
            # (the former reshape to data.shape before transposing was only correct for R == C)
            return np.transpose(out.T.reshape(C, R, B), axes=[1, 0, 2])
        return out.reshape(data.shape)

    return \
        reshape_output(
            interp_nodata_along_axis_2d(
                data_2d=reshape_input(data),
                nodata=reshape_input(nodata) if isinstance(nodata, np.ndarray) else nodata,
                axis=axis if axis != 2 else 1,
                method=method, **kw))
def interp_nodata_spatially_2d(data_2d: np.ndarray,
                               axis: int = 0,
                               nodata: Union[np.ndarray, Number] = np.nan,
                               method: str = 'linear',
                               fill_value: float = np.nan,
                               implementation: str = 'pandas'
                               ) -> np.ndarray:
    """Interpolate a 2D array spatially.

    NOTE: If data_2d contains NaN values that are unlabelled by a given nodata array,
          they are also overwritten in the pandas implementation.

    NOTE: The scipy implementation writes the interpolated values into the given array,
          the pandas implementation returns a new array of the same dtype.

    :param data_2d:         data to interpolate
    :param axis:            axis to interpolate (0: along columns; 1: along rows)
    :param nodata:          nodata array in the shape of data or nodata value
    :param method:          interpolation method
                            - if implementation='scipy': 'linear', 'nearest', 'cubic'
                            - if implementation='pandas': 'linear', 'nearest', 'slinear', 'quadratic', 'cubic', etc.
    :param fill_value:      value to fill into positions where no interpolation is possible
    :param implementation:  'scipy': interpolation based on scipy.interpolate.griddata
                            'pandas': interpolation based on pandas.DataFrame.interpolate
    :return:                interpolated array
    :raises ValueError:     if data_2d is not 2D or if the implementation is unknown
    """
    assert axis < 2
    if data_2d.ndim != 2:
        raise ValueError('Expected a 2D array. Received a %dD array.' % data_2d.ndim)

    badmask_full = _get_baddata_mask(data_2d, nodata)

    if badmask_full.any():
        if implementation == 'scipy':
            if axis == 0:
                y, x = np.indices(data_2d.shape)
            else:
                x, y = np.indices(data_2d.shape)

            data_2d[badmask_full] = \
                griddata(np.array([x[~badmask_full], y[~badmask_full]]).T,  # points we know
                         data_2d[~badmask_full],  # values we know
                         np.array([x[badmask_full], y[badmask_full]]).T,  # points to interpolate
                         method=method, fill_value=fill_value)

        elif implementation == 'pandas':
            # pandas interpolates NaN positions => work on a float copy with the bad positions set to NaN
            data2int = data_2d.astype(float)
            data2int[badmask_full] = np.nan

            data_interp = np.array(DataFrame(data2int).interpolate(method=method, axis=axis))

            # Fill the positions that could not be interpolated BEFORE casting back to the input dtype.
            # After a cast to an integer dtype, the NaNs are gone and could not be found anymore.
            if np.isfinite(fill_value):
                data_interp[np.isnan(data_interp)] = fill_value

            data_2d = data_interp.astype(data_2d.dtype)

        else:
            raise ValueError(implementation, 'Unknown implementation.')

    return data_2d
def interp_nodata_spatially_3d(data_3d: np.ndarray,
                               axis: int = 0,
                               nodata: Union[np.ndarray, Number] = np.nan,
                               method: str = 'linear',
                               fill_value: float = np.nan,
                               implementation: str = 'pandas',
                               CPUs: int = None
                               ) -> np.ndarray:
    """Interpolate a 3D array spatially, band-for-band.

    Each band is passed to interp_nodata_spatially_2d, either sequentially or distributed
    over a multiprocessing pool. The per-band results are stacked along the third axis.

    :param data_3d:         data to interpolate
    :param axis:            axis to interpolate (0: along columns; 1: along rows)
    :param nodata:          nodata array in the shape of data or nodata value
    :param method:          interpolation method
                            - if implementation='scipy': 'linear', 'nearest', 'cubic'
                            - if implementation='pandas': 'linear', 'nearest', 'slinear', 'quadratic', 'cubic', etc.
    :param fill_value:      value to fill into positions where no interpolation is possible
    :param implementation:  'scipy': interpolation based on scipy.interpolate.griddata
                            'pandas': interpolation based on pandas.DataFrame.interpolate
    :param CPUs:            number of CPUs to use (None: use all available CPUs)
    :return:                interpolated array
    """
    assert axis < 2

    badmask_full = _get_baddata_mask(data_3d, nodata)

    # None means "all available CPUs"; comparing None with an integer would raise a TypeError
    n_cpus = cpu_count() if CPUs is None else CPUs
    band_indices = range(data_3d.shape[2])

    if n_cpus > 1:
        # positional arguments in the parameter order of interp_nodata_spatially_2d
        args = [[data_3d[:, :, band], axis, badmask_full[:, :, band], method, fill_value, implementation]
                for band in band_indices]

        with Pool(n_cpus) as pool:
            results = pool.starmap(interp_nodata_spatially_2d, args)

            pool.close()  # needed for coverage to work in multiprocessing
            pool.join()

        return np.dstack(results)

    return \
        np.dstack([interp_nodata_spatially_2d(data_3d[:, :, band], axis=axis,
                                              nodata=badmask_full[:, :, band], method=method,
                                              fill_value=fill_value, implementation=implementation)
                   for band in band_indices])