diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml
index ddb25cd..8b31786 100644
--- a/.github/workflows/ci.yml
+++ b/.github/workflows/ci.yml
@@ -69,7 +69,7 @@ jobs:
strategy:
matrix:
python-version: [ "3.10", "3.11", "3.12", "3.13" ]
- example_app: ["todos"]
+ example_app: ["todos", "blog"]
services:
mongodb:
diff --git a/.gitignore b/.gitignore
index 2f02596..f67c6d6 100644
--- a/.gitignore
+++ b/.gitignore
@@ -135,6 +135,7 @@ venv/
ENV/
env.bak/
venv.bak/
+.venv3_13
# Spyder project settings
.spyderproject
diff --git a/CHANGELOG.md b/CHANGELOG.md
index a3fb163..f6648cd 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -7,6 +7,19 @@ and this project adheres to [Semantic Versioning](http://semver.org/).
## [Unreleased]
+### Changed
+
+- Allowed `EmbeddedJsonModel`s to also contain other `EmbeddedJsonModel`s where needed.
+- Allowed extra keyword arguments to be passed to `SQLModel()`
+- Made defining fields by calling `Field()` or `Relationship()` mandatory,
+  because SQLModels require this and fail otherwise.
+- Added support for "Many-to-One" relationships in the SQL `insert()` implementation
+
+### Added
+
+- Added the [examples/blog](./examples/blog) example
+- Added the `link_models` parameter to the `SQLModel()` function for passing link models used as through tables
+
## [0.1.6] - 2025-02-19
### Changed
diff --git a/CONTRIBUTING.md b/CONTRIBUTING.md
index e3e541b..0ab99fc 100644
--- a/CONTRIBUTING.md
+++ b/CONTRIBUTING.md
@@ -77,7 +77,7 @@ By contributing, you agree that your contributions will be licensed under its MI
- Install the dependencies
```bash
- pip install -r requirements.txt
+  pip install -e ".[all,test]"
```
- Run the pre-commit installation
diff --git a/LIMITATIONS.md b/LIMITATIONS.md
new file mode 100644
index 0000000..f522e1b
--- /dev/null
+++ b/LIMITATIONS.md
@@ -0,0 +1,15 @@
+# Limitations
+
+## Filtering
+
+### Redis
+
+- Mongo-style regular expression filtering is not supported.
+  This is because native Redis filtering is limited to basic text search and has no real
+  regular expression support. See the sketch below for the full-text alternative.
+
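+As a rough illustration (a sketch only, assuming the models and stores are defined and
+registered as in [examples/blog](./examples/blog)), a Mongo-style `$regex` query works
+against the SQL and Mongo stores, while the Redis store is instead queried with
+redis-om style full-text expressions:
+
+```python
+from models import RedisPost, SqlPost
+
+from nqlstore import RedisStore, SQLStore
+
+
+async def search_by_title(sql_store: SQLStore, redis_store: RedisStore, term: str):
+    # Mongo-style regex filters work on the SQL (and Mongo) stores ...
+    sql_hits = await sql_store.find(
+        SqlPost, query={"title": {"$regex": f".*{term}.*", "$options": "i"}}
+    )
+    # ... but not on the Redis store; use a full-text expression on the field instead
+    redis_hits = await redis_store.find(RedisPost, RedisPost.title % f"*{term}*")
+    return sql_hits, redis_hits
+```
+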
+## Update Operation
+
+### SQL
+
+- Even though an update can target nested models at an arbitrary depth,
+  the returned results only contain nested models one level deep, as sketched below.
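+
+A minimal sketch of what this means in practice, assuming `SqlPost` and an `SQLStore`
+registered as in [examples/blog](./examples/blog):
+
+```python
+from models import SqlPost
+
+from nqlstore import SQLStore
+
+
+async def retitle_post(sql_store: SQLStore, post_id: int):
+    # the update itself may touch nested models, e.g. replacing the post's comments ...
+    updated = await sql_store.update(
+        SqlPost,
+        query={"id": {"$eq": post_id}},
+        updates={"title": "new title", "comments": [{"content": "first!"}]},
+    )
+    post = updated[0]
+    # ... but the returned record nests only one level deep: `post.comments` is populated,
+    # while deeper relations such as `post.comments[0].author` are not
+    return post
+```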
diff --git a/README.md b/README.md
index 7fc459b..617fd6c 100644
--- a/README.md
+++ b/README.md
@@ -336,6 +336,11 @@ libraries = await redis_store.delete(
- [ ] Add documentation site
+## Limitations
+
+This library has a few known limitations in specific scenarios.
+Read through the [`LIMITATIONS.md`](./LIMITATIONS.md) file for details.
+
## Contributions
Contributions are welcome. The docs have to maintained, the code has to be made cleaner, more idiomatic and faster,
diff --git a/examples/blog/.gitignore b/examples/blog/.gitignore
new file mode 100644
index 0000000..f640b3b
--- /dev/null
+++ b/examples/blog/.gitignore
@@ -0,0 +1,188 @@
+# Byte-compiled / optimized / DLL files
+__pycache__/
+*.py[cod]
+*$py.class
+
+# C extensions
+*.so
+
+# Distribution / packaging
+.Python
+build/
+develop-eggs/
+dist/
+downloads/
+eggs/
+.eggs/
+lib/
+lib64/
+parts/
+sdist/
+var/
+wheels/
+share/python-wheels/
+*.egg-info/
+.installed.cfg
+*.egg
+MANIFEST
+
+# PyInstaller
+# Usually these files are written by a python script from a template
+# before PyInstaller builds the exe, so as to inject date/other infos into it.
+*.manifest
+*.spec
+
+# Installer logs
+pip-log.txt
+pip-delete-this-directory.txt
+
+# Unit test / coverage reports
+htmlcov/
+.tox/
+.nox/
+.coverage
+.coverage.*
+.cache
+nosetests.xml
+coverage.xml
+*.cover
+*.py,cover
+.hypothesis/
+.pytest_cache/
+cover/
+
+# Translations
+*.mo
+*.pot
+
+# Django stuff:
+*.log
+local_settings.py
+db.sqlite3
+db.sqlite3-journal
+
+# Flask stuff:
+instance/
+.webassets-cache
+
+# Scrapy stuff:
+.scrapy
+
+# Sphinx documentation
+docs/_build/
+
+# PyBuilder
+.pybuilder/
+target/
+
+# Jupyter Notebook
+.ipynb_checkpoints
+
+# IPython
+profile_default/
+ipython_config.py
+
+# pyenv
+# For a library or package, you might want to ignore these files since the code is
+# intended to run in multiple environments; otherwise, check them in:
+# .python-version
+
+# pipenv
+# According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
+# However, in case of collaboration, if having platform-specific dependencies or dependencies
+# having no cross-platform support, pipenv may install dependencies that don't work, or not
+# install all needed dependencies.
+#Pipfile.lock
+
+# UV
+# Similar to Pipfile.lock, it is generally recommended to include uv.lock in version control.
+# This is especially recommended for binary packages to ensure reproducibility, and is more
+# commonly ignored for libraries.
+#uv.lock
+
+# poetry
+# Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control.
+# This is especially recommended for binary packages to ensure reproducibility, and is more
+# commonly ignored for libraries.
+# https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control
+#poetry.lock
+
+# pdm
+# Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control.
+#pdm.lock
+# pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it
+# in version control.
+# https://pdm.fming.dev/latest/usage/project/#working-with-version-control
+.pdm.toml
+.pdm-python
+.pdm-build/
+
+# PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm
+__pypackages__/
+
+# Celery stuff
+celerybeat-schedule
+celerybeat.pid
+
+# SageMath parsed files
+*.sage.py
+
+# Environments
+.env
+.venv
+env/
+venv/
+ENV/
+env.bak/
+venv.bak/
+
+# Spyder project settings
+.spyderproject
+.spyproject
+
+# Rope project settings
+.ropeproject
+
+# mkdocs documentation
+/site
+
+# mypy
+.mypy_cache/
+.dmypy.json
+dmypy.json
+
+# Pyre type checker
+.pyre/
+
+# pytype static type analyzer
+.pytype/
+
+# Cython debug symbols
+cython_debug/
+
+# PyCharm
+# JetBrains specific template is maintained in a separate JetBrains.gitignore that can
+# be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore
+# and can be added to the global gitignore or merged into this file. For a more nuclear
+# option (not recommended) you can uncomment the following to ignore the entire idea folder.
+.idea/
+
+# vscode
+.vscode/
+
+# Ruff stuff:
+.ruff_cache/
+
+# PyPI configuration file
+.pypirc
+
+# environments
+senv/
+renv/
+menv/
+
+# pyenv
+.python-version
+
+# sqlite test db
+test.db
\ No newline at end of file
diff --git a/examples/blog/LICENSE b/examples/blog/LICENSE
new file mode 100644
index 0000000..efbf346
--- /dev/null
+++ b/examples/blog/LICENSE
@@ -0,0 +1,21 @@
+MIT License
+
+Copyright (c) 2025 Martin Ahindura
+
+Permission is hereby granted, free of charge, to any person obtaining a copy
+of this software and associated documentation files (the "Software"), to deal
+in the Software without restriction, including without limitation the rights
+to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+copies of the Software, and to permit persons to whom the Software is
+furnished to do so, subject to the following conditions:
+
+The above copyright notice and this permission notice shall be included in all
+copies or substantial portions of the Software.
+
+THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+SOFTWARE.
diff --git a/examples/blog/README.md b/examples/blog/README.md
new file mode 100644
index 0000000..c0e72ea
--- /dev/null
+++ b/examples/blog/README.md
@@ -0,0 +1,69 @@
+# Blog
+
+A simple Blog app based on [NQLStore](https://github.com/sopherapps/nqlstore)
+
+## Requirements
+
+- [Python 3.10+](https://python.org)
+- [NQLStore](https://github.com/sopherapps/nqlstore)
+- [Redis stack (optional)](https://redis.io/docs/latest/operate/oss_and_stack/install/install-stack/)
+- [MongoDB (optional)](https://www.mongodb.com/products/self-managed/community-edition)
+- [SQLite (optional)](https://www.sqlite.org/)
+
+## Getting Started
+
+- Ensure you have [Python 3.10+](https://python.org) installed
+
+- Clone this repository and enter this folder
+
+```shell
+git clone https://github.com/sopherapps/nqlstore.git
+cd nqlstore/examples/blog
+```
+
+- Create a virtual env, activate it and install requirements
+
+```shell
+python -m venv env
+source env/bin/activate
+pip install -r requirements.txt
+```
+
+- To use with [MongoDB](https://www.mongodb.com/try/download/community), install and start its server.
+- To use with redis, install [redis stack](https://redis.io/docs/latest/operate/oss_and_stack/install/install-stack/)
+ and start its server in another terminal.
+
+- Start the application, setting the URLs of the database(s) to use.
+ Options are:
+ - `SQL_URL` for [SQLite](https://www.sqlite.org/).
+  - `MONGO_URL` (required for MongoDB) and `MONGO_DB` (default: "todos") for [MongoDB](https://www.mongodb.com/products/self-managed/community-edition)
+ - `REDIS_URL` for [Redis](https://redis.io/docs/latest/operate/oss_and_stack/install/install-stack/).
+
+ _It is possible to use multiple databases at the same time. Just set multiple environment variables_
+
+```shell
+export SQL_URL="sqlite+aiosqlite:///test.db"
+#export MONGO_URL="mongodb://localhost:27017"
+#export MONGO_DB="testing"
+export REDIS_URL="redis://localhost:6379/0"
+fastapi dev main.py
+```
+
+## License
+
+Copyright (c) 2025 [Martin Ahindura](https://github.com/Tinitto)
+Licensed under the [MIT License](./LICENSE)
+
+## Gratitude
+
+Glory be to God for His unmatchable love.
+
+> "As Jesus was on His way, the crowds almost crushed Him.
+> And a woman was there who had been subject to bleeding
+> for twelve years, but no one could heal her.
+> She came up behind Him and touched the edge of His cloak,
+> and immediately her bleeding stopped."
+>
+> -- Luke 8: 42-44
+
+
diff --git a/examples/blog/auth.py b/examples/blog/auth.py
new file mode 100644
index 0000000..08de8cd
--- /dev/null
+++ b/examples/blog/auth.py
@@ -0,0 +1,171 @@
+"""Deal with authentication for the app"""
+
+import os
+from datetime import datetime, timedelta, timezone
+from typing import Annotated, Any
+
+import jwt
+from fastapi import Depends, HTTPException, status
+from fastapi.security import OAuth2PasswordBearer
+from jwt.exceptions import InvalidTokenError
+from models import MongoInternalAuthor, RedisInternalAuthor, SqlInternalAuthor
+from passlib.context import CryptContext
+from stores import MongoStoreDep, RedisStoreDep, SqlStoreDep
+
+_ALGORITHM = "HS256"
+_InternalAuthorModel = MongoInternalAuthor | SqlInternalAuthor | RedisInternalAuthor
+
+pwd_context = CryptContext(schemes=["bcrypt"], deprecated="auto")
+oauth2_scheme = OAuth2PasswordBearer(tokenUrl="signin")
+
+
+def verify_password(plain_password: str, hashed_password: str) -> bool:
+ """Verify that the passwords match
+
+ Args:
+ plain_password: the plain password
+ hashed_password: the hashed password
+
+ Returns:
+ True if the passwords match else False
+ """
+ return pwd_context.verify(plain_password, hashed_password)
+
+
+def get_password_hash(password: str) -> str:
+ """Creates a hash from the given password
+
+ Args:
+ password: the password to hash
+
+ Returns:
+ the password hash
+ """
+ return pwd_context.hash(password)
+
+
+def get_jwt_secret() -> str:
+ """Gets the JWT secret from the environment"""
+ return os.environ.get("JWT_SECRET")
+
+
+async def get_user(
+ sql_store: SqlStoreDep,
+ redis_store: RedisStoreDep,
+ mongo_store: MongoStoreDep,
+ query: dict[str, Any],
+) -> _InternalAuthorModel:
+ """Gets the user instance that matches the given query
+
+ Args:
+ sql_store: the SQL store from which to retrieve the user
+ redis_store: the redis store from which to retrieve the user
+ mongo_store: the mongo store from which to retrieve the user
+ query: the filter that the user must match
+
+ Returns:
+ the matching user
+
+ Raises:
+ HTTPException: Unauthorized
+ """
+ try:
+ results = []
+ if sql_store:
+ results += await sql_store.find(SqlInternalAuthor, query=query, limit=1)
+
+ if mongo_store:
+ results += await mongo_store.find(MongoInternalAuthor, query=query, limit=1)
+
+ if redis_store:
+ results += await redis_store.find(RedisInternalAuthor, query=query, limit=1)
+
+ return results[0]
+ except (InvalidTokenError, IndexError):
+ raise HTTPException(
+ status_code=status.HTTP_401_UNAUTHORIZED,
+ headers={"WWW-Authenticate": "Bearer"},
+ )
+
+
+async def authenticate_user(
+ sql_store: SqlStoreDep,
+ redis_store: RedisStoreDep,
+ mongo_store: MongoStoreDep,
+ email: str,
+ password: str,
+) -> _InternalAuthorModel | bool:
+ """Authenticates the user of the given email and password
+
+ Args:
+ sql_store: the SQL store from which to retrieve the user
+ redis_store: the redis store from which to retrieve the user
+ mongo_store: the mongo store from which to retrieve the user
+ email: the email of the user
+ password: the password of the user
+
+ Returns:
+ the authenticated user or False
+ """
+ user = await get_user(
+ sql_store=sql_store,
+ redis_store=redis_store,
+ mongo_store=mongo_store,
+ query={"email": {"$eq": email}},
+ )
+    if not verify_password(password, user.password):
+ return False
+ return user
+
+
+def create_access_token(secret_key: str, data: dict, ttl_minutes: float = 15) -> str:
+ """Creates an access token given a secret key
+
+ Args:
+ secret_key: the JWT secret key for creating the JWT
+ data: the data to encode
+ ttl_minutes: the time to live in minutes
+
+ Returns:
+ the access token
+ """
+ expire = datetime.now(timezone.utc) + timedelta(minutes=ttl_minutes)
+ encoded_jwt = jwt.encode({**data, "exp": expire}, secret_key, algorithm=_ALGORITHM)
+ return encoded_jwt
+
+
+async def get_current_user(
+ sql_store: SqlStoreDep,
+ redis_store: RedisStoreDep,
+ mongo_store: MongoStoreDep,
+ secret_key: Annotated[str, Depends(get_jwt_secret)],
+ token: Annotated[str, Depends(oauth2_scheme)],
+) -> _InternalAuthorModel:
+ """Gets the current user for the given token
+
+ Args:
+ sql_store: the SQL store from which to retrieve the user
+ redis_store: the redis store from which to retrieve the user
+ mongo_store: the mongo store from which to retrieve the user
+ secret_key: the secret key for JWT decoding
+ token: the token for the current user
+
+ Returns:
+ the user
+
+ Raises:
+ HTTPException: Could not validate credentials
+ """
+ payload = jwt.decode(token, secret_key, algorithms=[_ALGORITHM])
+ email = payload.get("sub")
+ return await get_user(
+ sql_store=sql_store,
+ redis_store=redis_store,
+ mongo_store=mongo_store,
+ query={"email": {"$eq": email}},
+ )
+
+
+# Dependencies
+SecretKeyDep = Annotated[str, Depends(get_jwt_secret)]
+CurrentUserDep = Annotated[_InternalAuthorModel, Depends(get_current_user)]
diff --git a/examples/blog/conftest.py b/examples/blog/conftest.py
new file mode 100644
index 0000000..8aae68e
--- /dev/null
+++ b/examples/blog/conftest.py
@@ -0,0 +1,228 @@
+"""Fixtures for tests"""
+
+import os
+from typing import Any
+
+import pytest
+import pytest_asyncio
+import pytest_mock
+from fastapi.testclient import TestClient
+from models import (
+ MongoAuthor,
+ MongoComment,
+ MongoInternalAuthor,
+ MongoPost,
+ MongoTag,
+ RedisAuthor,
+ RedisComment,
+ RedisInternalAuthor,
+ RedisPost,
+ RedisTag,
+ SqlComment,
+ SqlInternalAuthor,
+ SqlPost,
+ SqlTag,
+ SqlTagLink,
+)
+
+from nqlstore import MongoStore, RedisStore, SQLStore
+
+POST_LISTS: list[dict[str, Any]] = [
+ {
+ "title": "School Work",
+ "content": "foo bar man the stuff",
+ },
+ {
+ "title": "Home",
+ "tags": [
+ {"title": "home"},
+ {"title": "art"},
+ ],
+ },
+ {"title": "Boo", "content": "some random stuff", "tags": [{"title": "random"}]},
+]
+
+COMMENT_LIST: list[dict[str, Any]] = [
+ {
+ "content": "Fake comment",
+ },
+ {
+ "content": "Just wondering, who wrote this?",
+ },
+ {
+ "content": "Mann, this is off the charts!",
+ },
+ {
+ "content": "Woo hoo",
+ },
+ {
+ "content": "Not cool. Not cool at all.",
+ },
+]
+
+AUTHOR: dict[str, Any] = {
+ "name": "John Doe",
+ "email": "johndoe@example.com",
+ "password": "password123",
+}
+
+_SQL_DB = "test.db"
+_SQL_URL = f"sqlite+aiosqlite:///{_SQL_DB}"
+_REDIS_URL = "redis://localhost:6379/0"
+_MONGO_URL = "mongodb://localhost:27017"
+_MONGO_DB = "testing"
+ACCESS_TOKEN = "some-token"
+
+
+@pytest.fixture
+def mocked_auth(mocker: pytest_mock.MockerFixture):
+ """Mocks the auth to always return the AUTHOR as valid"""
+ mocker.patch("jwt.encode", return_value=ACCESS_TOKEN)
+ mocker.patch("jwt.decode", return_value={"sub": AUTHOR["email"]})
+ mocker.patch("auth.pwd_context.verify", return_value=True)
+ yield
+
+
+@pytest.fixture
+def client_with_sql(mocked_auth):
+ """The fastapi test client when SQL is enabled"""
+ _reset_env()
+ os.environ["SQL_URL"] = _SQL_URL
+
+ from main import app
+
+ yield TestClient(app)
+ _reset_env()
+
+
+@pytest_asyncio.fixture
+async def client_with_redis(mocked_auth):
+ """The fastapi test client when redis is enabled"""
+ _reset_env()
+ os.environ["REDIS_URL"] = _REDIS_URL
+
+ from main import app
+
+ yield TestClient(app)
+ _reset_env()
+
+
+@pytest.fixture
+def client_with_mongo(mocked_auth):
+ """The fastapi test client when mongodb is enabled"""
+ _reset_env()
+
+ os.environ["MONGO_URL"] = _MONGO_URL
+ os.environ["MONGO_DB"] = _MONGO_DB
+
+ from main import app
+
+ yield TestClient(app)
+ _reset_env()
+
+
+@pytest_asyncio.fixture()
+async def sql_store(mocked_auth):
+ """The sql store stored in memory"""
+ store = SQLStore(uri=_SQL_URL)
+
+ await store.register(
+ [
+ # SqlAuthor,
+ SqlTag,
+ SqlTagLink,
+ SqlPost,
+ SqlComment,
+ SqlInternalAuthor,
+ ]
+ )
+ # insert default user
+ await store.insert(SqlInternalAuthor, [AUTHOR])
+ yield store
+
+ # clean up
+ os.remove(_SQL_DB)
+
+
+@pytest_asyncio.fixture()
+async def mongo_store(mocked_auth):
+ """The mongodb store. Requires a running instance of mongodb"""
+ import pymongo
+
+ store = MongoStore(uri=_MONGO_URL, database=_MONGO_DB)
+ await store.register(
+ [
+ MongoAuthor,
+ MongoTag,
+ MongoPost,
+ MongoComment,
+ MongoInternalAuthor,
+ ]
+ )
+
+ # insert default user
+ await store.insert(MongoInternalAuthor, [AUTHOR])
+
+ yield store
+
+ # clean up
+ client = pymongo.MongoClient(_MONGO_URL) # type: ignore
+ client.drop_database(_MONGO_DB)
+
+
+@pytest_asyncio.fixture
+async def redis_store(mocked_auth):
+ """The redis store. Requires a running instance of redis stack"""
+ import redis
+
+ store = RedisStore(_REDIS_URL)
+ await store.register(
+ [
+ RedisAuthor,
+ RedisTag,
+ RedisPost,
+ RedisComment,
+ RedisInternalAuthor,
+ ]
+ )
+ # insert default user
+ await store.insert(RedisInternalAuthor, [AUTHOR])
+
+ yield store
+
+ # clean up
+ client = redis.Redis("localhost", 6379, 0)
+ client.flushall()
+
+
+@pytest_asyncio.fixture()
+async def sql_posts(sql_store: SQLStore):
+ """A list of posts in the sql store"""
+ records = await sql_store.insert(SqlPost, POST_LISTS)
+ yield records
+
+
+@pytest_asyncio.fixture()
+async def mongo_posts(mongo_store: MongoStore):
+ """A list of posts in the mongo store"""
+
+ records = await mongo_store.insert(MongoPost, POST_LISTS)
+ yield records
+
+
+@pytest_asyncio.fixture()
+async def redis_posts(redis_store: RedisStore):
+ """A list of posts in the redis store"""
+ records = await redis_store.insert(RedisPost, POST_LISTS)
+ yield records
+
+
+def _reset_env():
+ """Resets the environment variables available to the app"""
+ os.environ["SQL_URL"] = ""
+ os.environ["REDIS_URL"] = ""
+ os.environ["MONGO_URL"] = ""
+ os.environ["MONGO_DB"] = "testing"
+ os.environ["JWT_SECRET"] = (
+ "09d25e094faa6ca2556c818166b7a9563b93f7099f6f0f4caa6cf63b88e8d3e7"
+ )
diff --git a/examples/blog/main.py b/examples/blog/main.py
new file mode 100644
index 0000000..7390f45
--- /dev/null
+++ b/examples/blog/main.py
@@ -0,0 +1,296 @@
+import logging
+from contextlib import asynccontextmanager
+from typing import Annotated
+
+from auth import (
+ CurrentUserDep,
+ SecretKeyDep,
+ authenticate_user,
+ create_access_token,
+ get_password_hash,
+)
+from fastapi import Depends, FastAPI, HTTPException, Query, status
+from fastapi.security import OAuth2PasswordRequestForm
+from models import (
+    MongoInternalAuthor,
+    MongoPost,
+    RedisInternalAuthor,
+    RedisPost,
+    SqlInternalAuthor,
+    SqlPost,
+)
+from pydantic import BaseModel
+from schemas import InternalAuthor, PartialPost, Post, TokenResponse
+from stores import MongoStoreDep, RedisStoreDep, SqlStoreDep, clear_stores
+
+_ACCESS_TOKEN_EXPIRE_MINUTES = 30
+
+
+@asynccontextmanager
+async def lifespan(app_: FastAPI):
+ clear_stores()
+ yield
+ clear_stores()
+
+
+app = FastAPI(lifespan=lifespan)
+
+
+@app.post("/signup")
+async def signup(
+ sql: SqlStoreDep,
+ redis: RedisStoreDep,
+ mongo: MongoStoreDep,
+ payload: InternalAuthor,
+):
+ """Signup a new user"""
+ results = []
+ payload_dict = payload.model_dump(exclude_unset=True)
+ payload_dict["password"] = get_password_hash(payload_dict["password"])
+
+ try:
+ if sql:
+ results += await sql.insert(SqlInternalAuthor, [payload_dict])
+
+ if redis:
+            results += await redis.insert(RedisInternalAuthor, [payload_dict])
+
+ if mongo:
+            results += await mongo.insert(MongoInternalAuthor, [payload_dict])
+
+ result = results[0].model_dump(mode="json")
+ return result
+ except Exception as exp:
+ logging.error(exp)
+ raise exp
+
+
+@app.post("/signin")
+async def signin(
+ sql: SqlStoreDep,
+ redis: RedisStoreDep,
+ mongo: MongoStoreDep,
+ secret_key: SecretKeyDep,
+ form_data: Annotated[OAuth2PasswordRequestForm, Depends()],
+) -> TokenResponse:
+ user = await authenticate_user(
+ sql_store=sql,
+ redis_store=redis,
+ mongo_store=mongo,
+ email=form_data.username,
+ password=form_data.password,
+ )
+ if not user:
+ raise HTTPException(
+ status_code=status.HTTP_401_UNAUTHORIZED,
+ detail="Incorrect username or password",
+ headers={"WWW-Authenticate": "Bearer"},
+ )
+
+ access_token = create_access_token(
+ secret_key=secret_key,
+ data={"sub": user.username},
+ ttl_minutes=_ACCESS_TOKEN_EXPIRE_MINUTES,
+ )
+ return TokenResponse(access_token=access_token, token_type="bearer")
+
+
+@app.get("/posts")
+async def search(
+ sql: SqlStoreDep,
+ redis: RedisStoreDep,
+ mongo: MongoStoreDep,
+ title: str | None = Query(None),
+ author: str | None = Query(None),
+ tag: str | None = Query(None),
+):
+ """Searches for posts by title, author or tag"""
+
+ results = []
+ query_dict: dict[str, str] = {
+ k: v
+ for k, v in {"title": title, "author.name": author, "tags.title": tag}.items()
+ if v
+ }
+ query = {k: {"$regex": f".*{v}.*", "$options": "i"} for k, v in query_dict.items()}
+ try:
+ if sql:
+ results += await sql.find(SqlPost, query=query)
+
+ if redis:
+            # Redis' regex search is immature, so we use its full-text search instead.
+            # Unfortunately, Redis search does not permit us to search fields that are arrays.
+ redis_query = [
+ (
+ (_get_redis_field(RedisPost, k) == f"{v}")
+ if k == "tags.title"
+ else (_get_redis_field(RedisPost, k) % f"*{v}*")
+ )
+ for k, v in query_dict.items()
+ ]
+ results += await redis.find(RedisPost, *redis_query)
+
+ if mongo:
+ results += await mongo.find(MongoPost, query=query)
+ except Exception as exp:
+ logging.error(exp)
+ raise exp
+
+ return [item.model_dump(mode="json") for item in results]
+
+
+@app.get("/posts/{id_}")
+async def get_one(
+ sql: SqlStoreDep,
+ redis: RedisStoreDep,
+ mongo: MongoStoreDep,
+ id_: int | str,
+):
+ """Get post by id"""
+ results = []
+ query = {"id": {"$eq": id_}}
+
+ try:
+ if sql:
+ results += await sql.find(SqlPost, query=query, limit=1)
+
+ if redis:
+ results += await redis.find(RedisPost, query=query, limit=1)
+
+ if mongo:
+ results += await mongo.find(MongoPost, query=query, limit=1)
+ except Exception as exp:
+ logging.error(exp)
+ raise exp
+
+ try:
+ return results[0].model_dump(mode="json")
+ except IndexError:
+ raise HTTPException(status_code=status.HTTP_404_NOT_FOUND)
+
+
+@app.post("/posts")
+async def create_one(
+ sql: SqlStoreDep,
+ redis: RedisStoreDep,
+ mongo: MongoStoreDep,
+ current_user: CurrentUserDep,
+ payload: Post,
+):
+ """Create a post"""
+ results = []
+ payload_dict = payload.model_dump(exclude_unset=True)
+ payload_dict["author"] = current_user.model_dump()
+
+ try:
+ if sql:
+ results += await sql.insert(SqlPost, [payload_dict])
+
+ if redis:
+ results += await redis.insert(RedisPost, [payload_dict])
+
+ if mongo:
+ results += await mongo.insert(MongoPost, [payload_dict])
+
+ result = results[0].model_dump(mode="json")
+ return result
+ except Exception as exp:
+ logging.error(exp)
+ raise exp
+
+
+@app.put("/posts/{id_}")
+async def update_one(
+ sql: SqlStoreDep,
+ redis: RedisStoreDep,
+ mongo: MongoStoreDep,
+ current_user: CurrentUserDep,
+ id_: int | str,
+ payload: PartialPost,
+):
+ """Update a post"""
+ results = []
+ query = {"id": {"$eq": id_}}
+ updates = payload.model_dump(exclude_unset=True)
+ user_dict = current_user.model_dump()
+
+ if "comments" in updates:
+ # just resetting the author of all comments to current user.
+ # This is probably logically wrong.
+ # It is just here for illustration
+ updates["comments"] = [
+ {**item, "author": user_dict} for item in updates["comments"]
+ ]
+
+ try:
+ if sql:
+ results += await sql.update(SqlPost, query=query, updates=updates)
+
+ if redis:
+ results += await redis.update(RedisPost, query=query, updates=updates)
+
+ if mongo:
+ results += await mongo.update(MongoPost, query=query, updates=updates)
+ except Exception as exp:
+ logging.error(exp)
+ raise exp
+
+ try:
+ return results[0].model_dump(mode="json")
+ except IndexError:
+ raise HTTPException(status_code=status.HTTP_404_NOT_FOUND)
+
+
+@app.delete("/posts/{id_}")
+async def delete_one(
+ sql: SqlStoreDep,
+ redis: RedisStoreDep,
+ mongo: MongoStoreDep,
+ id_: int | str,
+):
+ """Delete a post"""
+ results = []
+ query = {"id": {"$eq": id_}}
+
+ try:
+ if sql:
+ results += await sql.delete(SqlPost, query=query)
+
+ if redis:
+ results += await redis.delete(RedisPost, query=query)
+
+ if mongo:
+ results += await mongo.delete(MongoPost, query=query)
+ except Exception as exp:
+ logging.error(exp)
+ raise exp
+
+ try:
+ return results[0].model_dump(mode="json")
+ except IndexError:
+ raise HTTPException(status_code=status.HTTP_404_NOT_FOUND)
+
+
+def _get_redis_field(model: type[BaseModel], path: str):
+ """Retrieves the Field at the given path, which may or may not be dotted
+
+    Args:
+        model: the parent model
+        path: the path to the field where dots signify relations e.g. "books.title"
+
+ Returns:
+ the FieldInfo at the given path
+
+ Raises:
+        ValueError: no field '{path}' found on '{model}'
+ """
+ path_segments = path.split(".")
+ current_parent = model
+
+ field = None
+    for idx, part in enumerate(path_segments):
+        try:
+            # resolve the next segment of the dotted path; an AttributeError on the
+            # leaf segment falls back to the last resolved field, otherwise it is re-raised
+            field = getattr(current_parent, part)
+            current_parent = field
+        except AttributeError as exp:
+            if idx == len(path_segments) - 1:
+                break
+            raise exp
+
+ if field is None:
+ raise ValueError(f"no field '{path}' found on '{model}'")
+
+ return field
diff --git a/examples/blog/models.py b/examples/blog/models.py
new file mode 100644
index 0000000..e03688b
--- /dev/null
+++ b/examples/blog/models.py
@@ -0,0 +1,64 @@
+"""Models that are saved in storage"""
+
+from schemas import Author, Comment, InternalAuthor, Post, Tag, TagLink
+
+from nqlstore import (
+ EmbeddedJsonModel,
+ EmbeddedMongoModel,
+ JsonModel,
+ MongoModel,
+ SQLModel,
+)
+
+# mongo models
+MongoInternalAuthor = MongoModel("MongoInternalAuthor", InternalAuthor)
+MongoAuthor = EmbeddedMongoModel("MongoAuthor", Author)
+MongoComment = EmbeddedMongoModel(
+ "MongoComment", Comment, embedded_models={"author": MongoAuthor}
+)
+MongoTag = EmbeddedMongoModel("MongoTag", Tag)
+MongoPost = MongoModel(
+ "MongoPost",
+ Post,
+ embedded_models={
+ "author": MongoAuthor | None,
+ "comments": list[MongoComment],
+ "tags": list[MongoTag],
+ },
+)
+
+
+# redis models
+RedisInternalAuthor = JsonModel("RedisInternalAuthor", InternalAuthor)
+RedisAuthor = EmbeddedJsonModel("RedisAuthor", Author)
+RedisComment = EmbeddedJsonModel(
+ "RedisComment", Comment, embedded_models={"author": RedisAuthor}
+)
+RedisTag = EmbeddedJsonModel("RedisTag", Tag)
+RedisPost = JsonModel(
+ "RedisPost",
+ Post,
+ embedded_models={
+ "author": RedisAuthor | None,
+ "comments": list[RedisComment],
+ "tags": list[RedisTag],
+ },
+)
+
+# sqlite models
+SqlInternalAuthor = SQLModel("SqlInternalAuthor", InternalAuthor)
+SqlComment = SQLModel(
+ "SqlComment", Comment, relationships={"author": SqlInternalAuthor | None}
+)
+SqlTagLink = SQLModel("SqlTagLink", TagLink)
+SqlTag = SQLModel("SqlTag", Tag)
+SqlPost = SQLModel(
+ "SqlPost",
+ Post,
+ relationships={
+ "author": SqlInternalAuthor | None,
+ "comments": list[SqlComment],
+ "tags": list[SqlTag],
+ },
+ link_models={"tags": SqlTagLink},
+)
diff --git a/examples/blog/requirements.txt b/examples/blog/requirements.txt
new file mode 100644
index 0000000..c546830
--- /dev/null
+++ b/examples/blog/requirements.txt
@@ -0,0 +1,10 @@
+fastapi[standard]~=0.115.8
+nqlstore[all]
+black~=25.1.0
+isort~=6.0.0
+pytest~=8.3.4
+pytest-asyncio~=0.25.3
+passlib[bcrypt]~=1.7.4
+pyjwt~=2.10.1
+pytest-freezegun~=0.4.2
+pytest-mock~=3.14.0
diff --git a/examples/blog/schemas.py b/examples/blog/schemas.py
new file mode 100644
index 0000000..8fac64f
--- /dev/null
+++ b/examples/blog/schemas.py
@@ -0,0 +1,96 @@
+"""Schemas for the application"""
+
+from datetime import datetime
+
+from pydantic import BaseModel
+from utils import Partial, current_timestamp
+
+from nqlstore import Field, Relationship
+
+
+class Author(BaseModel):
+ """The author as returned to the user"""
+
+ name: str = Field(index=True, full_text_search=True)
+
+
+class InternalAuthor(Author):
+ """The author as saved in database"""
+
+ password: str = Field()
+ email: str = Field(index=True)
+
+
+class Post(BaseModel):
+ """The post"""
+
+ title: str = Field(index=True, full_text_search=True)
+ content: str | None = Field(default="")
+ author_id: int | None = Field(
+ default=None,
+ foreign_key="sqlinternalauthor.id",
+ disable_on_mongo=True,
+ disable_on_redis=True,
+ )
+ author: Author | None = Relationship(default=None)
+ comments: list["Comment"] = Relationship(default=[])
+ tags: list["Tag"] = Relationship(default=[], link_model="TagLink")
+ created_at: str = Field(index=True, default_factory=current_timestamp)
+ updated_at: str = Field(index=True, default_factory=current_timestamp)
+
+
+class Comment(BaseModel):
+ """The comment on a post"""
+
+ post_id: int | None = Field(
+ default=None,
+ foreign_key="sqlpost.id",
+ disable_on_mongo=True,
+ disable_on_redis=True,
+ )
+ content: str | None = Field(default="")
+ author_id: int | None = Field(
+ default=None,
+ foreign_key="sqlinternalauthor.id",
+ disable_on_mongo=True,
+ disable_on_redis=True,
+ )
+ author: Author | None = Relationship(default=None)
+ created_at: str = Field(index=True, default_factory=current_timestamp)
+ updated_at: str = Field(index=True, default_factory=current_timestamp)
+
+
+class TagLink(BaseModel):
+ """The SQL-only join table between tags and posts"""
+
+ post_id: int | None = Field(
+ default=None,
+ foreign_key="sqlpost.id",
+ primary_key=True,
+ disable_on_mongo=True,
+ disable_on_redis=True,
+ )
+ tag_id: int | None = Field(
+ default=None,
+ foreign_key="sqltag.id",
+ primary_key=True,
+ disable_on_mongo=True,
+ disable_on_redis=True,
+ )
+
+
+class Tag(BaseModel):
+ """The tags to help searching for posts"""
+
+ title: str = Field(index=True, unique=True)
+
+
+class TokenResponse(BaseModel):
+ """HTTP-only response"""
+
+ access_token: str
+ token_type: str
+
+
+# Partial models
+PartialPost = Partial("PartialPost", Post)
diff --git a/examples/blog/stores.py b/examples/blog/stores.py
new file mode 100644
index 0000000..0c6b575
--- /dev/null
+++ b/examples/blog/stores.py
@@ -0,0 +1,111 @@
+"""Module containing the registry of stores at runtime"""
+
+import os
+from typing import Annotated
+
+from fastapi import Depends
+from models import ( # SqlAuthor,
+ MongoAuthor,
+ MongoComment,
+ MongoInternalAuthor,
+ MongoPost,
+ MongoTag,
+ RedisAuthor,
+ RedisComment,
+ RedisInternalAuthor,
+ RedisPost,
+ RedisTag,
+ SqlComment,
+ SqlInternalAuthor,
+ SqlPost,
+ SqlTag,
+ SqlTagLink,
+)
+
+from nqlstore import MongoStore, RedisStore, SQLStore
+
+_STORES: dict[str, MongoStore | RedisStore | SQLStore] = {}
+
+
+async def get_redis_store() -> RedisStore | None:
+ """Gets the redis store whose URL is specified in the environment"""
+ global _STORES
+
+ if redis_url := os.environ.get("REDIS_URL"):
+ try:
+ return _STORES[redis_url]
+ except KeyError:
+ store = RedisStore(uri=redis_url)
+ await store.register(
+ [
+ RedisAuthor,
+ RedisTag,
+ RedisPost,
+ RedisComment,
+ RedisInternalAuthor,
+ ]
+ )
+ _STORES[redis_url] = store
+ return store
+
+
+async def get_sql_store() -> SQLStore | None:
+ """Gets the sql store whose URL is specified in the environment"""
+ global _STORES
+
+ if sql_url := os.environ.get("SQL_URL"):
+ try:
+ return _STORES[sql_url]
+ except KeyError:
+ store = SQLStore(uri=sql_url)
+ await store.register(
+ [
+ # SqlAuthor,
+ SqlTag,
+ SqlTagLink,
+ SqlPost,
+ SqlComment,
+ SqlInternalAuthor,
+ ]
+ )
+ _STORES[sql_url] = store
+ return store
+
+
+async def get_mongo_store() -> MongoStore | None:
+ """Gets the mongo store whose URL and database are specified in the environment"""
+ global _STORES
+
+ if mongo_url := os.environ.get("MONGO_URL"):
+ mongo_db = os.environ.get("MONGO_DB", "todos")
+ mongo_full_url = f"{mongo_url}/{mongo_db}"
+
+ try:
+ return _STORES[mongo_full_url]
+ except KeyError:
+ store = MongoStore(uri=mongo_url, database=mongo_db)
+ await store.register(
+ [
+ MongoAuthor,
+ MongoTag,
+ MongoPost,
+ MongoComment,
+ MongoInternalAuthor,
+ ]
+ )
+ _STORES[mongo_full_url] = store
+ return store
+
+
+def clear_stores():
+ """Clears the registry of stores
+
+ Important for clean up
+ """
+ global _STORES
+ _STORES.clear()
+
+
+SqlStoreDep = Annotated[SQLStore | None, Depends(get_sql_store)]
+RedisStoreDep = Annotated[RedisStore | None, Depends(get_redis_store)]
+MongoStoreDep = Annotated[MongoStore | None, Depends(get_mongo_store)]
diff --git a/examples/blog/test_main.py b/examples/blog/test_main.py
new file mode 100644
index 0000000..cc15e0a
--- /dev/null
+++ b/examples/blog/test_main.py
@@ -0,0 +1,613 @@
+from datetime import datetime
+from typing import Any
+
+import pytest
+from bson import ObjectId
+from conftest import ACCESS_TOKEN, AUTHOR, COMMENT_LIST, POST_LISTS
+from fastapi.testclient import TestClient
+from main import MongoPost, RedisPost, SqlPost
+
+from nqlstore import MongoStore, RedisStore, SQLStore
+
+_TITLE_SEARCH_TERMS = ["ho", "oo", "work"]
+_TAG_SEARCH_TERMS = ["art", "om"]
+_HEADERS = {"Authorization": f"Bearer {ACCESS_TOKEN}"}
+
+
+@pytest.mark.asyncio
+@pytest.mark.parametrize("post", POST_LISTS)
+async def test_create_sql_post(
+ client_with_sql: TestClient, sql_store: SQLStore, post: dict, freezer
+):
+ """POST to /posts creates a post in sql and returns it"""
+ timestamp = datetime.now().isoformat()
+ with client_with_sql as client:
+ response = client.post("/posts", json=post, headers=_HEADERS)
+
+ got = response.json()
+ post_id = got["id"]
+ raw_tags = post.get("tags", [])
+ resp_tags = got["tags"]
+ expected = {
+ "id": post_id,
+ "title": post["title"],
+ "content": post.get("content", ""),
+ "author": {"id": 1, **AUTHOR},
+ "author_id": 1,
+ "tags": [
+ {
+ **raw,
+ "id": resp["id"],
+ }
+ for raw, resp in zip(raw_tags, resp_tags)
+ ],
+ "comments": [],
+ "created_at": timestamp,
+ "updated_at": timestamp,
+ }
+
+ db_query = {"id": {"$eq": post_id}}
+ db_results = await sql_store.find(SqlPost, query=db_query, limit=1)
+ record_in_db = db_results[0].model_dump()
+
+ assert got == expected
+ assert record_in_db == expected
+
+
+@pytest.mark.asyncio
+@pytest.mark.parametrize("post", POST_LISTS)
+async def test_create_redis_post(
+ client_with_redis: TestClient,
+ redis_store: RedisStore,
+ post: dict,
+ freezer,
+):
+ """POST to /posts creates a post in redis and returns it"""
+ timestamp = datetime.now().isoformat()
+ with client_with_redis as client:
+ response = client.post("/posts", json=post, headers=_HEADERS)
+
+ got = response.json()
+ post_id = got["id"]
+ raw_tags = post.get("tags", [])
+ resp_tags = got["tags"]
+ expected = {
+ "id": post_id,
+ "title": post["title"],
+ "content": post.get("content", ""),
+ "author": {**got["author"], **AUTHOR},
+ "pk": post_id,
+ "tags": [
+ {
+ **raw,
+ "id": resp["id"],
+ "pk": resp["pk"],
+ }
+ for raw, resp in zip(raw_tags, resp_tags)
+ ],
+ "comments": [],
+ "created_at": timestamp,
+ "updated_at": timestamp,
+ }
+
+ db_query = {"id": {"$eq": post_id}}
+ db_results = await redis_store.find(RedisPost, query=db_query, limit=1)
+ record_in_db = db_results[0].model_dump(mode="json")
+
+ assert got == expected
+ assert record_in_db == expected
+
+
+@pytest.mark.asyncio
+@pytest.mark.parametrize("post", POST_LISTS)
+async def test_create_mongo_post(
+ client_with_mongo: TestClient,
+ mongo_store: MongoStore,
+ post: dict,
+ freezer,
+):
+ """POST to /posts creates a post in redis and returns it"""
+ timestamp = datetime.now().isoformat()
+ with client_with_mongo as client:
+ response = client.post("/posts", json=post, headers=_HEADERS)
+
+ got = response.json()
+ post_id = got["id"]
+ raw_tags = post.get("tags", [])
+ expected = {
+ "id": post_id,
+ "title": post["title"],
+ "content": post.get("content", ""),
+ "author": {"name": AUTHOR["name"]},
+ "tags": raw_tags,
+ "comments": [],
+ "created_at": timestamp,
+ "updated_at": timestamp,
+ }
+
+ db_query = {"_id": {"$eq": ObjectId(post_id)}}
+ db_results = await mongo_store.find(MongoPost, query=db_query, limit=1)
+ record_in_db = db_results[0].model_dump(mode="json")
+
+ assert got == expected
+ assert record_in_db == expected
+
+
+@pytest.mark.asyncio
+@pytest.mark.parametrize("index", range(len(POST_LISTS)))
+async def test_update_sql_post(
+ client_with_sql: TestClient,
+ sql_store: SQLStore,
+ sql_posts: list[SqlPost],
+ index: int,
+ freezer,
+):
+ """PUT to /posts/{id} updates the sql post of given id and returns updated version"""
+ timestamp = datetime.now().isoformat()
+ with client_with_sql as client:
+ post = sql_posts[index]
+ post_dict = post.model_dump(mode="json", exclude_none=True, exclude_unset=True)
+ id_ = post.id
+ update = {
+ **post_dict,
+ "title": "some other title",
+ "tags": [
+ *post_dict["tags"],
+ {"title": "another one"},
+ {"title": "another one again"},
+ ],
+ "comments": [*post_dict["comments"], *COMMENT_LIST[index:]],
+ }
+
+ response = client.put(f"/posts/{id_}", json=update, headers=_HEADERS)
+
+ got = response.json()
+ expected = {
+ **post.model_dump(mode="json"),
+ **update,
+ "comments": [
+ {
+ **raw,
+ "id": final["id"],
+ "post_id": final["post_id"],
+ "author_id": 1,
+ "created_at": timestamp,
+ "updated_at": timestamp,
+ }
+ for raw, final in zip(update["comments"], got["comments"])
+ ],
+ "tags": [
+ {
+ **raw,
+ "id": final["id"],
+ }
+ for raw, final in zip(update["tags"], got["tags"])
+ ],
+ }
+ db_query = {"id": {"$eq": id_}}
+ db_results = await sql_store.find(SqlPost, query=db_query, limit=1)
+ record_in_db = db_results[0].model_dump(mode="json")
+
+ assert got == expected
+ assert record_in_db == expected
+
+
+@pytest.mark.asyncio
+@pytest.mark.parametrize("index", range(len(POST_LISTS)))
+async def test_update_redis_post(
+ client_with_redis: TestClient,
+ redis_store: RedisStore,
+ redis_posts: list[RedisPost],
+ index: int,
+ freezer,
+):
+ """PUT to /posts/{id} updates the redis post of given id and returns updated version"""
+ timestamp = datetime.now().isoformat()
+ with client_with_redis as client:
+ post = redis_posts[index]
+ post_dict = post.model_dump(mode="json", exclude_none=True, exclude_unset=True)
+ id_ = post.id
+ update = {
+ "title": "some other title",
+ "tags": [
+ *post_dict.get("tags", []),
+ {"title": "another one"},
+ {"title": "another one again"},
+ ],
+ "comments": [*post_dict.get("comments", []), *COMMENT_LIST[index:]],
+ }
+
+ response = client.put(f"/posts/{id_}", json=update, headers=_HEADERS)
+
+ got = response.json()
+ expected = {
+ **post.model_dump(mode="json"),
+ **update,
+ "comments": [
+ {
+ **raw,
+ "id": final["id"],
+ "author": final["author"],
+ "pk": final["pk"],
+ "created_at": timestamp,
+ "updated_at": timestamp,
+ }
+ for raw, final in zip(update["comments"], got["comments"])
+ ],
+ "tags": [
+ {
+ **raw,
+ "id": final["id"],
+ "pk": final["pk"],
+ }
+ for raw, final in zip(update["tags"], got["tags"])
+ ],
+ }
+ db_query = {"id": {"$eq": id_}}
+ db_results = await redis_store.find(RedisPost, query=db_query, limit=1)
+ record_in_db = db_results[0].model_dump(mode="json")
+ expected_in_db = {
+ **expected,
+ "tags": [
+ {
+ **raw,
+ "id": final["id"],
+ "pk": final["pk"],
+ }
+ for raw, final in zip(expected["tags"], record_in_db["tags"])
+ ],
+ "comments": [
+ {
+ **raw,
+ "id": final["id"],
+ "pk": final["pk"],
+ }
+ for raw, final in zip(expected["comments"], record_in_db["comments"])
+ ],
+ }
+
+ assert got == expected
+ assert record_in_db == expected_in_db
+
+
+@pytest.mark.asyncio
+@pytest.mark.parametrize("index", range(len(POST_LISTS)))
+async def test_update_mongo_post(
+ client_with_mongo: TestClient,
+ mongo_store: MongoStore,
+ mongo_posts: list[MongoPost],
+ index: int,
+ freezer,
+):
+ """PUT to /posts/{id} updates the mongo post of given id and returns updated version"""
+ timestamp = datetime.now().isoformat()
+ with client_with_mongo as client:
+ post = mongo_posts[index]
+ post_dict = post.model_dump(mode="json", exclude_none=True, exclude_unset=True)
+ id_ = post.id
+ update = {
+ "title": "some other title",
+ "tags": [
+ *post_dict.get("tags", []),
+ {"title": "another one"},
+ {"title": "another one again"},
+ ],
+ "comments": [*post_dict.get("comments", []), *COMMENT_LIST[index:]],
+ }
+
+ response = client.put(f"/posts/{id_}", json=update, headers=_HEADERS)
+
+ got = response.json()
+ expected = {
+ **post.model_dump(mode="json"),
+ **update,
+ "comments": [
+ {
+ **raw,
+ "author": final["author"],
+ "created_at": timestamp,
+ "updated_at": timestamp,
+ }
+ for raw, final in zip(update["comments"], got["comments"])
+ ],
+ }
+ db_query = {"_id": {"$eq": id_}}
+ db_results = await mongo_store.find(MongoPost, query=db_query, limit=1)
+ record_in_db = db_results[0].model_dump(mode="json")
+
+ assert got == expected
+ assert record_in_db == expected
+
+
+@pytest.mark.asyncio
+@pytest.mark.parametrize("index", range(len(POST_LISTS)))
+async def test_delete_sql_post(
+ client_with_sql: TestClient,
+ sql_store: SQLStore,
+ sql_posts: list[SqlPost],
+ index: int,
+):
+ """DELETE /posts/{id} deletes the sql post of given id and returns deleted version"""
+ with client_with_sql as client:
+ post = sql_posts[index]
+ id_ = post.id
+
+ response = client.delete(f"/posts/{id_}")
+
+ got = response.json()
+ expected = post.model_dump(mode="json")
+
+ db_query = {"id": {"$eq": id_}}
+ db_results = await sql_store.find(SqlPost, query=db_query, limit=1)
+
+ assert got == expected
+ assert db_results == []
+
+
+@pytest.mark.asyncio
+@pytest.mark.parametrize("index", range(len(POST_LISTS)))
+async def test_delete_redis_post(
+ client_with_redis: TestClient,
+ redis_store: RedisStore,
+ redis_posts: list[RedisPost],
+ index: int,
+):
+ """DELETE /posts/{id} deletes the redis post of given id and returns deleted version"""
+ with client_with_redis as client:
+ post = redis_posts[index]
+ id_ = post.id
+
+ response = client.delete(f"/posts/{id_}")
+
+ got = response.json()
+ expected = post.model_dump(mode="json")
+
+ db_query = {"id": {"$eq": id_}}
+ db_results = await redis_store.find(RedisPost, query=db_query, limit=1)
+
+ assert got == expected
+ assert db_results == []
+
+
+@pytest.mark.asyncio
+@pytest.mark.parametrize("index", range(len(POST_LISTS)))
+async def test_delete_mongo_post(
+ client_with_mongo: TestClient,
+ mongo_store: MongoStore,
+ mongo_posts: list[MongoPost],
+ index: int,
+):
+ """DELETE /posts/{id} deletes the mongo post of given id and returns deleted version"""
+ with client_with_mongo as client:
+ post = mongo_posts[index]
+ id_ = post.id
+
+ response = client.delete(f"/posts/{id_}")
+
+ got = response.json()
+ expected = post.model_dump(mode="json")
+
+ db_query = {"_id": {"$eq": id_}}
+ db_results = await mongo_store.find(MongoPost, query=db_query, limit=1)
+
+ assert got == expected
+ assert db_results == []
+
+
+@pytest.mark.asyncio
+@pytest.mark.parametrize("index", range(len(POST_LISTS)))
+async def test_read_one_sql_post(
+ client_with_sql: TestClient,
+ sql_store: SQLStore,
+ sql_posts: list[SqlPost],
+ index: int,
+):
+ """GET /posts/{id} gets the sql post of given id"""
+ with client_with_sql as client:
+ post = sql_posts[index]
+ id_ = post.id
+
+ response = client.get(f"/posts/{id_}")
+
+ got = response.json()
+ expected = post.model_dump(mode="json")
+
+ db_query = {"id": {"$eq": id_}}
+ db_results = await sql_store.find(SqlPost, query=db_query, limit=1)
+ record_in_db = db_results[0].model_dump(mode="json")
+
+ assert got == expected
+ assert record_in_db == expected
+
+
+@pytest.mark.asyncio
+@pytest.mark.parametrize("index", range(len(POST_LISTS)))
+async def test_read_one_redis_post(
+ client_with_redis: TestClient,
+ redis_store: RedisStore,
+ redis_posts: list[RedisPost],
+ index: int,
+):
+ """GET /posts/{id} gets the redis post of given id"""
+ with client_with_redis as client:
+ post = redis_posts[index]
+ id_ = post.id
+
+ response = client.get(f"/posts/{id_}")
+
+ got = response.json()
+ expected = post.model_dump(mode="json")
+
+ db_query = {"id": {"$eq": id_}}
+ db_results = await redis_store.find(RedisPost, query=db_query, limit=1)
+ record_in_db = db_results[0].model_dump(mode="json")
+
+ assert got == expected
+ assert record_in_db == expected
+
+
+@pytest.mark.asyncio
+@pytest.mark.parametrize("index", range(len(POST_LISTS)))
+async def test_read_one_mongo_post(
+ client_with_mongo: TestClient,
+ mongo_store: MongoStore,
+ mongo_posts: list[MongoPost],
+ index: int,
+):
+ """GET /posts/{id} gets the mongo post of given id"""
+ with client_with_mongo as client:
+ post = mongo_posts[index]
+ id_ = post.id
+
+ response = client.get(f"/posts/{id_}")
+
+ got = response.json()
+ expected = post.model_dump(mode="json")
+
+ db_query = {"_id": {"$eq": id_}}
+ db_results = await mongo_store.find(MongoPost, query=db_query, limit=1)
+ record_in_db = db_results[0].model_dump(mode="json")
+
+ assert got == expected
+ assert record_in_db == expected
+
+
+@pytest.mark.asyncio
+@pytest.mark.parametrize("q", _TITLE_SEARCH_TERMS)
+async def test_search_sql_by_title(
+ client_with_sql: TestClient,
+ sql_store: SQLStore,
+ sql_posts: list[SqlPost],
+ q: str,
+):
+ """GET /posts?title={} gets all sql posts with title containing search item"""
+ with client_with_sql as client:
+ response = client.get(f"/posts?title={q}")
+
+ got = response.json()
+ expected = [
+ v.model_dump(mode="json") for v in sql_posts if q in v.title.lower()
+ ]
+
+ assert got == expected
+
+
+@pytest.mark.asyncio
+@pytest.mark.parametrize("q", _TITLE_SEARCH_TERMS)
+async def test_search_redis_by_title(
+ client_with_redis: TestClient,
+ redis_store: RedisStore,
+ redis_posts: list[RedisPost],
+ q: str,
+):
+ """GET /posts?title={} gets all redis posts with title containing search item"""
+ with client_with_redis as client:
+ response = client.get(f"/posts?title={q}")
+
+ got = response.json()
+ expected = [
+ v.model_dump(mode="json") for v in redis_posts if q in v.title.lower()
+ ]
+
+ assert got == expected
+
+
+@pytest.mark.asyncio
+@pytest.mark.parametrize("q", _TITLE_SEARCH_TERMS)
+async def test_search_mongo_by_title(
+ client_with_mongo: TestClient,
+ mongo_store: MongoStore,
+ mongo_posts: list[MongoPost],
+ q: str,
+):
+ """GET /posts?title={} gets all mongo posts with title containing search item"""
+ with client_with_mongo as client:
+ response = client.get(f"/posts?title={q}")
+
+ got = response.json()
+ expected = [
+ v.model_dump(mode="json") for v in mongo_posts if q in v.title.lower()
+ ]
+
+ assert got == expected
+
+
+@pytest.mark.asyncio
+@pytest.mark.parametrize("q", _TAG_SEARCH_TERMS)
+async def test_search_sql_by_tag(
+ client_with_sql: TestClient,
+ sql_store: SQLStore,
+ sql_posts: list[SqlPost],
+ q: str,
+):
+ """GET /posts?tag={} gets all sql posts with tag containing search item"""
+ with client_with_sql as client:
+ response = client.get(f"/posts?tag={q}")
+
+ got = response.json()
+ expected = [
+ v.model_dump(mode="json")
+ for v in sql_posts
+ if any([q in tag.title.lower() for tag in v.tags])
+ ]
+
+ assert got == expected
+
+
+@pytest.mark.asyncio
+@pytest.mark.parametrize("q", ["random", "another one", "another one again"])
+async def test_search_redis_by_tag(
+ client_with_redis: TestClient,
+ redis_store: RedisStore,
+ redis_posts: list[RedisPost],
+ q: str,
+):
+ """GET /posts?tag={} gets all redis posts with tag containing search item. Partial searches nit supported."""
+ with client_with_redis as client:
+ response = client.get(f"/posts?tag={q}")
+
+ got = response.json()
+ expected = [
+ v.model_dump(mode="json")
+ for v in redis_posts
+ if any([q in tag.title.lower() for tag in v.tags])
+ ]
+
+ assert got == expected
+
+
+@pytest.mark.asyncio
+@pytest.mark.parametrize("q", _TAG_SEARCH_TERMS)
+async def test_search_mongo_by_tag(
+ client_with_mongo: TestClient,
+ mongo_store: MongoStore,
+ mongo_posts: list[MongoPost],
+ q: str,
+):
+ """GET /posts?tag={} gets all mongo posts with tag containing search item"""
+ with client_with_mongo as client:
+ response = client.get(f"/posts?tag={q}")
+
+ got = response.json()
+ expected = [
+ v.model_dump(mode="json")
+ for v in mongo_posts
+ if any([q in tag.title.lower() for tag in v.tags])
+ ]
+
+ assert got == expected
+
+
+def _get_id(item: Any) -> Any:
+ """Gets the id of the given record
+
+ Args:
+ item: the record whose id is to be obtained
+
+ Returns:
+ the id of the record
+ """
+ try:
+ return item.id
+ except AttributeError:
+ return item.pk
diff --git a/examples/blog/utils.py b/examples/blog/utils.py
new file mode 100644
index 0000000..4a917c0
--- /dev/null
+++ b/examples/blog/utils.py
@@ -0,0 +1,61 @@
+"""Some random utilities for the app"""
+
+import copy
+import sys
+from datetime import datetime
+from typing import Any, Literal, Optional, TypeVar, get_args
+
+from pydantic import BaseModel, create_model
+from pydantic.main import IncEx
+
+from nqlstore._field import FieldInfo
+
+_T = TypeVar("_T", bound=BaseModel)
+
+
+def current_timestamp() -> str:
+ """Gets the current timestamp as an timezone naive ISO format string
+
+ Returns:
+ string of the current datetime
+ """
+ return datetime.now().isoformat()
+
+
+def Partial(name: str, model: type[_T]) -> type[_T]:
+ """Creates a partial schema from another schema, with all fields optional
+
+ Args:
+ name: the name of the model
+ model: the original model
+
+ Returns:
+ A new model with all its fields optional
+ """
+ fields = {
+ k: (_make_optional(v.annotation), None)
+ for k, v in model.model_fields.items() # type: str, FieldInfo
+ }
+
+ return create_model(
+ name,
+ # module of the calling function
+ __module__=sys._getframe(1).f_globals["__name__"],
+ __doc__=model.__doc__,
+ __base__=(model,),
+ **fields,
+ )
+
+
+def _make_optional(type_: type) -> type:
+ """Makes a type optional if not optional
+
+ Args:
+ type_: the type to make optional
+
+ Returns:
+ the optional type
+ """
+ if type(None) in get_args(type_):
+ return type_
+ return type_ | None
diff --git a/examples/todos/schemas.py b/examples/todos/schemas.py
index 3fa0d3c..d765343 100644
--- a/examples/todos/schemas.py
+++ b/examples/todos/schemas.py
@@ -9,7 +9,7 @@ class TodoList(BaseModel):
"""A list of Todos"""
name: str = Field(index=True, full_text_search=True)
- description: str | None = None
+ description: str | None = Field(default=None)
todos: list["Todo"] = Relationship(back_populates="parent", default=[])
diff --git a/nqlstore/_compat.py b/nqlstore/_compat.py
index 770aebc..f22a9ba 100644
--- a/nqlstore/_compat.py
+++ b/nqlstore/_compat.py
@@ -41,9 +41,16 @@
sql imports; and their default if sqlmodel is missing
"""
try:
- from sqlalchemy import Column, Table
+ from sqlalchemy import Column, Table, func
+ from sqlalchemy.dialects.postgresql import insert as pg_insert
+ from sqlalchemy.dialects.sqlite import insert as sqlite_insert
from sqlalchemy.ext.asyncio import create_async_engine
- from sqlalchemy.orm import InstrumentedAttribute, RelationshipProperty, subqueryload
+ from sqlalchemy.orm import (
+ InstrumentedAttribute,
+ RelationshipDirection,
+ RelationshipProperty,
+ subqueryload,
+ )
from sqlalchemy.orm.exc import DetachedInstanceError
from sqlalchemy.sql._typing import (
_ColumnExpressionArgument,
@@ -58,6 +65,7 @@
from sqlmodel.main import IncEx, NoArgAnyCallable, OnDeleteType
from sqlmodel.main import RelationshipInfo as _RelationshipInfo
except ImportError:
+ import types
from typing import Mapping, Optional, Sequence
from typing import Set
from typing import Set as _ColumnExpressionArgument
@@ -75,14 +83,16 @@
OnDeleteType = Literal["CASCADE", "SET NULL", "RESTRICT"]
Column = Any
create_async_engine = lambda *a, **k: dict(**k)
- delete = insert = select = update = create_async_engine
+ pg_insert = sqlite_insert = delete = insert = select = update = create_async_engine
AsyncSession = Any
- RelationshipProperty = Set
+ RelationshipDirection = RelationshipProperty = Set
Table = Set
InstrumentedAttribute = Set
subqueryload = lambda *a, **kwargs: dict(**kwargs)
DetachedInstanceError = RuntimeError
IncEx = Set[Any] | dict
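+    # minimal stand-in for sqlalchemy's "func" namespace when sqlalchemy is missing; only "max" is stubbed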
+ func = types.ModuleType("func")
+ func.max = lambda *a, **kwargs: dict(**kwargs)
class _SqlFieldInfo(_FieldInfo): ...
diff --git a/nqlstore/_field.py b/nqlstore/_field.py
index c980347..026ef3e 100644
--- a/nqlstore/_field.py
+++ b/nqlstore/_field.py
@@ -394,6 +394,7 @@ def get_field_definitions(
schema: type[ModelT],
embedded_models: dict[str, Type] | None = None,
relationships: dict[str, Type] | None = None,
+ link_models: dict[str, Type] | None = None,
is_for_redis: bool = False,
is_for_mongo: bool = False,
is_for_sql: bool = False,
@@ -404,6 +405,8 @@ def get_field_definitions(
schema: the model schema class
embedded_models: the map of embedded models as :
relationships: the map of relationships as :
+        link_models: a map of field_name: Model class for all link (through)
+ tables in many-to-many relationships
is_for_redis: whether the definitions are for redis or not
is_for_mongo: whether the definitions are for mongo or not
is_for_sql: whether the definitions are for sql or not
@@ -417,16 +420,24 @@ def get_field_definitions(
if relationships is None:
relationships = {}
+ if link_models is None:
+ link_models = {}
+
fields = {}
for field_name, field in schema.model_fields.items(): # type: str, FieldInfo
class_field_definition = _get_class_field_definition(field)
- if is_for_redis and getattr(class_field_definition, "disable_on_redis", False):
+ if not isinstance(class_field_definition, (FieldInfo, RelationshipInfo)):
+ raise TypeError(
+ f"field '{schema.__name__}.{field_name}' was not initialized with a {Field.__name__}() or {Relationship.__name__}()"
+ )
+
+ if is_for_redis and class_field_definition.disable_on_redis:
continue
- if is_for_mongo and getattr(class_field_definition, "disable_on_mongo", False):
+ if is_for_mongo and class_field_definition.disable_on_mongo:
continue
- if is_for_sql and getattr(class_field_definition, "disable_on_sql", False):
+ if is_for_sql and class_field_definition.disable_on_sql:
continue
field_type = field.annotation
@@ -441,6 +452,7 @@ def get_field_definitions(
field_type = relationships[field_name]
# redefine the class so that SQLModel can redo its thing
field_info = class_field_definition
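+            # attach the link (through) table model for this field, if any (used for many-to-many relationships)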
+ field_info.link_model = link_models.get(field_name)
fields[field_name] = (field_type, field_info)
return fields
diff --git a/nqlstore/_redis.py b/nqlstore/_redis.py
index 2022ad3..fc18866 100644
--- a/nqlstore/_redis.py
+++ b/nqlstore/_redis.py
@@ -259,6 +259,7 @@ class _EmbeddedJsonModelMeta(_EmbeddedJsonModel, abc.ABC):
"""Base model for all EmbeddedJsonModels. Helpful with typing"""
id: str | None
+    __embedded_models__: list
@classmethod
def set_db(cls, db: Redis):
@@ -275,9 +276,26 @@ def set_db(cls, db: Redis):
except AttributeError:
cls.Meta.database = db
+ # cache the embedded models on the class
+ embedded_models = getattr(cls, "__embedded_models__", None)
+ if embedded_models is None:
+ embedded_models = [
+ model
+ for field in cls.model_fields.values() # type: FieldInfo
+ for model in _get_embed_models(field.annotation)
+ ]
+ setattr(cls, "__embedded_models__", embedded_models)
+
+ # set db on embedded models also
+ for model in cls.__embedded_models__:
+ model.set_db(db)
+
def EmbeddedJsonModel(
- name: str, schema: type[ModelT], /
+ name: str,
+ schema: type[ModelT],
+ /,
+ embedded_models: dict[str, Type] | None = None,
) -> type[_EmbeddedJsonModelMeta] | type[ModelT]:
"""Creates a new EmbeddedJsonModel for the given schema for redis
@@ -288,11 +306,14 @@ def EmbeddedJsonModel(
Args:
name: the name of the model
schema: the schema from which the model is to be made
+ embedded_models: a dict of embedded models as {field name: annotation}
Returns:
an EmbeddedJsonModel model class with the given name
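+
+ Example (an illustrative sketch only; ``Author`` and ``Profile`` are hypothetical
+ pydantic schemas, with ``Author.profile`` annotated as a ``Profile``)::
+
+     EmbeddedProfile = EmbeddedJsonModel("EmbeddedProfile", Profile)
+     EmbeddedAuthor = EmbeddedJsonModel(
+         "EmbeddedAuthor",
+         Author,
+         embedded_models={"profile": EmbeddedProfile},
+     )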
"""
- fields = get_field_definitions(schema, embedded_models=None, is_for_redis=True)
+ fields = get_field_definitions(
+ schema, embedded_models=embedded_models, is_for_redis=True
+ )
return create_model(
name,
diff --git a/nqlstore/_sql.py b/nqlstore/_sql.py
index 26b6ee1..18346bf 100644
--- a/nqlstore/_sql.py
+++ b/nqlstore/_sql.py
@@ -1,8 +1,9 @@
"""SQL implementation"""
+import copy
import sys
from collections.abc import Mapping, MutableMapping
-from typing import Any, Dict, Iterable, Literal, TypeVar, Union
+from typing import Any, Dict, Iterable, Literal, Sequence, TypeVar, Union
from pydantic import create_model
from pydantic.main import ModelT
@@ -14,6 +15,7 @@
DetachedInstanceError,
IncEx,
InstrumentedAttribute,
+ RelationshipDirection,
RelationshipProperty,
Table,
_ColumnExpressionArgument,
@@ -21,8 +23,11 @@
_SQLModel,
create_async_engine,
delete,
+ func,
insert,
+ pg_insert,
select,
+ sqlite_insert,
subqueryload,
update,
)
@@ -30,8 +35,86 @@
from .query.parsers import QueryParser
from .query.selectors import QuerySelector
-_T = TypeVar("_T", bound=_SQLModel)
_Filter = _ColumnExpressionArgument[bool] | bool
+_T = TypeVar("_T")
+
+
+class _SQLModelMeta(_SQLModel):
+ """The base class for all SQL models"""
+
+ id: int | None = Field(default=None, primary_key=True)
+ __rel_field_cache__: dict = {}
+ """dict of (name, Field) that have associated relationships"""
+
+ @classmethod
+ def __relational_fields__(cls) -> dict[str, Any]:
+ """dict of (name, Field) that have associated relationships"""
+
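+ # e.g. for a hypothetical Author model with a "books" relationship,
+ # this returns {"books": Author.books}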
+ cls_fullname = f"{cls.__module__}.{cls.__qualname__}"
+ try:
+ return cls.__rel_field_cache__[cls_fullname]
+ except KeyError:
+ value = {
+ k: v
+ for k, v in cls.__mapper__.all_orm_descriptors.items()
+ if isinstance(v.property, RelationshipProperty)
+ }
+ cls.__rel_field_cache__[cls_fullname] = value
+ return value
+
+ def model_dump(
+ self,
+ *,
+ mode: Union[Literal["json", "python"], str] = "python",
+ include: IncEx = None,
+ exclude: IncEx = None,
+ context: Union[Dict[str, Any], None] = None,
+ by_alias: bool = False,
+ exclude_unset: bool = False,
+ exclude_defaults: bool = False,
+ exclude_none: bool = False,
+ round_trip: bool = False,
+ warnings: Union[bool, Literal["none", "warn", "error"]] = True,
+ serialize_as_any: bool = False,
+ ) -> Dict[str, Any]:
+ data = super().model_dump(
+ mode=mode,
+ include=include,
+ exclude=exclude,
+ context=context,
+ by_alias=by_alias,
+ exclude_unset=exclude_unset,
+ exclude_defaults=exclude_defaults,
+ exclude_none=exclude_none,
+ round_trip=round_trip,
+ warnings=warnings,
+ serialize_as_any=serialize_as_any,
+ )
+ relations_mappers = self.__class__.__relational_fields__()
+ for k, field in relations_mappers.items():
+ if exclude is None or k not in exclude:
+ try:
+ value = getattr(self, k, None)
+ except DetachedInstanceError:
+ # ignore lazy loaded values
+ continue
+
+ if value is not None or not exclude_none:
+ data[k] = _serialize_embedded(
+ value,
+ field=field,
+ mode=mode,
+ context=context,
+ by_alias=by_alias,
+ exclude_unset=exclude_unset,
+ exclude_defaults=exclude_defaults,
+ exclude_none=exclude_none,
+ round_trip=round_trip,
+ warnings=warnings,
+ serialize_as_any=serialize_as_any,
+ )
+
+ return data
class SQLStore(BaseStore):
@@ -41,24 +124,29 @@ def __init__(self, uri: str, parser: QueryParser | None = None, **kwargs):
super().__init__(uri, parser=parser, **kwargs)
self._engine = create_async_engine(uri, **kwargs)
- async def register(self, models: list[type[_T]], checkfirst: bool = True):
- tables = [v.__table__ for v in models]
+ async def register(
+ self, models: list[type[_SQLModelMeta]], checkfirst: bool = True
+ ):
+ tables = [v.__table__ for v in models if hasattr(v, "__table__")]
async with self._engine.begin() as conn:
await conn.run_sync(
_SQLModel.metadata.create_all, tables=tables, checkfirst=checkfirst
)
async def insert(
- self, model: type[_T], items: Iterable[_T | dict], **kwargs
- ) -> list[_T]:
+ self,
+ model: type[_SQLModelMeta],
+ items: Iterable[_SQLModelMeta | dict],
+ **kwargs,
+ ) -> list[_SQLModelMeta]:
parsed_items = [
v if isinstance(v, model) else model.model_validate(v) for v in items
]
- relations_mapper = _get_relations_mapper(model)
+ relations_mapper = model.__relational_fields__()
async with AsyncSession(self._engine) as session:
- stmt = insert(model).returning(model)
- cursor = await session.stream_scalars(stmt, parsed_items)
+ insert_stmt = await _get_insert_func(session, model=model)
+ cursor = await session.stream_scalars(insert_stmt, parsed_items)
results = await cursor.all()
result_ids = [v.id for v in results]
@@ -72,254 +160,116 @@ async def insert(
for idx, record in enumerate(items):
parent = results[idx]
raw_value = _get_key_or_prop(record, k)
- embedded_value = _parse_embedded(raw_value, field, parent)
- if isinstance(embedded_value, Iterable):
- embedded_values += embedded_value
- elif isinstance(embedded_value, _SQLModel):
+ embedded_value = _embed_value(parent, field, raw_value)
+
+ if isinstance(embedded_value, _SQLModel):
embedded_values.append(embedded_value)
+ elif isinstance(embedded_value, Iterable):
+ embedded_values += embedded_value
# insert the related items
if len(embedded_values) > 0:
field_model = field.property.mapper.class_
- embed_stmt = insert(field_model).returning(field_model)
+ embed_stmt = await _get_insert_func(session, model=field_model)
await session.stream_scalars(embed_stmt, embedded_values)
+ # re-attach the parents so that any foreign-key changes made on them are persisted
+ session.add_all(results)
+
await session.commit()
refreshed_results = await self.find(model, model.id.in_(result_ids))
return list(refreshed_results)
async def find(
self,
- model: type[_T],
+ model: type[_SQLModelMeta],
*filters: _Filter,
query: QuerySelector | None = None,
skip: int = 0,
limit: int | None = None,
sort: tuple[_ColumnExpressionOrStrLabelArgument[Any]] = (),
**kwargs,
- ) -> list[_T]:
+ ) -> list[_SQLModelMeta]:
async with AsyncSession(self._engine) as session:
if query:
filters = (*filters, *self._parser.to_sql(model, query=query))
-
- relations = _get_relations(model)
-
- # eagerly load all relationships so that no validation errors occur due
- # to missing session if there is an attempt to load them lazily later
- eager_load_opts = [subqueryload(v) for v in relations]
-
- filtered_relations = _get_filtered_relations(
- filters=filters,
- relations=relations,
+ return await _find(
+ session, model, *filters, skip=skip, limit=limit, sort=sort
)
- # Note that we need to treat relations that are referenced in the filters
- # differently from those that are not. This is because filtering basing on a relationship
- # requires the use of an inner join. Yet an inner join automatically excludes rows
- # that are have null for a given relationship.
- #
- # An outer join on the other hand would just return all the rows in the left table.
- # We thus need to do an inner join on tables that are being filtered.
- stmt = select(model)
- for rel in filtered_relations:
- stmt = stmt.join_from(model, rel)
-
- cursor = await session.stream_scalars(
- stmt.options(*eager_load_opts)
- .where(*filters)
- .limit(limit)
- .offset(skip)
- .order_by(*sort)
- )
- results = await cursor.all()
- return list(results)
-
async def update(
self,
- model: type[_T],
+ model: type[_SQLModelMeta],
*filters: _Filter,
query: QuerySelector | None = None,
updates: dict | None = None,
**kwargs,
- ) -> list[_T]:
+ ) -> list[_SQLModelMeta]:
+ updates = copy.deepcopy(updates)
async with AsyncSession(self._engine) as session:
if query:
filters = (*filters, *self._parser.to_sql(model, query=query))
- # Construct filters that have sub queries
- relations = _get_relations(model)
- rel_filters, non_rel_filters = _sieve_rel_from_non_rel_filters(
- filters=filters,
- relations=relations,
- )
- rel_filters = _to_subquery_based_filters(
- model=model,
- rel_filters=rel_filters,
- relations=relations,
- )
-
- # dealing with nested models in the update
- relations_mapper = _get_relations_mapper(model)
- embedded_updates = {}
- for k in relations_mapper:
- try:
- embedded_updates[k] = updates.pop(k)
- except KeyError:
- pass
-
- stmt = (
- update(model)
- .where(*non_rel_filters, *rel_filters)
- .values(**updates)
- .returning(model.__table__)
+ relational_filters = _get_relational_filters(model, filters)
+ non_relational_filters = _get_non_relational_filters(model, filters)
+
+ # Let's update the fields that are not embedded model fields
+ # and return the affected results
+ results = await _update_non_embedded_fields(
+ session,
+ model,
+ *non_relational_filters,
+ *relational_filters,
+ updates=updates,
)
-
- cursor = await session.stream(stmt)
- raw_results = await cursor.fetchall()
- results = [model.model_validate(row._mapping) for row in raw_results]
result_ids = [v.id for v in results]
- for k, v in embedded_updates.items():
- field = relations_mapper[k]
- field_props = field.property
- field_model = field_props.mapper.class_
- # fk = foreign key
- fk_field_name = field_props.primaryjoin.right.name
- fk_field = getattr(field_model, fk_field_name)
- parent_id_field = field_props.primaryjoin.left.name
-
- # get the foreign keys to use in resetting all affected
- # relationships;
- # get parsed embedded values so that they can replace
- # the old relations.
- # Note: this operation is strictly replace, not patch
- embedded_values = []
- fk_values = []
- for parent in results:
- embedded_value = _parse_embedded(v, field, parent)
- if isinstance(embedded_value, Iterable):
- embedded_values += embedded_value
- fk_values.append(getattr(parent, parent_id_field))
- elif isinstance(embedded_value, _SQLModel):
- embedded_values.append(embedded_value)
- fk_values.append(getattr(parent, parent_id_field))
-
- # insert the related items
- if len(embedded_values) > 0:
- # Reset the relationship; delete all other related items
- # Currently, this operation replaces all past relations
- reset_stmt = delete(field_model).where(fk_field.in_(fk_values))
- await session.stream(reset_stmt)
-
- # insert the latest changes
- embed_stmt = insert(field_model).returning(field_model)
- await session.stream_scalars(embed_stmt, embedded_values)
-
+ # Let's update the embedded fields also
+ await _update_embedded_fields(
+ session, model=model, records=results, updates=updates
+ )
await session.commit()
- return await self.find(model, model.id.in_(result_ids))
+
+ refreshed_results = await self.find(model, model.id.in_(result_ids))
+ return refreshed_results
async def delete(
self,
- model: type[_T],
+ model: type[_SQLModelMeta],
*filters: _Filter,
query: QuerySelector | None = None,
**kwargs,
- ) -> list[_T]:
+ ) -> list[_SQLModelMeta]:
async with AsyncSession(self._engine) as session:
if query:
filters = (*filters, *self._parser.to_sql(model, query=query))
deleted_items = await self.find(model, *filters)
- # Construct filters that have sub queries
- relations = _get_relations(model)
- rel_filters, non_rel_filters = _sieve_rel_from_non_rel_filters(
- filters=filters,
- relations=relations,
- )
- rel_filters = _to_subquery_based_filters(
- model=model,
- rel_filters=rel_filters,
- relations=relations,
- )
+ relational_filters = _get_relational_filters(model, filters)
+ non_relational_filters = _get_non_relational_filters(model, filters)
+
exec_options = {}
- if len(rel_filters) > 0:
+ if len(relational_filters) > 0:
exec_options = {"is_delete_using": True}
await session.stream(
delete(model)
- .where(*non_rel_filters, *rel_filters)
+ .where(*non_relational_filters, *relational_filters)
.execution_options(**exec_options),
)
await session.commit()
return deleted_items
-class _SQLModelMeta(_SQLModel):
- """The base class for all SQL models"""
-
- id: int | None = Field(default=None, primary_key=True)
-
- def model_dump(
- self,
- *,
- mode: Union[Literal["json", "python"], str] = "python",
- include: IncEx = None,
- exclude: IncEx = None,
- context: Union[Dict[str, Any], None] = None,
- by_alias: bool = False,
- exclude_unset: bool = False,
- exclude_defaults: bool = False,
- exclude_none: bool = False,
- round_trip: bool = False,
- warnings: Union[bool, Literal["none", "warn", "error"]] = True,
- serialize_as_any: bool = False,
- ) -> Dict[str, Any]:
- data = super().model_dump(
- mode=mode,
- include=include,
- exclude=exclude,
- context=context,
- by_alias=by_alias,
- exclude_unset=exclude_unset,
- exclude_defaults=exclude_defaults,
- exclude_none=exclude_none,
- round_trip=round_trip,
- warnings=warnings,
- serialize_as_any=serialize_as_any,
- )
- relations_mappers = _get_relations_mapper(self.__class__)
- for k, field in relations_mappers.items():
- if exclude is None or k not in exclude:
- try:
- value = getattr(self, k, None)
- except DetachedInstanceError:
- # ignore lazy loaded values
- continue
-
- if value is not None or not exclude_none:
- data[k] = _serialize_embedded(
- value,
- field=field,
- mode=mode,
- context=context,
- by_alias=by_alias,
- exclude_unset=exclude_unset,
- exclude_defaults=exclude_defaults,
- exclude_none=exclude_none,
- round_trip=round_trip,
- warnings=warnings,
- serialize_as_any=serialize_as_any,
- )
-
- return data
-
-
def SQLModel(
name: str,
schema: type[ModelT],
/,
relationships: dict[str, type[Any] | type[Union[Any]]] = None,
+ link_models: dict[str, type[Any]] | None = None,
+ table: bool = True,
+ **kwargs: Any,
) -> type[_SQLModelMeta] | type[ModelT]:
"""Creates a new SQLModel for the given schema for redis
@@ -331,55 +281,30 @@ def SQLModel(
name: the name of the model
schema: the schema from which the model is to be made
relationships: a map of {field name: annotation} for all relationships
+ link_models: a map of {field name: Model class} for all link (through)
+ tables in many-to-many relationships
+ table: whether this model should have a table in the database or not;
+ default = True
+ kwargs: extra keyword args to pass to SQLModel when defining the model
Returns:
a SQLModel model class with the given name
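+
+ Example (an illustrative sketch only; ``Book``, ``Author`` and ``BookAuthor`` are
+ hypothetical pydantic schemas whose relationship fields are declared with
+ ``Relationship()``)::
+
+     SqlBookAuthor = SQLModel("SqlBookAuthor", BookAuthor)
+     SqlBook = SQLModel("SqlBook", Book)
+     SqlAuthor = SQLModel(
+         "SqlAuthor",
+         Author,
+         relationships={"books": list[SqlBook]},
+         link_models={"books": SqlBookAuthor},
+     )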
"""
- fields = get_field_definitions(schema, relationships=relationships, is_for_sql=True)
+ fields = get_field_definitions(
+ schema, relationships=relationships, link_models=link_models, is_for_sql=True
+ )
return create_model(
name,
# module of the calling function
__module__=sys._getframe(1).f_globals["__name__"],
__doc__=schema.__doc__,
- __cls_kwargs__={"table": True},
+ __cls_kwargs__={"table": table, **kwargs},
__base__=(_SQLModelMeta,),
**fields,
)
-def _get_relations(model: type[_SQLModel]):
- """Gets all the relational fields of the given model
-
- Args:
- model: the SQL model to inspect
-
- Returns:
- list of Fields that have associated relationships
- """
- return [
- v
- for v in model.__mapper__.all_orm_descriptors.values()
- if isinstance(v.property, RelationshipProperty)
- ]
-
-
-def _get_relations_mapper(model: type[_SQLModel]) -> dict[str, Any]:
- """Gets all the relational fields with their names of the given model
-
- Args:
- model: the SQL model to inspect
-
- Returns:
- dict of (name, Field) that have associated relationships
- """
- return {
- k: v
- for k, v in model.__mapper__.all_orm_descriptors.items()
- if isinstance(v.property, RelationshipProperty)
- }
-
-
def _get_filtered_tables(filters: Iterable[_Filter]) -> list[Table]:
"""Retrieves the tables that have been referenced in the filters
@@ -413,31 +338,51 @@ def _get_filtered_relations(
return [rel for rel in relations if rel.property.target in filtered_tables]
-def _sieve_rel_from_non_rel_filters(
- filters: Iterable[_Filter], relations: Iterable[InstrumentedAttribute[Any]]
-) -> tuple[list[_Filter], list[_Filter]]:
- """Separates relational filters from non-relational ones
+def _get_relational_filters(
+ model: type[_SQLModelMeta],
+ filters: Iterable[_Filter],
+) -> list[_Filter]:
+ """Gets the filters that are concerned with relationships on this model
+
+ The filters returned are in subquery form since 'update' and 'delete'
+ in sqlalchemy do not support joins, so the only way to attach these filters
+ to the model is through subqueries
Args:
+ model: the model under consideration
filters: the tuple of filters to inspect
- relations: all relations present on the model
Returns:
- tuple(rel, non_rel) where rel = list of relational filters,
- and non_rel = non-relational filters
+ list of filters that are concerned with relationships on this model
"""
- rel_targets = [v.property.target for v in relations]
- rel = []
- non_rel = []
-
- for filter_ in filters:
- operands = filter_.get_children()
- if any([getattr(v, "table", None) in rel_targets for v in operands]):
- rel.append(filter_)
- else:
- non_rel.append(filter_)
+ relationships = list(model.__relational_fields__().values())
+ targets = [v.property.target for v in relationships]
+ plain_filters = [
+ item
+ for item in filters
+ if any([getattr(v, "table", None) in targets for v in item.get_children()])
+ ]
+ return _to_subquery_based_filters(model, plain_filters, relationships)
+
- return rel, non_rel
+def _get_non_relational_filters(
+ model: type[_SQLModelMeta], filters: Iterable[_Filter]
+) -> list[_Filter]:
+ """Gets the filters that are NOT concerned with relationships on this model
+
+ Args:
+ model: the model under consideration
+ filters: the tuple of filters to inspect
+
+ Returns:
+ list of filters that are NOT concerned with relationships on this model
+ """
+ targets = [v.property.target for v in model.__relational_fields__().values()]
+ return [
+ item
+ for item in filters
+ if not any([getattr(v, "table", None) in targets for v in item.get_children()])
+ ]
def _to_subquery_based_filters(
@@ -514,45 +459,116 @@ def _with_value(obj: dict | Any, field: str, value: Any) -> Any:
return obj
-def _parse_embedded(
- value: Iterable[dict | Any] | dict | Any, field: Any, parent: _SQLModel
+def _embed_value(
+ parent: _SQLModel,
+ relationship: Any,
+ value: Iterable[dict | Any] | dict | Any,
) -> Iterable[_SQLModel] | _SQLModel | None:
- """Parses embedded items that can be a single item or many into SQLModels
+ """Embeds in place a given value into the parent basing on the given relationship
+
+ Note that the parent itself is changed to include the value
Args:
- value: the value to parse
- field: the field on which these embedded items are
- parent: the parent SQLModel to which this value is attached
+ parent: the model that contains the given relationships
+ relationship: the given relationship
+ value: the value(s) corresponding to the related field
Returns:
- An iterable of SQLModel instances or a single SQLModel instance
- or None if value is None
+ the embedded record(s)
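+
+ For example, for a hypothetical one-to-many ``Author.books`` relationship,
+ ``_embed_value(author, Author.books, [{"title": "A"}])`` would return a list of
+ ``Book`` instances whose foreign key points back to ``author``.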
"""
if value is None:
return None
- props = field.property # type: RelationshipProperty[Any]
+ props = relationship.property # type: RelationshipProperty[Any]
wrapper_type = props.collection_class
- field_model = props.mapper.class_
- fk_field = props.primaryjoin.right.name
- parent_id_field = props.primaryjoin.left.name
- fk_value = getattr(parent, parent_id_field)
+ relationship_model = props.mapper.class_
+ parent_foreign_key_field = props.primaryjoin.right.name
+ direction = props.direction
+
+ if direction == RelationshipDirection.MANYTOONE:
+ related_value_id_key = props.primaryjoin.left.name
+ parent_foreign_key_value = value.get(related_value_id_key)
+ # update the foreign key value in the parent
+ setattr(parent, parent_foreign_key_field, parent_foreign_key_value)
+ # create child
+ child = relationship_model.model_validate(value)
+ # update nested relationships
+ for (
+ field_name,
+ field_type,
+ ) in relationship_model.__relational_fields__().items():
+ if isinstance(value, dict):
+ nested_related_value = value.get(field_name)
+ else:
+ nested_related_value = getattr(value, field_name)
+
+ nested_related_records = _embed_value(
+ parent=child, relationship=field_type, value=nested_related_value
+ )
+ setattr(child, field_name, nested_related_records)
+
+ return child
- if issubclass(wrapper_type, (list, tuple, set)):
+ elif direction == RelationshipDirection.ONETOMANY:
+ related_value_id_key = props.primaryjoin.left.name
+ parent_foreign_key_value = getattr(parent, related_value_id_key)
# add foreign key values to link back to the parent
- return wrapper_type(
- [
- field_model.model_validate(_with_value(v, fk_field, fk_value))
- for v in value
- ]
- )
- elif wrapper_type is None:
- # add a foreign key value to link back to parent
- linked_value = _with_value(value, fk_field, fk_value)
- return field_model.model_validate(linked_value)
+ if issubclass(wrapper_type, (list, tuple, set)):
+ embedded_records = []
+ for v in value:
+ child = relationship_model.model_validate(
+ _with_value(v, parent_foreign_key_field, parent_foreign_key_value)
+ )
+
+ # update nested relationships
+ for (
+ field_name,
+ field_type,
+ ) in relationship_model.__relational_fields__().items():
+ if isinstance(v, dict):
+ nested_related_value = v.get(field_name)
+ else:
+ nested_related_value = getattr(v, field_name)
+
+ nested_related_records = _embed_value(
+ parent=child,
+ relationship=field_type,
+ value=nested_related_value,
+ )
+ setattr(child, field_name, nested_related_records)
+
+ embedded_records.append(child)
+
+ return wrapper_type(embedded_records)
+
+ elif direction == RelationshipDirection.MANYTOMANY:
+ if issubclass(wrapper_type, (list, tuple, set)):
+ embedded_records = []
+ for v in value:
+ child = relationship_model.model_validate(v)
+
+ # update nested relationships
+ for (
+ field_name,
+ field_type,
+ ) in relationship_model.__relational_fields__().items():
+ if isinstance(v, dict):
+ nested_related_value = v.get(field_name)
+ else:
+ nested_related_value = getattr(v, field_name)
+ nested_related_records = _embed_value(
+ parent=child,
+ relationship=field_type,
+ value=nested_related_value,
+ )
+ setattr(child, field_name, nested_related_records)
+
+ embedded_records.append(child)
+
+ return wrapper_type(embedded_records)
raise NotImplementedError(
- f"relationship of type annotation {wrapper_type} not supported yet"
+ f"relationship {direction} of type annotation {wrapper_type} not supported yet"
)
@@ -576,12 +592,370 @@ def _serialize_embedded(
props = field.property # type: RelationshipProperty[Any]
wrapper_type = props.collection_class
- if issubclass(wrapper_type, (list, tuple, set)):
+ if wrapper_type is None:
+ return value.model_dump(**kwargs)
+ elif issubclass(wrapper_type, (list, tuple, set)):
# serialize each of the related records
return wrapper_type([v.model_dump(**kwargs) for v in value])
- elif wrapper_type is None:
- return value.model_dump(**kwargs)
raise NotImplementedError(
f"relationship of type annotation {wrapper_type} not supported yet"
)
+
+
+async def _get_insert_func(session: AsyncSession, model: type[_SQLModelMeta]):
+ """Gets the insert statement for the given session
+
+ Args:
+ session: the async session connecting to the database
+ model: the model for which the insert statement is to be obtained
+
+ Returns:
+ the dialect-aware insert statement for the given model
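+ (e.g. an ``INSERT ... ON CONFLICT DO NOTHING ... RETURNING`` statement when the
+ session is bound to SQLite or PostgreSQL)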
+ """
+ conn = await session.connection()
+ dialect = conn.dialect
+ dialect_name = dialect.name
+
+ native_insert_func = insert
+
+ if dialect_name == "sqlite":
+ native_insert_func = sqlite_insert
+ if dialect_name == "postgresql":
+ native_insert_func = pg_insert
+
+ # build the insert statement, handling conflicts per dialect
+ try:
+ # PostgreSQL and SQLite support on_conflict_do_nothing
+ return native_insert_func(model).on_conflict_do_nothing().returning(model)
+ except AttributeError:
+ # MySQL supports prefix("IGNORE")
+ # Other databases might fail at this point
+ return (
+ native_insert_func(model)
+ .prefix_with("IGNORE", dialect="mysql")
+ .returning(model)
+ )
+
+
+async def _update_non_embedded_fields(
+ session: AsyncSession, model: type[_SQLModelMeta], *filters: _Filter, updates: dict
+):
+ """Updates only the non-embedded fields of the model
+
+ It ignores any relationships and returns the updated results
+
+ Args:
+ session: the sqlalchemy session
+ model: the model to be updated
+ filters: the filters against which to match the records that are to be updated
+ updates: the updates to add to each matched record
+
+ Returns:
+ the updated records
+ """
+ non_embedded_updates = _get_non_relational_updates(model, updates)
+ if len(non_embedded_updates) == 0:
+ # calling update() with an empty values dict would raise an error,
+ # so just return the matched records unchanged
+ return await _find(session, model, *filters)
+
+ stmt = update(model).where(*filters).values(**non_embedded_updates).returning(model)
+ cursor = await session.stream_scalars(stmt)
+ return await cursor.fetchall()
+
+
+async def _update_embedded_fields(
+ session: AsyncSession,
+ model: type[_SQLModelMeta],
+ records: list[_SQLModelMeta],
+ updates: dict,
+):
+ """Updates only the embedded fields of the model for the given records
+
+ It ignores any fields in the `updates` dict that are not for embedded models
+ Note: this operation replaces the values of the embedded fields with the new values
+ passed in the `updates` dictionary as opposed to patching the pre-existing values.
+
+ Args:
+ session: the sqlalchemy session
+ model: the model to be updated
+ records: the db records to update
+ updates: the updates to add to each record
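+
+ For example, ``updates={"books": [{"title": "New"}]}`` (where ``books`` is a
+ hypothetical relationship field) drops the existing ``books`` relations of each
+ record and inserts the new ones in their place.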
+ """
+ embedded_updates = _get_relational_updates(model, updates)
+ relations_mapper = model.__relational_fields__()
+ for k, v in embedded_updates.items():
+ relationship = relations_mapper[k]
+ link_model = model.__sqlmodel_relationships__[k].link_model
+
+ # this does a replace operation; i.e. removes old values and replaces them with the updates
+ await _bulk_embedded_delete(
+ session, relationship=relationship, data=records, link_model=link_model
+ )
+ await _bulk_embedded_insert(
+ session,
+ relationship=relationship,
+ data=records,
+ link_model=link_model,
+ payload=v,
+ )
+ # FIXME: Should the added records be updated with their embedded values?
+ # re-attach the parent records so that any changes made on them are persisted
+ session.add_all(records)
+
+
+async def _bulk_embedded_insert(
+ session: AsyncSession,
+ relationship: Any,
+ data: list[_SQLModelMeta],
+ link_model: type[_SQLModelMeta] | None,
+ payload: Iterable[dict] | dict,
+) -> Sequence[_SQLModelMeta] | None:
+ """Inserts the payload into the data following the given relationship
+
+ It updates the database also
+
+ Args:
+ session: the database session
+ relationship: the relationship the payload has with the data's schema
+ data: the parent records into which the payload is to be embedded
+ link_model: the model for the through table
+ payload: the payload to merge into each record in the data
+
+ Returns:
+ the updated data including the embedded data in each record
+ """
+ relationship_props = relationship.property # type: RelationshipProperty
+ relationship_model = relationship_props.mapper.class_
+
+ parsed_embedded_records = [_embed_value(v, relationship, payload) for v in data]
+
+ insert_stmt = await _get_insert_func(session, model=relationship_model)
+ embedded_cursor = await session.stream_scalars(
+ insert_stmt, _flatten_list(parsed_embedded_records)
+ )
+ embedded_db_records = await embedded_cursor.all()
+
+ # map each parent to the slice of inserted records that belongs to it;
+ # a running offset is needed since a parent may own more than one embedded record
+ parent_embedded_map, _offset = [], 0
+ for parent, raw_embedded in zip(data, parsed_embedded_records):
+     _count = len(_as_list(raw_embedded))
+     parent_embedded_map.append((parent, embedded_db_records[_offset : _offset + _count]))
+     _offset += _count
+
+ # insert through table values
+ await _bulk_insert_through_table_data(
+ session,
+ relationship=relationship,
+ link_model=link_model,
+ parent_embedded_map=parent_embedded_map,
+ )
+
+ return data
+
+
+async def _bulk_insert_through_table_data(
+ session: AsyncSession,
+ relationship: Any,
+ link_model: type[_SQLModelMeta] | None,
+ parent_embedded_map: list[tuple[_SQLModelMeta, list[_SQLModelMeta]]],
+):
+ """Inserts the link records into the through-table represented by the link_model
+
+ Args:
+ session: the database session
+ relationship: the relationship the embedded records are based on
+ link_model: the model for the through table
+ parent_embedded_map: the list of tuples of parent and its associated embedded db records
+ """
+ if link_model is not None:
+ relationship_props = relationship.property # type: RelationshipProperty
+ child_id_field_name = relationship_props.secondaryjoin.left.name
+ parent_id_field_name = relationship_props.primaryjoin.left.name
+ child_fk_field_name = relationship_props.secondaryjoin.right.name
+ parent_fk_field_name = relationship_props.primaryjoin.right.name
+
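+ # each link row pairs a parent id with a child id, e.g. a hypothetical
+ # {"author_id": 1, "book_id": 7} row for an Author<->Book through table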
+ link_raw_values = [
+ {
+ parent_fk_field_name: getattr(parent, parent_id_field_name),
+ child_fk_field_name: getattr(child, child_id_field_name),
+ }
+ for parent, children in parent_embedded_map
+ for child in children
+ ]
+
+ next_id = await _get_nextid(session, link_model)
+ link_model_values = [
+ link_model(id=next_id + idx, **v) for idx, v in enumerate(link_raw_values)
+ ]
+
+ insert_stmt = await _get_insert_func(session, model=link_model)
+ await session.stream_scalars(insert_stmt, link_model_values)
+
+
+async def _bulk_embedded_delete(
+ session: AsyncSession,
+ relationship: Any,
+ data: list[SQLModel],
+ link_model: type[_SQLModelMeta] | None,
+):
+ """Deletes the embedded records of the given parent records for the given relationship
+
+ Args:
+ session: the database session
+ relationship: the relationship whose embedded records are to be deleted for the given records
+ data: the parent records whose embedded records are to be deleted
+ link_model: the model for the through table
+ """
+ relationship_props = relationship.property # type: RelationshipProperty
+ relationship_model = relationship_props.mapper.class_
+
+ parent_id_field_name = relationship_props.primaryjoin.left.name
+ parent_foreign_keys = [getattr(item, parent_id_field_name) for item in data]
+
+ if link_model is None:
+ reverse_foreign_key_field_name = relationship_props.primaryjoin.right.name
+ reverse_foreign_key_field = getattr(
+ relationship_model, reverse_foreign_key_field_name
+ )
+ await session.stream(
+ delete(relationship_model).where(
+ reverse_foreign_key_field.in_(parent_foreign_keys)
+ )
+ )
+ else:
+ reverse_foreign_key_field = getattr(link_model, parent_id_field_name)
+ await session.stream(
+ delete(link_model).where(reverse_foreign_key_field.in_(parent_foreign_keys))
+ )
+
+
+async def _get_nextid(session: AsyncSession, model: type[_SQLModelMeta]):
+ """Gets the next id generator for the given model
+
+ It returns a generator for the auto-incremented integer ID
+
+ Args:
+ session: the database session
+ model: the model under consideration
+
+ Returns:
+ the next auto-incremented integer ID for the given model
+ """
+ # compute the next auto-incremented id
+ next_id = await session.scalar(func.max(model.id))
+ next_id = (next_id or 0) + 1
+ return next_id
+
+
+def _flatten_list(data: list[_T | list[_T]]) -> list[_T]:
+ """Flattens a list that may have lists of items at some indices
+
+ Args:
+ data: the list to flatten
+
+ Returns:
+ the flattened list
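+
+ e.g. ``_flatten_list([1, [2, 3], 4])`` returns ``[1, 2, 3, 4]``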
+ """
+ results = []
+ for item in data:
+ if isinstance(item, Iterable) and not isinstance(item, Mapping):
+ results += list(item)
+ else:
+ results.append(item)
+
+ return results
+
+
+def _as_list(value: Any) -> list:
+ """Wraps the value in a list if it is not an iterable
+
+ Args:
+ value: the value to wrap in a list if it is not one
+
+ Returns:
+ the value as a list if it is not already one
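+
+ e.g. ``_as_list(1)`` returns ``[1]``, ``_as_list((1, 2))`` returns ``[1, 2]``,
+ and ``_as_list({"a": 1})`` returns ``[{"a": 1}]``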
+ """
+ if isinstance(value, list):
+ return value
+ elif isinstance(value, Iterable) and not isinstance(value, Mapping):
+ return list(value)
+ return [value]
+
+
+def _get_relational_updates(model: type[_SQLModelMeta], updates: dict) -> dict:
+ """Gets the updates that are affect only the relationships on this model
+
+ Args:
+ model: the model to be updated
+ updates: the dict of new values to update on the matched records
+
+ Returns:
+ a dict with only updates concerning the relationships of the given model
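+
+ e.g. for a hypothetical model with a ``books`` relationship,
+ ``{"name": "new", "books": []}`` yields ``{"books": []}``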
+ """
+ return {k: v for k, v in updates.items() if k in model.__relational_fields__()}
+
+
+def _get_non_relational_updates(model: type[_SQLModelMeta], updates: dict) -> dict:
+ """Gets the updates that do not affect relationships on this model
+
+ Args:
+ model: the model to be updated
+ updates: the dict of new values to update on the matched records
+
+ Returns:
+ a dict with only updates that do not affect relationships on this model
+ """
+ return {k: v for k, v in updates.items() if k not in model.__relational_fields__()}
+
+
+async def _find(
+ session: AsyncSession,
+ model: type[_SQLModelMeta],
+ /,
+ *filters: _Filter,
+ skip: int = 0,
+ limit: int | None = None,
+ sort: tuple[_ColumnExpressionOrStrLabelArgument[Any]] = (),
+) -> list[_SQLModelMeta]:
+ """Finds the records that match the given filters
+
+ Args:
+ session: the sqlalchemy session
+ model: the model that is to be searched
+ filters: the filters to match
+ skip: number of records to ignore at the top of the returned results; default is 0
+ limit: maximum number of records to return; default is None.
+ sort: fields to sort by; default = ()
+
+ Returns:
+ the records that match the given filters
+ """
+ relations = list(model.__relational_fields__().values())
+
+ # eagerly load all relationships so that no validation errors occur due
+ # to missing session if there is an attempt to load them lazily later
+ eager_load_opts = [subqueryload(v) for v in relations]
+
+ filtered_relations = _get_filtered_relations(
+ filters=filters,
+ relations=relations,
+ )
+
+ # Note that we need to treat relations that are referenced in the filters
+ # differently from those that are not. This is because filtering based on a relationship
+ # requires the use of an inner join. Yet an inner join automatically excludes rows
+ # that have null for a given relationship.
+ #
+ # An outer join on the other hand would just return all the rows in the left table.
+ # We thus need to do an inner join on tables that are being filtered.
+ stmt = select(model)
+ for rel in filtered_relations:
+ stmt = stmt.join_from(model, rel)
+
+ cursor = await session.stream_scalars(
+ stmt.options(*eager_load_opts)
+ .where(*filters)
+ .limit(limit)
+ .offset(skip)
+ .order_by(*sort)
+ )
+ results = await cursor.all()
+ return list(results)
diff --git a/pyproject.toml b/pyproject.toml
index 582237b..63fa897 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -33,13 +33,13 @@ sql = [
"greenlet~=3.1.1",
]
mongo = ["beanie~=1.29.0"]
-redis = ["redis-om~=0.3.3"]
+redis = ["redis-om~=0.3.3,<0.3.4"]
all = [
"sqlmodel~=0.0.22",
"aiosqlite~=0.20.0",
"greenlet~=3.1.1",
"beanie~=1.29.0",
- "redis-om~=0.3.3",
+ "redis-om~=0.3.3,<0.3.4",
]
[project.urls]
diff --git a/tests/test_sql.py b/tests/test_sql.py
index c494958..51913fe 100644
--- a/tests/test_sql.py
+++ b/tests/test_sql.py
@@ -122,8 +122,6 @@ async def test_update_native(sql_store, inserted_sql_libs):
}
matches_query = lambda v: v.name.startswith("Bu") and v.address == _TEST_ADDRESS
- # in immediate response
- # NOTE: redis startswith/contains on single letters is not supported by redis
got = await sql_store.update(
SqlLibrary,
(SqlLibrary.name.startswith("Bu") & (SqlLibrary.address == _TEST_ADDRESS)),