Source code for sunkit_dem.base_model
"""
Base model class for DEM models
"""
from abc import ABC, abstractmethod
import ndcube
import numpy as np
from ndcube.extra_coords.table_coord import MultipleTableCoordinate, QuantityTableCoordinate
from ndcube.wcs.wrappers import CompoundLowLevelWCS
import astropy.units as u
from astropy.nddata import StdDevUncertainty
# Names exported by a star-import of this module.
__all__ = ["GenericModel"]
class BaseModel(ABC):
    """
    Abstract interface shared by all DEM models.

    Concrete models must override `_model`; `defines_model_for` raises
    `NotImplementedError` unless a subclass replaces it.
    """

    @abstractmethod
    def _model(self):
        """Run the model-specific computation. Must be overridden."""
        raise NotImplementedError()

    def defines_model_for(self):
        """Placeholder that subclasses are expected to replace."""
        raise NotImplementedError()
[docs]
class GenericModel(BaseModel):
    """
    Base class for implementing a differential emission measure model

    Parameters
    ----------
    data : `~ndcube.NDCollection`
        Input collection. Its members are looked up using the keys of
        ``kernel``.
    kernel : `dict`
        `~astropy.units.Quantity` objects containing the kernels
        of each response. The keys should correspond to those in ``data``.
    temperature_bin_edges : `~astropy.units.Quantity`
        Edges of the temperature bins in which the DEM is computed. The
        rightmost edge is included. The kernel is evaluated at the bin centers.
        The bin widths must be equal in log10.
    kernel_temperatures : optional
        Object with a ``shape`` (presumably a `~astropy.units.Quantity`)
        that every kernel must match in shape. If `None`, the temperature
        bin centers are used.

    Raises
    ------
    ValueError
        If ``data`` is not an `~ndcube.NDCollection`, if the number, keys
        or shapes of the kernels are inconsistent with ``data`` and
        ``kernel_temperatures``.
    """

    # Maps each subclass to its ``defines_model_for`` attribute.
    _registry = {}

    def __init_subclass__(cls, **kwargs):
        """
        Register every subclass in ``_registry`` at class-creation time.

        The subclass itself is the key and its ``defines_model_for``
        attribute is the value.
        """
        super().__init_subclass__(**kwargs)
        # BaseModel already provides ``defines_model_for``, so subclasses in
        # this hierarchy always satisfy this check.
        if hasattr(cls, 'defines_model_for'):
            cls._registry[cls] = cls.defines_model_for

    @u.quantity_input
    def __init__(self, data, kernel, temperature_bin_edges: u.K, kernel_temperatures=None, **kwargs):
        # Order matters: the ``kernel`` setter validates against both
        # ``data`` and ``kernel_temperatures``, so it must be assigned last.
        self.temperature_bin_edges = temperature_bin_edges
        self.data = data
        self.kernel_temperatures = kernel_temperatures
        if self.kernel_temperatures is None:
            self.kernel_temperatures = self.temperature_bin_centers
        self.kernel = kernel

    @property
    def _keys(self):
        # Internal reference for entries in kernel and data
        # This ensures consistent ordering in kernel and data matrices
        return sorted(self.kernel.keys())

    @property
    @u.quantity_input
    def temperature_bin_centers(self) -> u.K:
        """Arithmetic midpoints of adjacent temperature bin edges."""
        return (self.temperature_bin_edges[1:] + self.temperature_bin_edges[:-1])/2

    @property
    @u.quantity_input
    def temperature_bin_widths(self) -> u.K:
        """Differences between adjacent temperature bin edges."""
        return np.diff(self.temperature_bin_edges)

    @property
    def data(self) -> ndcube.NDCollection:
        """The input `~ndcube.NDCollection`."""
        return self._data

    @data.setter
    def data(self, data):
        """
        Check that input data is correctly formatted as an
        `ndcube.NDCollection`
        """
        if not isinstance(data, ndcube.NDCollection):
            raise ValueError('Input data must be an NDCollection')
        # NOTE(review): this only tests that a ``unit`` attribute exists; if
        # NDCube always exposes ``unit`` (possibly as None) the check never
        # fires -- confirm whether ``unit is None`` should be rejected.
        if not all(hasattr(data[k], 'unit') for k in data):
            raise u.UnitsError('Each NDCube in NDCollection must have units')
        self._data = data

    @property
    def combined_mask(self):
        """
        Combined mask of all members of ``data``. Will be True if any member is masked.
        This is propagated to the final DEM result
        """
        # Members without a mask contribute an all-False array of their
        # data shape.
        masks = [
            self.data[k].mask
            if self.data[k].mask is not None
            else np.full(self.data[k].data.shape, False)
            for k in self._keys
        ]
        return np.any(masks, axis=0)

    @property
    def kernel(self):
        """Dictionary of kernels, keyed like ``data``."""
        return self._kernel

    @kernel.setter
    def kernel(self, kernel):
        """
        Validate the kernels against ``data`` and ``kernel_temperatures``.

        Raises
        ------
        ValueError
            If the number of kernels differs from the number of entries in
            ``data``, if any kernel's shape differs from that of
            ``kernel_temperatures``, or if the kernel keys are not the same
            as the keys of ``data``.
        """
        if len(kernel) != len(self.data):
            raise ValueError('Number of kernels must be equal to length of wavelength dimension.')
        if not all(v.shape == self.kernel_temperatures.shape for v in kernel.values()):
            raise ValueError('Temperature bin centers and kernels must have the same shape.')
        # ``_keys`` is derived from the kernel and used to index ``data``;
        # fail here rather than with a KeyError later in ``fit``.
        if set(kernel.keys()) != set(self.data.keys()):
            raise ValueError('Keys of kernel must match the keys of data.')
        self._kernel = kernel

    @property
    def data_matrix(self):
        """Data arrays of all members stacked along a new leading axis, ordered by ``_keys``."""
        return np.stack([self.data[k].data for k in self._keys])

    @property
    def uncertainty_matrix(self):
        """
        Uncertainty arrays stacked along a new leading axis, ordered by
        ``_keys``, or `None` if any member has no uncertainty.
        """
        uncertainties = [self.data[k].uncertainty for k in self._keys]
        if any(_u is None for _u in uncertainties):
            return None
        return np.stack([_u.array for _u in uncertainties])

    @property
    def kernel_matrix(self):
        """Kernel values (units dropped) stacked along a new leading axis, ordered by ``_keys``."""
        return np.stack([self.kernel[k].value for k in self._keys])

    def fit(self, *args, **kwargs):
        r"""
        Apply inversion procedure to data.

        All arguments are forwarded to ``_model``, which must return a
        dictionary with a ``'dem'`` entry and, optionally, an
        ``'uncertainty'`` entry. Any further entries are returned as
        additional cubes.

        Returns
        -------
        dem : `~ndcube.NDCollection`
            Collection whose ``'dem'`` member is the differential emission
            measure as a function of temperature. The
            temperature axis is evenly spaced in :math:`\log{T}`. The number
            of dimensions depend on the input data.
        """
        dem_dict = self._model(*args, **kwargs)
        wcs = self._make_dem_wcs()
        meta = self._make_dem_meta()
        dem_data = dem_dict.pop('dem')
        # Broadcast the combined data mask over the leading axis of the DEM.
        mask = np.full(dem_data.shape, False)
        mask[:, ...] = self.combined_mask
        uncertainty = dem_dict.pop('uncertainty', None)
        if uncertainty is not None:
            uncertainty = StdDevUncertainty(uncertainty)
        dem = ndcube.NDCube(dem_data,
                            wcs,
                            meta=meta,
                            mask=mask,
                            uncertainty=uncertainty)
        cubes = [('dem', dem)]
        # Whatever else the model returned becomes an extra cube sharing the
        # same WCS and metadata.
        cubes.extend((k, ndcube.NDCube(v, wcs, meta=meta)) for k, v in dem_dict.items())
        return ndcube.NDCollection(cubes)

    def _make_dem_wcs(self):
        """
        Build the WCS of the result by combining the WCS of the first
        (by ``_keys`` order) data member with a temperature table coordinate
        defined at the bin centers.
        """
        data_wcs = self.data[self._keys[0]].wcs
        temp_table = QuantityTableCoordinate(self.temperature_bin_centers,
                                             names='temperature',
                                             physical_types='phys.temperature')
        temp_table_coord = MultipleTableCoordinate(temp_table)
        # Data pixel dimensions keep their indices; the temperature
        # dimension(s) are assigned the next index after them.
        mapping = list(range(data_wcs.pixel_n_dim))
        mapping.extend([data_wcs.pixel_n_dim] * temp_table_coord.wcs.pixel_n_dim)
        compound_wcs = CompoundLowLevelWCS(data_wcs, temp_table_coord.wcs, mapping=mapping)
        return compound_wcs

    def _make_dem_meta(self):
        # Individual classes should override this if they want specific metadata
        return {}