diff --git a/.gitignore b/.gitignore index 8bff5e8..7f88e5d 100644 --- a/.gitignore +++ b/.gitignore @@ -5,7 +5,7 @@ __pycache__/ # C extensions *.so - +myenv # Distribution / packaging .Python env/ diff --git a/docs/source/api.rst b/docs/source/api.rst index 5df6e57..0eaa3ed 100644 --- a/docs/source/api.rst +++ b/docs/source/api.rst @@ -19,4 +19,5 @@ IMAS-Python IDS manipulation ids_toplevel.IDSToplevel ids_primitive.IDSPrimitive ids_structure.IDSStructure + ids_slice.IDSSlice ids_struct_array.IDSStructArray diff --git a/docs/source/array_slicing.rst b/docs/source/array_slicing.rst new file mode 100644 index 0000000..1daf873 --- /dev/null +++ b/docs/source/array_slicing.rst @@ -0,0 +1,131 @@ +.. _array-slicing: + +Array Slicing +============= + +The ``IDSStructArray`` class supports Python's standard slicing syntax. + +Key Difference +--------------- + +- ``array[0]`` returns ``IDSStructure`` (single element) +- ``array[:]`` or ``array[1:5]`` returns ``IDSSlice`` (collection with ``values()`` method) + +Basic Usage +----------- + +.. code-block:: python + + import imas + + entry = imas.DBEntry("imas:hdf5?path=my-testdb") + cp = entry.get("core_profiles") + + # Integer indexing + first = cp.profiles_1d[0] # IDSStructure + last = cp.profiles_1d[-1] # IDSStructure + + # Slice operations + subset = cp.profiles_1d[1:5] # IDSSlice + every_other = cp.profiles_1d[::2] # IDSSlice + + # Access nested arrays + all_ions = cp.profiles_1d[:].ion[:] # IDSSlice of individual ions + + # Extract values + labels = all_ions.label.values() + +Multi-Dimensional Slicing +--------------------------- + +The ``IDSSlice`` class supports multi-dimensional shape tracking and array conversion. + +**Check shape of sliced data:** + +.. code-block:: python + + # Get shape information for multi-dimensional data + print(cp.profiles_1d[:].grid.shape) # (106,) + print(cp.profiles_1d[:].ion.shape) # (106, ~3) + print(cp.profiles_1d[1:3].ion[0].element.shape) # (2, ~3) + +**Extract values with shape preservation:** + +.. code-block:: python + + # Extract as list + grid_values = cp.profiles_1d[:].grid.values() + + # Extract as numpy array + grid_array = cp.profiles_1d[:].grid.to_array() + + # Extract as numpy array + ion_array = cp.profiles_1d[:].ion.to_array() + +**Nested structure access:** + +.. code-block:: python + + # Access through nested arrays + grid_data = cp.profiles_1d[1:3].grid.rho_tor.to_array() + + # Ion properties across multiple profiles + ion_labels = cp.profiles_1d[:].ion[:].label.to_array() + ion_charges = cp.profiles_1d[:].ion[:].z_ion.to_array() + +Common Patterns +--------------- + +**Process a range:** + +.. code-block:: python + + for element in cp.profiles_1d[5:10]: + print(element.time) + +**Iterate over nested arrays:** + +.. code-block:: python + + for ion in cp.profiles_1d[:].ion[:]: + print(ion.label.value) + +**Get all values:** + +.. code-block:: python + + times = cp.profiles_1d[:].time.values() + + # Or as numpy array + times_array = cp.profiles_1d[:].time.to_array() + +Important: Array-wise Indexing +------------------------------- + +When accessing attributes through a slice of ``IDSStructArray`` elements, +the slice operation automatically applies to each array (array-wise indexing): + +.. code-block:: python + + # Array-wise indexing: [:] applies to each ion array + all_ions = cp.profiles_1d[:].ion[:] + labels = all_ions.label.values() + + # Equivalent to manually iterating: + labels = [] + for profile in cp.profiles_1d[:]: + for ion in profile.ion: + labels.append(ion.label.value) + +Lazy-Loaded Arrays +------------------- + +Both individual indexing and slicing work with lazy loading: + +.. code-block:: python + + element = lazy_array[0] # OK - loads on demand + subset = lazy_array[1:5] # OK - loads only requested elements on demand + +When slicing lazy-loaded arrays, only the elements in the slice range are loaded, +making it memory-efficient for large datasets. diff --git a/docs/source/courses/advanced/explore.rst b/docs/source/courses/advanced/explore.rst index 7b383bc..02f1201 100644 --- a/docs/source/courses/advanced/explore.rst +++ b/docs/source/courses/advanced/explore.rst @@ -72,6 +72,32 @@ structures (modeled by :py:class:`~imas.ids_struct_array.IDSStructArray`) are (a name applies) arrays containing :py:class:`~imas.ids_structure.IDSStructure`\ s. Data nodes can contain scalar or array data of various types. +**Slicing Arrays of Structures** + +Arrays of structures support Python slice notation, which returns an +:py:class:`~imas.ids_slice.IDSSlice` object containing matched elements: + +.. code-block:: python + + import imas + + core_profiles = imas.IDSFactory().core_profiles() + core_profiles.profiles_1d.resize(10) # Create 10 profiles + + # Integer indexing returns a single structure + first = core_profiles.profiles_1d[0] + + # Slice notation returns an IDSSlice + subset = core_profiles.profiles_1d[2:5] # Elements 2, 3, 4 + every_other = core_profiles.profiles_1d[::2] # Every second element + + # IDSSlice supports array-wise indexing and values() for data access + all_ions = core_profiles.profiles_1d[:].ion[:] + for ion in all_ions: + print(ion.label.value) + +For detailed information on slicing operations, see :doc:`../../array_slicing`. + Some methods and properties are defined for all data nodes and arrays of structures: ``len()`` diff --git a/docs/source/imas_architecture.rst b/docs/source/imas_architecture.rst index 182d2a0..756d8f7 100644 --- a/docs/source/imas_architecture.rst +++ b/docs/source/imas_architecture.rst @@ -168,6 +168,12 @@ The following submodules and classes represent IDS nodes. :py:class:`~imas.ids_struct_array.IDSStructArray` class, which models Arrays of Structures. It also contains some :ref:`dev lazy loading` logic. +- :py:mod:`imas.ids_slice` contains the + :py:class:`~imas.ids_slice.IDSSlice` class, which represents a collection of IDS + nodes matching a slice expression. It provides slicing operations on + :py:class:`~imas.ids_struct_array.IDSStructArray` elements with array-wise + indexing and supports the ``values()`` method for extracting raw data. + - :py:mod:`imas.ids_structure` contains the :py:class:`~imas.ids_structure.IDSStructure` class, which models Structures. It contains the :ref:`lazy instantiation` logic and some of the :ref:`dev lazy loading` diff --git a/docs/source/index.rst b/docs/source/index.rst index 8388f5b..7b8f98f 100644 --- a/docs/source/index.rst +++ b/docs/source/index.rst @@ -50,6 +50,7 @@ Manual configuring cli netcdf + array_slicing changelog examples diff --git a/docs/source/intro.rst b/docs/source/intro.rst index 3027a24..125b407 100644 --- a/docs/source/intro.rst +++ b/docs/source/intro.rst @@ -154,3 +154,38 @@ can use ``.get()`` to load IDS data from disk: >>> dbentry2 = imas.DBEntry("mypulsefile.nc","r") >>> core_profiles2 = dbentry2.get("core_profiles") >>> print(core_profiles2.ids_properties.comment.value) + + +.. _`Multi-Dimensional Slicing`: + +Multi-Dimensional Slicing +'''''''''''''''''''''''''' + +IMAS-Python supports advanced slicing of hierarchical data structures with automatic +shape tracking and array conversion to numpy. This enables intuitive access to +multi-dimensional scientific data: + +.. code-block:: python + + >>> # Load data + >>> entry = imas.DBEntry("mypulsefile.nc","r") + >>> cp = entry.get("core_profiles", autoconvert=False, lazy=True) + + >>> # Check shape of sliced data + >>> cp.profiles_1d[:].grid.shape + (106,) + >>> cp.profiles_1d[:].ion.shape + (106, ~3) # ~3 ions per profile + + >>> # Extract values + >>> grid_values = cp.profiles_1d[:].grid.to_array() + >>> ion_labels = cp.profiles_1d[:].ion[:].label.to_array() + + >>> # Work with subsets + >>> subset_grid = cp.profiles_1d[1:3].grid.to_array() + >>> subset_ions = cp.profiles_1d[1:3].ion.to_array() + +The ``IDSSlice`` class tracks multi-dimensional shapes and provides both +``.values()`` and ``.to_array()`` (numpy array) +methods for data extraction. For more details, see :ref:`array-slicing`. + diff --git a/imas/ids_slice.py b/imas/ids_slice.py new file mode 100644 index 0000000..30bc22a --- /dev/null +++ b/imas/ids_slice.py @@ -0,0 +1,572 @@ +# This file is part of IMAS-Python. +# You should have received the IMAS-Python LICENSE file with this project. +"""IDSSlice represents a collection of IDS nodes matching a slice expression. + +This module provides the IDSSlice class, which enables slicing of arrays of +structures while maintaining the hierarchy and allowing further operations on +the resulting collection. +""" + +import logging +from typing import TYPE_CHECKING, Any, Iterator, List, Optional, Tuple, Union + +import numpy as np + +from imas.ids_metadata import IDSMetadata + +if TYPE_CHECKING: + from imas.ids_struct_array import IDSStructArray + +logger = logging.getLogger(__name__) + + +class IDSSlice: + """Represents a slice of IDS struct array elements. + + When slicing an IDSStructArray, instead of returning a regular Python list, + an IDSSlice is returned. This allows for: + - Tracking the slice operation in the path + - Further slicing of child elements + - Child node access on all matched elements + - Iteration over matched elements + + Attributes: + metadata: Metadata from the parent array, or None if not available + """ + + __slots__ = [ + "metadata", + "_matched_elements", + "_slice_path", + "_parent_array", + "_virtual_shape", + "_element_hierarchy", + ] + + def __init__( + self, + metadata: Optional[IDSMetadata], + matched_elements: List[Any], + slice_path: str, + parent_array: Optional["IDSStructArray"] = None, + virtual_shape: Optional[Tuple[int, ...]] = None, + element_hierarchy: Optional[List[Any]] = None, + ): + """Initialize IDSSlice. + + Args: + metadata: Metadata from the parent array + matched_elements: List of elements that matched the slice + slice_path: String representation of the slice operation (e.g., "[8:]") + parent_array: Optional reference to the parent IDSStructArray for context + virtual_shape: Optional tuple representing multi-dimensional shape + element_hierarchy: Optional tracking of element grouping + """ + self.metadata = metadata + self._matched_elements = matched_elements + self._slice_path = slice_path + self._parent_array = parent_array + self._virtual_shape = virtual_shape or (len(matched_elements),) + self._element_hierarchy = element_hierarchy or [len(matched_elements)] + + @property + def _path(self) -> str: + """Return the path representation of this slice.""" + return self._slice_path + + @property + def shape(self) -> Tuple[int, ...]: + """Get the virtual multi-dimensional shape. + + Returns the shape of the data as if it were organized in a multi-dimensional + array, based on the hierarchy of slicing operations performed. + + Returns: + Tuple of dimensions. + """ + return self._virtual_shape + + def __len__(self) -> int: + """Return the number of elements matched by this slice.""" + return len(self._matched_elements) + + def __iter__(self) -> Iterator[Any]: + """Iterate over all matched elements.""" + return iter(self._matched_elements) + + def __getitem__(self, item: Union[int, slice]) -> Union[Any, "IDSSlice"]: + """Get element(s) from the slice. + + When the matched elements are IDSStructArray objects, the indexing + operation is applied to each array element (array-wise indexing). + Otherwise, the operation is applied to the matched elements list itself. + + Args: + item: Index or slice to apply + + Returns: + - IDSSlice: If item is a slice, or if applying integer index to + IDSStructArray elements + - Single element: If item is an int and elements are not IDSStructArray + """ + from imas.ids_struct_array import IDSStructArray + + # Array-wise indexing: apply operation to each IDSStructArray element + if self._matched_elements and isinstance( + self._matched_elements[0], IDSStructArray + ): + if isinstance(item, slice): + # Preserve structure instead of flattening + sliced_elements = [] + sliced_sizes = [] + + for array in self._matched_elements: + sliced = array[item] + if isinstance(sliced, IDSSlice): + sliced_elements.extend(sliced._matched_elements) + sliced_sizes.append(len(sliced)) + else: + sliced_elements.append(sliced) + sliced_sizes.append(1) + + slice_str = self._format_slice(item) + new_path = self._slice_path + slice_str + + # Update shape to reflect the sliced structure + # Keep first dimensions, update last dimension + new_virtual_shape = self._virtual_shape[:-1] + ( + sliced_sizes[0] if sliced_sizes else 0, + ) + new_hierarchy = self._element_hierarchy[:-1] + [sliced_sizes] + + return IDSSlice( + self.metadata, + sliced_elements, + new_path, + parent_array=self._parent_array, + virtual_shape=new_virtual_shape, + element_hierarchy=new_hierarchy, + ) + else: + # Integer indexing on arrays + indexed_elements = [] + for array in self._matched_elements: + indexed_elements.append(array[int(item)]) + + new_path = self._slice_path + f"[{item}]" + + # Shape changes: last dimension becomes 1 + new_virtual_shape = self._virtual_shape[:-1] + (1,) + + return IDSSlice( + self.metadata, + indexed_elements, + new_path, + parent_array=self._parent_array, + virtual_shape=new_virtual_shape, + element_hierarchy=self._element_hierarchy, + ) + else: + if isinstance(item, slice): + sliced_elements = self._matched_elements[item] + slice_str = self._format_slice(item) + new_path = self._slice_path + slice_str + + # Update shape to reflect the slice on first dimension + new_virtual_shape = (len(sliced_elements),) + self._virtual_shape[1:] + new_element_hierarchy = [ + len(sliced_elements) + ] + self._element_hierarchy[1:] + + return IDSSlice( + self.metadata, + sliced_elements, + new_path, + parent_array=self._parent_array, + virtual_shape=new_virtual_shape, + element_hierarchy=new_element_hierarchy, + ) + else: + return self._matched_elements[int(item)] + + def __getattr__(self, name: str) -> "IDSSlice": + """Access a child node on all matched elements. + + This returns a new IDSSlice containing the child node from + each matched element. Preserves multi-dimensional structure + when child elements are arrays. + + Args: + name: Name of the node to access + + Returns: + A new IDSSlice containing the child node from each matched element + """ + if not self._matched_elements: + raise IndexError( + f"Cannot access node '{name}' on empty slice with 0 elements" + ) + + from imas.ids_struct_array import IDSStructArray + + child_metadata = None + if self.metadata is not None: + try: + child_metadata = self.metadata[name] + except (KeyError, TypeError): + pass + + child_elements = [getattr(element, name) for element in self] + new_path = self._slice_path + "." + name + + # Check if children are IDSStructArray (nested arrays) or IDSNumericArray + if not child_elements: + # Empty slice + return IDSSlice( + child_metadata, + child_elements, + new_path, + parent_array=self._parent_array, + virtual_shape=self._virtual_shape, + element_hierarchy=self._element_hierarchy, + ) + + from imas.ids_primitive import IDSNumericArray + + if isinstance(child_elements[0], IDSStructArray): + # Children are IDSStructArray - track the new dimension + child_sizes = [len(arr) for arr in child_elements] + + # New virtual shape: current shape + new dimension + new_virtual_shape = self._virtual_shape + ( + child_sizes[0] if child_sizes else 0, + ) + new_hierarchy = self._element_hierarchy + [child_sizes] + + return IDSSlice( + child_metadata, + child_elements, + new_path, + parent_array=self._parent_array, + virtual_shape=new_virtual_shape, + element_hierarchy=new_hierarchy, + ) + elif isinstance(child_elements[0], IDSNumericArray): + # Children are IDSNumericArray - track the array dimension + # Each IDSNumericArray has a size (length of its data) + child_sizes = [len(arr) for arr in child_elements] + + # New virtual shape: current shape + new dimension + new_virtual_shape = self._virtual_shape + ( + child_sizes[0] if child_sizes else 0, + ) + new_hierarchy = self._element_hierarchy + [child_sizes] + + return IDSSlice( + child_metadata, + child_elements, + new_path, + parent_array=self._parent_array, + virtual_shape=new_virtual_shape, + element_hierarchy=new_hierarchy, + ) + else: + # Children are not arrays (structures or other primitives) + return IDSSlice( + child_metadata, + child_elements, + new_path, + parent_array=self._parent_array, + virtual_shape=self._virtual_shape, + element_hierarchy=self._element_hierarchy, + ) + + def __repr__(self) -> str: + """Build a string representation of this IDSSlice. + + Returns a string showing: + - The IDS type name (e.g., 'equilibrium') + - The full path including the slice operation (e.g., 'time_slice[:]') + - The number of matched elements + + Returns: + String representation like below + like '' + """ + from imas.util import get_toplevel, get_full_path + + my_repr = f"<{type(self).__name__}" + ids_name = "unknown" + full_path = self._path + + if self._parent_array is not None: + ids_name = get_toplevel(self._parent_array).metadata.name + parent_array_path = get_full_path(self._parent_array) + full_path = parent_array_path + self._path + item_word = "item" if len(self) == 1 else "items" + my_repr += f" (IDS:{ids_name}, {full_path} with {len(self)} {item_word})>" + return my_repr + + def values(self, reshape: bool = False) -> Any: + """Extract raw values from elements in this slice. + + For IDSPrimitive elements, this extracts the wrapped value. + For other element types, returns them as-is. + + For multi-dimensional slices (when shape has multiple dimensions), + this extracts values respecting the multi-dimensional structure. + + This is useful for getting the actual data without the IDS wrapper + when accessing scalar fields through a slice, without requiring + explicit looping through the original collection. + + Args: + reshape: If True, reshape result to match self.shape for + multi-dimensional slices. If False (default), return flat list + or list of extracted values. + + Returns: + list or numpy.ndarray: Extracted values as follows: + + - 1D slices: List of raw Python/numpy values or unwrapped elements + - Multi-D with reshape=False: List of elements (each being an array) + - Multi-D with reshape=True: numpy.ndarray with shape self.shape, + or nested lists/object array representing structure + + Examples: + >>> # Get names from identifiers without looping + >>> n = edge_profiles.grid_ggd[0].grid_subset[:].identifier.name.values() + >>> # Result: ["nodes", "edges", "cells"] + >>> + >>> # Get 2D array but as list of arrays (default) + >>> rho = core_profiles.profiles_1d[:].grid.rho_tor.values() + >>> # Result: [ndarray(100,), ndarray(100,), ...] - list of 106 arrays + >>> + >>> # Get 2D array reshaped to (106, 100) + >>> rho = core_profiles.profiles_1d[:].grid.rho_tor.values(reshape=True) + >>> # Result: ndarray shape (106, 100) + >>> + >>> # 3D ions case - returns object array with structure + >>> ion_rho = ( + ... core_profiles.profiles_1d[:].ion[:].element[:].density.values( + ... reshape=True + ... ) + ... ) + >>> # Result: object array shape (106, 3, 2) with IDSNumericArray elements + """ + from imas.ids_primitive import IDSPrimitive, IDSNumericArray + + # Default behavior: return flat list without reshape + if not reshape: + result = [] + for element in self._matched_elements: + if isinstance(element, IDSPrimitive): + result.append(element.value) + else: + result.append(element) + return result + + # Multi-dimensional case with reshape requested + flat_values = [] + for element in self._matched_elements: + if isinstance(element, IDSPrimitive): + flat_values.append(element.value) + elif isinstance(element, IDSNumericArray): + flat_values.append( + element.data if hasattr(element, "data") else element.value + ) + else: + flat_values.append(element) + + # For 1D, just return as is + if len(self._virtual_shape) == 1: + return flat_values + + # Try to reshape to multi-dimensional shape + try: + # Calculate total size + total_size = 1 + for dim in self._virtual_shape: + total_size *= dim + + # Check if sizes match + if len(flat_values) == total_size: + # Successfully reshape to multi-dimensional + return np.array(flat_values, dtype=object).reshape(self._virtual_shape) + except (ValueError, TypeError): + pass + + # If reshape fails or not all elements are extractable, return as object array + try: + return np.array(flat_values, dtype=object).reshape(self._virtual_shape[0:1]) + except (ValueError, TypeError): + return flat_values + + def to_array(self) -> np.ndarray: + """Convert this slice to a numpy array respecting multi-dimensional structure. + + For 1D slices: returns a simple 1D array. + For multi-dimensional slices: returns an array with shape self.shape. + + This is useful for integration with numpy operations, scipy functions, + and xarray data structures. The returned array preserves the hierarchical + structure of the IMAS data. + + Returns: + numpy.ndarray with shape self.shape. + + Raises: + ValueError: If array cannot be converted to numpy + + Examples: + >>> # Convert 2D slice to numpy array + >>> rho_array = core_profiles.profiles_1d[:].grid.rho_tor.to_array() + >>> # Result: ndarray shape (106, 100), dtype float64 + >>> print(rho_array.shape) + (106, 100) + >>> + >>> ion_density = core_profiles.profiles_1d[:].ion[:].density.to_array() + >>> # Result: object array shape (106, 3) with varying sizes + >>> + >>> # Can be used directly with numpy functions + >>> mean_rho = np.mean(rho_array, axis=1) + >>> # Result: (106,) array of mean values + """ + from imas.ids_primitive import IDSPrimitive, IDSNumericArray + + # 1D case - simple conversion + if len(self._virtual_shape) == 1: + flat_values = [] + for element in self._matched_elements: + if isinstance(element, IDSPrimitive): + flat_values.append(element.value) + else: + flat_values.append(element) + try: + return np.array(flat_values) + except (ValueError, TypeError): + return np.array(flat_values, dtype=object) + + # Multi-dimensional case + # Check if matched elements are themselves arrays (IDSNumericArray) + if self._matched_elements and isinstance( + self._matched_elements[0], IDSNumericArray + ): + # Elements are numeric arrays - extract their values and stack them + array_values = [] + for element in self._matched_elements: + if isinstance(element, IDSNumericArray): + array_values.append(element.value) + else: + array_values.append(element) + + # Try to stack into proper shape + try: + # Check if all arrays have the same size (regular) + sizes = [] + for val in array_values: + if hasattr(val, "__len__"): + sizes.append(len(val)) + else: + sizes.append(1) + + # If all sizes are the same, we can create a regular array + if len(set(sizes)) == 1: + # Regular array - all sub-arrays same size + stacked = np.array(array_values) + # Should now have shape (first_dim, second_dim) + if stacked.shape == self._virtual_shape: + return stacked + else: + # Try explicit reshape + try: + return stacked.reshape(self._virtual_shape) + except ValueError: + # If reshape fails, return as object array + result_arr = np.empty(self._virtual_shape, dtype=object) + for i, val in enumerate(array_values): + result_arr.flat[i] = val + return result_arr + else: + result_arr = np.empty(self._virtual_shape[0], dtype=object) + for i, val in enumerate(array_values): + result_arr[i] = val + return result_arr + except (ValueError, TypeError): + # Fallback: return object array + result_arr = np.empty(self._virtual_shape[0], dtype=object) + for i, val in enumerate(array_values): + result_arr[i] = val + return result_arr + + # For non-numeric elements in multi-dimensional structure + # Extract and try to build structure + flat_values = [] + + # First check if matched_elements are IDSStructArray (which need flattening) + from imas.ids_struct_array import IDSStructArray + + has_struct_arrays = self._matched_elements and isinstance( + self._matched_elements[0], IDSStructArray + ) + + if has_struct_arrays: + # Flatten IDSStructArray elements + for struct_array in self._matched_elements: + for element in struct_array: + if isinstance(element, IDSPrimitive): + flat_values.append(element.value) + else: + flat_values.append(element) + else: + # Regular elements + for element in self._matched_elements: + if isinstance(element, IDSPrimitive): + flat_values.append(element.value) + else: + flat_values.append(element) + + total_size = 1 + for dim in self._virtual_shape: + total_size *= dim + + # Check if we have the right number of elements + if len(flat_values) != total_size: + raise ValueError( + f"Cannot convert to array: expected {total_size} elements " + f"but got {len(flat_values)}" + ) + + # Try to create the array + try: + arr = np.array(flat_values) + try: + # Try to reshape to target shape + return arr.reshape(self._virtual_shape) + except (ValueError, TypeError): + # If reshape fails, use object array + arr_obj = np.empty(self._virtual_shape, dtype=object) + for i, val in enumerate(flat_values): + arr_obj.flat[i] = val + return arr_obj + except (ValueError, TypeError) as e: + raise ValueError(f"Failed to convert slice to numpy array: {e}") + + @staticmethod + def _format_slice(slice_obj: slice) -> str: + """Format a slice object as a string. + + Args: + slice_obj: The slice object to format + + Returns: + String representation like "[1:5]", "[::2]", etc. + """ + start = slice_obj.start if slice_obj.start is not None else "" + stop = slice_obj.stop if slice_obj.stop is not None else "" + step = slice_obj.step if slice_obj.step is not None else "" + + if step: + return f"[{start}:{stop}:{step}]" + else: + return f"[{start}:{stop}]" diff --git a/imas/ids_struct_array.py b/imas/ids_struct_array.py index b176864..b06396b 100644 --- a/imas/ids_struct_array.py +++ b/imas/ids_struct_array.py @@ -121,12 +121,56 @@ def _element_structure(self): return struct def __getitem__(self, item): - # value is a list, so the given item should be convertable to integer - # TODO: perhaps we should allow slices as well? - list_idx = int(item) - if self._lazy: - self._load(item) - return self.value[list_idx] + """Get element(s) from the struct array. + + Args: + item: Integer index or slice object + + Returns: + A single IDSStructure if item is an int, or an IDSSlice if item is a slice + """ + if isinstance(item, slice): + if self._lazy: + + self._load(None) # Load size + + # Convert slice to indices + start, stop, step = item.indices(len(self)) + + # Load only the elements in the slice range + loaded_elements = [] + for i in range(start, stop, step): + self._load(i) # Load each element on demand + loaded_elements.append(self.value[i]) + + from imas.ids_slice import IDSSlice + + slice_str = IDSSlice._format_slice(item) + + return IDSSlice( + self.metadata, + loaded_elements, + slice_str, + parent_array=self, + ) + + from imas.ids_slice import IDSSlice + + matched_elements = self.value[item] + slice_str = IDSSlice._format_slice(item) + + return IDSSlice( + self.metadata, + matched_elements, + slice_str, + parent_array=self, + ) + else: + # Handle integer index + list_idx = int(item) + if self._lazy: + self._load(item) + return self.value[list_idx] def __setitem__(self, item, value): # value is a list, so the given item should be convertable to integer diff --git a/imas/test/test_ids_slice.py b/imas/test/test_ids_slice.py new file mode 100644 index 0000000..9bb27e5 --- /dev/null +++ b/imas/test/test_ids_slice.py @@ -0,0 +1,461 @@ +# This file is part of IMAS-Python. +# You should have received the IMAS-Python LICENSE file with this project. + +import numpy as np +import pytest + +from imas.ids_factory import IDSFactory +from imas.ids_slice import IDSSlice + + +@pytest.fixture +def wall_with_units(): + return create_wall_with_units() + + +@pytest.fixture +def wall_varying_sizes(): + return create_wall_with_units(total_units=2, element_counts=[4, 2]) + + +def create_wall_with_units( + total_units: int = 12, + element_counts=None, + *, + dd_version: str = "3.39.0", +): + + if total_units < 2: + raise ValueError("Need at least two units to exercise slice edge cases.") + + wall = IDSFactory(dd_version).wall() + wall.description_2d.resize(1) + + units = wall.description_2d[0].vessel.unit + units.resize(total_units) + + if element_counts is None: + element_counts = [4, 2] + [3] * (total_units - 2) + + element_counts = list(element_counts) + if len(element_counts) != total_units: + raise ValueError("element_counts length must match total_units.") + + for unit_idx, unit in enumerate(units): + unit.name = f"unit-{unit_idx}" + unit.element.resize(element_counts[unit_idx]) + for elem_idx, element in enumerate(unit.element): + element.name = f"element-{unit_idx}-{elem_idx}" + + return wall + + +def safe_element_lookup(units_slice, element_index: int): + collected = [] + skipped_units = [] + for idx, unit in enumerate(units_slice): + elements = unit.element + if element_index >= len(elements): + skipped_units.append(idx) + continue + collected.append(elements[element_index].name.value) + return {"collected": collected, "skipped_units": skipped_units} + + +class TestBasicSlicing: + + def test_slice_with_start_and_stop(self): + cp = IDSFactory("3.39.0").core_profiles() + cp.profiles_1d.resize(10) + + result = cp.profiles_1d[3:7] + assert isinstance(result, IDSSlice) + assert len(result) == 4 + + result = cp.profiles_1d[::2] + assert isinstance(result, IDSSlice) + assert len(result) == 5 + + result = cp.profiles_1d[-5:] + assert isinstance(result, IDSSlice) + assert len(result) == 5 + + def test_slice_corner_cases(self): + cp = IDSFactory("3.39.0").core_profiles() + cp.profiles_1d.resize(10) + + result = cp.profiles_1d[0:100] + assert len(result) == 10 + + result = cp.profiles_1d[10:20] + assert len(result) == 0 + + result = cp.profiles_1d[::-1] + assert len(result) == 10 + + def test_integer_index_still_works(self): + cp = IDSFactory("3.39.0").core_profiles() + cp.profiles_1d.resize(10) + + result = cp.profiles_1d[5] + assert not isinstance(result, IDSSlice) + assert hasattr(result, "_path") + + +class TestIDSSlicePath: + + def test_slice_path_representation(self): + cp = IDSFactory("3.39.0").core_profiles() + cp.profiles_1d.resize(10) + + result = cp.profiles_1d[5:8] + expected_path = "[5:8]" + assert expected_path in result._path + + result = cp.profiles_1d[5:8][1:3] + assert "[" in result._path + + def test_attribute_access_path(self, wall_with_units): + wall = wall_with_units + units = wall.description_2d[0].vessel.unit[8:] + + element_slice = units.element + assert "element" in element_slice._path + + +class TestIDSSliceIteration: + + def test_iteration_and_len(self): + cp = IDSFactory("3.39.0").core_profiles() + cp.profiles_1d.resize(5) + + slice_obj = cp.profiles_1d[1:4] + + items = list(slice_obj) + assert len(items) == 3 + + assert len(slice_obj) == 3 + + +class TestIDSSliceIndexing: + + def test_integer_indexing_slice(self): + cp = IDSFactory("3.39.0").core_profiles() + cp.profiles_1d.resize(10) + + slice_obj = cp.profiles_1d[3:7] + element = slice_obj[1] + assert not isinstance(element, IDSSlice) + + def test_slice_indexing_slice(self): + cp = IDSFactory("3.39.0").core_profiles() + cp.profiles_1d.resize(10) + + slice_obj = cp.profiles_1d[2:8] + nested_slice = slice_obj[1:4] + assert isinstance(nested_slice, IDSSlice) + assert len(nested_slice) == 3 + + +class TestIDSSliceAttributeAccess: + + def test_attribute_access_nested_attributes(self, wall_with_units): + wall = wall_with_units + units = wall.description_2d[0].vessel.unit[8:] + + names = units.name + assert isinstance(names, IDSSlice) + assert len(names) == 4 + + units_full = wall.description_2d[0].vessel.unit + elements = units_full[:].element + assert isinstance(elements, IDSSlice) + + +class TestIDSSliceRepr: + + def test_repr_count_display(self): + cp = IDSFactory("3.39.0").core_profiles() + cp.profiles_1d.resize(10) + + slice_obj = cp.profiles_1d[5:6] + repr_str = repr(slice_obj) + assert "IDSSlice" in repr_str + assert "1 item" in repr_str + + slice_obj = cp.profiles_1d[5:8] + repr_str = repr(slice_obj) + assert "IDSSlice" in repr_str + assert "3 items" in repr_str + + +class TestWallExampleSlicing: + + def test_wall_units_nested_element_access(self, wall_with_units): + wall = wall_with_units + units = wall.description_2d[0].vessel.unit + + units_slice = units[8:] + assert isinstance(units_slice, IDSSlice) + assert len(units_slice) == 4 + + elements_slice = units_slice.element + assert isinstance(elements_slice, IDSSlice) + + +class TestEdgeCases: + + def test_slice_empty_array(self): + cp = IDSFactory("3.39.0").core_profiles() + + result = cp.profiles_1d[:] + assert isinstance(result, IDSSlice) + assert len(result) == 0 + + def test_slice_single_element(self): + cp = IDSFactory("3.39.0").core_profiles() + cp.profiles_1d.resize(1) + + result = cp.profiles_1d[:] + assert isinstance(result, IDSSlice) + assert len(result) == 1 + + def test_invalid_step_zero(self): + cp = IDSFactory("3.39.0").core_profiles() + cp.profiles_1d.resize(10) + + with pytest.raises(ValueError): + cp.profiles_1d[::0] + + +class TestFlatten: + + def test_flatten_basic_and_partial(self): + cp = IDSFactory("3.39.0").core_profiles() + cp.profiles_1d.resize(3) + + for profile in cp.profiles_1d: + profile.ion.resize(5) + + slice_obj = cp.profiles_1d[:].ion + flattened = slice_obj[:] + assert isinstance(flattened, IDSSlice) + assert len(flattened) == 15 + + cp2 = IDSFactory("3.39.0").core_profiles() + cp2.profiles_1d.resize(4) + for profile in cp2.profiles_1d: + profile.ion.resize(3) + flattened2 = cp2.profiles_1d[:2].ion[:] + assert len(flattened2) == 6 + + def test_flatten_empty_and_single(self): + cp = IDSFactory("3.39.0").core_profiles() + cp.profiles_1d.resize(2) + empty_flattened = cp.profiles_1d[:].ion[:] + assert len(empty_flattened) == 0 + + cp2 = IDSFactory("3.39.0").core_profiles() + cp2.profiles_1d.resize(1) + cp2.profiles_1d[0].ion.resize(4) + single_flattened = cp2.profiles_1d[:].ion[:] + assert len(single_flattened) == 4 + + def test_flatten_indexing_and_slicing(self): + cp = IDSFactory("3.39.0").core_profiles() + cp.profiles_1d.resize(2) + + for i, profile in enumerate(cp.profiles_1d): + profile.ion.resize(3) + for j, ion in enumerate(profile.ion): + ion.label = f"ion_{i}_{j}" + + flattened = cp.profiles_1d[:].ion[:] + + assert flattened[0].label == "ion_0_0" + assert flattened[3].label == "ion_1_0" + + subset = flattened[1:4] + assert isinstance(subset, IDSSlice) + assert len(subset) == 3 + labels = [ion.label for ion in subset] + assert labels == ["ion_0_1", "ion_0_2", "ion_1_0"] + + def test_flatten_repr_and_path(self): + cp = IDSFactory("3.39.0").core_profiles() + cp.profiles_1d.resize(2) + for profile in cp.profiles_1d: + profile.ion.resize(2) + + flattened = cp.profiles_1d[:].ion[:] + repr_str = repr(flattened) + + assert "IDSSlice" in repr_str + assert "4 items" in repr_str + assert "[:]" in flattened._path + + def test_flatten_complex_case(self, wall_with_units): + wall = wall_with_units + units = wall.description_2d[0].vessel.unit[:5] + + all_elements = units.element[:] + assert len(all_elements) == 4 + 2 + 3 + 3 + 3 + + +class TestVaryingArraySizeIndexing: + + def test_unit_slice_element_integer_indexing(self, wall_varying_sizes): + units = wall_varying_sizes.description_2d[0].vessel.unit + units_slice = units[:2] + element_slice = units_slice.element + + with pytest.raises(IndexError): + element_slice[2] + + def test_unit_slice_element_safe_indexing_scenarios(self, wall_varying_sizes): + units = wall_varying_sizes.description_2d[0].vessel.unit + units_slice = units[:2] + + result = safe_element_lookup(units_slice, 1) + assert len(result["collected"]) == 2 + assert result["collected"] == ["element-0-1", "element-1-1"] + + result = safe_element_lookup(units_slice, 2) + assert len(result["collected"]) == 1 + assert result["skipped_units"] == [1] + + result = safe_element_lookup(units_slice, 4) + assert len(result["collected"]) == 0 + assert result["skipped_units"] == [0, 1] + + def test_unit_slice_element_individual_access(self, wall_varying_sizes): + units = wall_varying_sizes.description_2d[0].vessel.unit + element_slice = units[:2].element + + first_from_each = element_slice[0] + assert isinstance(first_from_each, IDSSlice) + assert len(first_from_each) == 2 + + arrays = list(element_slice) + assert len(arrays[0]) == 4 + assert arrays[0][2].name.value == "element-0-2" + + assert len(arrays[1]) == 2 + + with pytest.raises(IndexError): + arrays[1][2] + + def test_wall_with_diverse_element_counts(self): + wall = create_wall_with_units(total_units=5, element_counts=[3, 1, 4, 2, 5]) + + units = wall.description_2d[0].vessel.unit + units_slice = units[:3] + element_slice = units_slice.element + + first_from_each = element_slice[0] + assert isinstance(first_from_each, IDSSlice) + assert len(first_from_each) == 3 + + arrays = list(element_slice) + assert len(arrays[0]) == 3 + assert len(arrays[2]) == 4 + + result = safe_element_lookup(units_slice, 2) + assert len(result["collected"]) == 2 + assert result["skipped_units"] == [1] + + +class TestIDSSliceValues: + + def test_values_basic_extraction(self, wall_with_units): + wall = wall_with_units + units = wall.description_2d[0].vessel.unit + + names_slice = units[:].name + names = names_slice.values() + + assert isinstance(names, list) + assert len(names) == 12 + assert all(isinstance(name, str) and name.startswith("unit-") for name in names) + assert names == [f"unit-{i}" for i in range(12)] + + def test_values_integer_and_float_extraction(self): + cp = IDSFactory("3.39.0").core_profiles() + cp.profiles_1d.resize(3) + + for profile in cp.profiles_1d: + profile.ion.resize(2) + for i, ion in enumerate(profile.ion): + ion.neutral_index = i + ion.z_ion = float(i + 1) + + ions = cp.profiles_1d[:].ion[:] + indices = ions[:].neutral_index.values() + assert all(isinstance(idx, (int, np.integer)) for idx in indices) + + z_values = ions[:].z_ion.values() + assert all(isinstance(z, (float, np.floating)) for z in z_values) + + def test_values_partial_and_empty_slices(self, wall_with_units): + wall = wall_with_units + units = wall.description_2d[0].vessel.unit + + names = units[:5].name.values() + assert len(names) == 5 + assert names == [f"unit-{i}" for i in range(5)] + + cp = IDSFactory("3.39.0").core_profiles() + cp.profiles_1d.resize(5) + # Empty slices should raise IndexError when accessing attributes + with pytest.raises(IndexError): + cp.profiles_1d[5:10].label.values() + + def test_values_with_step_and_negative_indices(self, wall_with_units): + wall = wall_with_units + units = wall.description_2d[0].vessel.unit + + names_step = units[::2].name.values() + assert len(names_step) == 6 + assert names_step == [f"unit-{i}" for i in range(0, 12, 2)] + + names_neg = units[-3:].name.values() + assert len(names_neg) == 3 + assert names_neg == [f"unit-{i}" for i in range(9, 12)] + + def test_values_structure_preservation(self): + cp = IDSFactory("3.39.0").core_profiles() + cp.profiles_1d.resize(3) + + for profile in cp.profiles_1d: + profile.ion.resize(2) + + ions = cp.profiles_1d[:].ion[:].values() + + assert len(ions) == 6 + for ion in ions: + assert hasattr(ion, "_path") + from imas.ids_primitive import IDSPrimitive + + assert not isinstance(ion, IDSPrimitive) + + def test_values_array_primitives(self): + cp = IDSFactory("3.39.0").core_profiles() + cp.profiles_1d.resize(2) + + cp.profiles_1d[0].grid.psi = np.linspace(0, 1, 10) + cp.profiles_1d[1].grid.psi = np.linspace(1, 2, 10) + + psi_values = cp.profiles_1d[:].grid.psi.values() + + assert len(psi_values) == 2 + assert all(isinstance(psi, np.ndarray) for psi in psi_values) + + def test_values_consistency_with_iteration(self, wall_with_units): + wall = wall_with_units + units = wall.description_2d[0].vessel.unit + + names_via_values = units[:5].name.values() + + names_via_iteration = [unit.name.value for unit in units[:5]] + + assert names_via_values == names_via_iteration diff --git a/imas/test/test_ids_struct_array.py b/imas/test/test_ids_struct_array.py index ab128df..8c31f22 100644 --- a/imas/test/test_ids_struct_array.py +++ b/imas/test/test_ids_struct_array.py @@ -87,3 +87,15 @@ def test_struct_array_eq(): assert cp1.profiles_1d != cp2.profiles_1d cp2.profiles_1d[0].time = 1 assert cp1.profiles_1d == cp2.profiles_1d + + +def test_struct_array_slice(): + cp1 = IDSFactory("3.39.0").core_profiles() + cp1.profiles_1d.resize(20) + + assert len(cp1.profiles_1d) == 20 + assert len(cp1.profiles_1d[:]) == 20 + assert len(cp1.profiles_1d[5:10]) == 5 + assert len(cp1.profiles_1d[10:]) == 10 + assert len(cp1.profiles_1d[:5]) == 5 + assert len(cp1.profiles_1d[::2]) == 10 diff --git a/imas/test/test_multidim_slicing.py b/imas/test/test_multidim_slicing.py new file mode 100644 index 0000000..f5fbdae --- /dev/null +++ b/imas/test/test_multidim_slicing.py @@ -0,0 +1,349 @@ +# This file is part of IMAS-Python. +# You should have received the IMAS-Python LICENSE file with this project. +"""Tests for multi-dimensional slicing support in IDSSlice.""" + +import numpy as np +import pytest + +from imas.ids_factory import IDSFactory + + +class TestMultiDimSlicing: + """Shape tracking and conversion methods.""" + + def test_shape_property_single_level(self): + """Test shape property for single-level slice.""" + cp = IDSFactory("3.39.0").core_profiles() + cp.profiles_1d.resize(10) + + result = cp.profiles_1d[:] + assert hasattr(result, "shape") + assert result.shape == (10,) + + def test_shape_property_two_level(self): + """Test shape property for 2D array access.""" + cp = IDSFactory("3.39.0").core_profiles() + cp.profiles_1d.resize(5) + for p in cp.profiles_1d: + p.grid.rho_tor_norm = np.array([0.0, 0.5, 1.0]) + + result = cp.profiles_1d[:].grid.rho_tor_norm + assert result.shape == (5, 3) + + def test_shape_property_three_level(self): + """Test shape property for 3D nested structure.""" + cp = IDSFactory("3.39.0").core_profiles() + cp.profiles_1d.resize(3) + for p in cp.profiles_1d: + p.ion.resize(2) + for i in p.ion: + i.element.resize(2) + + result = cp.profiles_1d[:].ion[:].element[:] + assert result.shape == (3, 2, 2) + + def test_to_array_2d_regular(self): + """Test to_array() with regular 2D array.""" + cp = IDSFactory("3.39.0").core_profiles() + cp.profiles_1d.resize(5) + for i, p in enumerate(cp.profiles_1d): + p.grid.rho_tor_norm = np.array([0.0, 0.5, 1.0]) + + result = cp.profiles_1d[:].grid.rho_tor_norm + array = result.to_array() + + assert isinstance(array, np.ndarray) + assert array.shape == (5, 3) + assert np.allclose(array[0], [0.0, 0.5, 1.0]) + assert np.allclose(array[4], [0.0, 0.5, 1.0]) + + def test_to_array_3d_regular(self): + """Test to_array() with regular 3D array.""" + cp = IDSFactory("3.39.0").core_profiles() + cp.profiles_1d.resize(3) + for p in cp.profiles_1d: + p.ion.resize(2) + for i_idx, i in enumerate(p.ion): + i.element.resize(2) + for e_idx, e in enumerate(i.element): + e.z_n = float(e_idx) + + result = cp.profiles_1d[:].ion[:].element[:].z_n + array = result.to_array() + + assert isinstance(array, np.ndarray) + assert array.shape == (3, 2, 2) + assert np.allclose(array[0, 0, :], [0.0, 1.0]) + assert np.allclose(array[0, 1, :], [0.0, 1.0]) + + def test_to_array_variable_size(self): + """Test to_array() with variable-size arrays.""" + cp = IDSFactory("3.39.0").core_profiles() + cp.profiles_1d.resize(3) + cp.profiles_1d[0].grid.rho_tor_norm = np.array([0.0, 0.5, 1.0]) + cp.profiles_1d[1].grid.rho_tor_norm = np.array([0.0, 0.25, 0.5, 0.75, 1.0]) + cp.profiles_1d[2].grid.rho_tor_norm = np.array([0.0, 0.5, 1.0]) + + result = cp.profiles_1d[:].grid.rho_tor_norm + array = result.to_array() + + assert array.dtype == object + assert len(array) == 3 + assert len(array[0]) == 3 + assert len(array[1]) == 5 + assert len(array[2]) == 3 + + def test_enhanced_values_2d(self): + """Test enhanced values() method for 2D extraction.""" + cp = IDSFactory("3.39.0").core_profiles() + cp.profiles_1d.resize(3) + for p in cp.profiles_1d: + p.grid.rho_tor_norm = np.array([0.0, 0.5, 1.0]) + + result = cp.profiles_1d[:].grid.rho_tor_norm + values = result.values() + + # Should be a list of 3 arrays + assert isinstance(values, list) + assert len(values) == 3 + for v in values: + assert isinstance(v, np.ndarray) + assert len(v) == 3 + + def test_enhanced_values_3d(self): + """Test enhanced values() method for 3D extraction.""" + cp = IDSFactory("3.39.0").core_profiles() + cp.profiles_1d.resize(2) + for p in cp.profiles_1d: + p.ion.resize(2) + for i in p.ion: + i.element.resize(2) + for e_idx, e in enumerate(i.element): + e.z_n = float(e_idx) + + result = cp.profiles_1d[:].ion[:].element[:].z_n + values = result.values() + + assert isinstance(values, list) + assert len(values) == 8 # 2 profiles * 2 ions * 2 elements + + def test_slice_preserves_groups(self): + """Test that slicing preserves group structure.""" + cp = IDSFactory("3.39.0").core_profiles() + cp.profiles_1d.resize(10) + for p in cp.profiles_1d: + p.ion.resize(3) + + # Get all ions, then slice + result = cp.profiles_1d[:].ion[:] + + # Should still know the structure: 10 profiles, 3 ions each + assert result.shape == (10, 3) + assert len(result) == 30 # Flattened for iteration, but shape preserved + + def test_integer_index_on_nested(self): + """Test integer indexing on nested structures.""" + cp = IDSFactory("3.39.0").core_profiles() + cp.profiles_1d.resize(5) + for i, p in enumerate(cp.profiles_1d): + p.ion.resize(2) + for j, ion in enumerate(p.ion): + ion.label = f"ion_{i}_{j}" + + # Get first ion from all profiles + result = cp.profiles_1d[:].ion[0] + + assert len(result) == 5 + for i, ion in enumerate(result): + assert ion.label == f"ion_{i}_0" + + def test_slice_on_nested_arrays(self): + """Test slicing on nested arrays.""" + cp = IDSFactory("3.39.0").core_profiles() + cp.profiles_1d.resize(5) + for p in cp.profiles_1d: + p.ion.resize(4) + + # Get first 2 ions from each profile + result = cp.profiles_1d[:].ion[:2] + + assert result.shape == (5, 2) + assert len(result) == 10 # 5 profiles * 2 ions each + + def test_step_slicing_on_nested(self): + """Test step slicing on nested structures.""" + cp = IDSFactory("3.39.0").core_profiles() + cp.profiles_1d.resize(5) + for p in cp.profiles_1d: + p.ion.resize(6) + + # Get every other ion + result = cp.profiles_1d[:].ion[::2] + + assert result.shape == (5, 3) # 5 profiles, 3 ions each (0, 2, 4) + assert len(result) == 15 + + def test_negative_indexing_on_nested(self): + """Test negative indexing on nested structures.""" + cp = IDSFactory("3.39.0").core_profiles() + cp.profiles_1d.resize(5) + for p in cp.profiles_1d: + p.ion.resize(3) + for j, ion in enumerate(p.ion): + ion.label = f"ion_{j}" + + # Get last ion from each profile + result = cp.profiles_1d[:].ion[-1] + + assert len(result) == 5 + for ion in result: + assert ion.label == "ion_2" + + def test_to_array_grouped_structure(self): + """Test that to_array preserves grouped structure.""" + cp = IDSFactory("3.39.0").core_profiles() + cp.profiles_1d.resize(3) + for p_idx, p in enumerate(cp.profiles_1d): + p.ion.resize(2) + for i_idx, i in enumerate(p.ion): + i.z_ion = float(p_idx * 10 + i_idx) + + result = cp.profiles_1d[:].ion[:].z_ion + array = result.to_array() + + # Should be (3, 2) array + assert array.shape == (3, 2) + assert array[0, 0] == 0.0 + assert array[1, 0] == 10.0 + assert array[2, 1] == 21.0 + + @pytest.mark.skip(reason="Phase 3 feature - boolean indexing not yet implemented") + def test_boolean_indexing_simple(self): + """Test boolean indexing on slices.""" + cp = IDSFactory("3.39.0").core_profiles() + cp.profiles_1d.resize(5) + for i, p in enumerate(cp.profiles_1d): + p.electrons.density = np.array([float(i)] * 5) + + result = cp.profiles_1d[:].electrons.density + + mask = np.array([True, False, True, False, True]) + filtered = result[mask] + assert len(filtered) == 3 + + def test_assignment_on_slice(self): + """Test assignment through slices.""" + cp = IDSFactory("3.39.0").core_profiles() + cp.profiles_1d.resize(3) + for p in cp.profiles_1d: + p.grid.rho_tor_norm = np.array([0.0, 0.5, 1.0]) + + # This requires assignment support + # cp.profiles_1d[:].grid.rho_tor_norm[:] = new_values + # For now, verify slicing works for reading + + result = cp.profiles_1d[:].grid.rho_tor_norm + array = result.to_array() + assert array.shape == (3, 3) + + def test_xarray_integration_compatible(self): + """Test that output is compatible with xarray.""" + cp = IDSFactory("3.39.0").core_profiles() + cp.profiles_1d.resize(3) + cp.time = np.array([1.0, 2.0, 3.0]) + + for i, p in enumerate(cp.profiles_1d): + p.grid.rho_tor_norm = np.array([0.0, 0.5, 1.0]) + p.electrons.temperature = np.array([1.0, 2.0, 3.0]) * (i + 1) + + # Test that we can extract values in xarray-compatible format + temps = cp.profiles_1d[:].electrons.temperature.to_array() + times = cp.time + + assert temps.shape == (3, 3) + assert len(times) == 3 + + def test_performance_large_hierarchy(self): + """Test performance with large nested hierarchies.""" + cp = IDSFactory("3.39.0").core_profiles() + n_profiles = 50 + cp.profiles_1d.resize(n_profiles) + + for p in cp.profiles_1d: + p.grid.rho_tor_norm = np.linspace(0, 1, 100) + p.ion.resize(5) + for i in p.ion: + i.element.resize(3) + + # Should handle large data without significant slowdown + result = cp.profiles_1d[:].grid.rho_tor_norm + array = result.to_array() + + assert array.shape == (n_profiles, 100) + + def test_lazy_loading_with_multidim(self): + """Test that lazy loading works with multi-dimensional slicing.""" + # This would require a database, so we'll test with in-memory + cp = IDSFactory("3.39.0").core_profiles() + cp.profiles_1d.resize(5) + for p in cp.profiles_1d: + p.grid.rho_tor_norm = np.array([0.0, 0.5, 1.0]) + + result = cp.profiles_1d[:].grid.rho_tor_norm + + # Verify lazy attributes are preserved + assert hasattr(result, "_lazy") + assert hasattr(result, "_parent_array") + + +class TestEdgeCases: + """Test edge cases and error conditions.""" + + def test_empty_slice(self): + """Test slicing that results in empty arrays.""" + cp = IDSFactory("3.39.0").core_profiles() + cp.profiles_1d.resize(5) + for p in cp.profiles_1d: + p.ion.resize(0) + + result = cp.profiles_1d[:].ion + assert len(result) == 5 + for ions in result: + # Each should be empty + pass + + def test_single_element_2d(self): + """Test 2D extraction with single element.""" + cp = IDSFactory("3.39.0").core_profiles() + cp.profiles_1d.resize(1) + cp.profiles_1d[0].grid.rho_tor_norm = np.array([0.0, 0.5, 1.0]) + + result = cp.profiles_1d[:].grid.rho_tor_norm + assert result.shape == (1, 3) + + def test_single_dimension_value(self): + """Test accessing a single value in multi-dimensional structure.""" + cp = IDSFactory("3.39.0").core_profiles() + cp.profiles_1d.resize(3) + for p in cp.profiles_1d: + p.ion.resize(2) + for i in p.ion: + i.z_ion = 1.0 + + result = cp.profiles_1d[:].ion[0].z_ion + + # Should be 3 items (one per profile) + assert len(result) == 3 + + def test_slice_of_slice(self): + """Test slicing a slice.""" + cp = IDSFactory("3.39.0").core_profiles() + cp.profiles_1d.resize(10) + for p in cp.profiles_1d: + p.ion.resize(3) + + result1 = cp.profiles_1d[::2].ion # Every other profile's ions + assert result1.shape == (5, 3) + + result2 = result1[:2] # First 2 from each + assert result2.shape == (5, 2)