diff --git a/src/corner/__init__.py b/src/corner/__init__.py index 2e8b2c4..654cd8e 100644 --- a/src/corner/__init__.py +++ b/src/corner/__init__.py @@ -1,7 +1,22 @@ # -*- coding: utf-8 -*- -__all__ = ["corner", "hist2d", "quantile", "overplot_lines", "overplot_points"] +__all__ = [ + "corner", + "hist2d", + "quantile", + "overplot_lines", + "overplot_points", + "axis_from_param_indices", + "param_indices_from_axis", +] -from corner.core import hist2d, overplot_lines, overplot_points, quantile +from corner.core import ( + axis_from_param_indices, + hist2d, + overplot_lines, + overplot_points, + param_indices_from_axis, + quantile, +) from corner.corner import corner from corner.version import version as __version__ diff --git a/src/corner/core.py b/src/corner/core.py index 1d8d5f0..49aad67 100644 --- a/src/corner/core.py +++ b/src/corner/core.py @@ -892,6 +892,84 @@ def overplot_points(fig, xs, reverse=False, **kwargs): axes[k2, k1].plot(xs[k1], xs[k2], **kwargs) +def axis_from_param_indices(fig, ix, iy, return_axis=True): + """ + Get axis corresponding to indices ``ix``, ``iy`` of the input data. This can be used, e.g., for + manually adding additional data or labels to a specific axis. This is the inverse of + `param_indices_from_axis`. + + Parameters + ---------- + fig : Figure + The figure generated by a call to :func:`corner.corner`. + + ix, iy : int + Indices of the parameter list corresponding to the plotted ``x`` and ``y`` axes. Only cases + where ``ix <= iy`` have plotted axes, and ``ix == iy`` corresponds to the histogram axis for + parameter index ``ix``. The function doesn't raise an error when ``ix > iy`` corresponding to one + of the hidden axes, though it does raise an error if either ``ix`` or ``iy`` is too large for the + dimensions of the plotted ``fig``. + + return_axis : bool + Return either the axis itself or its integer index + + Returns + ------- + ax : axis + Entry in the ``fig.axes`` list. + """ + ndim = int(np.sqrt(len(fig.axes))) + if ix > ndim - 1: + msg = f"ix={ix} too large for ndim={ndim}" + raise ValueError(msg) + elif iy > ndim - 1: + msg = f"ix={ix} too large for ndim={ndim}" + raise ValueError(msg) + + for i in range(ndim**2): + ix_i = range(ndim)[(i % ndim)] + iy_i = range(ndim)[(i // ndim) - ndim] + + if (ix == ix_i) & (iy == iy_i): + break + + if return_axis: + return fig.axes[i] + else: + return i + + +def param_indices_from_axis(fig, i): + """ + Get indices ``ix``, ``iy`` of the input data associated with one of the plotted axes. This is the + inverse of `axis_from_param_indices`. + + Parameters + ---------- + fig : Figure + The figure generated by a call to :func:`corner.corner`. + + i : int + Index of an entry in the ``fig.axes`` list + + Returns + ------- + ix, iy : int + Indices of the figure axes list corresponding to the plotted ``x`` and ``y`` of the specified axis + index + """ + if i > len(fig.axes): + msg = f"{i} must be < len(fig.axes) = {len(fig.axes)}" + raise ValueError(msg) + + ndim = int(np.sqrt(len(fig.axes))) + + ix = range(ndim)[(i % ndim)] + iy = range(ndim)[(i // ndim) - ndim] + + return ix, iy + + def _parse_input(xs): xs = np.atleast_1d(xs) if len(xs.shape) == 1: diff --git a/tests/test_corner.py b/tests/test_corner.py index 78a7f96..4057361 100644 --- a/tests/test_corner.py +++ b/tests/test_corner.py @@ -65,6 +65,46 @@ def test_basic(): _run_corner() +def test_axis_index(): + + labels = ["a", "b", "c"] + fig = _run_corner(labels=labels, n=100) + + # This should be x=a vs. y=c plotted in the lower left corner with both labels + ax = corner.axis_from_param_indices(fig, 0, 2) + assert ax.get_xlabel() == labels[0] + assert ax.get_ylabel() == labels[2] + + # This should be x=b vs. y=c, to the right of the previous with no y label + ax = corner.axis_from_param_indices(fig, 1, 2) + assert ax.get_xlabel() == labels[1] + assert ax.get_ylabel() == "" + + # This should be the histogram of c at the lower right + ax = corner.axis_from_param_indices(fig, 2, 2) + + # Some big number, probably 1584 depending on the seed? + assert ax.get_ylim()[1] > 100 + + # ix > iy is hidden, which have ranges set to (0,1) + ax = corner.axis_from_param_indices(fig, 2, 1) + assert np.allclose(ax.get_xlim(), [0, 1]) + assert np.allclose(ax.get_ylim(), [0, 1]) + + with pytest.raises(ValueError): + ax = corner.axis_from_param_indices(fig, 2, 4) + + # Inverse + for ix in range(len(labels)): + for iy in range(ix + 1, len(labels)): + i = corner.axis_from_param_indices(fig, ix, iy, return_axis=False) + ix_i, iy_i = corner.param_indices_from_axis(fig, i) + assert np.allclose([ix_i, iy_i], [ix, iy]) + + with pytest.raises(ValueError): + _ = corner.param_indices_from_axis(fig, 100) + + @image_comparison( baseline_images=["basic_log"], remove_text=True, extensions=["png"] )