From 740268e7b4f02d0669cc53818595eae150f4a09d Mon Sep 17 00:00:00 2001 From: Gabriel Brammer Date: Thu, 4 Dec 2025 16:32:47 +0100 Subject: [PATCH 1/5] add get_axis_by_index function --- src/corner/core.py | 40 ++++++++++++++++++++++++++++++++++++++++ tests/test_corner.py | 27 +++++++++++++++++++++++++++ 2 files changed, 67 insertions(+) diff --git a/src/corner/core.py b/src/corner/core.py index 1d8d5f0..4e564a6 100644 --- a/src/corner/core.py +++ b/src/corner/core.py @@ -892,6 +892,46 @@ def overplot_points(fig, xs, reverse=False, **kwargs): axes[k2, k1].plot(xs[k1], xs[k2], **kwargs) +def get_axis_by_index(fig, ix, iy): + """ + Get axis corresponding to indices ``ix``, ``iy`` of the input data. This can be used, e.g., for + manually adding additional data or labels to a specific axis. + + Parameters + ---------- + fig : Figure + The figure generated by a call to :func:`corner.corner`. + + ix, iy : int + Indices of the parameter list corresponding to the plotted ``x`` and ``y`` axes. Only cases + where ``ix <= iy`` have plotted axes, and ``ix == iy`` corresponds to the histogram axis for + parameter index ``ix``. The function doesn't raise an error when ``ix > iy`` corresponding to one + of the hidden axes, though it does raise an error if either ``ix`` or ``iy`` is too large for the + dimensions of the plotted ``fig``. + + Returns + ------- + ax : axis + Entry in the ``fig.axes`` list. + """ + ndim = int(np.sqrt(len(fig.axes))) + if (ix > ndim - 1): + msg = f"ix={ix} too large for ndim={ndim}" + raise ValueError(msg) + elif (iy > ndim - 1): + msg = f"ix={ix} too large for ndim={ndim}" + raise ValueError(msg) + + for i in range(ndim**2): + ix_i = range(ndim)[(i % ndim)] + iy_i = range(ndim)[(i // ndim) - ndim] + + if (ix == ix_i) & (iy == iy_i): + break + + return fig.axes[i] + + def _parse_input(xs): xs = np.atleast_1d(xs) if len(xs.shape) == 1: diff --git a/tests/test_corner.py b/tests/test_corner.py index 78a7f96..1a91565 100644 --- a/tests/test_corner.py +++ b/tests/test_corner.py @@ -65,6 +65,33 @@ def test_basic(): _run_corner() +def test_axis_index(): + + labels = ["a", "b", "c"] + fig = _run_corner(labels=labels, n=100) + + # This should be x=a vs. y=c plotted in the lower left corner with both labels + ax = get_axis_by_index(fig, 0, 2) + assert(ax.get_xlabel() == labels[0]) + assert(ax.get_ylabel() == labels[2]) + + # This should be x=b vs. y=c, to the right of the previous with no y label + ax = get_axis_by_index(fig, 1, 2) + assert(ax.get_xlabel() == labels[1]) + assert(ax.get_ylabel() == "") + + # This should be the histogram of c at the lower right + ax = get_axis_by_index(fig, 2, 2) + + # Some big number, probably 1584 depending on the seed? + assert(ax.get_ylim()[1] > 100) + + # ix > iy is hidden, which have ranges set to (0,1) + ax = get_axis_by_index(fig, 2, 1) + assert np.allclose(ax.get_xlim(), [0,1]) + assert np.allclose(ax.get_ylim(), [0,1]) + + @image_comparison( baseline_images=["basic_log"], remove_text=True, extensions=["png"] ) From b778863a867d6b69743b444c895174348799d273 Mon Sep 17 00:00:00 2001 From: Gabriel Brammer Date: Thu, 4 Dec 2025 16:40:38 +0100 Subject: [PATCH 2/5] add function to init --- src/corner/__init__.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/corner/__init__.py b/src/corner/__init__.py index 2e8b2c4..790b8f4 100644 --- a/src/corner/__init__.py +++ b/src/corner/__init__.py @@ -2,6 +2,6 @@ __all__ = ["corner", "hist2d", "quantile", "overplot_lines", "overplot_points"] -from corner.core import hist2d, overplot_lines, overplot_points, quantile +from corner.core import hist2d, overplot_lines, overplot_points, get_axis_by_index, quantile from corner.corner import corner from corner.version import version as __version__ From d83843a3492f7123663707233ccc3e9675dd35b9 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Thu, 4 Dec 2025 15:41:20 +0000 Subject: [PATCH 3/5] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- src/corner/__init__.py | 8 +++++++- src/corner/core.py | 10 +++++----- tests/test_corner.py | 14 +++++++------- 3 files changed, 19 insertions(+), 13 deletions(-) diff --git a/src/corner/__init__.py b/src/corner/__init__.py index 790b8f4..a4bff25 100644 --- a/src/corner/__init__.py +++ b/src/corner/__init__.py @@ -2,6 +2,12 @@ __all__ = ["corner", "hist2d", "quantile", "overplot_lines", "overplot_points"] -from corner.core import hist2d, overplot_lines, overplot_points, get_axis_by_index, quantile +from corner.core import ( + get_axis_by_index, + hist2d, + overplot_lines, + overplot_points, + quantile, +) from corner.corner import corner from corner.version import version as __version__ diff --git a/src/corner/core.py b/src/corner/core.py index 4e564a6..6bceb38 100644 --- a/src/corner/core.py +++ b/src/corner/core.py @@ -896,7 +896,7 @@ def get_axis_by_index(fig, ix, iy): """ Get axis corresponding to indices ``ix``, ``iy`` of the input data. This can be used, e.g., for manually adding additional data or labels to a specific axis. - + Parameters ---------- fig : Figure @@ -915,20 +915,20 @@ def get_axis_by_index(fig, ix, iy): Entry in the ``fig.axes`` list. """ ndim = int(np.sqrt(len(fig.axes))) - if (ix > ndim - 1): + if ix > ndim - 1: msg = f"ix={ix} too large for ndim={ndim}" raise ValueError(msg) - elif (iy > ndim - 1): + elif iy > ndim - 1: msg = f"ix={ix} too large for ndim={ndim}" raise ValueError(msg) - + for i in range(ndim**2): ix_i = range(ndim)[(i % ndim)] iy_i = range(ndim)[(i // ndim) - ndim] if (ix == ix_i) & (iy == iy_i): break - + return fig.axes[i] diff --git a/tests/test_corner.py b/tests/test_corner.py index 1a91565..7557032 100644 --- a/tests/test_corner.py +++ b/tests/test_corner.py @@ -72,24 +72,24 @@ def test_axis_index(): # This should be x=a vs. y=c plotted in the lower left corner with both labels ax = get_axis_by_index(fig, 0, 2) - assert(ax.get_xlabel() == labels[0]) - assert(ax.get_ylabel() == labels[2]) + assert ax.get_xlabel() == labels[0] + assert ax.get_ylabel() == labels[2] # This should be x=b vs. y=c, to the right of the previous with no y label ax = get_axis_by_index(fig, 1, 2) - assert(ax.get_xlabel() == labels[1]) - assert(ax.get_ylabel() == "") + assert ax.get_xlabel() == labels[1] + assert ax.get_ylabel() == "" # This should be the histogram of c at the lower right ax = get_axis_by_index(fig, 2, 2) # Some big number, probably 1584 depending on the seed? - assert(ax.get_ylim()[1] > 100) + assert ax.get_ylim()[1] > 100 # ix > iy is hidden, which have ranges set to (0,1) ax = get_axis_by_index(fig, 2, 1) - assert np.allclose(ax.get_xlim(), [0,1]) - assert np.allclose(ax.get_ylim(), [0,1]) + assert np.allclose(ax.get_xlim(), [0, 1]) + assert np.allclose(ax.get_ylim(), [0, 1]) @image_comparison( From c6cb6cb45095a4d7cba45fbf5569074922b1550d Mon Sep 17 00:00:00 2001 From: Gabriel Brammer Date: Fri, 5 Dec 2025 11:01:51 +0100 Subject: [PATCH 4/5] add inverse function param_indices_from_axis and rename axis_from_param_indices --- src/corner/__init__.py | 11 ++++++++-- src/corner/core.py | 46 ++++++++++++++++++++++++++++++++++++++---- tests/test_corner.py | 20 ++++++++++++++---- 3 files changed, 67 insertions(+), 10 deletions(-) diff --git a/src/corner/__init__.py b/src/corner/__init__.py index 790b8f4..dc93cd7 100644 --- a/src/corner/__init__.py +++ b/src/corner/__init__.py @@ -1,7 +1,14 @@ # -*- coding: utf-8 -*- -__all__ = ["corner", "hist2d", "quantile", "overplot_lines", "overplot_points"] +__all__ = [ + "corner", "hist2d", "quantile", "overplot_lines", "overplot_points", + "axis_from_param_indices", + "param_indices_from_axis" +] + +from corner.core import ( + hist2d, overplot_lines, overplot_points, quantile, axis_from_param_indices, param_indices_from_axis +) -from corner.core import hist2d, overplot_lines, overplot_points, get_axis_by_index, quantile from corner.corner import corner from corner.version import version as __version__ diff --git a/src/corner/core.py b/src/corner/core.py index 4e564a6..2a80229 100644 --- a/src/corner/core.py +++ b/src/corner/core.py @@ -892,11 +892,12 @@ def overplot_points(fig, xs, reverse=False, **kwargs): axes[k2, k1].plot(xs[k1], xs[k2], **kwargs) -def get_axis_by_index(fig, ix, iy): +def axis_from_param_indices(fig, ix, iy, return_axis=True): """ Get axis corresponding to indices ``ix``, ``iy`` of the input data. This can be used, e.g., for - manually adding additional data or labels to a specific axis. - + manually adding additional data or labels to a specific axis. This is the inverse of + `param_indices_from_axis`. + Parameters ---------- fig : Figure @@ -909,6 +910,9 @@ def get_axis_by_index(fig, ix, iy): of the hidden axes, though it does raise an error if either ``ix`` or ``iy`` is too large for the dimensions of the plotted ``fig``. + return_axis : bool + Return either the axis itself or its integer index + Returns ------- ax : axis @@ -928,8 +932,42 @@ def get_axis_by_index(fig, ix, iy): if (ix == ix_i) & (iy == iy_i): break + + if return_axis: + return fig.axes[i] + else: + return i + + +def param_indices_from_axis(fig, i): + """ + Get indices ``ix``, ``iy`` of the input data associated with one of the plotted axes. This is the + inverse of `axis_from_param_indices`. + + Parameters + ---------- + fig : Figure + The figure generated by a call to :func:`corner.corner`. + + i : int + Index of an entry in the ``fig.axes`` list + + Returns + ------- + ix, iy : int + Indices of the figure axes list corresponding to the plotted ``x`` and ``y`` of the specified axis + index + """ + if (i > len(fig.axes)): + msg = f"{i} must be < len(fig.axes) = {len(fig.axes)}" + raise ValueError(msg) - return fig.axes[i] + ndim = int(np.sqrt(len(fig.axes))) + + ix = range(ndim)[(i % ndim)] + iy = range(ndim)[(i // ndim) - ndim] + + return ix, iy def _parse_input(xs): diff --git a/tests/test_corner.py b/tests/test_corner.py index 1a91565..46a7d13 100644 --- a/tests/test_corner.py +++ b/tests/test_corner.py @@ -71,26 +71,38 @@ def test_axis_index(): fig = _run_corner(labels=labels, n=100) # This should be x=a vs. y=c plotted in the lower left corner with both labels - ax = get_axis_by_index(fig, 0, 2) + ax = corner.axis_from_param_indices(fig, 0, 2) assert(ax.get_xlabel() == labels[0]) assert(ax.get_ylabel() == labels[2]) # This should be x=b vs. y=c, to the right of the previous with no y label - ax = get_axis_by_index(fig, 1, 2) + ax = corner.axis_from_param_indices(fig, 1, 2) assert(ax.get_xlabel() == labels[1]) assert(ax.get_ylabel() == "") # This should be the histogram of c at the lower right - ax = get_axis_by_index(fig, 2, 2) + ax = corner.axis_from_param_indices(fig, 2, 2) # Some big number, probably 1584 depending on the seed? assert(ax.get_ylim()[1] > 100) # ix > iy is hidden, which have ranges set to (0,1) - ax = get_axis_by_index(fig, 2, 1) + ax = corner.axis_from_param_indices(fig, 2, 1) assert np.allclose(ax.get_xlim(), [0,1]) assert np.allclose(ax.get_ylim(), [0,1]) + with pytest.raises(ValueError): + ax = corner.axis_from_param_indices(fig, 2, 4) + + # Inverse + for ix in range(len(labels)): + for iy in range(ix+1, len(labels)): + i = corner.axis_from_param_indices(fig, ix, iy, return_axis=False) + ix_i, iy_i = corner.param_indices_from_axis(fig, i) + assert np.allclose([ix_i, iy_i], [ix, iy]) + + with pytest.raises(ValueError): + _ = corner.param_indices_from_axis(fig, 100) @image_comparison( baseline_images=["basic_log"], remove_text=True, extensions=["png"] From c4d521486d4448e2b77470928d9f7097334f08fb Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Fri, 5 Dec 2025 10:19:10 +0000 Subject: [PATCH 5/5] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- src/corner/__init__.py | 16 ++++++++++++---- src/corner/core.py | 4 ++-- tests/test_corner.py | 15 ++++++++------- 3 files changed, 22 insertions(+), 13 deletions(-) diff --git a/src/corner/__init__.py b/src/corner/__init__.py index dc93cd7..654cd8e 100644 --- a/src/corner/__init__.py +++ b/src/corner/__init__.py @@ -1,14 +1,22 @@ # -*- coding: utf-8 -*- __all__ = [ - "corner", "hist2d", "quantile", "overplot_lines", "overplot_points", + "corner", + "hist2d", + "quantile", + "overplot_lines", + "overplot_points", "axis_from_param_indices", - "param_indices_from_axis" + "param_indices_from_axis", ] from corner.core import ( - hist2d, overplot_lines, overplot_points, quantile, axis_from_param_indices, param_indices_from_axis + axis_from_param_indices, + hist2d, + overplot_lines, + overplot_points, + param_indices_from_axis, + quantile, ) - from corner.corner import corner from corner.version import version as __version__ diff --git a/src/corner/core.py b/src/corner/core.py index e79cdac..49aad67 100644 --- a/src/corner/core.py +++ b/src/corner/core.py @@ -958,10 +958,10 @@ def param_indices_from_axis(fig, i): Indices of the figure axes list corresponding to the plotted ``x`` and ``y`` of the specified axis index """ - if (i > len(fig.axes)): + if i > len(fig.axes): msg = f"{i} must be < len(fig.axes) = {len(fig.axes)}" raise ValueError(msg) - + ndim = int(np.sqrt(len(fig.axes))) ix = range(ndim)[(i % ndim)] diff --git a/tests/test_corner.py b/tests/test_corner.py index bd041ba..4057361 100644 --- a/tests/test_corner.py +++ b/tests/test_corner.py @@ -72,13 +72,13 @@ def test_axis_index(): # This should be x=a vs. y=c plotted in the lower left corner with both labels ax = corner.axis_from_param_indices(fig, 0, 2) - assert(ax.get_xlabel() == labels[0]) - assert(ax.get_ylabel() == labels[2]) + assert ax.get_xlabel() == labels[0] + assert ax.get_ylabel() == labels[2] # This should be x=b vs. y=c, to the right of the previous with no y label ax = corner.axis_from_param_indices(fig, 1, 2) - assert(ax.get_xlabel() == labels[1]) - assert(ax.get_ylabel() == "") + assert ax.get_xlabel() == labels[1] + assert ax.get_ylabel() == "" # This should be the histogram of c at the lower right ax = corner.axis_from_param_indices(fig, 2, 2) @@ -88,15 +88,15 @@ def test_axis_index(): # ix > iy is hidden, which have ranges set to (0,1) ax = corner.axis_from_param_indices(fig, 2, 1) - assert np.allclose(ax.get_xlim(), [0,1]) - assert np.allclose(ax.get_ylim(), [0,1]) + assert np.allclose(ax.get_xlim(), [0, 1]) + assert np.allclose(ax.get_ylim(), [0, 1]) with pytest.raises(ValueError): ax = corner.axis_from_param_indices(fig, 2, 4) # Inverse for ix in range(len(labels)): - for iy in range(ix+1, len(labels)): + for iy in range(ix + 1, len(labels)): i = corner.axis_from_param_indices(fig, ix, iy, return_axis=False) ix_i, iy_i = corner.param_indices_from_axis(fig, i) assert np.allclose([ix_i, iy_i], [ix, iy]) @@ -104,6 +104,7 @@ def test_axis_index(): with pytest.raises(ValueError): _ = corner.param_indices_from_axis(fig, 100) + @image_comparison( baseline_images=["basic_log"], remove_text=True, extensions=["png"] )