import numbers
import numpy as np
import pytest
from astropy import units as u
from astropy.coordinates import SkyCoord
from astropy.io import fits
from astropy.nddata import CCDData, NDData
from astropy.table import Table
from astropy.visualization import (
AsymmetricPercentileInterval,
BaseInterval,
BaseStretch,
LogStretch,
ManualInterval,
)
from astropy.wcs import WCS
# Only the reusable test-suite class is part of this module's public API.
__all__ = ["ImageAPITest"]
# Shape of the synthetic test image; it is passed to numpy as an array
# shape, i.e. (rows, columns).
DEFAULT_IMAGE_SHAPE = (100, 150)
[docs]
class ImageAPITest:
[docs]
@pytest.fixture
def data(self):
    """
    Provide a reproducible random image of shape ``DEFAULT_IMAGE_SHAPE``.
    """
    # A fixed seed keeps the pixel values identical from run to run.
    generator = np.random.default_rng(1234)
    pixels = generator.random(DEFAULT_IMAGE_SHAPE)
    return pixels
[docs]
@pytest.fixture
def wcs(self):
    """
    Provide a two-axis WCS that uses an "Airy's zenithal" projection.

    The construction is copied from the astropy 4.3.1 documentation.
    """
    # The number of axes has to be fixed when the WCS object is created.
    projection = WCS(naxis=2)
    # Reminder: the reference pixel is 1-based, not 0-based.
    projection.wcs.crpix = [-234.75, 8.3393]
    projection.wcs.cdelt = np.array([-0.066667, 0.066667])
    projection.wcs.crval = [0, -90]
    projection.wcs.ctype = ["RA---AIR", "DEC--AIR"]
    projection.wcs.set_pv([(2, 1, 45.0)])
    return projection
[docs]
@pytest.fixture
def catalog(self, wcs: WCS) -> Table:
    """
    Provide a ten-row catalog of random positions inside the test image.

    The table has the columns ``x`` and ``y`` (pixel positions) and
    ``coord`` (the result of ``wcs.pixel_to_world(x, y)``). A fixed seed
    keeps the positions identical from run to run.
    """
    # DEFAULT_IMAGE_SHAPE is a numpy shape, i.e. (rows, columns), so the
    # x (column) range comes from the second entry and the y (row) range
    # from the first. Using them the other way round lets y fall outside
    # the image.
    n_rows, n_columns = DEFAULT_IMAGE_SHAPE
    rng = np.random.default_rng(45328975)
    x = rng.uniform(0, n_columns, size=10)
    y = rng.uniform(0, n_rows, size=10)
    coord = wcs.pixel_to_world(x, y)
    return Table(dict(x=x, y=y, coord=coord))
# This setup is run before each test, ensuring that there are no
# side effects of one test on another
[docs]
@pytest.fixture(autouse=True)
def setup(self):
    """
    Create a fresh viewer for every test and store it as ``self.image``.

    The viewer is built by calling ``image_widget_class``, which
    subclasses MUST define -- a class variable is enough.
    """
    widget_factory = self.image_widget_class
    self.image = widget_factory()
def _assert_empty_catalog_table(self, table):
    """Assert ``table`` is a zero-length Table with columns x, y and coord."""
    expected_columns = sorted(["x", "y", "coord"])
    assert isinstance(table, Table)
    assert len(table) == 0
    assert sorted(table.colnames) == expected_columns
def _get_catalog_labels_as_set(self):
    """Return the viewer's ``catalog_labels`` converted to a set."""
    return set(self.image.catalog_labels)
[docs]
@pytest.mark.parametrize("load_type", ["fits", "nddata", "array"])
def test_load(self, data, tmp_path, load_type):
    """Load an image from a FITS file path, an NDData or a plain array."""
    if load_type == "fits":
        # Write the array to disk and hand the viewer the file path.
        fits_file = tmp_path / "test.fits"
        primary = fits.PrimaryHDU(data=data)
        primary.header["BUNIT"] = "adu"
        primary.writeto(fits_file)
        load_arg = fits_file
    elif load_type == "nddata":
        load_arg = NDData(data=data)
    elif load_type == "array":
        load_arg = data
    self.image.load_image(load_arg)
[docs]
def test_set_get_center_xy(self, data):
    """A pixel center set on a labelled image is returned unchanged."""
    viewer = self.image
    viewer.load_image(data, image_label="test")
    # The center is given as (X, Y)
    viewer.set_viewport(center=(10, 10), image_label="test")
    viewport = viewer.get_viewport(image_label="test")
    assert viewport["center"] == (10, 10)
    assert viewport["image_label"] == "test"
[docs]
def test_set_get_center_world(self, data, wcs):
    """A SkyCoord center set on an image with WCS is returned as a SkyCoord."""
    viewer = self.image
    viewer.load_image(NDData(data=data, wcs=wcs), image_label="test")
    # Center the view on the reference sky position of the WCS.
    reference_position = SkyCoord(*wcs.wcs.crval, unit="deg")
    viewer.set_viewport(center=reference_position, image_label="test")
    viewport = viewer.get_viewport(image_label="test")
    center = viewport["center"]
    assert isinstance(center, SkyCoord)
    assert center.ra.deg == pytest.approx(wcs.wcs.crval[0])
    assert center.dec.deg == pytest.approx(wcs.wcs.crval[1])
[docs]
def test_set_get_fov_pixel(self, data):
    """A field of view given as a plain number is returned unchanged."""
    viewer = self.image
    # The image goes in first because it is needed to determine zoom level.
    viewer.load_image(data, image_label="test")
    viewer.set_viewport(fov=100, image_label="test")
    viewport = viewer.get_viewport(image_label="test")
    assert viewport["fov"] == 100
    assert viewport["image_label"] == "test"
[docs]
def test_set_get_fov_world(self, data, wcs):
    """An angular field of view is returned as a scalar angular Quantity."""
    viewer = self.image
    # The image goes in first because it is needed to determine zoom level.
    viewer.load_image(NDData(data=data, wcs=wcs), image_label="test")
    # Request the field of view in world (angular) units.
    viewer.set_viewport(fov=0.1 * u.deg, image_label="test")
    viewport = viewer.get_viewport(image_label="test")
    fov = viewport["fov"]
    assert isinstance(fov, u.Quantity)
    # Exactly one value is expected, not one per axis.
    assert len(np.atleast_1d(fov)) == 1
    assert fov.unit.physical_type == "angle"
    assert fov.to(u.degree).value == pytest.approx(0.1)
[docs]
def test_set_get_viewport_errors(self, data, wcs):
    """
    Check the errors raised by ``set_viewport`` and ``get_viewport`` for
    invalid arguments on an image that has a WCS.
    """
    self.image.load_image(NDData(data=data, wcs=wcs), image_label="test")
    # fov can be a float or an angular Quantity, so a length is rejected
    with pytest.raises(u.UnitTypeError, match="[Ii]ncorrect unit for fov"):
        self.image.set_viewport(fov=100 * u.meter, image_label="test")
    # try an fov that is completely the wrong type
    with pytest.raises(TypeError, match="[Ii]nvalid value for fov"):
        self.image.set_viewport(fov="not a valid value", image_label="test")
    # center can be a SkyCoord or a tuple of floats. Try a value that is neither
    with pytest.raises(TypeError, match="[Ii]nvalid value for center"):
        self.image.set_viewport(center="not a valid value", image_label="test")
    # Check that an error is raised if a label is provided that does not
    # match an image that is loaded.
    with pytest.raises(ValueError, match="[Ii]mage label.*not found"):
        self.image.set_viewport(
            center=(10, 10), fov=100, image_label="not a valid label"
        )
    # Getting a viewport for an image_label that does not exist should
    # raise an error
    with pytest.raises(ValueError, match="[Ii]mage label.*not found"):
        self.image.get_viewport(image_label="not a valid label")
    # If there are multiple images loaded, the image_label must be provided
    self.image.load_image(data, image_label="another test")
    with pytest.raises(ValueError, match="Multiple image labels defined"):
        self.image.get_viewport()
    # Setting sky_or_pixel to something other than 'sky' or 'pixel' or None
    # should raise an error. Two images are loaded at this point, so the
    # image is named explicitly; otherwise a backend that resolves the
    # label before validating sky_or_pixel would raise the "Multiple image
    # labels" error, whose message does not match the pattern below.
    with pytest.raises(ValueError, match="[Ss]ky_or_pixel must be"):
        self.image.get_viewport(
            image_label="test", sky_or_pixel="not a valid value"
        )
[docs]
def test_set_get_viewport_errors_because_no_wcs(self, data):
    """
    Sky-based viewport requests must fail when the image has no WCS.
    """
    viewer = self.image
    # A bare array carries no WCS.
    viewer.load_image(data, image_label="test")
    # A SkyCoord center cannot be used without a WCS
    sky_center = SkyCoord(ra=10, dec=20, unit="deg")
    with pytest.raises(TypeError, match="Center must be a tuple"):
        viewer.set_viewport(center=sky_center, image_label="test")
    # Neither can an angular field of view
    with pytest.raises(TypeError, match="FOV must be a float"):
        viewer.set_viewport(fov=100 * u.arcmin, image_label="test")
    # Asking for the viewport in sky coordinates fails too
    with pytest.raises(ValueError, match="WCS is not set"):
        viewer.get_viewport(image_label="test", sky_or_pixel="sky")
[docs]
@pytest.mark.parametrize("world", [True, False])
def test_viewport_is_defined_after_loading_image(self, tmp_path, data, wcs, world):
    """
    Loading an image gives it a default viewport even when none was set.

    The image is read from a FITS file so that at least one test loads
    an image with WCS from FITS.
    """
    image_wcs = wcs if world else None
    fits_file = tmp_path / "test.fits"
    CCDData(data=data, unit="adu", wcs=image_wcs).write(fits_file)
    self.image.load_image(fits_file)
    # No viewport has been set; asking for it must still work.
    viewport = self.image.get_viewport()
    for key in ("center", "fov", "image_label"):
        assert key in viewport
    assert viewport["image_label"] is None
    if not world:
        # Without a WCS the center is a tuple and the fov a plain number
        assert isinstance(viewport["center"], tuple)
        assert isinstance(viewport["fov"], numbers.Real)
    else:
        # With a WCS the center is a SkyCoord and the fov a Quantity
        assert isinstance(viewport["center"], SkyCoord)
        assert isinstance(viewport["fov"], u.Quantity)
[docs]
def test_set_get_viewport_no_image_label(self, data):
    """
    With a single unlabelled image the viewport can be set and read
    back without giving an image label.
    """
    viewer = self.image
    # Load the image without a label
    viewer.load_image(data)
    # Set the viewport, again without a label
    viewer.set_viewport(center=(10, 10), fov=100)
    # Reading it back should give the values just set
    viewport = viewer.get_viewport()
    assert viewport["center"] == (10, 10)
    assert viewport["fov"] == 100
    assert viewport["image_label"] is None
[docs]
def test_set_get_viewport_single_label(self, data):
    """
    With a single labelled image the viewport can be set and read back
    without naming the label, and the label is reported in the result.
    """
    viewer = self.image
    viewer.load_image(data, image_label="test")
    # The default viewport must be retrievable without a label
    default_viewport = viewer.get_viewport()
    for key in ("center", "fov", "image_label"):
        assert key in default_viewport
    assert default_viewport["image_label"] == "test"
    # Set the viewport without naming the image
    viewer.set_viewport(center=(10, 10), fov=100)
    # Reading it back should give the values just set
    updated_viewport = viewer.get_viewport()
    assert updated_viewport["center"] == (10, 10)
    assert updated_viewport["fov"] == 100
    assert updated_viewport["image_label"] == "test"
[docs]
def test_get_viewport_sky_or_pixel(self, data, wcs):
    """
    With a WCS present the viewport can be read back in either pixel
    or sky coordinates.
    """
    viewer = self.image
    viewer.load_image(NDData(data=data, wcs=wcs), image_label="test")
    input_center = SkyCoord(*wcs.wcs.crval, unit="deg")
    input_fov = 2 * u.arcmin
    viewer.set_viewport(center=input_center, fov=input_fov, image_label="test")

    # Pixel view of the viewport
    in_pixels = viewer.get_viewport(image_label="test", sky_or_pixel="pixel")
    # The reference pixel of the test WCS is 1-based rather than the usual
    # 0-based, hence the offset of one.
    assert all(in_pixels["center"] == (wcs.wcs.crpix - 1))
    # The expected pixel size of the fov is not pinned down here, so only
    # its type is checked.
    assert isinstance(in_pixels["fov"], numbers.Real)

    # Sky view of the viewport
    in_sky = viewer.get_viewport(image_label="test", sky_or_pixel="sky")
    assert in_sky["center"] == input_center
    assert in_sky["fov"] == input_fov
[docs]
@pytest.mark.parametrize("sky_or_pixel", ["sky", "pixel"])
def test_get_viewport_no_sky_or_pixel(self, data, wcs, sky_or_pixel):
    """
    When ``sky_or_pixel`` is omitted, ``get_viewport`` returns sky values
    for an image with WCS and pixel values for an image without one.
    """
    use_wcs = wcs if sky_or_pixel == "sky" else None
    self.image.load_image(NDData(data=data, wcs=use_wcs), image_label="test")
    viewport = self.image.get_viewport(image_label="test")
    if sky_or_pixel == "sky":
        assert isinstance(viewport["center"], SkyCoord)
        assert viewport["fov"].unit.physical_type == "angle"
    elif sky_or_pixel == "pixel":
        assert isinstance(viewport["center"], tuple)
        assert isinstance(viewport["fov"], numbers.Real)
[docs]
def test_get_viewport_with_wcs_set_pixel_or_world(self, data, wcs):
    """
    With a WCS present, a viewport set in sky coordinates can be read
    in pixels and a viewport set in pixels can be read in sky coordinates.
    """
    viewer = self.image
    viewer.load_image(NDData(data=data, wcs=wcs), image_label="test")

    # First direction: set in sky coordinates, read in pixels
    sky_center = SkyCoord(*wcs.wcs.crval, unit="deg")
    sky_fov = 2 * u.arcmin
    viewer.set_viewport(center=sky_center, fov=sky_fov, image_label="test")
    in_pixels = viewer.get_viewport(image_label="test", sky_or_pixel="pixel")
    assert all(in_pixels["center"] == (wcs.wcs.crpix - 1))
    assert isinstance(in_pixels["fov"], numbers.Real)

    # Second direction: set in pixels, read in sky coordinates
    crpix_x, crpix_y = wcs.wcs.crpix
    pixel_center = (crpix_x, crpix_y)
    pixel_fov = 100  # in pixels
    viewer.set_viewport(center=pixel_center, fov=pixel_fov, image_label="test")
    in_sky = viewer.get_viewport(image_label="test", sky_or_pixel="sky")
    assert in_sky["center"] == wcs.pixel_to_world(*pixel_center)
    assert isinstance(in_sky["fov"], u.Quantity)
[docs]
def test_viewport_round_trips(self, data, wcs):
    """
    The output of ``get_viewport`` can be passed straight back to
    ``set_viewport`` and leaves the viewport unchanged.
    """
    viewer = self.image
    viewer.load_image(NDData(data=data, wcs=wcs), image_label="test")
    viewer.set_viewport(center=(10, 10), fov=100, image_label="test")
    before = viewer.get_viewport(image_label="test")
    # Feed the retrieved settings back in as keyword arguments
    viewer.set_viewport(**before)
    after = viewer.get_viewport(image_label="test")
    assert after == before
[docs]
def test_set_catalog_style_before_catalog_data_raises_error(self):
    """Setting a catalog style before any catalog is loaded is an error."""
    expected_message = "Must load a catalog before setting a catalog style"
    with pytest.raises(ValueError, match=expected_message):
        self.image.set_catalog_style(color="red", shape="circle", size=10)
[docs]
def test_set_get_catalog_style_no_labels(self, catalog):
    """
    Without catalog labels: a default style exists, a style can be set
    and read back, and ``set_catalog_style`` accepts the output of
    ``get_catalog_style``.
    """
    viewer = self.image
    # Even before anything is set the style has the minimum required keys
    default_style = viewer.get_catalog_style()
    missing_keys = [
        key for key in ["color", "shape", "size"] if key not in default_style
    ]
    assert not missing_keys
    # A catalog has to be loaded before a style can be set
    viewer.load_catalog(catalog)
    marker_settings = dict(color="red", shape="crosshair", size=10)
    viewer.set_catalog_style(**dict(marker_settings))
    retrieved_style = viewer.get_catalog_style()
    # Every setting should come back with the value that was set
    for key in marker_settings:
        assert retrieved_style[key] == marker_settings[key]
    # The retrieved style must be usable as input to set
    viewer.set_catalog_style(**retrieved_style)
[docs]
def test_set_get_catalog_style_with_single_label(self, catalog):
    """
    With exactly one catalog label, ``get_catalog_style`` works without
    the label and returns what was set, including the label.
    """
    viewer = self.image
    viewer.load_catalog(catalog, catalog_label="test1")
    requested_style = dict(
        catalog_label="test1", color="blue", shape="square", size=5
    )
    viewer.set_catalog_style(**dict(requested_style))
    assert requested_style == viewer.get_catalog_style()
[docs]
def test_get_catalog_style_with_multiple_labels_raises_error(self, catalog):
    """
    With several catalog labels, ``get_catalog_style`` without a label
    raises an error.
    """
    viewer = self.image
    styles_by_label = {
        "test1": dict(color="blue", shape="square", size=5),
        "test2": dict(color="red", shape="circle", size=10),
    }
    # Load both catalogs first, then style each of them
    for label in styles_by_label:
        viewer.load_catalog(catalog, catalog_label=label)
    for label, style in styles_by_label.items():
        viewer.set_catalog_style(catalog_label=label, **style)
    with pytest.raises(ValueError, match="Multiple catalog styles"):
        viewer.get_catalog_style()
[docs]
def test_catalog_has_style_after_loading(self, catalog):
    """
    Loading a catalog gives it a default style, and loading it again
    under the same label keeps that style.
    """
    viewer = self.image
    viewer.load_catalog(catalog, catalog_label="test1")
    first_style = viewer.get_catalog_style(catalog_label="test1")
    assert isinstance(first_style, dict)
    for key in ("color", "shape", "size"):
        assert key in first_style
    # A second load under the same label should not alter the style
    viewer.load_catalog(catalog, catalog_label="test1")
    second_style = viewer.get_catalog_style(catalog_label="test1")
    assert second_style == first_style
[docs]
@pytest.mark.parametrize("catalog_label", ["test1", None])
def test_load_get_single_catalog_with_without_label(self, catalog, catalog_label):
    """A single catalog can be retrieved with or without its label."""
    viewer = self.image
    load_options = dict(
        x_colname="x",
        y_colname="y",
        skycoord_colname="coord",
        catalog_label=catalog_label,
        use_skycoord=False,
    )
    viewer.load_catalog(catalog, **load_options)
    # Retrieval without a label always has to work
    unlabelled = viewer.get_catalog()
    assert (unlabelled == catalog).all()
    if catalog_label is None:
        return
    # When a label was used, retrieval by that label has to work as well
    labelled = viewer.get_catalog(catalog_label=catalog_label)
    assert (labelled == catalog).all()
[docs]
def test_load_multiple_catalogs(self, catalog):
    """
    Load one catalog under two labels, then check retrieval by label,
    the error for an ambiguous request, and removal of one catalog.
    """
    viewer = self.image

    def assert_positions_match(table):
        # Compare pixel positions with the input catalog, row by row.
        assert (table["x"] == catalog["x"]).all()
        assert (table["y"] == catalog["y"]).all()

    # The same catalog goes in twice, under different names
    for label in ("test1", "test2"):
        viewer.load_catalog(
            catalog,
            x_colname="x",
            y_colname="y",
            catalog_label=label,
        )
    assert sorted(viewer.catalog_labels) == ["test1", "test2"]

    # Rows are not guaranteed to come back in the original order, so both
    # sides are sorted before they are compared.
    first = viewer.get_catalog(catalog_label="test1")
    first.sort(["x", "y"])
    catalog.sort(["x", "y"])
    assert_positions_match(first)

    second = viewer.get_catalog(catalog_label="test2")
    second.sort(["x", "y"])
    assert_positions_match(second)

    # With two catalogs loaded a request without a label is ambiguous
    with pytest.raises(ValueError, match="Multiple catalog styles defined."):
        viewer.get_catalog()

    # Once one catalog is removed, the other can be fetched without a label
    viewer.remove_catalog(catalog_label="test1")
    assert viewer.catalog_labels == ("test2",)
    remaining = viewer.get_catalog()
    remaining.sort(["x", "y"])
    assert_positions_match(remaining)

    # Asking for the removed catalog gives an empty table with the right
    # columns
    removed = viewer.get_catalog(catalog_label="test1")
    self._assert_empty_catalog_table(removed)
[docs]
def test_load_catalog_multiple_same_label(self, catalog):
    """
    Loading a catalog twice under the same label raises no error and
    leaves the number of rows unchanged.
    """
    viewer = self.image
    for _ in range(2):
        viewer.load_catalog(catalog, catalog_label="test1")
    stored = viewer.get_catalog(catalog_label="test1")
    assert len(stored) == len(catalog)
[docs]
def test_load_catalog_with_skycoord_no_wcs(self, catalog, data):
    """
    A catalog with sky coordinates but no x/y columns cannot be loaded
    with default arguments onto an image without WCS; a ValueError is
    expected.
    """
    self.image.load_image(data)
    # Drop the pixel columns so that only the sky coordinates remain
    del catalog["x", "y"]
    expected_message = "Cannot use pixel coordinates without pixel columns"
    with pytest.raises(ValueError, match=expected_message):
        self.image.load_catalog(catalog)
[docs]
def test_load_catalog_with_use_skycoord_no_skycoord_no_wcs(self, catalog, data):
    """
    ``use_skycoord=True`` raises an error when the catalog has no sky
    coordinate column and the image has no WCS.
    """
    viewer = self.image
    viewer.load_image(data)
    # Drop the sky coordinate column so that only x/y remain
    del catalog["coord"]
    with pytest.raises(ValueError, match="Cannot use sky coordinates without"):
        viewer.load_catalog(catalog, use_skycoord=True)
[docs]
def test_load_catalog_with_xy_and_wcs(self, catalog, data, wcs):
    """
    A catalog with only x/y columns can be loaded with
    ``use_skycoord=True`` when the image has a WCS; the retrieved
    catalog then has a matching ``coord`` column.
    """
    viewer = self.image
    viewer.load_image(NDData(data=data, wcs=wcs))
    # Drop the sky coordinate column so that only x/y remain
    del catalog["coord"]
    viewer.load_catalog(catalog, use_skycoord=True)
    retrieved = viewer.get_catalog()
    # All three position columns should be present in the result
    for column in ("x", "y", "coord"):
        assert column in retrieved.colnames
    # The sky coordinates should agree with what the WCS gives for x/y
    expected_coords = wcs.pixel_to_world(catalog["x"], catalog["y"])
    offsets = expected_coords.separation(retrieved["coord"])
    assert all(offsets < 1e-9 * u.deg)
[docs]
def test_catalog_info_preserved_after_load(self, catalog):
    """
    Columns other than the position columns survive a load/get round
    trip.
    """
    viewer = self.image
    # Attach an extra, non-positional column
    catalog["extra_info"] = np.arange(len(catalog))
    viewer.load_catalog(catalog, catalog_label="test1")
    stored = viewer.get_catalog(catalog_label="test1")
    assert "extra_info" in stored.colnames
    assert (stored["extra_info"] == catalog["extra_info"]).all()
[docs]
def test_load_catalog_with_no_style_has_a_style(self, catalog):
    """
    A catalog loaded without an explicit style still ends up with a
    style dict containing color, shape and size.
    """
    self.image.load_catalog(catalog, catalog_label="test1")
    default_style = self.image.get_catalog_style(catalog_label="test1")
    assert isinstance(default_style, dict)
    for key in ("color", "shape", "size"):
        assert key in default_style
[docs]
def test_load_catalog_with_style_sets_style(self, catalog):
    """
    A style passed to ``load_catalog`` becomes the style of that catalog.
    """
    style = dict(color="blue", shape="square", size=10)
    # Pass a copy of the style to the viewer.
    self.image.load_catalog(
        catalog, catalog_label="test1", catalog_style=dict(style)
    )
    retrieved_style = self.image.get_catalog_style(catalog_label="test1")
    # The retrieved style also carries the catalog label
    expected_style = {**style, "catalog_label": "test1"}
    assert retrieved_style == expected_style
[docs]
def test_remove_catalog(self):
    """Removing a catalog label that was never loaded raises ValueError."""
    unknown_label = "arf"
    with pytest.raises(ValueError, match="arf"):
        self.image.remove_catalog(catalog_label=unknown_label)
[docs]
def test_remove_catalogs_name_all(self):
    """The label ``"*"`` removes every loaded catalog."""
    positions = Table(data=np.arange(10).reshape(5, 2), names=["x", "y"])
    for label in ("test1", "test2"):
        self.image.load_catalog(positions, catalog_label=label, use_skycoord=False)
    self.image.remove_catalog(catalog_label="*")
    remaining = self.image.get_catalog()
    self._assert_empty_catalog_table(remaining)
[docs]
def test_remove_catalog_does_not_accept_list(self):
    """Passing a list of labels to ``remove_catalog`` raises TypeError."""
    labels = ["test1", "test2"]
    positions = Table(data=np.arange(10).reshape(5, 2), names=["x", "y"])
    for label in labels:
        self.image.load_catalog(positions, catalog_label=label, use_skycoord=False)
    expected_message = "Cannot remove multiple catalogs from a list"
    with pytest.raises(TypeError, match=expected_message):
        self.image.remove_catalog(catalog_label=labels)
[docs]
def test_adding_catalog_as_world(self, data, wcs):
    """
    A catalog holding only sky coordinates comes back with pixel
    positions that agree with the WCS of the loaded image.
    """
    viewer = self.image
    viewer.load_image(NDData(data=data, wcs=wcs))
    # Start from known pixel positions and convert them to sky coordinates
    pixel_grid = np.linspace(0, 100, num=10).reshape(5, 2)
    pixel_table = Table(data=pixel_grid, names=["x", "y"], dtype=("float", "float"))
    sky_positions = wcs.pixel_to_world(pixel_table["x"], pixel_table["y"])
    sky_table = Table(data=[sky_positions], names=["coord"])
    viewer.load_catalog(sky_table, use_skycoord=True)
    result = viewer.get_catalog()
    # One x value is exactly zero, so a purely relative tolerance cannot
    # succeed there; a small absolute tolerance is allowed for x.
    np.testing.assert_allclose(result["x"], pixel_table["x"], atol=1e-9)
    np.testing.assert_allclose(result["y"], pixel_table["y"])
    # The sky coordinates themselves should be unchanged
    expected_sky = sky_table["coord"]
    np.testing.assert_allclose(result["coord"].ra.deg, expected_sky.ra.deg)
    np.testing.assert_allclose(result["coord"].dec.deg, expected_sky.dec.deg)
[docs]
def test_stretch(self):
    """
    An invalid stretch raises TypeError and leaves the stretch alone;
    a valid stretch replaces it.
    """
    viewer = self.image
    stretch_before = viewer.get_stretch()
    with pytest.raises(TypeError, match=r"Stretch.*not valid.*"):
        viewer.set_stretch("not a valid value")
    # The rejected value must not have replaced the stretch
    assert viewer.get_stretch() is stretch_before
    viewer.set_stretch(LogStretch())
    # The accepted value must have replaced the stretch
    assert viewer.get_stretch() is not stretch_before
    assert isinstance(viewer.get_stretch(), LogStretch)
[docs]
def test_cuts(self, data):
    """
    Invalid cuts raise TypeError; an interval object or a 2-tuple of
    limits is accepted and can be read back.
    """
    viewer = self.image
    # Neither a string nor a 3-tuple is a valid cut specification
    for invalid_cuts in ("not a valid value", (1, 10, 100)):
        with pytest.raises(TypeError, match="[mM]ust be"):
            viewer.set_cuts(invalid_cuts)
    # Setting using histogram requires data
    viewer.load_image(data)
    viewer.set_cuts(AsymmetricPercentileInterval(0.1, 99.9))
    assert isinstance(viewer.get_cuts(), AsymmetricPercentileInterval)
    # A pair of numbers is read back as a ManualInterval with those limits
    viewer.set_cuts((10, 100))
    assert isinstance(viewer.get_cuts(), ManualInterval)
    assert viewer.get_cuts().get_limits(data) == (10, 100)
[docs]
def test_stretch_cuts_labels(self, data):
    """Stretch and cuts can be set and read back using an image label."""
    viewer = self.image
    viewer.load_image(data, image_label="test")
    # Set both properties on the labelled image
    viewer.set_stretch(LogStretch(), image_label="test")
    viewer.set_cuts((10, 100), image_label="test")
    # Read both back using the same label
    retrieved_stretch = viewer.get_stretch(image_label="test")
    retrieved_cuts = viewer.get_cuts(image_label="test")
    assert isinstance(retrieved_stretch, LogStretch)
    assert isinstance(retrieved_cuts, ManualInterval)
    assert retrieved_cuts.get_limits(data) == (10, 100)
[docs]
def test_stretch_cuts_are_set_after_loading_image(self, data):
    """Loading an image gives it a default stretch and default cuts."""
    viewer = self.image
    viewer.load_image(data, image_label="test")
    default_stretch = viewer.get_stretch(image_label="test")
    default_cuts = viewer.get_cuts(image_label="test")
    # Backends are free to choose their defaults, so only the base
    # classes are checked.
    assert isinstance(default_stretch, BaseStretch)
    assert isinstance(default_cuts, BaseInterval)
[docs]
def test_stretch_cuts_errors(self, data):
    """
    Getting or setting stretch or cuts for an unknown image label
    raises ValueError.
    """
    viewer = self.image
    viewer.load_image(data, image_label="test")
    unknown_label = "not a valid label"
    # Each entry performs one call that should fail for the unknown label
    failing_calls = [
        lambda: viewer.get_stretch(image_label=unknown_label),
        lambda: viewer.get_cuts(image_label=unknown_label),
        lambda: viewer.set_stretch(LogStretch(), image_label=unknown_label),
        lambda: viewer.set_cuts((10, 100), image_label=unknown_label),
    ]
    for failing_call in failing_calls:
        with pytest.raises(ValueError, match="[Ii]mage label.*not found"):
            failing_call()
[docs]
def test_set_get_colormap(self, data):
    """
    With one labelled image the colormap can be set and read back both
    without and with the image label.
    """
    viewer = self.image
    viewer.load_image(data, image_label="test")
    # First without naming the image
    viewer.set_colormap("gray")
    assert viewer.get_colormap() == "gray"
    # Then with the image label given explicitly
    viewer.set_colormap("viridis", image_label="test")
    assert viewer.get_colormap(image_label="test") == "viridis"
[docs]
def test_set_colormap_errors(self, data):
    """
    Colormap calls raise ValueError for an unknown image label and,
    when several images are loaded, for a missing image label.
    """
    viewer = self.image
    viewer.load_image(data, image_label="test")
    # Getting the colormap for a label that does not exist must fail
    with pytest.raises(ValueError, match="[Ii]mage label.*not found"):
        viewer.get_colormap(image_label="not a valid label")
    # With a second image loaded, a call without a label is ambiguous
    viewer.load_image(data, image_label="another test")
    ambiguous_message = "Multiple image labels defined"
    with pytest.raises(ValueError, match=ambiguous_message):
        viewer.set_colormap("gray")
    # The same holds for reading the colormap
    with pytest.raises(ValueError, match=ambiguous_message):
        viewer.get_colormap()
[docs]
def test_save(self, tmp_path):
    """Saving the viewer creates a file at the requested path."""
    output_file = tmp_path / "woot.png"
    self.image.save(output_file)
    assert output_file.is_file()
[docs]
def test_save_overwrite(self, tmp_path):
    """
    Saving over an existing file raises FileExistsError unless
    ``overwrite=True`` is given.
    """
    viewer = self.image
    output_file = tmp_path / "woot.png"
    # Nothing exists yet, so the first save succeeds
    viewer.save(output_file)
    assert output_file.is_file()
    # The file now exists, so a plain second save is refused
    with pytest.raises(FileExistsError):
        viewer.save(output_file)
    # With overwrite the save goes through
    viewer.save(output_file, overwrite=True)
[docs]
def test_image_labels(self, data):
    """``image_labels`` is a tuple that grows when an image is loaded."""
    viewer = self.image
    # A new viewer reports no image labels
    assert len(viewer.image_labels) == 0
    assert isinstance(viewer.image_labels, tuple)
    viewer.load_image(data, image_label="test")
    assert len(viewer.image_labels) == 1
    assert viewer.image_labels[-1] == "test"
[docs]
def test_get_image(self, data):
    """
    ``get_image`` returns something loadable for a known label and
    raises ValueError for an unknown one.
    """
    viewer = self.image
    viewer.load_image(data, image_label="test")
    # The API does not currently specify the returned type, so the
    # checks are limited to "not None".
    assert viewer.get_image() is not None
    assert viewer.get_image(image_label="test") is not None
    # Whatever comes back must be accepted by load_image again
    fetched = viewer.get_image(image_label="test")
    viewer.load_image(fetched, image_label="another test")
    assert viewer.get_image(image_label="another test") is not None
    with pytest.raises(ValueError, match="[Ii]mage label.*not found"):
        viewer.get_image(image_label="not a valid label")
[docs]
def test_all_methods_accept_additional_kwargs(self, data, catalog, tmp_path):
    """
    Make sure all methods accept additional keyword arguments
    that are not defined in the protocol.

    Every callable named in ``ImageViewerInterface.__protocol_attrs__`` is
    called with three extra keyword arguments. A ``TypeError`` that does
    not mention a missing required positional argument is recorded as a
    failure of that method. A ``TypeError`` about a missing required
    positional argument is re-raised, because it means this test did not
    supply an argument the method needs.
    """
    from astro_image_display_api import ImageViewerInterface

    all_methods = [
        method
        for method in ImageViewerInterface.__protocol_attrs__
        if callable(getattr(self.image, method))
    ]
    # Keyword arguments with arbitrary names that no method declares
    additional_kwargs = {k: f"value{k}" for k in ["fsda", "urioeh", "m898h]"]}
    # The required argument for each method that has one
    required_args = dict(
        load_image=data,
        set_cuts=(10, 100),
        set_stretch=LogStretch(),
        set_colormap="viridis",
        save=tmp_path / "test.png",
        load_catalog=catalog,
    )
    failed_methods = []

    def call_with_additional_kwargs(method):
        # Call one viewer method with the extra kwargs and record the
        # method as failed if it rejects them.
        call_args = (required_args[method],) if method in required_args else ()
        try:
            getattr(self.image, method)(*call_args, **additional_kwargs)
        except TypeError as e:
            if "required positional argument" in str(e):
                # The test itself called the method incorrectly.
                raise
            failed_methods.append(method)

    # The loading methods must run first and remove_catalog must run last,
    # so they are taken out of the general list. Sorting makes the call
    # order reproducible; the iteration order of a set of strings can
    # differ from one interpreter run to the next.
    other_methods = sorted(
        set(all_methods) - {"load_image", "load_catalog", "remove_catalog"}
    )

    # Load an image and a catalog first since other methods require them
    call_with_additional_kwargs("load_image")
    call_with_additional_kwargs("load_catalog")

    if not failed_methods:
        # remove_catalog goes last so that it does not interfere with the
        # methods that need a catalog to be loaded.
        for method in [*other_methods, "remove_catalog"]:
            call_with_additional_kwargs(method)
    else:
        # No point in running the rest if loading already failed
        failed_methods.append(
            "No other methods were tested because the ones above failed."
        )

    # Joined outside the f-string: a backslash inside an f-string
    # expression is a syntax error before Python 3.12.
    failure_report = "\n\t".join(failed_methods)
    assert not failed_methods, (
        "The following methods failed when called with additional kwargs:\n\t"
        f"{failure_report}"
    )
[docs]
def test_every_method_attribute_has_docstring(self):
    """
    Check that every method and attribute in the protocol has a docstring.

    Names are looked up on the viewer's class where possible. Looking a
    property up on the instance yields the property's *value*, so the
    check would inspect the docstring of that value's type instead of the
    docstring written for the property.
    """
    from astro_image_display_api import ImageViewerInterface

    viewer_class = type(self.image)
    method_attrs_no_docs = []
    # Sorted so that the failure message is the same on every run
    for name in sorted(ImageViewerInterface.__protocol_attrs__):
        try:
            attr = getattr(viewer_class, name)
        except AttributeError:
            # Not defined on the class (e.g. set in __init__); fall back
            # to the instance attribute.
            attr = getattr(self.image, name)
        # Collect the names without a docstring and assert at the end
        # that there are none.
        if not attr.__doc__:
            method_attrs_no_docs.append(name)

    # Joined outside the f-string: a backslash inside an f-string
    # expression is a syntax error before Python 3.12.
    missing_report = "\n\t".join(method_attrs_no_docs)
    assert not method_attrs_no_docs, (
        "The following methods and attributes have no docstring:\n\t"
        f"{missing_report}"
    )