Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
19 changes: 17 additions & 2 deletions src/corner/__init__.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,22 @@
# -*- coding: utf-8 -*-

__all__ = ["corner", "hist2d", "quantile", "overplot_lines", "overplot_points"]
__all__ = [
"corner",
"hist2d",
"quantile",
"overplot_lines",
"overplot_points",
"axis_from_param_indices",
"param_indices_from_axis",
]

from corner.core import hist2d, overplot_lines, overplot_points, quantile
from corner.core import (
axis_from_param_indices,
hist2d,
overplot_lines,
overplot_points,
param_indices_from_axis,
quantile,
)
from corner.corner import corner
from corner.version import version as __version__
78 changes: 78 additions & 0 deletions src/corner/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -892,6 +892,84 @@ def overplot_points(fig, xs, reverse=False, **kwargs):
axes[k2, k1].plot(xs[k1], xs[k2], **kwargs)


def axis_from_param_indices(fig, ix, iy, return_axis=True):
"""
Get axis corresponding to indices ``ix``, ``iy`` of the input data. This can be used, e.g., for
manually adding additional data or labels to a specific axis. This is the inverse of
`param_indices_from_axis`.

Parameters
----------
fig : Figure
The figure generated by a call to :func:`corner.corner`.

ix, iy : int
Indices of the parameter list corresponding to the plotted ``x`` and ``y`` axes. Only cases
where ``ix <= iy`` have plotted axes, and ``ix == iy`` corresponds to the histogram axis for
parameter index ``ix``. The function doesn't raise an error when ``ix > iy`` corresponding to one
of the hidden axes, though it does raise an error if either ``ix`` or ``iy`` is too large for the
dimensions of the plotted ``fig``.

return_axis : bool
Return either the axis itself or its integer index

Returns
-------
ax : axis
Entry in the ``fig.axes`` list.
"""
ndim = int(np.sqrt(len(fig.axes)))
if ix > ndim - 1:
msg = f"ix={ix} too large for ndim={ndim}"
raise ValueError(msg)
elif iy > ndim - 1:
msg = f"ix={ix} too large for ndim={ndim}"
raise ValueError(msg)

for i in range(ndim**2):
ix_i = range(ndim)[(i % ndim)]
iy_i = range(ndim)[(i // ndim) - ndim]

if (ix == ix_i) & (iy == iy_i):
break

if return_axis:
return fig.axes[i]
else:
return i


def param_indices_from_axis(fig, i):
"""
Get indices ``ix``, ``iy`` of the input data associated with one of the plotted axes. This is the
inverse of `axis_from_param_indices`.

Parameters
----------
fig : Figure
The figure generated by a call to :func:`corner.corner`.

i : int
Index of an entry in the ``fig.axes`` list

Returns
-------
ix, iy : int
Indices of the figure axes list corresponding to the plotted ``x`` and ``y`` of the specified axis
index
"""
if i > len(fig.axes):
msg = f"{i} must be < len(fig.axes) = {len(fig.axes)}"
raise ValueError(msg)

ndim = int(np.sqrt(len(fig.axes)))

ix = range(ndim)[(i % ndim)]
iy = range(ndim)[(i // ndim) - ndim]

return ix, iy


def _parse_input(xs):
xs = np.atleast_1d(xs)
if len(xs.shape) == 1:
Expand Down
40 changes: 40 additions & 0 deletions tests/test_corner.py
Original file line number Diff line number Diff line change
Expand Up @@ -65,6 +65,46 @@ def test_basic():
_run_corner()


def test_axis_index():

labels = ["a", "b", "c"]
fig = _run_corner(labels=labels, n=100)

# This should be x=a vs. y=c plotted in the lower left corner with both labels
ax = corner.axis_from_param_indices(fig, 0, 2)
assert ax.get_xlabel() == labels[0]
assert ax.get_ylabel() == labels[2]

# This should be x=b vs. y=c, to the right of the previous with no y label
ax = corner.axis_from_param_indices(fig, 1, 2)
assert ax.get_xlabel() == labels[1]
assert ax.get_ylabel() == ""

# This should be the histogram of c at the lower right
ax = corner.axis_from_param_indices(fig, 2, 2)

# Some big number, probably 1584 depending on the seed?
assert ax.get_ylim()[1] > 100

# ix > iy is hidden, which have ranges set to (0,1)
ax = corner.axis_from_param_indices(fig, 2, 1)
assert np.allclose(ax.get_xlim(), [0, 1])
assert np.allclose(ax.get_ylim(), [0, 1])

with pytest.raises(ValueError):
ax = corner.axis_from_param_indices(fig, 2, 4)

# Inverse
for ix in range(len(labels)):
for iy in range(ix + 1, len(labels)):
i = corner.axis_from_param_indices(fig, ix, iy, return_axis=False)
ix_i, iy_i = corner.param_indices_from_axis(fig, i)
assert np.allclose([ix_i, iy_i], [ix, iy])

with pytest.raises(ValueError):
_ = corner.param_indices_from_axis(fig, 100)


@image_comparison(
baseline_images=["basic_log"], remove_text=True, extensions=["png"]
)
Expand Down