Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 10 additions & 2 deletions linopy/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,8 +14,13 @@
import linopy.monkey_patch_xarray # noqa: F401
from linopy.common import align
from linopy.config import options
from linopy.constants import EQUAL, GREATER_EQUAL, LESS_EQUAL
from linopy.constraints import Constraint, Constraints
from linopy.constants import EQUAL, GREATER_EQUAL, LESS_EQUAL, PerformanceWarning
from linopy.constraints import (
Constraint,
ConstraintBase,
Constraints,
MutableConstraint,
)
from linopy.expressions import LinearExpression, QuadraticExpression, merge
from linopy.io import read_netcdf
from linopy.model import Model, Variable, Variables, available_solvers
Expand All @@ -30,8 +35,11 @@

__all__ = (
"Constraint",
"ConstraintBase",
"Constraints",
"MutableConstraint",
"EQUAL",
"PerformanceWarning",
"GREATER_EQUAL",
"LESS_EQUAL",
"LinearExpression",
Expand Down
139 changes: 123 additions & 16 deletions linopy/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@
import operator
import os
from collections.abc import Callable, Generator, Hashable, Iterable, Sequence
from functools import partial, reduce, wraps
from functools import cached_property, partial, reduce, wraps
from pathlib import Path
from typing import TYPE_CHECKING, Any, Generic, TypeVar, overload
from warnings import warn
Expand All @@ -19,6 +19,7 @@
import pandas as pd
import polars as pl
from numpy import arange, signedinteger
from polars.datatypes import DataTypeClass
from xarray import DataArray, Dataset, apply_ufunc, broadcast
from xarray import align as xr_align
from xarray.core import dtypes, indexing
Expand All @@ -40,7 +41,7 @@
)

if TYPE_CHECKING:
from linopy.constraints import Constraint
from linopy.constraints import ConstraintBase
from linopy.expressions import LinearExpression, QuadraticExpression
from linopy.variables import Variable

Expand Down Expand Up @@ -327,7 +328,7 @@ def check_has_nulls(df: pd.DataFrame, name: str) -> None:
raise ValueError(f"Fields {name} contains nan's in field(s) {fields}")


def infer_schema_polars(ds: Dataset) -> dict[str, DataTypeClass]:
    """
    Infer the polars data schema from an xarray dataset.

    Parameters
    ----------
    ds : Dataset
        Dataset whose data variables are mapped to polars data types.

    Returns
    -------
    dict: A dictionary mapping column names to their corresponding Polars data types.
    """
    schema: dict[str, DataTypeClass] = {}
    # NOTE(review): on Windows with numpy < 2 the default integer is 32-bit,
    # so integers are mirrored as Int32 in polars — presumably to keep dtypes
    # aligned with numpy; confirm against callers.
    np_major_version = int(np.__version__.split(".")[0])
    use_int32 = os.name == "nt" and np_major_version < 2
    for name, array in ds.items():
        name = str(name)  # xarray names are Hashable; polars schema keys must be str
        if np.issubdtype(array.dtype, np.integer):
            schema[name] = pl.Int32 if use_int32 else pl.Int64
        elif np.issubdtype(array.dtype, np.floating):
            schema[name] = pl.Float64
        elif np.issubdtype(array.dtype, np.bool_):
            schema[name] = pl.Boolean
        elif np.issubdtype(array.dtype, np.object_):
            schema[name] = pl.Object
        else:
            # Fallback: represent any other dtype as UTF-8 strings.
            schema[name] = pl.Utf8
    return schema


def to_polars(ds: Dataset, **kwargs: Any) -> pl.DataFrame:
Expand Down Expand Up @@ -429,7 +431,7 @@ def filter_nulls_polars(df: pl.DataFrame) -> pl.DataFrame:
if "labels" in df.columns:
cond.append(pl.col("labels").ne(-1))

cond = reduce(operator.and_, cond) # type: ignore
cond = reduce(operator.and_, cond) # type: ignore[arg-type]
return df.filter(cond)


Expand Down Expand Up @@ -554,7 +556,7 @@ def fill_missing_coords(
return ds


T = TypeVar("T", Dataset, "Variable", "LinearExpression", "Constraint")
T = TypeVar("T", Dataset, "Variable", "LinearExpression", "ConstraintBase")


@overload
Expand Down Expand Up @@ -583,10 +585,10 @@ def iterate_slices(

@overload
def iterate_slices(
    ds: ConstraintBase,
    slice_size: int | None = 10_000,
    slice_dims: list | None = None,
) -> Generator[ConstraintBase, None, None]: ...


def iterate_slices(
Expand Down Expand Up @@ -655,7 +657,7 @@ def iterate_slices(
start = i * chunk_size
end = min(start + chunk_size, size_of_leading_dim)
slice_dict = {leading_dim: slice(start, end)}
yield ds.isel(slice_dict)
yield ds.isel(slice_dict) # type: ignore[attr-defined]


def _remap(array: np.ndarray, mapping: np.ndarray) -> np.ndarray:
Expand Down Expand Up @@ -939,6 +941,57 @@ def find_single(value: int) -> tuple[str, dict] | tuple[None, None]:
raise ValueError("Array's with more than two dimensions is not supported")


class VariableLabelIndex:
    """
    Index for O(1) mapping between variable labels and dense positions.

    Both arrays are computed lazily and cached via ``functools.cached_property``:

    - ``vlabels``: active variable labels in encounter order, shape (n_active_vars,)
    - ``label_to_pos``: derived from ``vlabels``; sized by the model's ``_xCounter``,
      maps label -> dense position (-1 if masked or unused)

    Call :meth:`invalidate` (or clear the instance ``__dict__``) whenever
    variables are added or removed so the cached arrays are recomputed.
    """

    def __init__(self, variables: Any) -> None:
        # ``variables`` must offer ``.items()`` yielding (name, variable) pairs
        # and a ``.model._xCounter`` attribute — presumably the linopy
        # ``Variables`` container; typed ``Any`` to avoid an import cycle.
        self._variables = variables

    @cached_property
    def vlabels(self) -> np.ndarray:
        """Active variable labels in encounter order, shape (n_active_vars,)."""
        label_lists = []
        for _, var in self._variables.items():
            labels = var.labels.values.ravel()
            mask = labels != -1  # -1 marks masked (inactive) entries
            label_lists.append(labels[mask])
        return (
            np.concatenate(label_lists) if label_lists else np.array([], dtype=np.intp)
        )

    @cached_property
    def label_to_pos(self) -> np.ndarray:
        """
        Mapping from variable label to dense position, shape (_xCounter,).

        Position i in the active variable array corresponds to label vlabels[i].
        Masked or unused labels map to -1.
        """
        vlabels = self.vlabels
        n = self._variables.model._xCounter
        label_to_pos = np.full(n, -1, dtype=np.intp)
        label_to_pos[vlabels] = np.arange(len(vlabels), dtype=np.intp)
        return label_to_pos

    @property
    def n_active_vars(self) -> int:
        """Number of active (non-masked) variables."""
        return len(self.vlabels)

    def invalidate(self) -> None:
        """Clear cached arrays so they are recomputed on next access."""
        self.__dict__.pop("vlabels", None)
        self.__dict__.pop("label_to_pos", None)


def get_label_position(
obj: Any,
values: int | np.ndarray,
Expand Down Expand Up @@ -1306,7 +1359,7 @@ def align(
"Variable",
"LinearExpression",
"QuadraticExpression",
"Constraint",
"ConstraintBase",
)


Expand All @@ -1324,7 +1377,7 @@ def __getitem__(
# expand the indexer so we can handle Ellipsis
labels = indexing.expanded_indexer(key, self.object.ndim)
key = dict(zip(self.object.dims, labels))
return self.object.sel(key)
return self.object.sel(key) # type: ignore[attr-defined]


class EmptyDeprecationWrapper:
Expand Down Expand Up @@ -1358,6 +1411,60 @@ def __call__(self) -> bool:
return self.value


def coords_to_dataset_vars(coords: list[pd.Index]) -> dict[str, DataArray]:
    """
    Serialize a list of pd.Index (including MultiIndex) to a DataArray dict.

    Suitable for embedding coordinate metadata as plain data variables in a
    Dataset that has its own unrelated dimensions (e.g. CSR netcdf format).
    Reconstruct with :func:`coords_from_dataset`.
    """
    out: dict[str, DataArray] = {}
    for idx in coords:
        if not isinstance(idx, pd.MultiIndex):
            # Plain index: store its values under a single synthetic dimension.
            out[f"_coord_{idx.name}"] = DataArray(
                np.array(idx), dims=[f"_coorddim_{idx.name}"]
            )
            continue
        # MultiIndex: store each level's values plus the (n, nlevels) code matrix.
        for lvl_name, lvl_values in zip(idx.names, idx.levels):
            out[f"_coord_{idx.name}_level_{lvl_name}"] = DataArray(
                np.array(lvl_values),
                dims=[f"_coorddim_{idx.name}_level_{lvl_name}"],
            )
        out[f"_coord_{idx.name}_codes"] = DataArray(
            np.array(idx.codes).T,
            dims=[f"_coorddim_{idx.name}", f"_coorddim_{idx.name}_nlevels"],
        )
    return out


def coords_from_dataset(ds: Dataset, coord_dims: list[str]) -> list[pd.Index]:
    """
    Deserialize a list of pd.Index (including MultiIndex) from a Dataset.

    Reconstructs coordinates previously serialized by :func:`coords_to_dataset_vars`.
    """
    result: list[pd.Index] = []
    for dim in coord_dims:
        if f"_coord_{dim}_codes" not in ds:
            # Plain index: values were stored directly.
            result.append(pd.Index(ds[f"_coord_{dim}"].values, name=dim))
            continue
        # MultiIndex: rebuild each level's values from the stored code matrix.
        codes = ds[f"_coord_{dim}_codes"].values.T
        prefix = f"_coord_{dim}_level_"
        level_names = [str(k)[len(prefix) :] for k in ds if str(k).startswith(prefix)]
        level_arrays = [
            ds[f"_coord_{dim}_level_{nm}"].values[codes[i]]
            for i, nm in enumerate(level_names)
        ]
        mi = pd.MultiIndex.from_arrays(level_arrays, names=level_names)
        mi.name = dim
        result.append(mi)
    return result


def is_constant(x: SideLike) -> bool:
"""
Check if the given object is a constant type or an expression type without
Expand Down
11 changes: 9 additions & 2 deletions linopy/constants.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,11 @@
GREATER_EQUAL = ">="
LESS_EQUAL = "<="


class PerformanceWarning(UserWarning):
    """Emitted when an operation triggers an expensive Dataset reconstruction."""


long_EQUAL = "=="
short_GREATER_EQUAL = ">"
short_LESS_EQUAL = "<"
Expand Down Expand Up @@ -211,9 +216,11 @@ def process(cls, status: str, termination_condition: str) -> "Status":

@classmethod
def from_termination_condition(
    cls, termination_condition: Union["TerminationCondition", str, None]
) -> "Status":
    """
    Build a Status from a termination condition.

    Parameters
    ----------
    termination_condition : TerminationCondition | str | None
        Condition reported by the solver; ``None`` is treated as "unknown".

    Returns
    -------
    Status
        Status combining the derived solver status and the processed
        termination condition.
    """
    # Solvers may report no condition at all; normalize None to "unknown"
    # before processing so TerminationCondition.process never sees None.
    termination_condition = TerminationCondition.process(
        termination_condition if termination_condition is not None else "unknown"
    )
    solver_status = SolverStatus.from_termination_condition(termination_condition)
    return cls(solver_status, termination_condition)

Expand Down
Loading
Loading