diff --git a/src/typedal/cli.py b/src/typedal/cli.py index b0ae463..467a0ce 100644 --- a/src/typedal/cli.py +++ b/src/typedal/cli.py @@ -390,7 +390,8 @@ def fake_migrations( previously_migrated = ( db( - db.ewh_implemented_features.name.belongs(to_fake) & (db.ewh_implemented_features.installed == True) # noqa E712 + db.ewh_implemented_features.name.belongs(to_fake) + & (db.ewh_implemented_features.installed == True) # noqa E712 ) .select(db.ewh_implemented_features.name) .column("name") diff --git a/src/typedal/fields.py b/src/typedal/fields.py index abe2394..6a478b1 100644 --- a/src/typedal/fields.py +++ b/src/typedal/fields.py @@ -373,7 +373,9 @@ def UploadField(**kw: t.Unpack[FieldSettings]) -> TypedField[str]: Upload = UploadField -def ReferenceField[T_subclass: (TypedTable, Table)]( +def ReferenceField[ + T_subclass: (TypedTable, Table) +]( other_table: str | t.Type[TypedTable] | TypedTable | Table | T_subclass, **kw: t.Unpack[FieldSettings], ) -> TypedField[int]: diff --git a/src/typedal/query_builder.py b/src/typedal/query_builder.py index ff59b78..5dcc9fb 100644 --- a/src/typedal/query_builder.py +++ b/src/typedal/query_builder.py @@ -33,6 +33,7 @@ OnQuery, OrderBy, Query, + Row, Rows, SelectKwargs, T_MetaInstance, @@ -517,7 +518,11 @@ def _collect(self) -> str: """ return self.to_sql() - def _collect_cached(self, metadata: Metadata) -> "TypedRows[T_MetaInstance] | None": + def _collect_cached( + self, + metadata: Metadata, + into: t.Type[_TypedTable], + ) -> "TypedRows[T_MetaInstance] | None": expires_at = metadata["cache"].get("expires_at") metadata["cache"] |= { # key is partly dependant on cache metadata but not these: @@ -529,6 +534,7 @@ def _collect_cached(self, metadata: Metadata) -> "TypedRows[T_MetaInstance] | No _, key = create_and_hash_cache_key( self.model, + f"{into.__module__}.{into.__qualname__}", metadata, self.query, self.select_args, @@ -566,12 +572,15 @@ def collect( verbose: bool = False, _to: t.Type["TypedRows[t.Any]"] = None, add_id: bool = True, + _into: t.Type[_TypedTable] | None = None, + _init: t.Callable[[_TypedTable, Row], None] | None = None, ) -> "TypedRows[T_MetaInstance]": """ Execute the built query and turn it into model instances, while handling relationships. """ if _to is None: _to = TypedRows + into = _into or self.model if not isinstance(self.model, TableMeta): # tried to use querybuilder with a non-typedal table, @@ -585,7 +594,7 @@ def collect( metadata: Metadata = self.metadata.copy() - if metadata.get("cache", {}).get("enabled") and (result := self._collect_cached(metadata)): + if metadata.get("cache", {}).get("enabled") and (result := self._collect_cached(metadata, into)): return result query, select_args, select_kwargs = self._before_query(metadata, add_id=add_id) @@ -609,12 +618,12 @@ def collect( if not self.relationships: # easy - typed_rows = _to.from_rows(rows, self.model, metadata=metadata) + typed_rows = _to.from_rows(rows, self.model, metadata=metadata, into=into, init=_init) else: # harder: try to match rows to the belonging objects # assume structure of {'table': } per row. # if that's not the case, return default behavior again - typed_rows = self._collect_with_relationships(rows, metadata=metadata, _to=_to) + typed_rows = self._collect_with_relationships(rows, metadata=metadata, _to=_to, _into=into, _init=_init) for fn_after in db._after_collect: fn_after(self, typed_rows, rows) @@ -622,6 +631,35 @@ def collect( # only saves if requested in metadata: return save_to_cache(typed_rows, rows) + def collect_into[T_Into: _TypedTable]( + self, + into: t.Type[T_Into], + verbose: bool = False, + add_id: bool = True, + init: t.Callable[[T_Into, Row], None] | None = None, + ) -> "TypedRows[T_Into]": + """ + Execute the built query and instantiate root records as another model class. + """ + self._validate_collect_into_model(into) + _init = t.cast(t.Callable[[_TypedTable, Row], None] | None, init) + rows = self.collect(verbose=verbose, add_id=add_id, _into=into, _init=_init) + return t.cast("TypedRows[T_Into]", rows) + + def _validate_collect_into_model(self, into: t.Type[t.Any]) -> None: + if not isinstance(into, TableMeta): + raise TypeError("collect_into expects a TypedTable class") + + source = self.model._ensure_table_defined() + target = into._ensure_table_defined() + + if source is target or str(source) == str(target): + return + + raise ValueError( + f"collect_into target '{into.__name__}' must be bound to table '{source}', got '{target}'", + ) + @t.overload def column[T: t.Any](self, field: TypedField[T], **options: t.Unpack[SelectKwargs]) -> list[T]: """ @@ -843,12 +881,15 @@ def _collect_with_relationships( rows: Rows, metadata: Metadata, _to: t.Type["TypedRows[T_MetaInstance]"], + _into: t.Type[_TypedTable] | None = None, + _init: t.Callable[[_TypedTable, Row], None] | None = None, ) -> "TypedRows[T_MetaInstance]": """ Transform the raw rows into Typed Table model instances with nested relationships. """ db = self._get_db() main_table = self._ensure_table_defined() + into = _into or self.model # id: Model records: dict[t.Any, T_MetaInstance] = {} @@ -866,7 +907,9 @@ def _collect_with_relationships( raw_per_id[main_id].append(normalize_table_keys(row)) if main_id not in records: - records[main_id] = self.model(main) + records[main_id] = t.cast(T_MetaInstance, into(main)) + if _init: + _init(t.cast(_TypedTable, records[main_id]), row) records[main_id]._with = list(self.relationships.keys()) # Setup all relationship defaults (once) diff --git a/src/typedal/rows.py b/src/typedal/rows.py index 29c1719..1c82cec 100644 --- a/src/typedal/rows.py +++ b/src/typedal/rows.py @@ -62,21 +62,7 @@ def __init__( `model` is a Typed Table class """ - def _get_id(row: Row) -> int: - """ - Try to find the id field in a row. - - If _extra exists, the row changes: - - """ - if idx := getattr(row, "id", None): - return t.cast(int, idx) - elif main := getattr(row, str(model), None): - return t.cast(int, main.id) - else: # pragma: no cover - raise NotImplementedError(f"`id` could not be found for {row}") - - records = records or {_get_id(row): model(row) for row in rows} + records = records or {self._get_id(row, model): model(row) for row in rows} raw = raw or {} for idx, entity in records.items(): @@ -87,6 +73,21 @@ def _get_id(row: Row) -> int: self.metadata = metadata or {} self.colnames = rows.colnames + @staticmethod + def _get_id(row: Row, model: t.Type[t.Any]) -> int: + """ + Try to find the id field in a row. + + If _extra exists, the row changes: + + """ + if idx := getattr(row, "id", None): + return t.cast(int, idx) + elif main := getattr(row, str(model), None): + return t.cast(int, main.id) + else: # pragma: no cover + raise NotImplementedError(f"`id` could not be found for {row}") + def __len__(self) -> int: """ Return the count of rows. @@ -374,11 +375,22 @@ def from_rows( rows: Rows, model: t.Type[T_MetaInstance], metadata: Metadata = None, + into: t.Type[_TypedTable] | None = None, + init: t.Callable[[_TypedTable, Row], None] | None = None, ) -> "TypedRows[T_MetaInstance]": """ Internal method to convert a Rows object to a TypedRows. """ - return cls(rows, model, metadata=metadata) + target_model = into or model + + def build(row: Row) -> T_MetaInstance: + instance = t.cast(T_MetaInstance, target_model(row)) + if init: + init(instance, row) + return instance + + records = {cls._get_id(row, model): build(row) for row in rows} + return cls(rows, model, records=records, metadata=metadata) def __getstate__(self) -> AnyDict: """ diff --git a/src/typedal/tables.py b/src/typedal/tables.py index 37a40bc..149aab2 100644 --- a/src/typedal/tables.py +++ b/src/typedal/tables.py @@ -386,6 +386,17 @@ def collect(self: t.Type[T_MetaInstance], verbose: bool = False) -> "TypedRows[T """ return QueryBuilder(self).collect(verbose=verbose) + def collect_into[T_Into: _TypedTable]( + self: t.Type[_TypedTable], + into: t.Type[T_Into], + verbose: bool = False, + init: t.Callable[[T_Into, Row], None] | None = None, + ) -> "TypedRows[T_Into]": + """ + See QueryBuilder.collect_into! + """ + return QueryBuilder(self).collect_into(into=into, verbose=verbose, init=init) + @property def ALL(cls) -> pydal.objects.SQLALL: """ @@ -1153,7 +1164,9 @@ def render(self, fields: list[Field] = None, compact: bool = False) -> t.Self: relation_row = row[relation_name] - if isinstance(relation_row, list): + if relation_row is None: + row[relation_name] = None + elif isinstance(relation_row, list): # list of rows combined = [] diff --git a/tests/test_query_builder.py b/tests/test_query_builder.py index 2ae013a..b1ad0a7 100644 --- a/tests/test_query_builder.py +++ b/tests/test_query_builder.py @@ -28,6 +28,10 @@ class TestRelationship(TypedTable): querytable: TestQueryTable +class TestQueryTableBound(TestQueryTable): + pass + + class Undefined(TypedTable): value: int @@ -210,6 +214,60 @@ def test_select(): assert other.value +def test_collect_into(): + _setup_data() + + rows = TestQueryTable.where(lambda row: row.number < 3).collect_into(TestQueryTableBound) + first = rows.first() + + assert first + assert isinstance(first, TestQueryTableBound) + assert rows.model is TestQueryTable + + joined = TestQueryTable.join("relations").where(id=1).collect_into(TestQueryTableBound) + joined_first = joined.first() + + assert joined_first + assert isinstance(joined_first, TestQueryTableBound) + assert isinstance(joined_first.relations[0], TestRelationship) + + marker = object() + + def bind(sticker: TestQueryTableBound, _row): + sticker.item = marker + + bound_rows = TestQueryTable.where(lambda row: row.number < 3).collect_into(TestQueryTableBound, init=bind) + assert all(getattr(row, "item", None) is marker for row in bound_rows) + + calls: list[int] = [] + + def bind_once_per_root(sticker: TestQueryTableBound, _row): + calls.append(sticker.id) + + TestQueryTable.join("relations").where(id=1).collect_into(TestQueryTableBound, init=bind_once_per_root) + assert calls == [1] + + with pytest.raises(TypeError): + TestQueryTable.collect_into(dict) # type: ignore[arg-type] + + with pytest.raises(ValueError): + TestQueryTable.collect_into(TestRelationship) + + +def test_collect_into_cache_isolation(): + _setup_data() + + regular = TestQueryTable.where(id=1).cache().collect() + remapped_fresh = TestQueryTable.where(id=1).cache().collect_into(TestQueryTableBound) + remapped_cached = TestQueryTable.where(id=1).cache().collect_into(TestQueryTableBound) + + assert regular.metadata["cache"]["status"] == "fresh" + assert remapped_fresh.metadata["cache"]["status"] == "fresh" + assert remapped_cached.metadata["cache"]["status"] == "cached" + + assert isinstance(remapped_cached.first(), TestQueryTableBound) + + def test_paginate(): _setup_data() diff --git a/tests/test_row.py b/tests/test_row.py index d5e1891..6f27975 100644 --- a/tests/test_row.py +++ b/tests/test_row.py @@ -307,3 +307,28 @@ class RenderTable(TypedTable): } ], } + + +def test_render_with_none_single_relationship_row(): + @db.define() + class RelatedTableNone(TypedTable): + value: str + + @db.define() + class RenderTableNone(TypedTable): + normal: str + related = relationship( + RelatedTableNone, + condition=lambda this, that: this.normal == that.value, + ) + + RenderTableNone.insert(normal="no-match") + + row = RenderTableNone.select().join("related").first() + + assert row + assert row.related is None + + rendered = row.render() + + assert rendered.related is None