7070
7171
def _coerce_categorical_source(cat_source: Any) -> pd.Categorical:
    """Return a pandas Categorical from known, concrete sources only.

    Parameters
    ----------
    cat_source
        A ``dd.Series``, ``pd.Series``, ``pd.Categorical``, or ``np.ndarray``.
        A dask series is computed into a pandas series first.

    Returns
    -------
    pd.Categorical
        The input itself when it already is a ``pd.Categorical``, the backing
        array of an already-categorical ``pd.Series``, or a new
        ``pd.Categorical`` built from the values otherwise.

    Raises
    ------
    TypeError
        If *cat_source* is not a ``dd.Series``, ``pd.Series``,
        ``pd.Categorical``, or ``np.ndarray``.
    """
    if isinstance(cat_source, dd.Series):
        # Only touch the ``.cat`` accessor when the dtype is categorical.
        # A missing ``known`` attribute is treated as "already known".
        is_categorical = isinstance(cat_source.dtype, pd.CategoricalDtype)
        if is_categorical and getattr(cat_source.cat, "known", True) is False:
            cat_source = cat_source.cat.as_known()
        # Falls through to the pd.Series handling below.
        cat_source = cat_source.compute()

    if isinstance(cat_source, pd.Series):
        if isinstance(cat_source.dtype, pd.CategoricalDtype):
            # The backing array of a categorical Series is the Categorical.
            return cat_source.array
        return pd.Categorical(cat_source)
    if isinstance(cat_source, pd.Categorical):
        return cat_source
    if isinstance(cat_source, np.ndarray):
        return pd.Categorical(cat_source)

    # Unsupported inputs are rejected explicitly instead of being coerced.
    raise TypeError(
        f"Cannot coerce {type(cat_source).__name__} to pd.Categorical. "
        "Expected dd.Series, pd.Series, pd.Categorical, or np.ndarray."
    )
8799
88100
89101def _build_datashader_color_key (
@@ -209,20 +221,6 @@ def _render_shapes(
209221
210222 values_are_categorical = color_source_vector is not None
211223
212- # When groups are specified and na_color is fully transparent (na_color=None),
213- # filter out non-matching elements instead of showing them as invisible geometry.
214- if groups is not None and values_are_categorical and render_params .cmap_params .na_color .alpha == "00" :
215- csv_series = pd .Series (color_source_vector )
216- keep = csv_series .isin (groups ).values
217- shapes = shapes [keep ].reset_index (drop = True )
218- sdata_filt [element ] = shapes
219- color_source_vector = pd .Categorical (csv_series [keep ].reset_index (drop = True ))
220- color_vector = (
221- np .asarray (color_vector )[keep ]
222- if not hasattr (color_vector , "reset_index" )
223- else (color_vector [keep ].reset_index (drop = True ))
224- )
225-
226224 # color_source_vector is None when the values aren't categorical
227225 if values_are_categorical and render_params .transfunc is not None :
228226 color_vector = render_params .transfunc (color_vector )
@@ -352,7 +350,7 @@ def _render_shapes(
352350 color_by_categorical = col_for_color is not None and color_source_vector is not None
353351 if color_by_categorical :
354352 cat_series = transformed_element [col_for_color ]
355- if not pd . api . types . is_categorical_dtype (cat_series ):
353+ if not isinstance (cat_series . dtype , pd . CategoricalDtype ):
356354 cat_series = cat_series .astype ("category" )
357355 transformed_element [col_for_color ] = cat_series
358356
@@ -842,29 +840,6 @@ def _render_points(
842840 )
843841 points_dd = points_with_color_dd
844842
845- # When groups are specified and na_color is fully transparent (na_color=None),
846- # filter out non-matching points instead of rendering invisible geometry.
847- if groups is not None and color_source_vector is not None and render_params .cmap_params .na_color .alpha == "00" :
848- csv_series = pd .Series (color_source_vector )
849- keep = csv_series .isin (groups ).values
850- color_source_vector = pd .Categorical (csv_series [keep ].reset_index (drop = True ))
851- color_vector = (
852- np .asarray (color_vector )[keep ]
853- if not hasattr (color_vector , "reset_index" )
854- else (color_vector [keep ].reset_index (drop = True ))
855- )
856- # filter the materialized points, adata, and re-register in sdata_filt
857- points = points [keep ].reset_index (drop = True )
858- adata = adata [keep ].copy ()
859- points_dd = dask .dataframe .from_pandas (points , npartitions = 1 )
860- sdata_filt .points [element ] = PointsModel .parse (points_dd , coordinates = {"x" : "x" , "y" : "y" })
861- set_transformation (
862- element = sdata_filt .points [element ],
863- transformation = transformation_in_cs ,
864- to_coordinate_system = coordinate_system ,
865- )
866- n_points = int (keep .sum ())
867-
868843 # color_source_vector is None when the values aren't categorical
869844 if color_source_vector is None and render_params .transfunc is not None :
870845 color_vector = render_params .transfunc (color_vector )
@@ -931,11 +906,11 @@ def _render_points(
931906 color_dtype = transformed_element [col_for_color ].dtype if col_for_color is not None else None
932907 color_by_categorical = col_for_color is not None and (
933908 color_source_vector is not None
934- or pd . api . types . is_categorical_dtype (color_dtype )
909+ or isinstance (color_dtype , pd . CategoricalDtype )
935910 or pd .api .types .is_object_dtype (color_dtype )
936911 or pd .api .types .is_string_dtype (color_dtype )
937912 )
938- if color_by_categorical and not pd . api . types . is_categorical_dtype (color_dtype ):
913+ if color_by_categorical and not isinstance (color_dtype , pd . CategoricalDtype ):
939914 transformed_element [col_for_color ] = transformed_element [col_for_color ].astype ("category" )
940915
941916 aggregate_with_reduction = None
@@ -944,7 +919,7 @@ def _render_points(
944919 if color_by_categorical :
945920 # add nan as category so that nan points are shown in the nan color
946921 cat_series = transformed_element [col_for_color ]
947- if not pd . api . types . is_categorical_dtype (cat_series ):
922+ if not isinstance (cat_series . dtype , pd . CategoricalDtype ):
948923 cat_series = cat_series .astype ("category" )
949924 if hasattr (cat_series .cat , "as_known" ):
950925 cat_series = cat_series .cat .as_known ()
0 commit comments