import re
import functools
from collections import namedtuple
import numpy as np
from astropy import coordinates as coord
from astropy import units as u
from astropy.modeling import core, models, projections
from astropy.io import fits
from gwcs import coordinate_frames as cf
from gwcs import utils as gwutils
from gwcs.utils import sky_pairs, specsystems
from gwcs.wcs import WCS as gWCS
AffineMatrices = namedtuple("AffineMatrices", "matrix offset")
FrameMapping = namedtuple("FrameMapping", "cls description")
# Type of CoordinateFrame to construct for a FITS keyword and
# readable name so user knows what's going on
frame_mapping = {'WAVE': FrameMapping(cf.SpectralFrame, "Wavelength in vacuo"),
'AWAV': FrameMapping(cf.SpectralFrame, "Wavelength in air")}
#-----------------------------------------------------------------------------
# FITS-WCS -> gWCS
#-----------------------------------------------------------------------------
[docs]def fitswcs_to_gwcs(hdr):
"""
Create and return a gWCS object from a FITS header. If it can't
construct one, it should quietly return None.
"""
# coordinate names for CelestialFrame
coordinate_outputs = {'alpha_C', 'delta_C'}
# transform = gw.make_fitswcs_transform(hdr)
try:
transform = make_fitswcs_transform(hdr)
except Exception as e:
return None
outputs = transform.outputs
wcs_info = read_wcs_from_header(hdr)
naxes = transform.n_inputs
axes_names = ('x', 'y', 'z', 'u', 'v', 'w')[:naxes]
in_frame = cf.CoordinateFrame(naxes=naxes, axes_type=['SPATIAL'] * naxes,
axes_order=tuple(range(naxes)), name="pixels",
axes_names=axes_names, unit=[u.pix] * naxes)
out_frames = []
for i, output in enumerate(outputs):
unit_name = wcs_info["CUNIT"][i]
try:
unit = u.Unit(unit_name)
except TypeError:
unit = None
try:
frame_type = output[:4].upper()
frame_info = frame_mapping[frame_type]
except KeyError:
if output in coordinate_outputs:
continue
frame = cf.CoordinateFrame(naxes=1, axes_type=("SPATIAL",),
axes_order=(i,), unit=unit,
axes_names=(output,), name=output)
else:
frame = frame_info.cls(axes_order=(i,), unit=unit,
axes_names=(frame_type,),
name=frame_info.description)
out_frames.append(frame)
if coordinate_outputs.issubset(outputs):
frame_name = wcs_info["RADESYS"] # FK5, etc.
try:
ref_frame = getattr(coord, frame_name)()
axes_names = None
# TODO? Work out how to stuff EQUINOX and OBS-TIME into the frame
except (AttributeError, TypeError):
ref_frame = None
axes_names = ('lon', 'lat')
axes_order = (outputs.index('alpha_C'), outputs.index('delta_C'))
# Call it 'world' if there are no other axes, otherwise 'sky'
name = 'SKY' if len(outputs) > 2 else 'world'
cel_frame = cf.CelestialFrame(reference_frame=ref_frame, name=name,
axes_names=axes_names, axes_order=axes_order)
out_frames.append(cel_frame)
out_frame = (out_frames[0] if len(out_frames) == 1
else cf.CompositeFrame(out_frames, name='world'))
return gWCS([(in_frame, transform),
(out_frame, None)])
# -----------------------------------------------------------------------------
# gWCS -> FITS-WCS
# -----------------------------------------------------------------------------
[docs]def gwcs_to_fits(ndd, hdr=None):
"""
Convert a gWCS object to a collection of FITS WCS keyword/value pairs,
if possible. If the FITS WCS is only approximate, this should be indicated
with a dict entry {'FITS-WCS': 'APPROXIMATE'}. If there is no suitable
FITS representation, then a ValueError or NotImplementedError can be
raised.
Parameters
----------
ndd: NDData
The NDData whose wcs attribute we want converted
hdr: fits.Header
A Header object that may contain some useful keywords
Returns
-------
dict: values to insert into the FITS header to express this WCS
"""
if hdr is None:
hdr = {}
wcs = ndd.wcs
transform = wcs.forward_transform
world_axes = list(wcs.output_frame.axes_names)
nworld_axes = len(world_axes)
wcs_dict = {'WCSAXES': nworld_axes,
'WCSDIM': nworld_axes}
wcs_dict.update({f'CD{i+1}_{j+1}': 0. for j in range(nworld_axes)
for i in range(nworld_axes)})
pix_center = [0.5 * (length - 1) for length in ndd.shape[::-1]]
wcs_center = transform(*pix_center)
# Find and process the sky projection first
if {'lon', 'lat'}.issubset(world_axes):
if isinstance(wcs.output_frame, cf.CelestialFrame):
cel_frame = wcs.output_frame
elif isinstance(wcs.output_frame, cf.CompositeFrame):
for frame in wcs.output_frame.frames:
if isinstance(frame, cf.CelestialFrame):
cel_frame = frame
# TODO: Non-ecliptic coordinate frames
cel_ref_frame = cel_frame.reference_frame
if not isinstance(cel_ref_frame, coord.builtin_frames.BaseRADecFrame):
raise NotImplementedError("Cannot write non-ecliptic frames yet")
wcs_dict['RADESYS'] = cel_ref_frame.name.upper()
for m in transform:
if isinstance(m, models.RotateNative2Celestial):
nat2cel = m
if isinstance(m, models.Pix2SkyProjection):
m.name = 'pix2sky'
# Determine which sort of projection this is
for projcode in projections.projcodes:
if isinstance(m, getattr(models, f'Pix2Sky_{projcode}')):
break
else:
raise ValueError("Unknown projection class: {}".
format(m.__class__.__name__))
lon_axis = world_axes.index('lon')
lat_axis = world_axes.index('lat')
world_axes[lon_axis] = f'RA---{projcode}'
world_axes[lat_axis] = f'DEC--{projcode}'
wcs_dict[f'CRVAL{lon_axis+1}'] = nat2cel.lon.value
wcs_dict[f'CRVAL{lat_axis+1}'] = nat2cel.lat.value
# Remove projection parts so we can calculate the CD matrix
if projcode:
nat2cel.name = 'nat2cel'
transform = transform.replace_submodel('pix2sky', models.Identity(2))
transform = transform.replace_submodel('nat2cel', models.Identity(2))
# Deal with other axes
# TODO: AD should refactor to allow the descriptor to be used here
for i, axis_type in enumerate(wcs.output_frame.axes_type, start=1):
if f'CRVAL{i}' in wcs_dict:
continue
if axis_type == "SPECTRAL":
wcs_dict[f'CRVAL{i}'] = hdr.get('CENTWAVE', wcs_center[i-1] if nworld_axes > 1 else wcs_center)
wcs_dict[f'CTYPE{i}'] = wcs.output_frame.axes_names[i-1] # AWAV/WAVE
else: # Just something
wcs_dict[f'CRVAL{i}'] = 0
# Flag if we can't construct a perfect CD matrix
if not model_is_affine(transform):
wcs_dict['FITS-WCS'] = ('APPROXIMATE', 'FITS WCS is approximate')
affine = calculate_affine_matrices(transform, ndd.shape)
# Convert to x-first order
affine_matrix = affine.matrix[::-1, ::-1]
# Require an inverse to write out
if np.linalg.det(affine_matrix) == 0:
affine_matrix[-1, -1] = 1.
wcs_dict.update({f'CD{i+1}_{j+1}': affine_matrix[i, j]
for j, _ in enumerate(world_axes)
for i, _ in enumerate(world_axes)})
# Don't overwrite CTYPEi keywords we've already created
wcs_dict.update({f'CTYPE{i}': axis.upper()[:8]
for i, axis in enumerate(world_axes, start=1)
if f'CTYPE{i}' not in wcs_dict})
crval = [wcs_dict[f'CRVAL{i+1}'] for i, _ in enumerate(world_axes)]
crpix = np.array(wcs.backward_transform(*crval)) + 1
if nworld_axes == 1:
wcs_dict['CRPIX1'] = crpix
else:
# Comply with FITS standard, must define CRPIXj for "extra" axes
wcs_dict.update({f'CRPIX{j}': cpix for j, cpix in enumerate(np.concatenate([crpix, [0] * (nworld_axes-len(ndd.shape))]), start=1)})
for i, unit in enumerate(wcs.output_frame.unit, start=1):
try:
wcs_dict[f'CUNIT{i}'] = unit.name
except AttributeError:
pass
return wcs_dict
# -----------------------------------------------------------------------------
# Helper functions
# -----------------------------------------------------------------------------
[docs]def model_is_affine(model):
""""
Test a Model for affinity. This is currently done by checking the
name of its class (or the class names of all its submodels)
TODO: Is this the right thing to do? We could compute the affine
matrices *assuming* affinity, and then check that a number of random
points behave as expected. Is that better?
"""
if isinstance(model, dict): # handle fix_inputs()
return True
try:
return np.logical_and.reduce([model_is_affine(m)
for m in model])
except TypeError:
# TODO: Delete "Const" one fix_inputs() broadcastingis fixed
return model.__class__.__name__[:5] in ('Affin', 'Rotat', 'Scale',
'Shift', 'Ident', 'Mappi',
'Const')
[docs]def calculate_affine_matrices(func, shape):
"""
Compute the matrix and offset necessary of an affine transform that
represents the supplied function. This is done by computing the
linear matrix along all axes extending from the centre of the region,
and then calculating the offset such that the transformation is
accurate at the centre of the region. The matrix and offset are returned
in the standard python order (i.e., y-first for 2D).
Parameters
----------
func: callable
function that maps input->output coordinates
shape: sequence
shape to use for fiducial points
Returns
-------
AffineMatrices(array, array): affine matrix and offset
"""
indim = len(shape)
try:
ndim = len(func(*shape)) # handle increase in number of axes
except TypeError:
ndim = 1
halfsize = [0.5 * length for length in shape] + [1.] * (ndim - indim)
points = np.array([halfsize] * (2 * ndim + 1)).T
points[:, 1:ndim + 1] += np.eye(ndim) * points[:, 0]
points[:, ndim + 1:] -= np.eye(ndim) * points[:, 0]
if ndim > 1:
transformed = np.array(list(zip(*list(func(*point[:indim])
for point in points.T)))).T
else:
transformed = np.array([func(*points)]).T
matrix = np.array([[0.5 * (transformed[j + 1, i] - transformed[ndim + j + 1, i]) / halfsize[j]
for j in range(ndim)] for i in range(ndim)])
offset = transformed[0] - np.dot(matrix, halfsize)
return AffineMatrices(matrix[::-1, ::-1], offset[::-1])
# -------------------------------------------------------------------------
# This stuff will hopefully all go into gwcs.utils
# -------------------------------------------------------------------------
[docs]def get_axes(header):
"""
Matches input with spectral and sky coordinate axes.
Parameters
----------
header : astropy.io.fits.Header or dict
FITS Header (or dict) with basic WCS information.
Returns
-------
sky_inmap, spectral_inmap, unknown : lists
indices in the output representing sky and spectral coordinates.
"""
if isinstance(header, fits.Header):
wcs_info = read_wcs_from_header(header)
elif isinstance(header, dict):
wcs_info = header
else:
raise TypeError("Expected a FITS Header or a dict.")
# Split each CTYPE value at "-" and take the first part.
# This should represent the coordinate system.
ctype = [ax.split('-')[0].upper() for ax in wcs_info['CTYPE']]
sky_inmap = []
spec_inmap = []
unknown = []
skysystems = np.array(list(sky_pairs.values())).flatten()
for ind, ax in enumerate(ctype):
if ax in specsystems:
spec_inmap.append(ind)
elif ax in skysystems:
sky_inmap.append(ind)
else:
unknown.append(ind)
if len(sky_inmap) == 1:
unknown.append(sky_inmap.pop())
if sky_inmap:
_is_skysys_consistent(ctype, sky_inmap)
return sky_inmap, spec_inmap, unknown
def _is_skysys_consistent(ctype, sky_inmap):
""" Determine if the sky axes in CTYPE match to form a standard celestial system."""
if len(sky_inmap) != 2:
raise ValueError("{} sky coordinate axes found. "
"There must be exactly 2".format(len(sky_inmap)))
for item in sky_pairs.values():
if ctype[sky_inmap[0]] == item[0]:
if ctype[sky_inmap[1]] != item[1]:
raise ValueError(
"Inconsistent ctype for sky coordinates {0} and {1}".format(*ctype))
break
elif ctype[sky_inmap[1]] == item[0]:
if ctype[sky_inmap[0]] != item[1]:
raise ValueError(
"Inconsistent ctype for sky coordinates {0} and {1}".format(*ctype))
sky_inmap.reverse()
break
def _get_contributing_axes(wcs_info, world_axes):
"""
Returns a tuple indicating which axes in the pixel frame make a
contribution to an axis or axes in the output frame.
Parameters
----------
wcs_info : dict
dict of WCS information
world_axes : int/iterable of ints
axes in the world coordinate system
Returns
-------
axes: list
axes whose pixel coordinates affect the output axis/axes
"""
cd = wcs_info['CD']
try:
return sorted(set(np.nonzero(cd[tuple(world_axes), :wcs_info['NAXIS']])[1]))
except TypeError: # world_axes is an int
return sorted(np.nonzero(cd[world_axes, :wcs_info['NAXIS']])[0])
#return sorted(set(j for j in range(wcs_info['NAXIS'])
# for i in world_axes if cd[i, j] != 0))
[docs]def fitswcs_image(header):
"""
Make a complete transform from CRPIX-shifted pixels to
sky coordinates from FITS WCS keywords. A Mapping is inserted
at the beginning, which may be removed later
Parameters
----------
header : astropy.io.fits.Header or dict
FITS Header or dict with basic FITS WCS keywords.
"""
if isinstance(header, fits.Header):
wcs_info = read_wcs_from_header(header)
elif isinstance(header, dict):
wcs_info = header
else:
raise TypeError("Expected a FITS Header or a dict.")
crpix = wcs_info['CRPIX']
cd = wcs_info['CD']
# get the part of the PC matrix corresponding to the imaging axes
sky_axes, spec_axes, unknown = get_axes(wcs_info)
if not sky_axes:
return
#if len(unknown) == 2:
# sky_axes = unknown
#else: # No sky here
# return
pixel_axes = _get_contributing_axes(wcs_info, sky_axes)
if len(pixel_axes) > 2:
raise ValueError("More than 2 pixel axes contribute to the sky coordinates")
translation_models = [models.Shift(-(crpix[i] - 1), name='crpix' + str(i + 1))
for i in pixel_axes]
translation = functools.reduce(lambda x, y: x & y, translation_models)
transforms = [translation]
# If only one axis is contributing to the sky (e.g., slit spectrum)
# then it must be that there's an extra axis in the CD matrix, so we
# create a "ghost" orthogonal axis here so an inverse can be defined
# Modify the CD matrix in case we have to use a backup Matrix Model later
if len(pixel_axes) == 1:
cd[sky_axes[0], -1] = -cd[sky_axes[1], pixel_axes[0]]
cd[sky_axes[1], -1] = cd[sky_axes[0], pixel_axes[0]]
sky_cd = cd[np.ix_(sky_axes, pixel_axes + [-1])]
affine = models.AffineTransformation2D(matrix=sky_cd, name='cd_matrix')
# TODO: replace when PR#10362 is in astropy
#rotation = models.fix_inputs(affine, {'y': 0})
rotation = models.Mapping((0, 0)) | models.Identity(1) & models.Const1D(0) | affine
rotation.inverse = affine.inverse | models.Mapping((0,), n_inputs=2)
else:
sky_cd = cd[np.ix_(sky_axes, pixel_axes)]
rotation = models.AffineTransformation2D(matrix=sky_cd, name='cd_matrix')
transforms.append(rotation)
projection = gwutils.fitswcs_nonlinear(wcs_info)
if projection:
transforms.append(projection)
sky_model = functools.reduce(lambda x, y: x | y, transforms)
sky_model.name = 'SKY'
sky_model.meta.update({'input_axes': pixel_axes,
'output_axes': sky_axes})
return sky_model
[docs]def fitswcs_linear(header):
"""
Create WCS linear transforms for any axes not associated with
celestial coordinates. We require that each world axis aligns
precisely with only a single pixel axis.
Parameters
----------
header : astropy.io.fits.Header or dict
FITS Header or dict with basic FITS WCS keywords.
"""
# We *always* want the wavelength solution model to be called "WAVE"
# even if the CTYPE keyword is "AWAV"
model_name_mapping = {"AWAV": "WAVE"}
if isinstance(header, fits.Header):
wcs_info = read_wcs_from_header(header)
elif isinstance(header, dict):
wcs_info = header
else:
raise TypeError("Expected a FITS Header or a dict.")
cd = wcs_info['CD']
crpix = wcs_info['CRPIX']
crval = wcs_info['CRVAL']
# get the part of the CD matrix corresponding to the imaging axes
sky_axes, spec_axes, unknown = get_axes(wcs_info)
#if not sky_axes and len(unknown) == 2:
# unknown = []
linear_models = []
for ax in spec_axes + unknown:
pixel_axes = _get_contributing_axes(wcs_info, ax)
if len(pixel_axes) == 1:
pixel_axis = pixel_axes[0]
linear_model = (models.Shift(1 - crpix[pixel_axis],
name='crpix' + str(pixel_axis + 1)) |
models.Scale(cd[ax, pixel_axis]) |
models.Shift(crval[ax]))
ctype = wcs_info['CTYPE'][ax][:4].upper()
linear_model.name = model_name_mapping.get(ctype, ctype)
linear_model.outputs = (wcs_info['CTYPE'][ax],)
linear_model.meta.update({'input_axes': pixel_axes,
'output_axes': [ax]})
linear_models.append(linear_model)
else:
raise ValueError(f"Axis {ax} depends on more than one input axis")
return linear_models