Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
145 changes: 84 additions & 61 deletions linopy/constraints.py
Original file line number Diff line number Diff line change
Expand Up @@ -497,9 +497,9 @@ class Constraint(ConstraintBase):
constraint grid (including masked/empty rows).
rhs : np.ndarray
Shape (n_flat,). Right-hand-side values.
sign : str
Constraint sign: one of '=', '<=', '>='.
Note: per-element signs are not supported (documented regression vs MutableConstraint).
sign : str or np.ndarray
Constraint sign. Either a single str ('=', '<=', '>=') for uniform
signs, or a per-row np.ndarray of sign strings for mixed signs.
coords : list of pd.Index
One index per coordinate dimension defining the constraint grid.
model : Model
Expand Down Expand Up @@ -529,7 +529,7 @@ def __init__(
csr: scipy.sparse.csr_array,
con_labels: np.ndarray,
rhs: np.ndarray,
sign: str,
sign: str | np.ndarray,
coords: list[pd.Index],
model: Model,
name: str = "",
Expand Down Expand Up @@ -613,16 +613,19 @@ def nterm(self) -> int:
def coord_names(self) -> list[str]:
return [str(c.name) for c in self._coords]

def _active_to_dataarray(
self, active_values: np.ndarray, fill: float | int | str = -1
) -> DataArray:
full = np.full(self.full_size, fill, dtype=active_values.dtype)
full[self.active_positions] = active_values
return DataArray(full.reshape(self.shape), coords=self._coords)

@property
def labels(self) -> DataArray:
"""Get labels DataArray, shape (*coord_dims)."""
if self._cindex is None:
return DataArray([])
shape = self.shape
full_size = self.full_size
labels_flat = np.full(full_size, -1, dtype=np.int64)
labels_flat[self.active_positions] = self._con_labels
return DataArray(labels_flat.reshape(shape), coords=self._coords)
return self._active_to_dataarray(self._con_labels, fill=-1)

@property
def coeffs(self) -> DataArray:
Expand All @@ -648,16 +651,39 @@ def vars(self) -> DataArray:

@property
def sign(self) -> DataArray:
"""Get sign DataArray (scalar, same sign for all entries)."""
return DataArray(np.full(self.shape, self._sign), coords=self._coords)
"""Get sign DataArray."""
if isinstance(self._sign, str):
return DataArray(np.full(self.shape, self._sign), coords=self._coords)
return self._active_to_dataarray(self._sign, fill="")

@property
def rhs(self) -> DataArray:
"""Get RHS DataArray, shape (*coord_dims)."""
shape = self.shape
rhs_full = np.full(self.full_size, np.nan)
rhs_full[self.active_positions] = self._rhs
return DataArray(rhs_full.reshape(shape), coords=self._coords)
return self._active_to_dataarray(self._rhs, fill=np.nan)

@rhs.setter
def rhs(self, value: ExpressionLike | VariableLike | ConstantLike) -> None:
self._refreeze_after(lambda mc: setattr(mc, "rhs", value))

@property
def lhs(self) -> expressions.LinearExpression:
"""Get LHS as LinearExpression (triggers Dataset reconstruction)."""
return self.mutable().lhs

@lhs.setter
def lhs(self, value: ExpressionLike | VariableLike | ConstantLike) -> None:
self._refreeze_after(lambda mc: setattr(mc, "lhs", value))

def _refreeze_after(self, mutate: Callable[[MutableConstraint], None]) -> None:
mc = self.mutable()
mutate(mc)
refrozen = Constraint.from_mutable(mc, self._cindex)
self._csr = refrozen._csr
self._con_labels = refrozen._con_labels
self._rhs = refrozen._rhs
self._sign = refrozen._sign
self._coords = refrozen._coords
self._dual = None

@property
@has_optimized_model
Expand All @@ -667,9 +693,7 @@ def dual(self) -> DataArray:
raise AttributeError(
"Underlying is optimized but does not have dual values stored."
)
dual_full = np.full(self.full_size, np.nan)
dual_full[self.active_positions] = self._dual
return DataArray(dual_full.reshape(self.shape), coords=self._coords)
return self._active_to_dataarray(self._dual, fill=np.nan)

@dual.setter
def dual(self, value: DataArray) -> None:
Expand Down Expand Up @@ -731,24 +755,10 @@ def _to_dataset(self, nterm: int) -> Dataset:
def data(self) -> Dataset:
"""Reconstruct the xarray Dataset from the CSR representation."""
ds = self._to_dataset(self.nterm)
shape = self.shape
active_pos = self.active_positions
rhs_full = np.full(self.full_size, np.nan)
rhs_full[active_pos] = self._rhs
ds = ds.assign(
sign=DataArray(np.full(shape, self._sign), coords=self._coords),
rhs=DataArray(rhs_full.reshape(shape), coords=self._coords),
)
ds = ds.assign(sign=self.sign, rhs=self.rhs)
if self._dual is not None:
dual_full = np.full(self.full_size, np.nan)
dual_full[active_pos] = self._dual
ds = ds.assign(
dual=DataArray(dual_full.reshape(shape), coords=self._coords)
)
attrs: dict[str, Any] = {"name": self._name}
if self._cindex is not None:
attrs["label_range"] = (self._cindex, self._cindex + self.full_size)
return ds.assign_attrs(attrs)
ds = ds.assign(dual=self._active_to_dataarray(self._dual, fill=np.nan))
return ds.assign_attrs(self.attrs)

def __repr__(self) -> str:
"""Print the constraint without reconstructing the full Dataset."""
Expand Down Expand Up @@ -777,7 +787,8 @@ def row_expr(row: int) -> str:
coeffs_row = np.zeros(nterm, dtype=csr.dtype)
vars_row[: end - start] = csr.indices[start:end]
coeffs_row[: end - start] = csr.data[start:end]
return f"{print_single_expression(coeffs_row, vars_row, 0, self._model)} {SIGNS_pretty[self._sign]} {self._rhs[row]}"
sign = self._sign if isinstance(self._sign, str) else self._sign[row]
return f"{print_single_expression(coeffs_row, vars_row, 0, self._model)} {SIGNS_pretty[sign]} {self._rhs[row]}"

if size > 1:
for indices in generate_indices_for_printout(shape, max_lines):
Expand Down Expand Up @@ -819,21 +830,22 @@ def to_netcdf_ds(self) -> Dataset:
"rhs": DataArray(self._rhs, dims=["_flat"]),
"_con_labels": DataArray(self._con_labels, dims=["_flat"]),
}
if isinstance(self._sign, np.ndarray):
data_vars["_sign"] = DataArray(self._sign, dims=["_flat"])
data_vars.update(coords_to_dataset_vars(self._coords))
if self._dual is not None:
data_vars["dual"] = DataArray(self._dual, dims=["_flat"])
dim_names = [c.name for c in self._coords]
return Dataset(
data_vars,
attrs={
"_linopy_format": "csr",
"sign": self._sign,
"cindex": self._cindex if self._cindex is not None else -1,
"shape": list(csr.shape),
"coord_dims": dim_names,
"name": self._name,
},
)
attrs: dict[str, Any] = {
"_linopy_format": "csr",
"cindex": self._cindex if self._cindex is not None else -1,
"shape": list(csr.shape),
"coord_dims": dim_names,
"name": self._name,
}
if isinstance(self._sign, str):
attrs["sign"] = self._sign
return Dataset(data_vars, attrs=attrs)

@classmethod
def from_netcdf_ds(cls, ds: Dataset, model: Model, name: str) -> Constraint:
Expand All @@ -845,7 +857,7 @@ def from_netcdf_ds(cls, ds: Dataset, model: Model, name: str) -> Constraint:
shape=shape,
)
rhs = ds["rhs"].values
sign = attrs["sign"]
sign: str | np.ndarray = ds["_sign"].values if "_sign" in ds else attrs["sign"]
_cindex_raw = int(attrs["cindex"])
cindex: int | None = _cindex_raw if _cindex_raw >= 0 else None
coord_dims = attrs["coord_dims"]
Expand Down Expand Up @@ -873,7 +885,10 @@ def to_matrix_with_rhs(
self, label_index: VariableLabelIndex
) -> tuple[scipy.sparse.csr_array, np.ndarray, np.ndarray, np.ndarray]:
"""Return (csr, con_labels, b, sense) — all pre-stored, no recomputation."""
sense = np.full(len(self._rhs), self._sign[0])
if isinstance(self._sign, str):
sense = np.full(len(self._rhs), self._sign[0])
else:
sense = np.array([s[0] for s in self._sign])
return self._csr, self._con_labels, self._rhs, sense

def sanitize_zeros(self) -> Constraint:
Expand All @@ -888,18 +903,25 @@ def sanitize_missings(self) -> Constraint:

def sanitize_infinities(self) -> Constraint:
"""Mask out rows with invalid infinite RHS values (mutates in-place)."""
if self._sign == LESS_EQUAL:
invalid = self._rhs == np.inf
elif self._sign == GREATER_EQUAL:
invalid = self._rhs == -np.inf
if isinstance(self._sign, str):
if self._sign == LESS_EQUAL:
invalid = self._rhs == np.inf
elif self._sign == GREATER_EQUAL:
invalid = self._rhs == -np.inf
else:
return self
else:
return self
invalid = ((self._sign == LESS_EQUAL) & (self._rhs == np.inf)) | (
(self._sign == GREATER_EQUAL) & (self._rhs == -np.inf)
)
if not invalid.any():
return self
keep = ~invalid
self._csr = self._csr[keep]
self._con_labels = self._con_labels[keep]
self._rhs = self._rhs[keep]
if not isinstance(self._sign, str):
self._sign = self._sign[keep]
return self

def freeze(self) -> Constraint:
Expand Down Expand Up @@ -939,13 +961,14 @@ def from_mutable(
active_mask = (labels_flat != -1) & (vars_flat != -1).any(axis=1)
rhs = con.rhs.values.ravel()[active_mask]
sign_vals = con.sign.values.ravel()
unique_signs = np.unique(sign_vals[active_mask])
if len(unique_signs) > 1:
raise ValueError(
"Constraint has per-element signs; cannot freeze to immutable Constraint. "
"This is a known limitation — use MutableConstraint instead."
)
sign = str(unique_signs[0]) if len(unique_signs) == 1 else "="
active_signs = sign_vals[active_mask]
unique_signs = np.unique(active_signs)
if len(unique_signs) == 0:
sign: str | np.ndarray = "="
elif len(unique_signs) == 1:
sign = str(unique_signs[0])
else:
sign = active_signs
dual = (
con.data["dual"].values.ravel()[active_mask] if "dual" in con.data else None
)
Expand Down
Loading