diff --git a/.github/workflows/commit.yaml b/.github/workflows/commit.yaml index 38b2d16..941a51f 100644 --- a/.github/workflows/commit.yaml +++ b/.github/workflows/commit.yaml @@ -5,21 +5,6 @@ on: types: [opened, reopened, synchronize] jobs: - labeler: - name: apply labels - permissions: - contents: read - pull-requests: write - issues: write - runs-on: [ubuntu-latest] - steps: - - uses: actions/checkout@v6 - - uses: actions/labeler@v6.0.1 - with: - repo-token: ${{ secrets.GITHUB_TOKEN }} - configuration-path: .github/labeler.yaml - sync-labels: true - lint-commit-messages: name: lint commit message runs-on: [ubuntu-latest] diff --git a/.github/workflows/labeler.yaml b/.github/workflows/labeler.yaml new file mode 100644 index 0000000..46e6446 --- /dev/null +++ b/.github/workflows/labeler.yaml @@ -0,0 +1,21 @@ +name: labeler + +on: + pull_request_target: + types: [opened, reopened, synchronize] + +jobs: + labeler: + name: apply labels + permissions: + contents: read + pull-requests: write + issues: write + runs-on: [ubuntu-latest] + steps: + - uses: actions/checkout@v6 + - uses: actions/labeler@v6.0.1 + with: + repo-token: ${{ secrets.GITHUB_TOKEN }} + configuration-path: .github/labeler.yaml + sync-labels: true diff --git a/src/plexosdb/db.py b/src/plexosdb/db.py index 19461b8..86d48a4 100644 --- a/src/plexosdb/db.py +++ b/src/plexosdb/db.py @@ -8,13 +8,22 @@ from pathlib import Path from string import Template from typing import Any, Literal, TypedDict, cast +from collections.abc import Sequence import warnings from loguru import logger from .checks import check_memberships_from_records from .db_manager import SQLiteManager -from .enums import ClassEnum, CollectionEnum, Schema, get_default_collection, str2enum +from .enums import ( + ClassEnum, + CollectionEnum, + Schema, + get_default_collection, + str2enum, + parse_class_enum, + parse_collection_enum, +) from .exceptions import ( NameError, NoPropertiesError, @@ -1835,22 +1844,19 @@ def copy_object( category = self.query("SELECT name from t_category WHERE category_id = ?", (category_id[0][0],)) new_object_id = self.add_object(object_class, new_object_name, category=category[0][0]) membership_mapping = self.copy_object_memberships( - object_class=object_class, original_name=new_object_name, new_name=new_object_name + object_class=object_class, original_name=original_object_name, new_name=new_object_name ) - # If we do not find a membership, we just look for the system membership - if not membership_mapping: - membership_mapping = {} - system_membership_id = self.list_object_memberships(object_class, original_object_name)[0][ - "membership_id" - ] - new_membership_id = self.list_object_memberships(object_class, new_object_name)[0][ - "membership_id" - ] - membership_mapping[system_membership_id] = new_membership_id + system_collection = get_default_collection(object_class) + old_sys_id = self.get_membership_id("System", original_object_name, system_collection) + new_sys_id = self.get_membership_id("System", new_object_name, system_collection) + membership_mapping[old_sys_id] = new_sys_id + + if not copy_properties: + return new_object_id data_ids = self.get_object_data_ids(object_class, name=original_object_name) - if not data_ids and copy_properties: + if not data_ids: logger.debug(f"No properties found for {original_object_name}") return new_object_id @@ -1869,13 +1875,11 @@ def copy_object_memberships( for membership in all_memberships: parent_name = membership["parent_name"] child_name = membership["child_name"] - parent_class = ClassEnum[membership["parent_class_name"]] - child_class = ClassEnum[membership["child_class_name"]] - collection = CollectionEnum[membership["collection_name"]] + parent_class = parse_class_enum(membership["parent_class_name"]) + child_class = parse_class_enum(membership["child_class_name"]) + collection = parse_collection_enum(membership["collection_name"]) - # Determine if original object was parent or child - if child_name == original_name: - # Original object is child, new object will be child + if child_class == object_class and child_name == original_name: old_id = self.get_membership_id(parent_name, original_name, collection) try: new_id = self.add_membership(parent_class, child_class, parent_name, new_name, collection) @@ -1883,8 +1887,7 @@ def copy_object_memberships( except Exception as e: logger.warning(f"Could not create child membership: {e}") - elif parent_name == original_name: - # Original object is parent, new object will be parent + elif parent_class == object_class and parent_name == original_name: old_id = self.get_membership_id(original_name, child_name, collection) try: new_id = self.add_membership(parent_class, child_class, new_name, child_name, collection) @@ -1932,15 +1935,46 @@ def _copy_object_properties(self, membership_mapping: dict[int, int]) -> bool: self._db.execute("CREATE TEMPORARY TABLE temp_data_mapping (old_id INTEGER, new_id INTEGER)") self._db.execute(""" - INSERT INTO temp_data_mapping - SELECT old_d.data_id AS old_id, new_d.data_id AS new_id - FROM t_data old_d - JOIN temp_mapping tm ON old_d.membership_id = tm.old_id - JOIN t_data new_d ON - new_d.membership_id = tm.new_id AND - new_d.property_id = old_d.property_id AND - new_d.value = old_d.value - WHERE new_d.data_id NOT IN (SELECT data_id FROM t_tag) + INSERT INTO temp_data_mapping (old_id, new_id) + WITH + old_rows AS ( + SELECT + d.data_id AS old_id, + tm.new_id AS new_membership_id, + d.property_id, + d.value, + d.state, + ROW_NUMBER() OVER ( + PARTITION BY d.membership_id, d.property_id, d.value, d.state + ORDER BY d.data_id + ) AS rn + FROM t_data d + JOIN temp_mapping tm ON d.membership_id = tm.old_id + ), + new_rows AS ( + SELECT + d.data_id AS new_id, + d.membership_id AS new_membership_id, + d.property_id, + d.value, + d.state, + ROW_NUMBER() OVER ( + PARTITION BY d.membership_id, d.property_id, d.value, d.state + ORDER BY d.data_id + ) AS rn + FROM t_data d + WHERE d.membership_id IN (SELECT new_id FROM temp_mapping) + ) + SELECT + o.old_id, + n.new_id + FROM old_rows o + JOIN new_rows n + ON n.new_membership_id = o.new_membership_id + AND n.property_id = o.property_id + AND n.value = o.value + AND n.state IS o.state + AND n.rn = o.rn """) # Copy tags using data ID mapping @@ -2824,7 +2858,9 @@ def get_object_data_ids( d.data_id """ result = self._db.query(query, tuple(params)) - assert result + # assert result + if not result: + return [] return [row[0] for row in result] def get_object_properties( @@ -3627,9 +3663,30 @@ def list_objects_by_class(self, class_enum: ClassEnum, /, *, category: str | Non ['Generator1', 'Generator2'] """ class_id = self.get_class_id(class_enum) - query = f"SELECT name from {Schema.Objects.name} WHERE class_id = ?" - result = self._db.query(query, (class_id,)) - return [d[0] for d in result] + + params: Sequence[Any] + if category is None: + query = f"SELECT name FROM {Schema.Objects.name} WHERE class_id = ? ORDER BY name" + params = (class_id,) + else: + if not self.check_category_exists(class_enum, category): + msg = f"Category '{category}' does not exist for class {class_enum}." + raise NotFoundError(msg) + + query = f""" + SELECT obj.name + FROM {Schema.Objects.name} AS obj + JOIN {Schema.Categories.name} AS cat + ON obj.category_id = cat.category_id + WHERE obj.class_id = ? + AND cat.name = ? + ORDER BY obj.name + """ + params = (class_id, category) + + result = self._db.query(query, params) + assert result is not None + return [row[0] for row in result] def list_parent_objects( self, diff --git a/src/plexosdb/enums.py b/src/plexosdb/enums.py index d4042a4..56380cd 100644 --- a/src/plexosdb/enums.py +++ b/src/plexosdb/enums.py @@ -158,3 +158,34 @@ def get_default_collection(class_enum: ClassEnum) -> CollectionEnum: if collection_name not in CollectionEnum.__members__: collection_name = class_enum.name return CollectionEnum[collection_name] + + +def _parse_str_enum(enum_cls: type[Enum], value: str | Enum) -> Enum: + """Parse a string or Enum to an Enum instance of the specified enum class.""" + if isinstance(value, enum_cls): + return value + + # Exact value match + for e in enum_cls: + if e.value == value: + return e + + # Enum name without spaces + if isinstance(value, str): + key = value.replace(" ", "") + try: + return enum_cls[key] + except KeyError: + raise ValueError(f"{value!r} is not a valid {enum_cls.__name__}") + else: + raise ValueError(f"{value!r} is not a valid {enum_cls.__name__}") + + +def parse_class_enum(value: str | ClassEnum) -> ClassEnum: + """Parse a string or ClassEnum to a ClassEnum instance.""" + return cast(ClassEnum, _parse_str_enum(ClassEnum, value)) + + +def parse_collection_enum(value: str | CollectionEnum) -> CollectionEnum: + """Parse a string or CollectionEnum to a CollectionEnum instance.""" + return cast(CollectionEnum, _parse_str_enum(CollectionEnum, value)) diff --git a/tests/test_enums_functions.py b/tests/test_enums_functions.py index 7594f96..2ec3653 100644 --- a/tests/test_enums_functions.py +++ b/tests/test_enums_functions.py @@ -359,3 +359,45 @@ def test_get_default_collection_for_supported_classes(): for class_member in supported_classes: result = get_default_collection(class_member) assert isinstance(result, CollectionEnum) + + +def test_schema_enum_name_and_label_properties(): + """Test Schema enum name and label properties for all members.""" + from plexosdb.enums import Schema + + for member in Schema: + assert isinstance(member.name, str) + assert member.label is None or isinstance(member.label, str) + + +def test_parse_str_enum_exact_value_and_spaces(): + """Test _parse_str_enum for exact value and name with spaces.""" + from plexosdb.enums import _parse_str_enum, ClassEnum + + assert _parse_str_enum(ClassEnum, "Generator") == ClassEnum.Generator + assert _parse_str_enum(ClassEnum, "Data File") == ClassEnum.DataFile + assert _parse_str_enum(ClassEnum, "DataFile") == ClassEnum.DataFile + + +def test_parse_str_enum_invalid_value_raises(): + """Test _parse_str_enum raises ValueError for invalid value.""" + from plexosdb.enums import _parse_str_enum, ClassEnum + + with pytest.raises(ValueError): + _parse_str_enum(ClassEnum, "NotAClass") + + +def test_parse_class_enum_and_collection_enum(): + """Test parse_class_enum and parse_collection_enum utility functions.""" + from plexosdb.enums import parse_class_enum, parse_collection_enum, ClassEnum, CollectionEnum + + assert parse_class_enum("Generator") == ClassEnum.Generator + assert parse_class_enum(ClassEnum.Generator) == ClassEnum.Generator + + assert parse_collection_enum("Generators") == CollectionEnum.Generators + assert parse_collection_enum(CollectionEnum.Generators) == CollectionEnum.Generators + + with pytest.raises(ValueError): + parse_class_enum("NotAClass") + with pytest.raises(ValueError): + parse_collection_enum("NotACollection")