Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
18 changes: 12 additions & 6 deletions src/spatialdata_plot/pl/_datashader.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,15 +61,21 @@ def _build_datashader_color_key(
) -> dict[str, str]:
"""Build a datashader ``color_key`` dict from a categorical series and its color vector."""
na_hex = _hex_no_alpha(na_color_hex) if na_color_hex.startswith("#") else na_color_hex
colors_arr = np.asarray(color_vector, dtype=object)
categories = np.asarray(cat_series.categories, dtype=str)
codes = np.asarray(cat_series.codes)

# Use np.unique to find the first occurrence of each category in one pass,
# avoiding a Python loop over all points. See #379.
unique_codes, first_indices = np.unique(codes, return_index=True)

first_color: dict[str, str] = {}
for code, color in zip(cat_series.codes, colors_arr, strict=False):
for code, idx in zip(unique_codes, first_indices, strict=True):
if code < 0:
continue
cat_name = str(cat_series.categories[code])
if cat_name not in first_color:
first_color[cat_name] = _hex_no_alpha(color) if isinstance(color, str) and color.startswith("#") else color
return {str(c): first_color.get(str(c), na_hex) for c in cat_series.categories}
c = color_vector[idx]
first_color[categories[code]] = _hex_no_alpha(c) if isinstance(c, str) and c.startswith("#") else c

return {cat: first_color.get(cat, na_hex) for cat in categories}


def _inject_ds_nan_sentinel(series: pd.Series, sentinel: str = _DS_NAN_CATEGORY) -> pd.Series:
Expand Down
25 changes: 14 additions & 11 deletions src/spatialdata_plot/pl/render.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,6 +51,7 @@
from spatialdata_plot.pl.utils import (
_ax_show_and_transform,
_convert_shapes,
_datashader_canvas_from_dataframe,
_decorate_axs,
_get_collection_shape,
_get_colors_for_categorical_obs,
Expand Down Expand Up @@ -81,14 +82,15 @@ def _want_decorations(color_vector: Any, na_color: Color) -> bool:
cv = np.asarray(color_vector)
if cv.size == 0:
return False
unique_vals = set(cv.tolist())
if len(unique_vals) != 1:
# Fast check: if any value differs from the first, there is variety → show decorations.
first = cv.flat[0]
if not (cv == first).all():
return True
only_val = next(iter(unique_vals))
# All values are the same — suppress decorations when that value is the NA color.
na_hex = na_color.get_hex()
if isinstance(only_val, str) and only_val.startswith("#") and na_hex.startswith("#"):
return _hex_no_alpha(only_val) != _hex_no_alpha(na_hex)
return bool(only_val != na_hex)
if isinstance(first, str) and first.startswith("#") and na_hex.startswith("#"):
return _hex_no_alpha(first) != _hex_no_alpha(na_hex)
return bool(first != na_hex)


def _reparse_points(
Expand Down Expand Up @@ -828,15 +830,16 @@ def _render_points(
# use dpi/100 as a factor for cases where dpi!=100
px = int(np.round(np.sqrt(render_params.size) * (fig_params.fig.dpi / 100)))

# apply transformations
# Apply transformations and materialize to pandas immediately so
# datashader aggregates without dask scheduler overhead. See #379.
transformed_element = PointsModel.parse(
trans.transform(sdata_filt.points[element][["x", "y"]]),
annotation=sdata_filt.points[element][sdata_filt.points[element].columns.drop(["x", "y"])],
transformations={coordinate_system: Identity()},
)
).compute()

plot_width, plot_height, x_ext, y_ext, factor = _get_extent_and_range_for_datashader_canvas(
transformed_element, coordinate_system, ax, fig_params
plot_width, plot_height, x_ext, y_ext, factor = _datashader_canvas_from_dataframe(
transformed_element, ax, fig_params
)

# use datashader for the visualization of points
Expand Down Expand Up @@ -901,7 +904,7 @@ def _render_points(
and isinstance(color_vector[0], str)
and color_vector[0].startswith("#")
):
color_vector = np.asarray([_hex_no_alpha(x) for x in color_vector])
color_vector = np.asarray([c[:7] if len(c) == 9 else c for c in color_vector])

nan_shaded = None
if color_by_categorical or col_for_color is None:
Expand Down
40 changes: 34 additions & 6 deletions src/spatialdata_plot/pl/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -2973,15 +2973,16 @@ def set_zero_in_cmap_to_transparent(cmap: Colormap | str, steps: int | None = No
return ListedColormap(colors)


def _get_extent_and_range_for_datashader_canvas(
spatial_element: SpatialElement,
coordinate_system: str,
def _compute_datashader_canvas_params(
x_ext: list[Any],
y_ext: list[Any],
ax: Axes,
fig_params: FigParams,
) -> tuple[Any, Any, list[Any], list[Any], Any]:
extent = get_extent(spatial_element, coordinate_system=coordinate_system)
x_ext = [min(0, extent["x"][0]), extent["x"][1]]
y_ext = [min(0, extent["y"][0]), extent["y"][1]]
"""Compute datashader canvas dimensions from spatial extents.

Shared logic used by both the dask-based and pandas-based entry points.
"""
previous_xlim = ax.get_xlim()
previous_ylim = ax.get_ylim()
# increase range if sth larger was rendered on the axis before
Expand Down Expand Up @@ -3015,6 +3016,33 @@ def _get_extent_and_range_for_datashader_canvas(
return plot_width, plot_height, x_ext, y_ext, factor


def _get_extent_and_range_for_datashader_canvas(
    spatial_element: SpatialElement,
    coordinate_system: str,
    ax: Axes,
    fig_params: FigParams,
) -> tuple[Any, Any, list[Any], list[Any], Any]:
    """Derive datashader canvas parameters for a SpatialElement.

    Queries ``get_extent()`` for the element's bounds in the given
    coordinate system, clamps each lower bound to at most 0, and delegates
    the shared canvas-sizing logic to ``_compute_datashader_canvas_params``.
    """
    extent = get_extent(spatial_element, coordinate_system=coordinate_system)
    x_lo, x_hi = extent["x"][0], extent["x"][1]
    y_lo, y_hi = extent["y"][0], extent["y"][1]
    x_range = [min(0, x_lo), x_hi]
    y_range = [min(0, y_lo), y_hi]
    return _compute_datashader_canvas_params(x_range, y_range, ax, fig_params)


def _datashader_canvas_from_dataframe(
    df: pd.DataFrame,
    ax: Axes,
    fig_params: FigParams,
) -> tuple[Any, Any, list[Any], list[Any], Any]:
    """Compute datashader canvas params directly from a pandas DataFrame.

    Reads the x/y min/max straight off the already-materialised columns
    instead of going through ``get_extent()`` (which requires a dask-backed
    SpatialElement). Lower bounds are clamped to at most 0, matching the
    SpatialElement-based entry point.
    """
    xs = df["x"]
    ys = df["y"]
    x_ext = [min(0, float(xs.min())), float(xs.max())]
    y_ext = [min(0, float(ys.min())), float(ys.max())]
    return _compute_datashader_canvas_params(x_ext, y_ext, ax, fig_params)


def _create_image_from_datashader_result(
ds_result: ds.transfer_functions.Image | np.ndarray[Any, np.dtype[np.uint8]],
factor: float,
Expand Down
23 changes: 23 additions & 0 deletions tests/pl/test_render_points.py
Original file line number Diff line number Diff line change
Expand Up @@ -741,3 +741,26 @@ def test_datashader_alpha_not_applied_twice(sdata_blobs: SpatialData):
"on top of the alpha already in the RGBA channels — causing double transparency."
)
plt.close(fig)


def test_datashader_points_categorical_with_nan(sdata_blobs: SpatialData):
    """Datashader must handle categorical coloring with NaN values.

    Regression test for https://github.com/scverse/spatialdata-plot/issues/379.
    Exercises the optimised aggregation and color-key paths (pandas DataFrame
    instead of dask, early-exit in _build_datashader_color_key).
    """
    n_points = 200
    rng = get_standard_RNG()
    # Categorical with missing entries (None) mixed among the real categories.
    cat_values = pd.Categorical(rng.choice(["A", "B", None], n_points))

    df = sdata_blobs["blobs_points"].compute().head(n_points).copy()
    # Cast to object so PointsModel accepts the column as-is.
    df["cat"] = cat_values.astype("object")
    sdata_blobs.points["test_pts"] = PointsModel.parse(df)

    fig, ax = plt.subplots()
    sdata_blobs.pl.render_points("test_pts", method="datashader", color="cat").pl.show(ax=ax)

    rendered = [child for child in ax.get_children() if isinstance(child, matplotlib.image.AxesImage)]
    assert len(rendered) > 0, "Datashader should produce at least one AxesImage"
    plt.close(fig)
Loading