Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion src/typedal/cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -390,7 +390,8 @@ def fake_migrations(

previously_migrated = (
db(
db.ewh_implemented_features.name.belongs(to_fake) & (db.ewh_implemented_features.installed == True) # noqa E712
db.ewh_implemented_features.name.belongs(to_fake)
& (db.ewh_implemented_features.installed == True) # noqa E712
)
.select(db.ewh_implemented_features.name)
.column("name")
Expand Down
4 changes: 3 additions & 1 deletion src/typedal/fields.py
Original file line number Diff line number Diff line change
Expand Up @@ -373,7 +373,9 @@ def UploadField(**kw: t.Unpack[FieldSettings]) -> TypedField[str]:
Upload = UploadField


def ReferenceField[T_subclass: (TypedTable, Table)](
def ReferenceField[
T_subclass: (TypedTable, Table)
](
other_table: str | t.Type[TypedTable] | TypedTable | Table | T_subclass,
**kw: t.Unpack[FieldSettings],
) -> TypedField[int]:
Expand Down
53 changes: 48 additions & 5 deletions src/typedal/query_builder.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,7 @@
OnQuery,
OrderBy,
Query,
Row,
Rows,
SelectKwargs,
T_MetaInstance,
Expand Down Expand Up @@ -517,7 +518,11 @@ def _collect(self) -> str:
"""
return self.to_sql()

def _collect_cached(self, metadata: Metadata) -> "TypedRows[T_MetaInstance] | None":
def _collect_cached(
self,
metadata: Metadata,
into: t.Type[_TypedTable],
) -> "TypedRows[T_MetaInstance] | None":
expires_at = metadata["cache"].get("expires_at")
metadata["cache"] |= {
# key is partly dependant on cache metadata but not these:
Expand All @@ -529,6 +534,7 @@ def _collect_cached(self, metadata: Metadata) -> "TypedRows[T_MetaInstance] | No

_, key = create_and_hash_cache_key(
self.model,
f"{into.__module__}.{into.__qualname__}",
metadata,
self.query,
self.select_args,
Expand Down Expand Up @@ -566,12 +572,15 @@ def collect(
verbose: bool = False,
_to: t.Type["TypedRows[t.Any]"] = None,
add_id: bool = True,
_into: t.Type[_TypedTable] | None = None,
_init: t.Callable[[_TypedTable, Row], None] | None = None,
) -> "TypedRows[T_MetaInstance]":
"""
Execute the built query and turn it into model instances, while handling relationships.
"""
if _to is None:
_to = TypedRows
into = _into or self.model

if not isinstance(self.model, TableMeta):
# tried to use querybuilder with a non-typedal table,
Expand All @@ -585,7 +594,7 @@ def collect(

metadata: Metadata = self.metadata.copy()

if metadata.get("cache", {}).get("enabled") and (result := self._collect_cached(metadata)):
if metadata.get("cache", {}).get("enabled") and (result := self._collect_cached(metadata, into)):
return result

query, select_args, select_kwargs = self._before_query(metadata, add_id=add_id)
Expand All @@ -609,19 +618,48 @@ def collect(

if not self.relationships:
# easy
typed_rows = _to.from_rows(rows, self.model, metadata=metadata)
typed_rows = _to.from_rows(rows, self.model, metadata=metadata, into=into, init=_init)
else:
# harder: try to match rows to the belonging objects
# assume structure of {'table': <data>} per row.
# if that's not the case, return default behavior again
typed_rows = self._collect_with_relationships(rows, metadata=metadata, _to=_to)
typed_rows = self._collect_with_relationships(rows, metadata=metadata, _to=_to, _into=into, _init=_init)

for fn_after in db._after_collect:
fn_after(self, typed_rows, rows)

# only saves if requested in metadata:
return save_to_cache(typed_rows, rows)

def collect_into[T_Into: _TypedTable](
self,
into: t.Type[T_Into],
verbose: bool = False,
add_id: bool = True,
init: t.Callable[[T_Into, Row], None] | None = None,
) -> "TypedRows[T_Into]":
"""
Execute the built query and instantiate root records as another model class.
"""
self._validate_collect_into_model(into)
_init = t.cast(t.Callable[[_TypedTable, Row], None] | None, init)
rows = self.collect(verbose=verbose, add_id=add_id, _into=into, _init=_init)
return t.cast("TypedRows[T_Into]", rows)

def _validate_collect_into_model(self, into: t.Type[t.Any]) -> None:
if not isinstance(into, TableMeta):
raise TypeError("collect_into expects a TypedTable class")

source = self.model._ensure_table_defined()
target = into._ensure_table_defined()

if source is target or str(source) == str(target):
return

raise ValueError(
f"collect_into target '{into.__name__}' must be bound to table '{source}', got '{target}'",
)

@t.overload
def column[T: t.Any](self, field: TypedField[T], **options: t.Unpack[SelectKwargs]) -> list[T]:
"""
Expand Down Expand Up @@ -843,12 +881,15 @@ def _collect_with_relationships(
rows: Rows,
metadata: Metadata,
_to: t.Type["TypedRows[T_MetaInstance]"],
_into: t.Type[_TypedTable] | None = None,
_init: t.Callable[[_TypedTable, Row], None] | None = None,
) -> "TypedRows[T_MetaInstance]":
"""
Transform the raw rows into Typed Table model instances with nested relationships.
"""
db = self._get_db()
main_table = self._ensure_table_defined()
into = _into or self.model

# id: Model
records: dict[t.Any, T_MetaInstance] = {}
Expand All @@ -866,7 +907,9 @@ def _collect_with_relationships(
raw_per_id[main_id].append(normalize_table_keys(row))

if main_id not in records:
records[main_id] = self.model(main)
records[main_id] = t.cast(T_MetaInstance, into(main))
if _init:
_init(t.cast(_TypedTable, records[main_id]), row)
records[main_id]._with = list(self.relationships.keys())

# Setup all relationship defaults (once)
Expand Down
44 changes: 28 additions & 16 deletions src/typedal/rows.py
Original file line number Diff line number Diff line change
Expand Up @@ -62,21 +62,7 @@ def __init__(
`model` is a Typed Table class
"""

def _get_id(row: Row) -> int:
"""
Try to find the id field in a row.

If _extra exists, the row changes:
<Row {'test_relationship': {'id': 1}, '_extra': {'COUNT("test_relationship"."querytable")': 8}}>
"""
if idx := getattr(row, "id", None):
return t.cast(int, idx)
elif main := getattr(row, str(model), None):
return t.cast(int, main.id)
else: # pragma: no cover
raise NotImplementedError(f"`id` could not be found for {row}")

records = records or {_get_id(row): model(row) for row in rows}
records = records or {self._get_id(row, model): model(row) for row in rows}
raw = raw or {}

for idx, entity in records.items():
Expand All @@ -87,6 +73,21 @@ def _get_id(row: Row) -> int:
self.metadata = metadata or {}
self.colnames = rows.colnames

@staticmethod
def _get_id(row: Row, model: t.Type[t.Any]) -> int:
"""
Try to find the id field in a row.

If _extra exists, the row changes:
<Row {'test_relationship': {'id': 1}, '_extra': {'COUNT("test_relationship"."querytable")': 8}}>
"""
if idx := getattr(row, "id", None):
return t.cast(int, idx)
elif main := getattr(row, str(model), None):
return t.cast(int, main.id)
else: # pragma: no cover
raise NotImplementedError(f"`id` could not be found for {row}")

def __len__(self) -> int:
"""
Return the count of rows.
Expand Down Expand Up @@ -374,11 +375,22 @@ def from_rows(
rows: Rows,
model: t.Type[T_MetaInstance],
metadata: Metadata = None,
into: t.Type[_TypedTable] | None = None,
init: t.Callable[[_TypedTable, Row], None] | None = None,
) -> "TypedRows[T_MetaInstance]":
"""
Internal method to convert a Rows object to a TypedRows.
"""
return cls(rows, model, metadata=metadata)
target_model = into or model

def build(row: Row) -> T_MetaInstance:
instance = t.cast(T_MetaInstance, target_model(row))
if init:
init(instance, row)
return instance

records = {cls._get_id(row, model): build(row) for row in rows}
return cls(rows, model, records=records, metadata=metadata)

def __getstate__(self) -> AnyDict:
"""
Expand Down
15 changes: 14 additions & 1 deletion src/typedal/tables.py
Original file line number Diff line number Diff line change
Expand Up @@ -386,6 +386,17 @@ def collect(self: t.Type[T_MetaInstance], verbose: bool = False) -> "TypedRows[T
"""
return QueryBuilder(self).collect(verbose=verbose)

def collect_into[T_Into: _TypedTable](
self: t.Type[_TypedTable],
into: t.Type[T_Into],
verbose: bool = False,
init: t.Callable[[T_Into, Row], None] | None = None,
) -> "TypedRows[T_Into]":
"""
See QueryBuilder.collect_into!
"""
return QueryBuilder(self).collect_into(into=into, verbose=verbose, init=init)

@property
def ALL(cls) -> pydal.objects.SQLALL:
"""
Expand Down Expand Up @@ -1153,7 +1164,9 @@ def render(self, fields: list[Field] = None, compact: bool = False) -> t.Self:

relation_row = row[relation_name]

if isinstance(relation_row, list):
if relation_row is None:
row[relation_name] = None
elif isinstance(relation_row, list):
# list of rows
combined = []

Expand Down
58 changes: 58 additions & 0 deletions tests/test_query_builder.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,10 @@ class TestRelationship(TypedTable):
querytable: TestQueryTable


class TestQueryTableBound(TestQueryTable):
pass


class Undefined(TypedTable):
value: int

Expand Down Expand Up @@ -210,6 +214,60 @@ def test_select():
assert other.value


def test_collect_into():
_setup_data()

rows = TestQueryTable.where(lambda row: row.number < 3).collect_into(TestQueryTableBound)
first = rows.first()

assert first
assert isinstance(first, TestQueryTableBound)
assert rows.model is TestQueryTable

joined = TestQueryTable.join("relations").where(id=1).collect_into(TestQueryTableBound)
joined_first = joined.first()

assert joined_first
assert isinstance(joined_first, TestQueryTableBound)
assert isinstance(joined_first.relations[0], TestRelationship)

marker = object()

def bind(sticker: TestQueryTableBound, _row):
sticker.item = marker

bound_rows = TestQueryTable.where(lambda row: row.number < 3).collect_into(TestQueryTableBound, init=bind)
assert all(getattr(row, "item", None) is marker for row in bound_rows)

calls: list[int] = []

def bind_once_per_root(sticker: TestQueryTableBound, _row):
calls.append(sticker.id)

TestQueryTable.join("relations").where(id=1).collect_into(TestQueryTableBound, init=bind_once_per_root)
assert calls == [1]

with pytest.raises(TypeError):
TestQueryTable.collect_into(dict) # type: ignore[arg-type]

with pytest.raises(ValueError):
TestQueryTable.collect_into(TestRelationship)


def test_collect_into_cache_isolation():
_setup_data()

regular = TestQueryTable.where(id=1).cache().collect()
remapped_fresh = TestQueryTable.where(id=1).cache().collect_into(TestQueryTableBound)
remapped_cached = TestQueryTable.where(id=1).cache().collect_into(TestQueryTableBound)

assert regular.metadata["cache"]["status"] == "fresh"
assert remapped_fresh.metadata["cache"]["status"] == "fresh"
assert remapped_cached.metadata["cache"]["status"] == "cached"

assert isinstance(remapped_cached.first(), TestQueryTableBound)


def test_paginate():
_setup_data()

Expand Down
25 changes: 25 additions & 0 deletions tests/test_row.py
Original file line number Diff line number Diff line change
Expand Up @@ -307,3 +307,28 @@ class RenderTable(TypedTable):
}
],
}


def test_render_with_none_single_relationship_row():
@db.define()
class RelatedTableNone(TypedTable):
value: str

@db.define()
class RenderTableNone(TypedTable):
normal: str
related = relationship(
RelatedTableNone,
condition=lambda this, that: this.normal == that.value,
)

RenderTableNone.insert(normal="no-match")

row = RenderTableNone.select().join("related").first()

assert row
assert row.related is None

rendered = row.render()

assert rendered.related is None
Loading