Skip to content

Commit 01add96

Browse files
committed
fix: review fixes for PR #630 (matrix accessor rewrite)
- Fix __repr__ passing CSR positions instead of variable labels - Fix set_blocks failing on frozen Constraint - Extract _active_to_dataarray helper to reduce DRY violations - Simplify reset_dual to direct mutation instead of reconstruction - Add tests for freeze/mutable roundtrip, VariableLabelIndex, to_matrix_with_rhs, from_mutable mixed signs, repr correctness
1 parent 19125ac commit 01add96

2 files changed

Lines changed: 104 additions & 43 deletions

File tree

linopy/constraints.py

Lines changed: 23 additions & 43 deletions
Original file line numberDiff line numberDiff line change
@@ -595,16 +595,19 @@ def nterm(self) -> int:
595595
def coord_names(self) -> list[str]:
596596
return [c.name for c in self._coords]
597597

598+
def _active_to_dataarray(
599+
self, active_values: np.ndarray, fill: float | int = -1
600+
) -> DataArray:
601+
full = np.full(self.full_size, fill, dtype=active_values.dtype)
602+
full[self.active_positions] = active_values
603+
return DataArray(full.reshape(self.shape), coords=self._coords)
604+
598605
@property
599606
def labels(self) -> DataArray:
600607
"""Get labels DataArray, shape (*coord_dims)."""
601608
if self._cindex is None:
602609
return DataArray([])
603-
shape = self.shape
604-
full_size = self.full_size
605-
labels_flat = np.full(full_size, -1, dtype=np.int64)
606-
labels_flat[self.active_positions] = self._con_labels
607-
return DataArray(labels_flat.reshape(shape), coords=self._coords)
610+
return self._active_to_dataarray(self._con_labels, fill=-1)
608611

609612
@property
610613
def coeffs(self) -> DataArray:
@@ -636,10 +639,7 @@ def sign(self) -> DataArray:
636639
@property
637640
def rhs(self) -> DataArray:
638641
"""Get RHS DataArray, shape (*coord_dims)."""
639-
shape = self.shape
640-
rhs_full = np.full(self.full_size, np.nan)
641-
rhs_full[self.active_positions] = self._rhs
642-
return DataArray(rhs_full.reshape(shape), coords=self._coords)
642+
return self._active_to_dataarray(self._rhs, fill=np.nan)
643643

644644
@property
645645
@has_optimized_model
@@ -649,9 +649,7 @@ def dual(self) -> DataArray:
649649
raise AttributeError(
650650
"Underlying is optimized but does not have dual values stored."
651651
)
652-
dual_full = np.full(self.full_size, np.nan)
653-
dual_full[self.active_positions] = self._dual
654-
return DataArray(dual_full.reshape(self.shape), coords=self._coords)
652+
return self._active_to_dataarray(self._dual, fill=np.nan)
655653

656654
@dual.setter
657655
def dual(self, value: DataArray) -> None:
@@ -713,24 +711,10 @@ def _to_dataset(self, nterm: int) -> Dataset:
713711
def data(self) -> Dataset:
714712
"""Reconstruct the xarray Dataset from the CSR representation."""
715713
ds = self._to_dataset(self.nterm)
716-
shape = self.shape
717-
active_pos = self.active_positions
718-
rhs_full = np.full(self.full_size, np.nan)
719-
rhs_full[active_pos] = self._rhs
720-
ds = ds.assign(
721-
sign=DataArray(np.full(shape, self._sign), coords=self._coords),
722-
rhs=DataArray(rhs_full.reshape(shape), coords=self._coords),
723-
)
714+
ds = ds.assign(sign=self.sign, rhs=self.rhs)
724715
if self._dual is not None:
725-
dual_full = np.full(self.full_size, np.nan)
726-
dual_full[active_pos] = self._dual
727-
ds = ds.assign(
728-
dual=DataArray(dual_full.reshape(shape), coords=self._coords)
729-
)
730-
attrs: dict[str, Any] = {"name": self._name}
731-
if self._cindex is not None:
732-
attrs["label_range"] = (self._cindex, self._cindex + self.full_size)
733-
return ds.assign_attrs(attrs)
716+
ds = ds.assign(dual=self._active_to_dataarray(self._dual, fill=np.nan))
717+
return ds.assign_attrs(self.attrs)
734718

735719
def __repr__(self) -> str:
736720
"""Print the constraint without reconstructing the full Dataset."""
@@ -753,11 +737,13 @@ def __repr__(self) -> str:
753737
header_string = f"{self.type} `{self._name}`" if self._name else f"{self.type}"
754738
lines = []
755739

740+
vlabels = self._model.variables.label_index.vlabels
741+
756742
def row_expr(row: int) -> str:
757743
start, end = int(csr.indptr[row]), int(csr.indptr[row + 1])
758744
vars_row = np.full(nterm, -1, dtype=np.int64)
759745
coeffs_row = np.zeros(nterm, dtype=csr.dtype)
760-
vars_row[: end - start] = csr.indices[start:end]
746+
vars_row[: end - start] = vlabels[csr.indices[start:end]]
761747
coeffs_row[: end - start] = csr.data[start:end]
762748
return f"{print_single_expression(coeffs_row, vars_row, 0, self._model)} {SIGNS_pretty[self._sign]} {self._rhs[row]}"
763749

@@ -1610,7 +1596,12 @@ def set_blocks(self, block_map: np.ndarray) -> None:
16101596

16111597
res = res.where(not_missing.any(constraint.term_dim), -1)
16121598
res = res.where(not_zero.any(constraint.term_dim), 0)
1613-
constraint._data = assign_multiindex_safe(constraint.data, blocks=res)
1599+
if isinstance(constraint, MutableConstraint):
1600+
constraint._data = assign_multiindex_safe(constraint.data, blocks=res)
1601+
else:
1602+
mc = constraint.mutable()
1603+
mc._data = assign_multiindex_safe(mc.data, blocks=res)
1604+
self.data[name] = Constraint.from_mutable(mc, constraint._cindex)
16141605

16151606
@property
16161607
def flat(self) -> pd.DataFrame:
@@ -1667,18 +1658,7 @@ def reset_dual(self) -> None:
16671658
"""
16681659
for k, c in self.items():
16691660
if isinstance(c, Constraint):
1670-
if c._dual is not None:
1671-
self.data[k] = Constraint(
1672-
c._csr,
1673-
c._con_labels,
1674-
c._rhs,
1675-
c._sign,
1676-
c._coords,
1677-
c._model,
1678-
c._name,
1679-
cindex=c._cindex,
1680-
dual=None,
1681-
)
1661+
c._dual = None
16821662
else:
16831663
if "dual" in c.data:
16841664
c._data = c.data.drop_vars("dual")

test/test_constraint.py

Lines changed: 81 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -706,3 +706,84 @@ def test_constraints_inequalities(m: Model) -> None:
706706

707707
def test_constraints_equalities(m: Model) -> None:
708708
assert isinstance(m.constraints.equalities, Constraints)
709+
710+
711+
def test_freeze_mutable_roundtrip(m: Model) -> None:
712+
frozen = m.constraints["c"]
713+
assert isinstance(frozen, Constraint)
714+
mc = frozen.mutable()
715+
assert isinstance(mc, MutableConstraint)
716+
refrozen = Constraint.from_mutable(mc, frozen._cindex)
717+
assert_equal(frozen.labels, refrozen.labels)
718+
assert_equal(frozen.rhs, refrozen.rhs)
719+
assert_equal(frozen.sign, refrozen.sign)
720+
np.testing.assert_array_equal(frozen._csr.toarray(), refrozen._csr.toarray())
721+
np.testing.assert_array_equal(frozen._con_labels, refrozen._con_labels)
722+
723+
724+
def test_freeze_mutable_roundtrip_with_masking() -> None:
725+
m = Model()
726+
x = m.add_variables(coords=[pd.RangeIndex(5, name="i")], name="x")
727+
mask = xr.DataArray([True, False, True, False, True], dims=["i"])
728+
m.add_constraints(x.where(mask) >= 0, name="c")
729+
frozen = m.constraints["c"]
730+
mc = frozen.mutable()
731+
refrozen = Constraint.from_mutable(mc, frozen._cindex)
732+
assert_equal(frozen.labels, refrozen.labels)
733+
assert_equal(frozen.rhs, refrozen.rhs)
734+
assert frozen.ncons == refrozen.ncons == 3
735+
736+
737+
def test_from_mutable_mixed_signs_raises() -> None:
738+
m = Model()
739+
x = m.add_variables(coords=[pd.RangeIndex(3, name="i")], name="x")
740+
m.add_constraints(x >= 0, name="mixed", freeze=False)
741+
mc = m.constraints["mixed"]
742+
assert isinstance(mc, MutableConstraint)
743+
mc._data["sign"] = xr.DataArray(["<=", ">=", "<="], dims=["i"])
744+
with pytest.raises(ValueError, match="per-element signs"):
745+
Constraint.from_mutable(mc)
746+
747+
748+
def test_variable_label_index(m: Model) -> None:
749+
li = m.variables.label_index
750+
assert li.n_active_vars > 0
751+
assert len(li.vlabels) == li.n_active_vars
752+
assert li.label_to_pos.shape[0] == m._xCounter
753+
for lbl in li.vlabels:
754+
assert li.label_to_pos[lbl] >= 0
755+
assert (li.label_to_pos[li.vlabels] == np.arange(li.n_active_vars)).all()
756+
757+
758+
def test_variable_label_index_invalidation(m: Model) -> None:
759+
li = m.variables.label_index
760+
old_vlabels = li.vlabels.copy()
761+
m.add_variables(name="w")
762+
li.invalidate()
763+
assert len(li.vlabels) > len(old_vlabels)
764+
765+
766+
def test_to_matrix_with_rhs(m: Model) -> None:
767+
c = m.constraints["c"]
768+
li = m.variables.label_index
769+
csr, con_labels, b, sense = c.to_matrix_with_rhs(li)
770+
assert csr.shape[0] == len(con_labels)
771+
assert csr.shape[0] == len(b)
772+
assert csr.shape[0] == len(sense)
773+
assert all(s in ("<", ">", "=") for s in sense)
774+
np.testing.assert_array_equal(b, c._rhs)
775+
776+
777+
def test_to_matrix_with_rhs_mutable(m: Model) -> None:
778+
mc = m.constraints["c"].mutable()
779+
li = m.variables.label_index
780+
csr, con_labels, b, sense = mc.to_matrix_with_rhs(li)
781+
assert csr.shape[0] == len(con_labels)
782+
assert csr.shape[0] == len(b)
783+
assert csr.shape[0] == len(sense)
784+
785+
786+
def test_constraint_repr_shows_variable_names(m: Model) -> None:
787+
c = m.constraints["c"]
788+
r = repr(c)
789+
assert "x" in r

0 commit comments

Comments
 (0)