Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
15 changes: 0 additions & 15 deletions .github/workflows/commit.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -5,21 +5,6 @@ on:
types: [opened, reopened, synchronize]

jobs:
labeler:
name: apply labels
permissions:
contents: read
pull-requests: write
issues: write
runs-on: [ubuntu-latest]
steps:
- uses: actions/checkout@v6
- uses: actions/labeler@v6.0.1
with:
repo-token: ${{ secrets.GITHUB_TOKEN }}
configuration-path: .github/labeler.yaml
sync-labels: true

lint-commit-messages:
name: lint commit message
runs-on: [ubuntu-latest]
Expand Down
21 changes: 21 additions & 0 deletions .github/workflows/labeler.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,21 @@
name: labeler

on:
pull_request_target:
types: [opened, reopened, synchronize]

jobs:
labeler:
name: apply labels
permissions:
contents: read
pull-requests: write
issues: write
runs-on: [ubuntu-latest]
steps:
- uses: actions/checkout@v6
- uses: actions/labeler@v6.0.1
with:
repo-token: ${{ secrets.GITHUB_TOKEN }}
configuration-path: .github/labeler.yaml
sync-labels: true
125 changes: 91 additions & 34 deletions src/plexosdb/db.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,13 +8,22 @@
from pathlib import Path
from string import Template
from typing import Any, Literal, TypedDict, cast
from collections.abc import Sequence
import warnings

from loguru import logger

from .checks import check_memberships_from_records
from .db_manager import SQLiteManager
from .enums import ClassEnum, CollectionEnum, Schema, get_default_collection, str2enum
from .enums import (
ClassEnum,
CollectionEnum,
Schema,
get_default_collection,
str2enum,
parse_class_enum,
parse_collection_enum,
)
from .exceptions import (
NameError,
NoPropertiesError,
Expand Down Expand Up @@ -1835,22 +1844,19 @@ def copy_object(
category = self.query("SELECT name from t_category WHERE category_id = ?", (category_id[0][0],))
new_object_id = self.add_object(object_class, new_object_name, category=category[0][0])
membership_mapping = self.copy_object_memberships(
object_class=object_class, original_name=new_object_name, new_name=new_object_name
object_class=object_class, original_name=original_object_name, new_name=new_object_name
)

# If we do not find a membership, we just look for the system membership
if not membership_mapping:
membership_mapping = {}
system_membership_id = self.list_object_memberships(object_class, original_object_name)[0][
"membership_id"
]
new_membership_id = self.list_object_memberships(object_class, new_object_name)[0][
"membership_id"
]
membership_mapping[system_membership_id] = new_membership_id
system_collection = get_default_collection(object_class)
old_sys_id = self.get_membership_id("System", original_object_name, system_collection)
new_sys_id = self.get_membership_id("System", new_object_name, system_collection)
membership_mapping[old_sys_id] = new_sys_id

if not copy_properties:
return new_object_id

data_ids = self.get_object_data_ids(object_class, name=original_object_name)
if not data_ids and copy_properties:
if not data_ids:
logger.debug(f"No properties found for {original_object_name}")
return new_object_id

Expand All @@ -1869,22 +1875,19 @@ def copy_object_memberships(
for membership in all_memberships:
parent_name = membership["parent_name"]
child_name = membership["child_name"]
parent_class = ClassEnum[membership["parent_class_name"]]
child_class = ClassEnum[membership["child_class_name"]]
collection = CollectionEnum[membership["collection_name"]]
parent_class = parse_class_enum(membership["parent_class_name"])
child_class = parse_class_enum(membership["child_class_name"])
collection = parse_collection_enum(membership["collection_name"])

# Determine if original object was parent or child
if child_name == original_name:
# Original object is child, new object will be child
if child_class == object_class and child_name == original_name:
old_id = self.get_membership_id(parent_name, original_name, collection)
try:
new_id = self.add_membership(parent_class, child_class, parent_name, new_name, collection)
membership_mapping[old_id] = new_id
except Exception as e:
logger.warning(f"Could not create child membership: {e}")

elif parent_name == original_name:
# Original object is parent, new object will be parent
elif parent_class == object_class and parent_name == original_name:
old_id = self.get_membership_id(original_name, child_name, collection)
try:
new_id = self.add_membership(parent_class, child_class, new_name, child_name, collection)
Expand Down Expand Up @@ -1932,15 +1935,46 @@ def _copy_object_properties(self, membership_mapping: dict[int, int]) -> bool:
self._db.execute("CREATE TEMPORARY TABLE temp_data_mapping (old_id INTEGER, new_id INTEGER)")

self._db.execute("""
INSERT INTO temp_data_mapping
SELECT old_d.data_id AS old_id, new_d.data_id AS new_id
FROM t_data old_d
JOIN temp_mapping tm ON old_d.membership_id = tm.old_id
JOIN t_data new_d ON
new_d.membership_id = tm.new_id AND
new_d.property_id = old_d.property_id AND
new_d.value = old_d.value
WHERE new_d.data_id NOT IN (SELECT data_id FROM t_tag)
INSERT INTO temp_data_mapping (old_id, new_id)
WITH
old_rows AS (
SELECT
d.data_id AS old_id,
tm.new_id AS new_membership_id,
d.property_id,
d.value,
d.state,
ROW_NUMBER() OVER (
PARTITION BY d.membership_id, d.property_id, d.value, d.state
ORDER BY d.data_id
) AS rn
FROM t_data d
JOIN temp_mapping tm ON d.membership_id = tm.old_id
),
new_rows AS (
SELECT
d.data_id AS new_id,
d.membership_id AS new_membership_id,
d.property_id,
d.value,
d.state,
ROW_NUMBER() OVER (
PARTITION BY d.membership_id, d.property_id, d.value, d.state
ORDER BY d.data_id
) AS rn
FROM t_data d
WHERE d.membership_id IN (SELECT new_id FROM temp_mapping)
)
SELECT
o.old_id,
n.new_id
FROM old_rows o
JOIN new_rows n
ON n.new_membership_id = o.new_membership_id
AND n.property_id = o.property_id
AND n.value = o.value
AND n.state IS o.state
AND n.rn = o.rn
""")

# Copy tags using data ID mapping
Expand Down Expand Up @@ -2824,7 +2858,9 @@ def get_object_data_ids(
d.data_id
"""
result = self._db.query(query, tuple(params))
assert result
# assert result
if not result:
return []
return [row[0] for row in result]

def get_object_properties(
Expand Down Expand Up @@ -3627,9 +3663,30 @@ def list_objects_by_class(self, class_enum: ClassEnum, /, *, category: str | Non
['Generator1', 'Generator2']
"""
class_id = self.get_class_id(class_enum)
query = f"SELECT name from {Schema.Objects.name} WHERE class_id = ?"
result = self._db.query(query, (class_id,))
return [d[0] for d in result]

params: Sequence[Any]
if category is None:
query = f"SELECT name FROM {Schema.Objects.name} WHERE class_id = ? ORDER BY name"
params = (class_id,)
else:
if not self.check_category_exists(class_enum, category):
msg = f"Category '{category}' does not exist for class {class_enum}."
raise NotFoundError(msg)

query = f"""
SELECT obj.name
FROM {Schema.Objects.name} AS obj
JOIN {Schema.Categories.name} AS cat
ON obj.category_id = cat.category_id
WHERE obj.class_id = ?
AND cat.name = ?
ORDER BY obj.name
"""
params = (class_id, category)

result = self._db.query(query, params)
assert result is not None
return [row[0] for row in result]

def list_parent_objects(
self,
Expand Down
31 changes: 31 additions & 0 deletions src/plexosdb/enums.py
Original file line number Diff line number Diff line change
Expand Up @@ -158,3 +158,34 @@ def get_default_collection(class_enum: ClassEnum) -> CollectionEnum:
if collection_name not in CollectionEnum.__members__:
collection_name = class_enum.name
return CollectionEnum[collection_name]


def _parse_str_enum(enum_cls: type[Enum], value: str | Enum) -> Enum:
"""Parse a string or Enum to an Enum instance of the specified enum class."""
if isinstance(value, enum_cls):
return value

# Exact value match
for e in enum_cls:
if e.value == value:
return e

# Enum name without spaces
if isinstance(value, str):
key = value.replace(" ", "")
try:
return enum_cls[key]
except KeyError:
raise ValueError(f"{value!r} is not a valid {enum_cls.__name__}")
else:
raise ValueError(f"{value!r} is not a valid {enum_cls.__name__}")


def parse_class_enum(value: str | ClassEnum) -> ClassEnum:
"""Parse a string or ClassEnum to a ClassEnum instance."""
return cast(ClassEnum, _parse_str_enum(ClassEnum, value))


def parse_collection_enum(value: str | CollectionEnum) -> CollectionEnum:
"""Parse a string or CollectionEnum to a CollectionEnum instance."""
return cast(CollectionEnum, _parse_str_enum(CollectionEnum, value))
42 changes: 42 additions & 0 deletions tests/test_enums_functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -359,3 +359,45 @@ def test_get_default_collection_for_supported_classes():
for class_member in supported_classes:
result = get_default_collection(class_member)
assert isinstance(result, CollectionEnum)


def test_schema_enum_name_and_label_properties():
"""Test Schema enum name and label properties for all members."""
from plexosdb.enums import Schema

for member in Schema:
assert isinstance(member.name, str)
assert member.label is None or isinstance(member.label, str)


def test_parse_str_enum_exact_value_and_spaces():
"""Test _parse_str_enum for exact value and name with spaces."""
from plexosdb.enums import _parse_str_enum, ClassEnum

assert _parse_str_enum(ClassEnum, "Generator") == ClassEnum.Generator
assert _parse_str_enum(ClassEnum, "Data File") == ClassEnum.DataFile
assert _parse_str_enum(ClassEnum, "DataFile") == ClassEnum.DataFile


def test_parse_str_enum_invalid_value_raises():
"""Test _parse_str_enum raises ValueError for invalid value."""
from plexosdb.enums import _parse_str_enum, ClassEnum

with pytest.raises(ValueError):
_parse_str_enum(ClassEnum, "NotAClass")


def test_parse_class_enum_and_collection_enum():
"""Test parse_class_enum and parse_collection_enum utility functions."""
from plexosdb.enums import parse_class_enum, parse_collection_enum, ClassEnum, CollectionEnum

assert parse_class_enum("Generator") == ClassEnum.Generator
assert parse_class_enum(ClassEnum.Generator) == ClassEnum.Generator

assert parse_collection_enum("Generators") == CollectionEnum.Generators
assert parse_collection_enum(CollectionEnum.Generators) == CollectionEnum.Generators

with pytest.raises(ValueError):
parse_class_enum("NotAClass")
with pytest.raises(ValueError):
parse_collection_enum("NotACollection")