Coverage for enpt/processors/dead_pixel_correction/dead_pixel_correction.py: 91%

144 statements  

« prev     ^ index     » next       coverage.py v7.4.1, created at 2024-03-07 11:39 +0000

1# -*- coding: utf-8 -*- 

2 

3# EnPT, EnMAP Processing Tool - A Python package for pre-processing of EnMAP Level-1B data 

4# 

5# Copyright (C) 2018-2024 Karl Segl (GFZ Potsdam, segl@gfz-potsdam.de), Daniel Scheffler 

6# (GFZ Potsdam, danschef@gfz-potsdam.de), Niklas Bohn (GFZ Potsdam, nbohn@gfz-potsdam.de), 

7# Stéphane Guillaso (GFZ Potsdam, stephane.guillaso@gfz-potsdam.de) 

8# 

9# This software was developed within the context of the EnMAP project supported 

10# by the DLR Space Administration with funds of the German Federal Ministry of 

11# Economic Affairs and Energy (on the basis of a decision by the German Bundestag: 

12# 50 EE 1529) and contributions from DLR, GFZ and OHB System AG. 

13# 

14# This program is free software: you can redistribute it and/or modify it under 

15# the terms of the GNU General Public License as published by the Free Software 

16# Foundation, either version 3 of the License, or (at your option) any later 

17# version. Please note the following exception: `EnPT` depends on tqdm, which 

18# is distributed under the Mozilla Public Licence (MPL) v2.0 except for the files 

19# "tqdm/_tqdm.py", "setup.py", "README.rst", "MANIFEST.in" and ".gitignore". 

20# Details can be found here: https://github.com/tqdm/tqdm/blob/master/LICENCE. 

21# 

22# This program is distributed in the hope that it will be useful, but WITHOUT 

23# ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or FITNESS 

24# FOR A PARTICULAR PURPOSE. See the GNU Lesser General Public License for more 

25# details. 

26# 

27# You should have received a copy of the GNU Lesser General Public License along 

28# with this program. If not, see <https://www.gnu.org/licenses/>. 

29 

30"""EnPT 'dead pixel correction' module. 

31 

32Performs the Dead Pixel Correction using a linear interpolation in spectral dimension. 

33""" 

34from typing import Union 

35from numbers import Number # noqa: F401 

36import logging 

37from warnings import filterwarnings 

38 

39import numpy as np 

40import numpy_indexed as npi 

41from multiprocessing import Pool, cpu_count 

42from scipy.interpolate import griddata, make_interp_spline 

43filterwarnings("ignore", "\nPyArrow", DeprecationWarning) # mute pandas warning 

44from pandas import DataFrame # noqa: E402 

45from geoarray import GeoArray # noqa: E402 

46 

47__author__ = 'Daniel Scheffler' 

48 

49 

class Dead_Pixel_Corrector(object):
    """EnPT Dead Pixel Correction class.

    The EnPT dead pixel correction uses the pixel masks provided by DLR and interpolates the EnMAP image
    data at the indicated dead pixel positions. It supports two interpolation algorithms:

    1. spectral interpolation
        * Interpolates the data in the spectral domain.
        * Points outside the data range are extrapolated.
    2. spatial interpolation
        * Interpolates the data spatially.
        * Remaining missing data positions (e.g., outermost columns) are spectrally interpolated.
    """

    def __init__(self,
                 algorithm: str = 'spatial',
                 interp_spectral: str = 'linear',
                 interp_spatial: str = 'linear',
                 CPUs: int = None,
                 logger: logging.Logger = None):
        """Get an instance of Dead_Pixel_Corrector.

        :param algorithm:       algorithm how to correct dead pixels
                                'spectral': interpolate in the spectral domain
                                'spatial':  interpolate in the spatial domain
        :param interp_spectral: spectral interpolation algorithm ('linear', 'quadratic', 'cubic')
        :param interp_spatial:  spatial interpolation method, passed as ``method`` to
                                pandas.DataFrame.interpolate (e.g., 'linear', 'nearest', 'quadratic', 'cubic')
        :param CPUs:            number of CPUs to use for interpolation (only relevant if algorithm = 'spatial');
                                defaults to all available CPUs
        :param logger:          logger to use; defaults to the root logger
        :raises ValueError:     if ``algorithm`` is neither 'spectral' nor 'spatial'
        """
        # reject typos instead of silently falling back to the spatial algorithm in correct()
        if algorithm not in ('spectral', 'spatial'):
            raise ValueError(f"'{algorithm}' is not a valid dead pixel correction algorithm. "
                             f"Choose between 'spectral' and 'spatial'.")

        self.algorithm = algorithm
        self.interp_alg_spectral = interp_spectral
        self.interp_alg_spatial = interp_spatial
        self.CPUs = CPUs or cpu_count()
        self.logger = logger or logging.getLogger()

    @staticmethod
    def _validate_inputs(image2correct: GeoArray,
                         deadpixel_map: GeoArray):
        """Check that the dead pixel map matches the image to be corrected.

        A 2D map must have the shape (bands, columns) of the image, a 3D map must have the image shape.

        :raises ValueError: on a shape mismatch or if the map is neither 2D nor 3D
        """
        if deadpixel_map.ndim == 2:
            if (image2correct.bands, image2correct.columns) != deadpixel_map.shape:
                raise ValueError('The given image to be corrected (shape: %s) requires a dead column map with shape '
                                 '(%s, %s). Received %s.'
                                 % (image2correct.shape, image2correct.bands,
                                    image2correct.columns, deadpixel_map.shape))
        elif deadpixel_map.ndim == 3:
            if image2correct.shape != deadpixel_map.shape:
                raise ValueError('The given image to be corrected (shape: %s) requires a dead pixel map with equal '
                                 'shape. Received %s.' % (image2correct.shape, deadpixel_map.shape))
        else:
            raise ValueError('Unexpected number of dimensions of dead column map.')

    def _interpolate_nodata_spectrally(self,
                                       image2correct: GeoArray,
                                       deadpixel_map: GeoArray):
        """Interpolate the flagged pixels along the band axis (axis 2).

        :param image2correct: 3D image to correct
        :param deadpixel_map: 3D dead pixel map with the same shape as image2correct
        :return: corrected image
        """
        assert deadpixel_map.ndim == 3, "3D dead pixel map expected."
        if deadpixel_map.shape != image2correct.shape:
            raise ValueError("Dead pixel map and image to be corrected must have equal shape.")

        image_corrected = interp_nodata_along_axis(image2correct, axis=2, nodata=deadpixel_map[:],
                                                   method=self.interp_alg_spectral)

        return image_corrected

    def _interpolate_nodata_spatially(self,
                                      image2correct: GeoArray,
                                      deadpixel_map: GeoArray):
        """Interpolate the flagged pixels spatially, then fill what is left spectrally.

        Columns are interpolated first; any NaN left afterwards is interpolated along the rows. Remaining
        NaN values (e.g., in the outermost columns) are finally interpolated along the band axis.
        The result is written back into image2correct.

        :param image2correct: 3D image to correct
        :param deadpixel_map: 3D dead pixel map with the same shape as image2correct
        :return: corrected image
        """
        assert deadpixel_map.ndim == 3, "3D dead pixel map expected."
        if deadpixel_map.shape != image2correct.shape:
            raise ValueError("Dead pixel map and image to be corrected must have equal shape.")

        # only bands containing at least one dead pixel need to be processed
        band_indices_with_nodata = np.flatnonzero(np.any(deadpixel_map[:], axis=(0, 1)))
        image_sub = image2correct[:, :, band_indices_with_nodata]
        # cast to bool: integer (0/1) masks would otherwise be used as integer indices downstream
        deadpixel_map_sub = np.asarray(deadpixel_map[:, :, band_indices_with_nodata], dtype=bool)

        kw = dict(method=self.interp_alg_spatial, fill_value=np.nan, implementation='pandas', CPUs=self.CPUs)

        # correct dead columns
        image_sub_interp = interp_nodata_spatially_3d(image_sub, axis=1, nodata=deadpixel_map_sub, **kw)

        # correct dead rows
        if np.isnan(image_sub_interp).any():
            image_sub_interp = interp_nodata_spatially_3d(image_sub_interp, axis=0,
                                                          nodata=np.isnan(image_sub_interp), **kw)

        image2correct[:, :, band_indices_with_nodata] = image_sub_interp

        # correct remaining nodata by spectral interpolation (e.g., outermost columns)
        if np.isnan(image2correct).any():
            image2correct = interp_nodata_along_axis(image2correct, axis=2, nodata=np.isnan(image2correct),
                                                     method=self.interp_alg_spectral)

        return image2correct

    def correct(self,
                image2correct: Union[np.ndarray, GeoArray],
                deadpixel_map: Union[np.ndarray, GeoArray]):
        """Run the dead pixel correction.

        :param image2correct: image to correct
        :param deadpixel_map: dead pixel map (2D (bands x columns) or 3D (rows x columns x bands));
                              pixels labelled with 1 are treated as defective
        :return: corrected image (image2correct itself if no pixel is labelled as defective)
        """
        image2correct = image2correct if isinstance(image2correct, GeoArray) else GeoArray(image2correct)

        self._validate_inputs(image2correct, deadpixel_map)

        # O(n) membership test instead of sorting the whole mask via np.unique
        if not np.any(deadpixel_map[:] == 1):
            self.logger.info("Dead pixel correction skipped because dead pixel mask labels no pixels as 'defective'.")
            return image2correct

        n_rows, n_cols = image2correct.shape[:2]

        if deadpixel_map.ndim == 2:
            deadcolumn_map = deadpixel_map

            # compute dead pixel percentage (a dead column entry affects all rows)
            prop_dp_anyband = np.any(deadcolumn_map, axis=0).sum() * n_rows / (n_rows * n_cols)
            prop_dp = deadcolumn_map.sum() * n_rows / image2correct.size

            # convert the 2D (bands x columns) map into a 3D (rows x columns x bands) map
            n_bands, n_cols_map = deadcolumn_map.shape
            deadpixel_map = np.empty((n_rows, n_cols_map, n_bands), bool)
            deadpixel_map[:, :, :] = deadcolumn_map.T

        else:
            # compute dead pixel percentage
            prop_dp_anyband = np.any(deadpixel_map, axis=2).sum() / (n_rows * n_cols)
            prop_dp = deadpixel_map.sum() / image2correct.size

        self.logger.info('Percentage of defective pixels: %.2f', prop_dp * 100)
        self.logger.debug('Percentage of pixels with a defect in any band: %.2f', prop_dp_anyband * 100)

        # run correction
        if self.algorithm == 'spectral':
            return self._interpolate_nodata_spectrally(image2correct, deadpixel_map)
        else:
            return self._interpolate_nodata_spatially(image2correct, deadpixel_map)

188 

189 

def _get_baddata_mask(data: np.ndarray,
                      nodata: Union[np.ndarray, Number] = np.nan,
                      transpose_inNodata: bool = False):
    """Return a boolean mask that is True at invalid positions of ``data``.

    :param data:               data array
    :param nodata:             either a mask array (any nonzero/True entry marks invalid data) or a scalar
                               nodata value; a non-finite scalar (NaN/inf) marks all non-finite values of data
    :param transpose_inNodata: transpose a given mask array before comparing it with data
    :return: boolean mask with the shape of data
    :raises ValueError: if a given mask array does not match the shape of data
    """
    if isinstance(nodata, np.ndarray):
        badmask = nodata.T if transpose_inNodata else nodata

        if badmask.shape != data.shape:
            raise ValueError('No-data mask and data must have the same shape.')

        # Integer masks (e.g., uint8 0/1) would otherwise be used as integer fancy indices
        # (data[mask]) or bitwise-negated (~mask) by the callers. copy=False keeps bool masks
        # as they are.
        return badmask.astype(bool, copy=False)

    if np.isfinite(nodata):
        return data == nodata

    return ~np.isfinite(data)

203 

204 

def interp_nodata_along_axis_2d(data_2d: np.ndarray,
                                axis: int = 0,
                                nodata: Union[np.ndarray, Number] = np.nan,
                                method: str = 'linear',
                                **kw):
    """Interpolate a 2D array along the given axis (based on scipy.interpolate.make_interp_spline).

    Each 1D series along the given axis is interpolated at its invalid positions using the valid ones.
    Series with less than two valid values are left unchanged. If a series has too few valid values for
    the requested spline degree, the degree is reduced to (number of valid values - 1).

    NOTE: The interpolated values are assigned into data_2d (or its transposed view), i.e., the input
          array is modified in place and the values are cast to its dtype.

    :param data_2d: data to interpolate
    :param axis:    axis to interpolate (0: along columns; 1: along rows); negative values count from the end
    :param nodata:  nodata array in the shape of data or nodata value
    :param method:  interpolation method ('linear', 'quadratic', 'cubic')
    :param kw:      keyword arguments to be passed to scipy.interpolate.make_interp_spline
    :return: interpolated array
    :raises ValueError: on non-2D input, an out-of-bounds axis or an unknown interpolation method
    """
    if data_2d.ndim != 2:
        raise ValueError('Expected a 2D array. Received a %dD array.' % data_2d.ndim)
    if not -data_2d.ndim <= axis < data_2d.ndim:
        raise ValueError("axis=%d is out of bounds for data with %d dimensions." % (axis, data_2d.ndim))
    axis = axis % data_2d.ndim
    if method not in ['linear', 'quadratic', 'cubic']:
        raise ValueError(f"'{method}' is not a valid interpolation method. "
                         f"Choose between 'linear', 'quadratic', and 'cubic'.")
    degree = {'linear': 1, 'quadratic': 2, 'cubic': 3}[method]

    # always interpolate along axis 1 internally
    data_2d = data_2d if axis == 1 else data_2d.T

    badmask_full = _get_baddata_mask(data_2d, nodata, transpose_inNodata=axis == 0)
    all_positions = np.arange(data_2d.shape[1])

    # call 1D interpolation vectorized
    # => group the dataset by rows that have nodata at the same column position
    # => remember the row positions, call the interpolation for these rows at once (vectorized)
    #    and substitute the original data at the previously grouped row positions
    groups_unique_rows = npi.group_by(badmask_full).split(np.arange(len(badmask_full)))

    for indices_unique_rows in groups_unique_rows:
        # all rows of a group share the same nodata pattern, so the first row is representative
        badpos = np.flatnonzero(badmask_full[indices_unique_rows[0], :])
        if badpos.size == 0:
            continue

        goodpos = np.delete(all_positions, badpos)
        if goodpos.size < 2:
            # not enough valid values for any interpolation -> leave these rows unchanged
            continue

        # make_interp_spline needs at least k + 1 points -> reduce the degree for sparse rows
        k = min(degree, goodpos.size - 1)

        data_2d_grouped_rows = data_2d[indices_unique_rows]
        data_2d_grouped_rows[:, badpos] = \
            make_interp_spline(goodpos, data_2d_grouped_rows[:, goodpos], axis=1, k=k, **kw)(badpos)

        data_2d[indices_unique_rows, :] = data_2d_grouped_rows

    return data_2d if axis == 1 else data_2d.T

253 

254 

def interp_nodata_along_axis(data,
                             axis=0,
                             nodata: Union[np.ndarray, Number] = np.nan,
                             method: str = 'linear',
                             **kw):
    """Interpolate a 2D or 3D array along the given axis (based on scipy.interpolate.make_interp_spline).

    3D data of shape (rows, columns, bands) is reshaped to a 2D array whose rows or columns are exactly the
    1D series along the requested axis. The series are interpolated by interp_nodata_along_axis_2d and the
    result is reshaped back to the input shape.

    :param data:   data to interpolate
    :param axis:   axis to interpolate (0: along columns; 1: along rows, 2: along bands);
                   negative values count from the end
    :param nodata: nodata array in the shape of data or nodata value
    :param method: interpolation method ('linear', 'quadratic', 'cubic')
    :param kw:     keyword arguments to be passed to scipy.interpolate.make_interp_spline
    :return: interpolated array
    :raises ValueError: on data that is not 2D/3D, an out-of-bounds axis or a mask of a different shape
    """
    if data.ndim not in [2, 3]:
        raise ValueError('Expected a 2D or 3D array. Received a %dD array.' % data.ndim)
    if not -data.ndim <= axis < data.ndim:
        raise ValueError("axis=%d is out of bounds for data with %d dimensions." % (axis, data.ndim))
    axis = axis % data.ndim
    if isinstance(nodata, np.ndarray) and nodata.shape != data.shape:
        raise ValueError('No-data mask and data must have the same shape.')

    if data.ndim == 2:
        return interp_nodata_along_axis_2d(data, axis=axis, nodata=nodata, method=method, **kw)

    R, C, B = data.shape

    def reshape_input(arr):
        """Reshape a (R, C, B) array to 2D so that the series along `axis` become rows or columns."""
        if axis == 0:
            # (R, C, B) -> (R, C*B): column c*B+b holds the series arr[:, c, b]
            return arr.reshape(R, C * B)
        if axis == 1:
            # (R, C, B) -> (C, R, B) -> (C, R*B) -> (R*B, C): row r*B+b holds the series arr[r, :, b]
            return np.transpose(arr, axes=[1, 0, 2]).reshape(C, R * B).T
        # (R, C, B) -> (R*C, B): each row holds one spectrum
        return arr.reshape(R * C, B)

    def reshape_output(out):
        """Invert reshape_input."""
        if axis == 1:
            # (R*B, C) -> (C, R*B) -> (C, R, B) -> (R, C, B)
            return np.transpose(out.T.reshape(C, R, B), axes=[1, 0, 2])
        return out.reshape(R, C, B)

    return \
        reshape_output(
            interp_nodata_along_axis_2d(
                data_2d=reshape_input(data),
                nodata=reshape_input(nodata) if isinstance(nodata, np.ndarray) else nodata,
                # the 3D axes 0 and 1 keep their meaning in 2D; the band axis becomes axis 1
                axis=1 if axis == 2 else axis,
                method=method, **kw))

298 

299 

def interp_nodata_spatially_2d(data_2d: np.ndarray,
                               axis: int = 0,
                               nodata: Union[np.ndarray, Number] = np.nan,
                               method: str = 'linear',
                               fill_value: float = np.nan,
                               implementation: str = 'pandas'
                               ) -> np.ndarray:
    """Interpolate a 2D array spatially.

    NOTE: If data_2d contains NaN values that are unlabelled by a given nodata array,
          they are also overwritten in the pandas implementation.
    NOTE: The scipy implementation assigns the interpolated values into data_2d in place,
          the pandas implementation returns a new array with the dtype of data_2d.

    :param data_2d:        data to interpolate
    :param axis:           axis to interpolate (0: along columns; 1: along rows)
    :param nodata:         nodata array in the shape of data or nodata value
    :param method:         interpolation method
                           - if implementation='scipy': 'linear', 'nearest', 'cubic'
                           - if implementation='pandas': 'linear', 'nearest', 'slinear', 'quadratic', 'cubic', etc.
    :param fill_value:     value to fill into positions where no interpolation is possible
    :param implementation: 'scipy': interpolation based on scipy.interpolate.griddata
                           'pandas': interpolation based on pandas.DataFrame.interpolate
    :return: interpolated array (data_2d itself if it contains no nodata)
    :raises ValueError: on non-2D input or an unknown implementation (the latter only if there is nodata)
    """
    assert axis < 2
    if data_2d.ndim != 2:
        raise ValueError('Expected a 2D array. Received a %dD array.' % data_2d.ndim)

    badmask_full = _get_baddata_mask(data_2d, nodata)

    if not badmask_full.any():
        return data_2d

    if implementation == 'scipy':
        if axis == 0:
            y, x = np.indices(data_2d.shape)
        else:
            x, y = np.indices(data_2d.shape)

        data_2d[badmask_full] = \
            griddata(np.array([x[~badmask_full], y[~badmask_full]]).T,  # points we know
                     data_2d[~badmask_full],  # values we know
                     np.array([x[badmask_full], y[badmask_full]]).T,  # points to interpolate
                     method=method, fill_value=fill_value)

    elif implementation == 'pandas':
        data2int = data_2d.astype(float)
        data2int[badmask_full] = np.nan

        # np.array() copies -> the result is writable, independent of pandas' copy-on-write behaviour
        data_interp = np.array(DataFrame(data2int).interpolate(method=method, axis=axis))

        # Fill non-interpolable positions BEFORE casting back to the input dtype: integer dtypes
        # cannot represent NaN, so casting first would turn it into arbitrary, undetectable numbers.
        if np.isfinite(fill_value):
            mask_nan = np.isnan(data_interp)
            if mask_nan.any():
                data_interp[mask_nan] = fill_value

        data_2d = data_interp.astype(data_2d.dtype)

    else:
        raise ValueError(implementation, 'Unknown implementation.')

    return data_2d

358 

359 

def interp_nodata_spatially_3d(data_3d: np.ndarray,
                               axis: int = 0,
                               nodata: Union[np.ndarray, Number] = np.nan,
                               method: str = 'linear',
                               fill_value: float = np.nan,
                               implementation: str = 'pandas',
                               CPUs: int = None
                               ) -> np.ndarray:
    """Interpolate a 3D array spatially, band by band.

    Each band (data_3d[:, :, band]) is interpolated by interp_nodata_spatially_2d, either sequentially
    or in a multiprocessing pool. The results are stacked along the third axis.

    :param data_3d:        data to interpolate
    :param axis:           axis to interpolate (0: along columns; 1: along rows)
    :param nodata:         nodata array in the shape of data or nodata value
    :param method:         interpolation method
                           - if implementation='scipy': 'linear', 'nearest', 'cubic'
                           - if implementation='pandas': 'linear', 'nearest', 'slinear', 'quadratic', 'cubic', etc.
    :param fill_value:     value to fill into positions where no interpolation is possible
    :param implementation: 'scipy': interpolation based on scipy.interpolate.griddata
                           'pandas': interpolation based on pandas.DataFrame.interpolate
    :param CPUs:           number of CPUs to use (None: all available CPUs); at most one worker process
                           per band is started and a single worker runs without a process pool
    :return: interpolated array
    """
    assert axis < 2

    badmask_full = _get_baddata_mask(data_3d, nodata)
    n_bands = data_3d.shape[2]

    # None means 'all CPUs' (comparing None with an int would raise a TypeError);
    # more workers than bands would only add process start-up overhead
    n_workers = min(cpu_count() if CPUs is None else CPUs, n_bands)

    if n_workers > 1:
        args = [[data_3d[:, :, band], axis, badmask_full[:, :, band], method, fill_value, implementation]
                for band in range(n_bands)]

        with Pool(n_workers) as pool:
            results = pool.starmap(interp_nodata_spatially_2d, args)

            pool.close()  # needed for coverage to work in multiprocessing
            pool.join()

    else:
        results = [interp_nodata_spatially_2d(data_3d[:, :, band], axis=axis,
                                              nodata=badmask_full[:, :, band], method=method,
                                              fill_value=fill_value, implementation=implementation)
                   for band in range(n_bands)]

    return np.dstack(results)