Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
40 changes: 31 additions & 9 deletions src/plexosdb/db.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,11 +32,13 @@
from .utils import (
apply_scenario_tags,
create_membership_record,
get_system_object_name,
insert_property_texts,
insert_property_values,
no_space,
normalize_names,
plan_property_inserts,
resolve_membership_id,
)
from .xml_handler import XMLHandler

Expand Down Expand Up @@ -695,7 +697,8 @@ def add_object(

if not collection_enum:
collection_enum = get_default_collection(class_enum)
_ = self.add_membership(ClassEnum.System, class_enum, "System", name, collection_enum)
system_name = get_system_object_name(self)
_ = self.add_membership(ClassEnum.System, class_enum, system_name, name, collection_enum)
return object_id

def add_objects(
Expand Down Expand Up @@ -756,7 +759,8 @@ def add_objects(
collection_enum = get_default_collection(class_enum)
object_ids = self.get_objects_id(names, class_enum=class_enum)
parent_class_id = self.get_class_id(ClassEnum.System)
parent_object_id = self.get_object_id(ClassEnum.System, "System")
system_name = get_system_object_name(self)
parent_object_id = self.get_object_id(ClassEnum.System, system_name)
collection_id = self.get_collection_id(
collection_enum, parent_class_enum=ClassEnum.System, child_class_enum=class_enum
)
Expand Down Expand Up @@ -1013,7 +1017,8 @@ def add_property(
Class enumeration of the parent object. If None, defaults to ClassEnum.System,
by default None
parent_object_name : str | None, optional
Name of the parent object. If None, defaults to "System", by default None
Name of the parent object. If None, membership is resolved from
`parent_class_enum`, `collection_enum`, and `object_name`.

Returns
-------
Expand Down Expand Up @@ -1079,7 +1084,14 @@ def add_property(
child_class_enum=object_class_enum,
collection_enum=collection_enum,
)
membership_id = self.get_membership_id(parent_object_name or "System", object_name, collection_enum)
membership_id = resolve_membership_id(
self,
object_name,
object_class=object_class_enum,
collection=collection_enum,
parent_class=parent_class_enum,
parent_object_name=parent_object_name,
)

query = f"INSERT INTO {Schema.Data.name}(membership_id, property_id, value) values (?, ?, ?)"
_ = self._db.execute(query, (membership_id, property_id, value))
Expand Down Expand Up @@ -1848,8 +1860,9 @@ def copy_object(
)

system_collection = get_default_collection(object_class)
old_sys_id = self.get_membership_id("System", original_object_name, system_collection)
new_sys_id = self.get_membership_id("System", new_object_name, system_collection)
system_name = get_system_object_name(self)
old_sys_id = self.get_membership_id(system_name, original_object_name, system_collection)
new_sys_id = self.get_membership_id(system_name, new_object_name, system_collection)
membership_mapping[old_sys_id] = new_sys_id

if not copy_properties:
Expand Down Expand Up @@ -2119,6 +2132,7 @@ def delete_property(
property_name: str,
collection: CollectionEnum | None = None,
parent_class: ClassEnum | None = None,
parent_object_name: str | None = None,
scenario: str | None = None,
) -> None:
"""Delete a property from an object.
Expand All @@ -2141,6 +2155,9 @@ def delete_property(
parent_class : ClassEnum | None, optional
Parent class enumeration for the property. If None, defaults to
ClassEnum.System, by default None
parent_object_name : str | None, optional
Parent object name. If None, membership is resolved by parent class,
child class, collection, and child object name.
scenario : str | None, optional
Name of the scenario to filter by. If specified, only deletes
property data associated with this scenario, by default None
Expand Down Expand Up @@ -2194,9 +2211,14 @@ def delete_property(
collection_enum=collection,
)

# For parent object name, default to "System" if not specified
parent_object_name = "System" # This matches the pattern used in add_property
membership_id = self.get_membership_id(parent_object_name, object_name, collection)
membership_id = resolve_membership_id(
self,
object_name,
object_class=object_class,
collection=collection,
parent_class=parent_class,
parent_object_name=parent_object_name,
)

# Build the delete query
if scenario is not None:
Expand Down
174 changes: 163 additions & 11 deletions src/plexosdb/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,10 +12,11 @@

from loguru import logger

from .enums import ClassEnum
from .exceptions import NotFoundError

if TYPE_CHECKING:
from plexosdb import ClassEnum, CollectionEnum, PlexosDB
from plexosdb import CollectionEnum, PlexosDB
from plexosdb.db_manager import SQLiteManager


Expand Down Expand Up @@ -282,7 +283,13 @@ def plan_property_inserts(
collection, parent_class_enum=parent_class, child_class_enum=object_class
)
collection_properties = _fetch_collection_properties(db, collection_id=collection_id)
name_to_membership = _resolve_membership_map(db, normalized_records, object_class=object_class)
name_to_membership = _resolve_membership_map(
db,
normalized_records,
object_class=object_class,
parent_class=parent_class,
collection=collection,
)
property_id_map = {prop: pid for prop, pid in collection_properties}

params, metadata_map = _build_property_rows(
Expand All @@ -304,17 +311,46 @@ def _resolve_membership_map(
normalized_records: list[dict[str, Any]],
*,
object_class: ClassEnum,
parent_class: ClassEnum,
collection: CollectionEnum,
) -> dict[str, int]:
"""Resolve membership ids for each object name."""
component_names = tuple({d["name"] for d in normalized_records if d.get("name") is not None})
try:
memberships = db.get_memberships_system(component_names, object_class=object_class)
except Exception as exc:
missing = ", ".join(sorted(name for name in component_names if name))
raise NotFoundError(
f"Objects not found: {missing}. Add them with `add_object` or `add_objects` before "
"adding properties."
) from exc
if not component_names:
return {}

memberships: list[tuple[str, int]] | list[dict[str, Any]]
if parent_class == ClassEnum.System:
try:
memberships = db.get_memberships_system(
component_names,
object_class=object_class,
collection=collection,
)
except Exception as exc:
missing = ", ".join(sorted(name for name in component_names if name))
raise NotFoundError(
f"Objects not found: {missing}. Add them with `add_object` or `add_objects` before "
"adding properties."
) from exc
else:
collection_id = db.get_collection_id(
collection, parent_class_enum=parent_class, child_class_enum=object_class
)
parent_class_id = db.get_class_id(parent_class)
child_class_id = db.get_class_id(object_class)
placeholders = ",".join("?" for _ in component_names)
query = f"""
SELECT child_object.name, mem.membership_id
FROM t_membership AS mem
INNER JOIN t_object AS child_object ON child_object.object_id = mem.child_object_id
WHERE mem.parent_class_id = ?
AND mem.child_class_id = ?
AND mem.collection_id = ?
AND child_object.name IN ({placeholders})
"""
params: tuple[Any, ...] = (parent_class_id, child_class_id, collection_id, *component_names)
memberships = db._db.fetchall(query, params)

if not memberships:
missing = ", ".join(sorted(name for name in component_names if name))
Expand All @@ -323,7 +359,123 @@ def _resolve_membership_map(
"adding properties."
)

return {membership["name"]: membership["membership_id"] for membership in memberships}
name_to_membership: dict[str, int] = {}
ambiguous_objects: set[str] = set()
for membership in memberships:
if isinstance(membership, dict):
object_name = membership["name"]
membership_id = membership["membership_id"]
else:
object_name = membership[0]
membership_id = membership[1]
existing_membership_id = name_to_membership.get(object_name)
if existing_membership_id is not None and existing_membership_id != membership_id:
ambiguous_objects.add(object_name)
continue
name_to_membership[object_name] = membership_id

if ambiguous_objects:
ambiguous_names = ", ".join(sorted(ambiguous_objects))
raise ValueError(
"Multiple memberships found for objects in the same parent class/collection: "
f"{ambiguous_names}. Resolve membership ambiguity before bulk insert."
)

logger.trace("Resolved {} memberships for collection {}", len(name_to_membership), collection.value)
return name_to_membership


def get_system_object_name(db: PlexosDB) -> str:
"""Return the canonical System object name for this model.

Parameters
----------
db : PlexosDB
Database instance to query.

Returns
-------
str
The name of the System object.

Raises
------
NotFoundError
If no System object exists in the model.
ValueError
If multiple System objects exist and none is named "System".
"""
system_objects = db.list_objects_by_class(ClassEnum.System)
if not system_objects:
raise NotFoundError("No System object found in the model.")
if len(system_objects) == 1:
logger.trace("Resolved System object: {}", system_objects[0])
return system_objects[0]
if "System" in system_objects:
logger.trace("Multiple System objects found, defaulting to 'System'")
return "System"
raise ValueError(
"Multiple System objects found and no default could be inferred. "
"Pass an explicit `parent_object_name`."
)


def resolve_membership_id(
db: PlexosDB,
object_name: str,
*,
object_class: ClassEnum,
collection: CollectionEnum,
parent_class: ClassEnum,
parent_object_name: str | None = None,
) -> int:
"""Resolve a single membership ID for property operations.

Parameters
----------
db : PlexosDB
Database instance to query.
object_name : str
Name of the child object.
object_class : ClassEnum
Class of the child object.
collection : CollectionEnum
Collection defining the membership relationship.
parent_class : ClassEnum
Class of the parent object.
parent_object_name : str | None, optional
Explicit parent object name. If provided, uses direct lookup.

Returns
-------
int
The membership ID.

Raises
------
NotFoundError
If no matching membership is found.
ValueError
If multiple memberships exist (ambiguous).
"""
if parent_object_name is not None:
logger.trace("Resolving membership via explicit parent: {}", parent_object_name)
return db.get_membership_id(parent_object_name, object_name, collection)

mapping = _resolve_membership_map(
db,
[{"name": object_name}],
object_class=object_class,
parent_class=parent_class,
collection=collection,
)
if object_name not in mapping:
raise NotFoundError(
f"No membership found for '{object_name}' in collection '{collection.value}' "
f"with parent class '{parent_class.value}'."
)
logger.trace("Resolved membership_id={} for {}", mapping[object_name], object_name)
return mapping[object_name]


def _build_property_rows(
Expand Down
3 changes: 3 additions & 0 deletions tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,6 +53,7 @@ def db_instance():
"""Create a base DB instance that lasts the entire test session."""
db = PlexosDB()
yield db
db._db.close()


@pytest.fixture()
Expand All @@ -63,6 +64,7 @@ def db_instance_with_xml(data_folder, tmp_path):
shutil.copy(xml_fname, xml_copy)
db = PlexosDB.from_xml(xml_path=xml_copy)
yield db
db._db.close()
xml_copy.unlink()


Expand Down Expand Up @@ -152,6 +154,7 @@ def db_instance_with_schema() -> PlexosDB:
"INSERT INTO t_property_report(property_id, collection_id, name) VALUES (1, 1, 'Units')"
)
yield db
db._db.close()


@pytest.fixture(scope="function")
Expand Down
5 changes: 5 additions & 0 deletions tests/test_db_manager_iterators.py
Original file line number Diff line number Diff line change
Expand Up @@ -303,6 +303,7 @@ def test_iter_dicts_reraises_sqlite_error() -> None:
from plexosdb.db_manager import SQLiteManager

db = SQLiteManager()
original_conn = db._con
mock_conn = MagicMock()
mock_cursor = MagicMock()
mock_cursor.execute.side_effect = sqlite3.Error("Query failed")
Expand All @@ -313,12 +314,15 @@ def test_iter_dicts_reraises_sqlite_error() -> None:
with pytest.raises(sqlite3.Error):
list(db.iter_dicts("SELECT * FROM nonexistent"))

original_conn.close()


def test_iter_dicts_cursor_cleanup_on_error() -> None:
"""Test that cursor is closed even when error occurs during iteration."""
from plexosdb.db_manager import SQLiteManager

db = SQLiteManager()
original_conn = db._con
mock_conn = MagicMock()
mock_cursor = MagicMock()
mock_cursor.execute.side_effect = sqlite3.Error("Query failed")
Expand All @@ -330,6 +334,7 @@ def test_iter_dicts_cursor_cleanup_on_error() -> None:
list(db.iter_dicts("SELECT * FROM test"))

mock_cursor.close.assert_called_once()
original_conn.close()


def test_fetchmany_happy_path(db_with_large_dataset: SQLiteManager) -> None:
Expand Down
Loading