diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml
index ddb25cd..8b31786 100644
--- a/.github/workflows/ci.yml
+++ b/.github/workflows/ci.yml
@@ -69,7 +69,7 @@ jobs:
       strategy:
         matrix:
           python-version: [ "3.10", "3.11", "3.12", "3.13" ]
-          example_app: ["todos"]
+          example_app: ["todos", "blog"]
 
     services:
       mongodb:
diff --git a/.gitignore b/.gitignore
index 2f02596..f67c6d6 100644
--- a/.gitignore
+++ b/.gitignore
@@ -135,6 +135,7 @@ venv/
 ENV/
 env.bak/
 venv.bak/
+.venv3_13
 
 # Spyder project settings
 .spyderproject
diff --git a/CHANGELOG.md b/CHANGELOG.md
index a3fb163..f6648cd 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -7,6 +7,19 @@ and this project adheres to [Semantic Versioning](http://semver.org/).
 
 ## [Unreleased]
 
+### Changed
+
+- Allowed `EmbeddedJsonModel`s to also contain other `EmbeddedJsonModel`s where needed.
+- Allowed extra keyword arguments to be passed to `SQLModel()`.
+- Made defining fields by calling `Field()` or `Relationship()` mandatory.
+  This is because SQLModel requires this and fails if it is not the case.
+- Added support for "Many-to-One" relationships in the SQL `insert()` implementation
+
+### Added
+
+- Added the [examples/blog](./examples/blog) example
+- Added the `link_models` parameter to the `SQLModel()` function for adding link models used as through tables
+
 ## [0.1.6] - 2025-02-19
 
 ### Changed
diff --git a/CONTRIBUTING.md b/CONTRIBUTING.md
index e3e541b..0ab99fc 100644
--- a/CONTRIBUTING.md
+++ b/CONTRIBUTING.md
@@ -77,7 +77,7 @@ By contributing, you agree that your contributions will be licensed under its MI
 - Install the dependencies
 
   ```bash
-  pip install -r requirements.txt
+  pip install ".[all,test]"
   ```
 
 - Run the pre-commit installation
diff --git a/LIMITATIONS.md b/LIMITATIONS.md
new file mode 100644
index 0000000..f522e1b
--- /dev/null
+++ b/LIMITATIONS.md
@@ -0,0 +1,15 @@
+# Limitations
+
+## Filtering
+
+### Redis
+
+- Mongo-style regular expression filtering is not supported.
+  This is because Redis's native regular-expression filtering is limited to very basic text-based search.
+
+## Update Operation
+
+### SQL
+
+- Even though a model can be updated to an arbitrary number of levels deep,
+  the returned results can only contain nested models one level deep and no more.
diff --git a/README.md b/README.md
index 7fc459b..617fd6c 100644
--- a/README.md
+++ b/README.md
@@ -336,6 +336,11 @@ libraries = await redis_store.delete(
 
 - [ ] Add documentation site
 
+## Limitations
+
+This library is limited in some specific cases.
+Read through the [`LIMITATIONS.md`](./LIMITATIONS.md) file for more details.
+
 ## Contributions
 
 Contributions are welcome. The docs have to maintained, the code has to be made cleaner, more idiomatic and faster,
diff --git a/examples/blog/.gitignore b/examples/blog/.gitignore
new file mode 100644
index 0000000..f640b3b
--- /dev/null
+++ b/examples/blog/.gitignore
@@ -0,0 +1,188 @@
+# Byte-compiled / optimized / DLL files
+__pycache__/
+*.py[cod]
+*$py.class
+
+# C extensions
+*.so
+
+# Distribution / packaging
+.Python
+build/
+develop-eggs/
+dist/
+downloads/
+eggs/
+.eggs/
+lib/
+lib64/
+parts/
+sdist/
+var/
+wheels/
+share/python-wheels/
+*.egg-info/
+.installed.cfg
+*.egg
+MANIFEST
+
+# PyInstaller
+# Usually these files are written by a python script from a template
+# before PyInstaller builds the exe, so as to inject date/other infos into it.
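The Redis filtering limitation called out in `LIMITATIONS.md` above is worked around in the blog example further down in this diff: instead of a Mongo-style `$regex` query, the Redis path falls back to RediSearch wildcard/full-text matching on indexed fields. Below is a minimal, hedged sketch of that fallback, assuming the `RedisPost`/`SqlPost` models and store factories from `examples/blog` and a running Redis Stack with `REDIS_URL` set; the search term is illustrative only.

```python
"""Hedged sketch: regex-style title search per store (see LIMITATIONS.md above)."""

from models import RedisPost, SqlPost  # dynamic models from examples/blog/models.py
from stores import get_redis_store, get_sql_store  # store factories from examples/blog/stores.py


async def search_posts_by_title(term: str) -> list:
    results = []

    # SQL (and Mongo) stores accept Mongo-style regex selectors directly
    if sql := await get_sql_store():
        results += await sql.find(
            SqlPost, query={"title": {"$regex": f".*{term}.*", "$options": "i"}}
        )

    # Redis does not support $regex, so fall back to RediSearch wildcard matching
    # on the indexed ``title`` field instead, as examples/blog/main.py does
    if redis := await get_redis_store():
        results += await redis.find(RedisPost, RedisPost.title % f"*{term}*")

    return results
```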
+*.manifest +*.spec + +# Installer logs +pip-log.txt +pip-delete-this-directory.txt + +# Unit test / coverage reports +htmlcov/ +.tox/ +.nox/ +.coverage +.coverage.* +.cache +nosetests.xml +coverage.xml +*.cover +*.py,cover +.hypothesis/ +.pytest_cache/ +cover/ + +# Translations +*.mo +*.pot + +# Django stuff: +*.log +local_settings.py +db.sqlite3 +db.sqlite3-journal + +# Flask stuff: +instance/ +.webassets-cache + +# Scrapy stuff: +.scrapy + +# Sphinx documentation +docs/_build/ + +# PyBuilder +.pybuilder/ +target/ + +# Jupyter Notebook +.ipynb_checkpoints + +# IPython +profile_default/ +ipython_config.py + +# pyenv +# For a library or package, you might want to ignore these files since the code is +# intended to run in multiple environments; otherwise, check them in: +# .python-version + +# pipenv +# According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. +# However, in case of collaboration, if having platform-specific dependencies or dependencies +# having no cross-platform support, pipenv may install dependencies that don't work, or not +# install all needed dependencies. +#Pipfile.lock + +# UV +# Similar to Pipfile.lock, it is generally recommended to include uv.lock in version control. +# This is especially recommended for binary packages to ensure reproducibility, and is more +# commonly ignored for libraries. +#uv.lock + +# poetry +# Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control. +# This is especially recommended for binary packages to ensure reproducibility, and is more +# commonly ignored for libraries. +# https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control +#poetry.lock + +# pdm +# Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control. +#pdm.lock +# pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it +# in version control. +# https://pdm.fming.dev/latest/usage/project/#working-with-version-control +.pdm.toml +.pdm-python +.pdm-build/ + +# PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm +__pypackages__/ + +# Celery stuff +celerybeat-schedule +celerybeat.pid + +# SageMath parsed files +*.sage.py + +# Environments +.env +.venv +env/ +venv/ +ENV/ +env.bak/ +venv.bak/ + +# Spyder project settings +.spyderproject +.spyproject + +# Rope project settings +.ropeproject + +# mkdocs documentation +/site + +# mypy +.mypy_cache/ +.dmypy.json +dmypy.json + +# Pyre type checker +.pyre/ + +# pytype static type analyzer +.pytype/ + +# Cython debug symbols +cython_debug/ + +# PyCharm +# JetBrains specific template is maintained in a separate JetBrains.gitignore that can +# be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore +# and can be added to the global gitignore or merged into this file. For a more nuclear +# option (not recommended) you can uncomment the following to ignore the entire idea folder. 
+.idea/ + +# vscode +.vscode/ + +# Ruff stuff: +.ruff_cache/ + +# PyPI configuration file +.pypirc + +# environments +senv/ +renv/ +menv/ + +# pyenv +.python-version + +# sqlite test db +test.db \ No newline at end of file diff --git a/examples/blog/LICENSE b/examples/blog/LICENSE new file mode 100644 index 0000000..efbf346 --- /dev/null +++ b/examples/blog/LICENSE @@ -0,0 +1,21 @@ +MIT License + +Copyright (c) 2025 Martin Ahindura + +Permission is hereby granted, free of charge, to any person obtaining a copy +of this software and associated documentation files (the "Software"), to deal +in the Software without restriction, including without limitation the rights +to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +copies of the Software, and to permit persons to whom the Software is +furnished to do so, subject to the following conditions: + +The above copyright notice and this permission notice shall be included in all +copies or substantial portions of the Software. + +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +SOFTWARE. diff --git a/examples/blog/README.md b/examples/blog/README.md new file mode 100644 index 0000000..c0e72ea --- /dev/null +++ b/examples/blog/README.md @@ -0,0 +1,69 @@ +# Blog + +A simple Blog app based on [NQLStore](https://github.com/sopherapps/nqlstore) + +## Requirements + +- [Python +3.10](https://python.org) +- [NQLStore](https://github.com/sopherapps/nqlstore) +- [Redis stack (optional)](https://redis.io/docs/latest/operate/oss_and_stack/install/install-stack/) +- [MongoDB (optional)](https://www.mongodb.com/products/self-managed/community-edition) +- [SQLite (optional)](https://www.sqlite.org/) + +## Getting Started + +- Ensure you have [Python +3.10](https://python.org) installed + +- Copy this repository and enter this folder + +```shell +git clone https://github.com/sopherapps/nqlstore.git +cd nqlstore/examples/blog +``` + +- Create a virtual env, activate it and install requirements + +```shell +python -m venv env +source env/bin/activate +pip install -r requirements.txt +``` + +- To use with [MongoDB](https://www.mongodb.com/try/download/community), install and start its server. +- To use with redis, install [redis stack](https://redis.io/docs/latest/operate/oss_and_stack/install/install-stack/) + and start its server in another terminal. + +- Start the application, set the URL's for the database(s) to use. + Options are: + - `SQL_URL` for [SQLite](https://www.sqlite.org/). + - `MONGO_URL` (required) and `MONGO_DB` (default: "todos") for [MongoDB](https://www.mongodb.com/products/self-managed/community-edition) + - `REDIS_URL` for [Redis](https://redis.io/docs/latest/operate/oss_and_stack/install/install-stack/). + + _It is possible to use multiple databases at the same time. 
Just set multiple environment variables_ + +```shell +export SQL_URL="sqlite+aiosqlite:///test.db" +#export MONGO_URL="mongodb://localhost:27017" +#export MONGO_DB="testing" +export REDIS_URL="redis://localhost:6379/0" +fastapi dev main.py +``` + +## License + +Copyright (c) 2025 [Martin Ahindura](https://github.com/Tinitto) +Licensed under the [MIT License](./LICENSE) + +## Gratitude + +Glory be to God for His unmatchable love. + +> "As Jesus was on His way, the crowds almost crushed Him. +> And a woman was there who had been subject to bleeding +> for twelve years, but no one could heal her. +> She came up behind Him and touched the edge of His cloak, +> and immediately her bleeding stopped." +> +> -- Luke 8: 42-44 + +Buy Me A Coffee diff --git a/examples/blog/auth.py b/examples/blog/auth.py new file mode 100644 index 0000000..08de8cd --- /dev/null +++ b/examples/blog/auth.py @@ -0,0 +1,171 @@ +"""Deal with authentication for the app""" + +import os +from datetime import datetime, timedelta, timezone +from typing import Annotated, Any + +import jwt +from fastapi import Depends, HTTPException, status +from fastapi.security import OAuth2PasswordBearer +from jwt.exceptions import InvalidTokenError +from models import MongoInternalAuthor, RedisInternalAuthor, SqlInternalAuthor +from passlib.context import CryptContext +from stores import MongoStoreDep, RedisStoreDep, SqlStoreDep + +_ALGORITHM = "HS256" +_InternalAuthorModel = MongoInternalAuthor | SqlInternalAuthor | RedisInternalAuthor + +pwd_context = CryptContext(schemes=["bcrypt"], deprecated="auto") +oauth2_scheme = OAuth2PasswordBearer(tokenUrl="signin") + + +def verify_password(plain_password, hashed_password): + """Verify that the passwords match + + Args: + plain_password: the plain password + hashed_password: the hashed password + + Returns: + True if the passwords match else False + """ + return pwd_context.verify(plain_password, hashed_password) + + +def get_password_hash(password: str) -> str: + """Creates a hash from the given password + + Args: + password: the password to hash + + Returns: + the password hash + """ + return pwd_context.hash(password) + + +def get_jwt_secret() -> str: + """Gets the JWT secret from the environment""" + return os.environ.get("JWT_SECRET") + + +async def get_user( + sql_store: SqlStoreDep, + redis_store: RedisStoreDep, + mongo_store: MongoStoreDep, + query: dict[str, Any], +) -> _InternalAuthorModel: + """Gets the user instance that matches the given query + + Args: + sql_store: the SQL store from which to retrieve the user + redis_store: the redis store from which to retrieve the user + mongo_store: the mongo store from which to retrieve the user + query: the filter that the user must match + + Returns: + the matching user + + Raises: + HTTPException: Unauthorized + """ + try: + results = [] + if sql_store: + results += await sql_store.find(SqlInternalAuthor, query=query, limit=1) + + if mongo_store: + results += await mongo_store.find(MongoInternalAuthor, query=query, limit=1) + + if redis_store: + results += await redis_store.find(RedisInternalAuthor, query=query, limit=1) + + return results[0] + except (InvalidTokenError, IndexError): + raise HTTPException( + status_code=status.HTTP_401_UNAUTHORIZED, + headers={"WWW-Authenticate": "Bearer"}, + ) + + +async def authenticate_user( + sql_store: SqlStoreDep, + redis_store: RedisStoreDep, + mongo_store: MongoStoreDep, + email: str, + password: str, +) -> _InternalAuthorModel: + """Authenticates the user of the given email and password + + Args: + 
sql_store: the SQL store from which to retrieve the user + redis_store: the redis store from which to retrieve the user + mongo_store: the mongo store from which to retrieve the user + email: the email of the user + password: the password of the user + + Returns: + the authenticated user or False + """ + user = await get_user( + sql_store=sql_store, + redis_store=redis_store, + mongo_store=mongo_store, + query={"email": {"$eq": email}}, + ) + if not verify_password(password, user.hashed_password): + return False + return user + + +def create_access_token(secret_key: str, data: dict, ttl_minutes: float = 15) -> str: + """Creates an access token given a secret key + + Args: + secret_key: the JWT secret key for creating the JWT + data: the data to encode + ttl_minutes: the time to live in minutes + + Returns: + the access token + """ + expire = datetime.now(timezone.utc) + timedelta(minutes=ttl_minutes) + encoded_jwt = jwt.encode({**data, "exp": expire}, secret_key, algorithm=_ALGORITHM) + return encoded_jwt + + +async def get_current_user( + sql_store: SqlStoreDep, + redis_store: RedisStoreDep, + mongo_store: MongoStoreDep, + secret_key: Annotated[str, Depends(get_jwt_secret)], + token: Annotated[str, Depends(oauth2_scheme)], +) -> _InternalAuthorModel: + """Gets the current user for the given token + + Args: + sql_store: the SQL store from which to retrieve the user + redis_store: the redis store from which to retrieve the user + mongo_store: the mongo store from which to retrieve the user + secret_key: the secret key for JWT decoding + token: the token for the current user + + Returns: + the user + + Raises: + HTTPException: Could not validate credentials + """ + payload = jwt.decode(token, secret_key, algorithms=[_ALGORITHM]) + email = payload.get("sub") + return await get_user( + sql_store=sql_store, + redis_store=redis_store, + mongo_store=mongo_store, + query={"email": {"$eq": email}}, + ) + + +# Dependencies +SecretKeyDep = Annotated[str, Depends(get_jwt_secret)] +CurrentUserDep = Annotated[_InternalAuthorModel, Depends(get_current_user)] diff --git a/examples/blog/conftest.py b/examples/blog/conftest.py new file mode 100644 index 0000000..8aae68e --- /dev/null +++ b/examples/blog/conftest.py @@ -0,0 +1,228 @@ +"""Fixtures for tests""" + +import os +from typing import Any + +import pytest +import pytest_asyncio +import pytest_mock +from fastapi.testclient import TestClient +from models import ( + MongoAuthor, + MongoComment, + MongoInternalAuthor, + MongoPost, + MongoTag, + RedisAuthor, + RedisComment, + RedisInternalAuthor, + RedisPost, + RedisTag, + SqlComment, + SqlInternalAuthor, + SqlPost, + SqlTag, + SqlTagLink, +) + +from nqlstore import MongoStore, RedisStore, SQLStore + +POST_LISTS: list[dict[str, Any]] = [ + { + "title": "School Work", + "content": "foo bar man the stuff", + }, + { + "title": "Home", + "tags": [ + {"title": "home"}, + {"title": "art"}, + ], + }, + {"title": "Boo", "content": "some random stuff", "tags": [{"title": "random"}]}, +] + +COMMENT_LIST: list[dict[str, Any]] = [ + { + "content": "Fake comment", + }, + { + "content": "Just wondering, who wrote this?", + }, + { + "content": "Mann, this is off the charts!", + }, + { + "content": "Woo hoo", + }, + { + "content": "Not cool. 
Not cool at all.", + }, +] + +AUTHOR: dict[str, Any] = { + "name": "John Doe", + "email": "johndoe@example.com", + "password": "password123", +} + +_SQL_DB = "test.db" +_SQL_URL = f"sqlite+aiosqlite:///{_SQL_DB}" +_REDIS_URL = "redis://localhost:6379/0" +_MONGO_URL = "mongodb://localhost:27017" +_MONGO_DB = "testing" +ACCESS_TOKEN = "some-token" + + +@pytest.fixture +def mocked_auth(mocker: pytest_mock.MockerFixture): + """Mocks the auth to always return the AUTHOR as valid""" + mocker.patch("jwt.encode", return_value=ACCESS_TOKEN) + mocker.patch("jwt.decode", return_value={"sub": AUTHOR["email"]}) + mocker.patch("auth.pwd_context.verify", return_value=True) + yield + + +@pytest.fixture +def client_with_sql(mocked_auth): + """The fastapi test client when SQL is enabled""" + _reset_env() + os.environ["SQL_URL"] = _SQL_URL + + from main import app + + yield TestClient(app) + _reset_env() + + +@pytest_asyncio.fixture +async def client_with_redis(mocked_auth): + """The fastapi test client when redis is enabled""" + _reset_env() + os.environ["REDIS_URL"] = _REDIS_URL + + from main import app + + yield TestClient(app) + _reset_env() + + +@pytest.fixture +def client_with_mongo(mocked_auth): + """The fastapi test client when mongodb is enabled""" + _reset_env() + + os.environ["MONGO_URL"] = _MONGO_URL + os.environ["MONGO_DB"] = _MONGO_DB + + from main import app + + yield TestClient(app) + _reset_env() + + +@pytest_asyncio.fixture() +async def sql_store(mocked_auth): + """The sql store stored in memory""" + store = SQLStore(uri=_SQL_URL) + + await store.register( + [ + # SqlAuthor, + SqlTag, + SqlTagLink, + SqlPost, + SqlComment, + SqlInternalAuthor, + ] + ) + # insert default user + await store.insert(SqlInternalAuthor, [AUTHOR]) + yield store + + # clean up + os.remove(_SQL_DB) + + +@pytest_asyncio.fixture() +async def mongo_store(mocked_auth): + """The mongodb store. Requires a running instance of mongodb""" + import pymongo + + store = MongoStore(uri=_MONGO_URL, database=_MONGO_DB) + await store.register( + [ + MongoAuthor, + MongoTag, + MongoPost, + MongoComment, + MongoInternalAuthor, + ] + ) + + # insert default user + await store.insert(MongoInternalAuthor, [AUTHOR]) + + yield store + + # clean up + client = pymongo.MongoClient(_MONGO_URL) # type: ignore + client.drop_database(_MONGO_DB) + + +@pytest_asyncio.fixture +async def redis_store(mocked_auth): + """The redis store. 
Requires a running instance of redis stack""" + import redis + + store = RedisStore(_REDIS_URL) + await store.register( + [ + RedisAuthor, + RedisTag, + RedisPost, + RedisComment, + RedisInternalAuthor, + ] + ) + # insert default user + await store.insert(RedisInternalAuthor, [AUTHOR]) + + yield store + + # clean up + client = redis.Redis("localhost", 6379, 0) + client.flushall() + + +@pytest_asyncio.fixture() +async def sql_posts(sql_store: SQLStore): + """A list of posts in the sql store""" + records = await sql_store.insert(SqlPost, POST_LISTS) + yield records + + +@pytest_asyncio.fixture() +async def mongo_posts(mongo_store: MongoStore): + """A list of posts in the mongo store""" + + records = await mongo_store.insert(MongoPost, POST_LISTS) + yield records + + +@pytest_asyncio.fixture() +async def redis_posts(redis_store: RedisStore): + """A list of posts in the redis store""" + records = await redis_store.insert(RedisPost, POST_LISTS) + yield records + + +def _reset_env(): + """Resets the environment variables available to the app""" + os.environ["SQL_URL"] = "" + os.environ["REDIS_URL"] = "" + os.environ["MONGO_URL"] = "" + os.environ["MONGO_DB"] = "testing" + os.environ["JWT_SECRET"] = ( + "09d25e094faa6ca2556c818166b7a9563b93f7099f6f0f4caa6cf63b88e8d3e7" + ) diff --git a/examples/blog/main.py b/examples/blog/main.py new file mode 100644 index 0000000..7390f45 --- /dev/null +++ b/examples/blog/main.py @@ -0,0 +1,296 @@ +import logging +from contextlib import asynccontextmanager +from typing import Annotated + +from auth import ( + CurrentUserDep, + SecretKeyDep, + authenticate_user, + create_access_token, + get_password_hash, +) +from fastapi import Depends, FastAPI, HTTPException, Query, status +from fastapi.security import OAuth2PasswordRequestForm +from models import MongoPost, RedisPost, SqlInternalAuthor, SqlPost +from pydantic import BaseModel +from schemas import InternalAuthor, PartialPost, Post, TokenResponse +from stores import MongoStoreDep, RedisStoreDep, SqlStoreDep, clear_stores + +_ACCESS_TOKEN_EXPIRE_MINUTES = 30 + + +@asynccontextmanager +async def lifespan(app_: FastAPI): + clear_stores() + yield + clear_stores() + + +app = FastAPI(lifespan=lifespan) + + +@app.post("/signup") +async def signup( + sql: SqlStoreDep, + redis: RedisStoreDep, + mongo: MongoStoreDep, + payload: InternalAuthor, +): + """Signup a new user""" + results = [] + payload_dict = payload.model_dump(exclude_unset=True) + payload_dict["password"] = get_password_hash(payload_dict["password"]) + + try: + if sql: + results += await sql.insert(SqlInternalAuthor, [payload_dict]) + + if redis: + results += await redis.insert(RedisPost, [payload_dict]) + + if mongo: + results += await mongo.insert(MongoPost, [payload_dict]) + + result = results[0].model_dump(mode="json") + return result + except Exception as exp: + logging.error(exp) + raise exp + + +@app.post("/signin") +async def signin( + sql: SqlStoreDep, + redis: RedisStoreDep, + mongo: MongoStoreDep, + secret_key: SecretKeyDep, + form_data: Annotated[OAuth2PasswordRequestForm, Depends()], +) -> TokenResponse: + user = await authenticate_user( + sql_store=sql, + redis_store=redis, + mongo_store=mongo, + email=form_data.username, + password=form_data.password, + ) + if not user: + raise HTTPException( + status_code=status.HTTP_401_UNAUTHORIZED, + detail="Incorrect username or password", + headers={"WWW-Authenticate": "Bearer"}, + ) + + access_token = create_access_token( + secret_key=secret_key, + data={"sub": user.username}, + 
ttl_minutes=_ACCESS_TOKEN_EXPIRE_MINUTES, + ) + return TokenResponse(access_token=access_token, token_type="bearer") + + +@app.get("/posts") +async def search( + sql: SqlStoreDep, + redis: RedisStoreDep, + mongo: MongoStoreDep, + title: str | None = Query(None), + author: str | None = Query(None), + tag: str | None = Query(None), +): + """Searches for posts by title, author or tag""" + + results = [] + query_dict: dict[str, str] = { + k: v + for k, v in {"title": title, "author.name": author, "tags.title": tag}.items() + if v + } + query = {k: {"$regex": f".*{v}.*", "$options": "i"} for k, v in query_dict.items()} + try: + if sql: + results += await sql.find(SqlPost, query=query) + + if redis: + # redis's regex search is not mature so we use its full text search + # Unfortunately, redis search does not permit us to search fields that are arrays. + redis_query = [ + ( + (_get_redis_field(RedisPost, k) == f"{v}") + if k == "tags.title" + else (_get_redis_field(RedisPost, k) % f"*{v}*") + ) + for k, v in query_dict.items() + ] + results += await redis.find(RedisPost, *redis_query) + + if mongo: + results += await mongo.find(MongoPost, query=query) + except Exception as exp: + logging.error(exp) + raise exp + + return [item.model_dump(mode="json") for item in results] + + +@app.get("/posts/{id_}") +async def get_one( + sql: SqlStoreDep, + redis: RedisStoreDep, + mongo: MongoStoreDep, + id_: int | str, +): + """Get post by id""" + results = [] + query = {"id": {"$eq": id_}} + + try: + if sql: + results += await sql.find(SqlPost, query=query, limit=1) + + if redis: + results += await redis.find(RedisPost, query=query, limit=1) + + if mongo: + results += await mongo.find(MongoPost, query=query, limit=1) + except Exception as exp: + logging.error(exp) + raise exp + + try: + return results[0].model_dump(mode="json") + except IndexError: + raise HTTPException(status_code=status.HTTP_404_NOT_FOUND) + + +@app.post("/posts") +async def create_one( + sql: SqlStoreDep, + redis: RedisStoreDep, + mongo: MongoStoreDep, + current_user: CurrentUserDep, + payload: Post, +): + """Create a post""" + results = [] + payload_dict = payload.model_dump(exclude_unset=True) + payload_dict["author"] = current_user.model_dump() + + try: + if sql: + results += await sql.insert(SqlPost, [payload_dict]) + + if redis: + results += await redis.insert(RedisPost, [payload_dict]) + + if mongo: + results += await mongo.insert(MongoPost, [payload_dict]) + + result = results[0].model_dump(mode="json") + return result + except Exception as exp: + logging.error(exp) + raise exp + + +@app.put("/posts/{id_}") +async def update_one( + sql: SqlStoreDep, + redis: RedisStoreDep, + mongo: MongoStoreDep, + current_user: CurrentUserDep, + id_: int | str, + payload: PartialPost, +): + """Update a post""" + results = [] + query = {"id": {"$eq": id_}} + updates = payload.model_dump(exclude_unset=True) + user_dict = current_user.model_dump() + + if "comments" in updates: + # just resetting the author of all comments to current user. + # This is probably logically wrong. 
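+        # (a real application would look up and preserve each comment's original author)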
+ # It is just here for illustration + updates["comments"] = [ + {**item, "author": user_dict} for item in updates["comments"] + ] + + try: + if sql: + results += await sql.update(SqlPost, query=query, updates=updates) + + if redis: + results += await redis.update(RedisPost, query=query, updates=updates) + + if mongo: + results += await mongo.update(MongoPost, query=query, updates=updates) + except Exception as exp: + logging.error(exp) + raise exp + + try: + return results[0].model_dump(mode="json") + except IndexError: + raise HTTPException(status_code=status.HTTP_404_NOT_FOUND) + + +@app.delete("/posts/{id_}") +async def delete_one( + sql: SqlStoreDep, + redis: RedisStoreDep, + mongo: MongoStoreDep, + id_: int | str, +): + """Delete a post""" + results = [] + query = {"id": {"$eq": id_}} + + try: + if sql: + results += await sql.delete(SqlPost, query=query) + + if redis: + results += await redis.delete(RedisPost, query=query) + + if mongo: + results += await mongo.delete(MongoPost, query=query) + except Exception as exp: + logging.error(exp) + raise exp + + try: + return results[0].model_dump(mode="json") + except IndexError: + raise HTTPException(status_code=status.HTTP_404_NOT_FOUND) + + +def _get_redis_field(model: type[BaseModel], path: str): + """Retrieves the Field at the given path, which may or may not be dotted + + Args: + path: the path to the field where dots signify relations; example books.title + model: the parent model + + Returns: + the FieldInfo at the given path + + Raises: + ValueError: no field '{path}' found on '{parent}' + """ + path_segments = path.split(".") + current_parent = model + + field = None + for idx, part in enumerate(path_segments): + field = getattr(current_parent, part) + try: + current_parent = field + except AttributeError as exp: + if idx == len(path_segments) - 1: + break + raise exp + + if field is None: + raise ValueError(f"no field '{path}' found on '{model}'") + + return field diff --git a/examples/blog/models.py b/examples/blog/models.py new file mode 100644 index 0000000..e03688b --- /dev/null +++ b/examples/blog/models.py @@ -0,0 +1,64 @@ +"""Models that are saved in storage""" + +from schemas import Author, Comment, InternalAuthor, Post, Tag, TagLink + +from nqlstore import ( + EmbeddedJsonModel, + EmbeddedMongoModel, + JsonModel, + MongoModel, + SQLModel, +) + +# mongo models +MongoInternalAuthor = MongoModel("MongoInternalAuthor", InternalAuthor) +MongoAuthor = EmbeddedMongoModel("MongoAuthor", Author) +MongoComment = EmbeddedMongoModel( + "MongoComment", Comment, embedded_models={"author": MongoAuthor} +) +MongoTag = EmbeddedMongoModel("MongoTag", Tag) +MongoPost = MongoModel( + "MongoPost", + Post, + embedded_models={ + "author": MongoAuthor | None, + "comments": list[MongoComment], + "tags": list[MongoTag], + }, +) + + +# redis models +RedisInternalAuthor = JsonModel("RedisInternalAuthor", InternalAuthor) +RedisAuthor = EmbeddedJsonModel("RedisAuthor", Author) +RedisComment = EmbeddedJsonModel( + "RedisComment", Comment, embedded_models={"author": RedisAuthor} +) +RedisTag = EmbeddedJsonModel("RedisTag", Tag) +RedisPost = JsonModel( + "RedisPost", + Post, + embedded_models={ + "author": RedisAuthor | None, + "comments": list[RedisComment], + "tags": list[RedisTag], + }, +) + +# sqlite models +SqlInternalAuthor = SQLModel("SqlInternalAuthor", InternalAuthor) +SqlComment = SQLModel( + "SqlComment", Comment, relationships={"author": SqlInternalAuthor | None} +) +SqlTagLink = SQLModel("SqlTagLink", TagLink) +SqlTag = SQLModel("SqlTag", 
Tag) +SqlPost = SQLModel( + "SqlPost", + Post, + relationships={ + "author": SqlInternalAuthor | None, + "comments": list[SqlComment], + "tags": list[SqlTag], + }, + link_models={"tags": SqlTagLink}, +) diff --git a/examples/blog/requirements.txt b/examples/blog/requirements.txt new file mode 100644 index 0000000..c546830 --- /dev/null +++ b/examples/blog/requirements.txt @@ -0,0 +1,10 @@ +fastapi[standard]~=0.115.8 +nqlstore[all] +black~=25.1.0 +isort~=6.0.0 +pytest~=8.3.4 +pytest-asyncio~=0.25.3 +passlib[bcrypt]~=1.7.4 +pyjwt~=2.10.1 +pytest-freezegun~=0.4.2 +pytest-mock~=3.14.0 diff --git a/examples/blog/schemas.py b/examples/blog/schemas.py new file mode 100644 index 0000000..8fac64f --- /dev/null +++ b/examples/blog/schemas.py @@ -0,0 +1,96 @@ +"""Schemas for the application""" + +from datetime import datetime + +from pydantic import BaseModel +from utils import Partial, current_timestamp + +from nqlstore import Field, Relationship + + +class Author(BaseModel): + """The author as returned to the user""" + + name: str = Field(index=True, full_text_search=True) + + +class InternalAuthor(Author): + """The author as saved in database""" + + password: str = Field() + email: str = Field(index=True) + + +class Post(BaseModel): + """The post""" + + title: str = Field(index=True, full_text_search=True) + content: str | None = Field(default="") + author_id: int | None = Field( + default=None, + foreign_key="sqlinternalauthor.id", + disable_on_mongo=True, + disable_on_redis=True, + ) + author: Author | None = Relationship(default=None) + comments: list["Comment"] = Relationship(default=[]) + tags: list["Tag"] = Relationship(default=[], link_model="TagLink") + created_at: str = Field(index=True, default_factory=current_timestamp) + updated_at: str = Field(index=True, default_factory=current_timestamp) + + +class Comment(BaseModel): + """The comment on a post""" + + post_id: int | None = Field( + default=None, + foreign_key="sqlpost.id", + disable_on_mongo=True, + disable_on_redis=True, + ) + content: str | None = Field(default="") + author_id: int | None = Field( + default=None, + foreign_key="sqlinternalauthor.id", + disable_on_mongo=True, + disable_on_redis=True, + ) + author: Author | None = Relationship(default=None) + created_at: str = Field(index=True, default_factory=current_timestamp) + updated_at: str = Field(index=True, default_factory=current_timestamp) + + +class TagLink(BaseModel): + """The SQL-only join table between tags and posts""" + + post_id: int | None = Field( + default=None, + foreign_key="sqlpost.id", + primary_key=True, + disable_on_mongo=True, + disable_on_redis=True, + ) + tag_id: int | None = Field( + default=None, + foreign_key="sqltag.id", + primary_key=True, + disable_on_mongo=True, + disable_on_redis=True, + ) + + +class Tag(BaseModel): + """The tags to help searching for posts""" + + title: str = Field(index=True, unique=True) + + +class TokenResponse(BaseModel): + """HTTP-only response""" + + access_token: str + token_type: str + + +# Partial models +PartialPost = Partial("PartialPost", Post) diff --git a/examples/blog/stores.py b/examples/blog/stores.py new file mode 100644 index 0000000..0c6b575 --- /dev/null +++ b/examples/blog/stores.py @@ -0,0 +1,111 @@ +"""Module containing the registry of stores at runtime""" + +import os +from typing import Annotated + +from fastapi import Depends +from models import ( # SqlAuthor, + MongoAuthor, + MongoComment, + MongoInternalAuthor, + MongoPost, + MongoTag, + RedisAuthor, + RedisComment, + RedisInternalAuthor, + 
RedisPost, + RedisTag, + SqlComment, + SqlInternalAuthor, + SqlPost, + SqlTag, + SqlTagLink, +) + +from nqlstore import MongoStore, RedisStore, SQLStore + +_STORES: dict[str, MongoStore | RedisStore | SQLStore] = {} + + +async def get_redis_store() -> RedisStore | None: + """Gets the redis store whose URL is specified in the environment""" + global _STORES + + if redis_url := os.environ.get("REDIS_URL"): + try: + return _STORES[redis_url] + except KeyError: + store = RedisStore(uri=redis_url) + await store.register( + [ + RedisAuthor, + RedisTag, + RedisPost, + RedisComment, + RedisInternalAuthor, + ] + ) + _STORES[redis_url] = store + return store + + +async def get_sql_store() -> SQLStore | None: + """Gets the sql store whose URL is specified in the environment""" + global _STORES + + if sql_url := os.environ.get("SQL_URL"): + try: + return _STORES[sql_url] + except KeyError: + store = SQLStore(uri=sql_url) + await store.register( + [ + # SqlAuthor, + SqlTag, + SqlTagLink, + SqlPost, + SqlComment, + SqlInternalAuthor, + ] + ) + _STORES[sql_url] = store + return store + + +async def get_mongo_store() -> MongoStore | None: + """Gets the mongo store whose URL and database are specified in the environment""" + global _STORES + + if mongo_url := os.environ.get("MONGO_URL"): + mongo_db = os.environ.get("MONGO_DB", "todos") + mongo_full_url = f"{mongo_url}/{mongo_db}" + + try: + return _STORES[mongo_full_url] + except KeyError: + store = MongoStore(uri=mongo_url, database=mongo_db) + await store.register( + [ + MongoAuthor, + MongoTag, + MongoPost, + MongoComment, + MongoInternalAuthor, + ] + ) + _STORES[mongo_full_url] = store + return store + + +def clear_stores(): + """Clears the registry of stores + + Important for clean up + """ + global _STORES + _STORES.clear() + + +SqlStoreDep = Annotated[SQLStore | None, Depends(get_sql_store)] +RedisStoreDep = Annotated[RedisStore | None, Depends(get_redis_store)] +MongoStoreDep = Annotated[MongoStore | None, Depends(get_mongo_store)] diff --git a/examples/blog/test_main.py b/examples/blog/test_main.py new file mode 100644 index 0000000..cc15e0a --- /dev/null +++ b/examples/blog/test_main.py @@ -0,0 +1,613 @@ +from datetime import datetime +from typing import Any + +import pytest +from bson import ObjectId +from conftest import ACCESS_TOKEN, AUTHOR, COMMENT_LIST, POST_LISTS +from fastapi.testclient import TestClient +from main import MongoPost, RedisPost, SqlPost + +from nqlstore import MongoStore, RedisStore, SQLStore + +_TITLE_SEARCH_TERMS = ["ho", "oo", "work"] +_TAG_SEARCH_TERMS = ["art", "om"] +_HEADERS = {"Authorization": f"Bearer {ACCESS_TOKEN}"} + + +@pytest.mark.asyncio +@pytest.mark.parametrize("post", POST_LISTS) +async def test_create_sql_post( + client_with_sql: TestClient, sql_store: SQLStore, post: dict, freezer +): + """POST to /posts creates a post in sql and returns it""" + timestamp = datetime.now().isoformat() + with client_with_sql as client: + response = client.post("/posts", json=post, headers=_HEADERS) + + got = response.json() + post_id = got["id"] + raw_tags = post.get("tags", []) + resp_tags = got["tags"] + expected = { + "id": post_id, + "title": post["title"], + "content": post.get("content", ""), + "author": {"id": 1, **AUTHOR}, + "author_id": 1, + "tags": [ + { + **raw, + "id": resp["id"], + } + for raw, resp in zip(raw_tags, resp_tags) + ], + "comments": [], + "created_at": timestamp, + "updated_at": timestamp, + } + + db_query = {"id": {"$eq": post_id}} + db_results = await sql_store.find(SqlPost, query=db_query, 
limit=1) + record_in_db = db_results[0].model_dump() + + assert got == expected + assert record_in_db == expected + + +@pytest.mark.asyncio +@pytest.mark.parametrize("post", POST_LISTS) +async def test_create_redis_post( + client_with_redis: TestClient, + redis_store: RedisStore, + post: dict, + freezer, +): + """POST to /posts creates a post in redis and returns it""" + timestamp = datetime.now().isoformat() + with client_with_redis as client: + response = client.post("/posts", json=post, headers=_HEADERS) + + got = response.json() + post_id = got["id"] + raw_tags = post.get("tags", []) + resp_tags = got["tags"] + expected = { + "id": post_id, + "title": post["title"], + "content": post.get("content", ""), + "author": {**got["author"], **AUTHOR}, + "pk": post_id, + "tags": [ + { + **raw, + "id": resp["id"], + "pk": resp["pk"], + } + for raw, resp in zip(raw_tags, resp_tags) + ], + "comments": [], + "created_at": timestamp, + "updated_at": timestamp, + } + + db_query = {"id": {"$eq": post_id}} + db_results = await redis_store.find(RedisPost, query=db_query, limit=1) + record_in_db = db_results[0].model_dump(mode="json") + + assert got == expected + assert record_in_db == expected + + +@pytest.mark.asyncio +@pytest.mark.parametrize("post", POST_LISTS) +async def test_create_mongo_post( + client_with_mongo: TestClient, + mongo_store: MongoStore, + post: dict, + freezer, +): + """POST to /posts creates a post in redis and returns it""" + timestamp = datetime.now().isoformat() + with client_with_mongo as client: + response = client.post("/posts", json=post, headers=_HEADERS) + + got = response.json() + post_id = got["id"] + raw_tags = post.get("tags", []) + expected = { + "id": post_id, + "title": post["title"], + "content": post.get("content", ""), + "author": {"name": AUTHOR["name"]}, + "tags": raw_tags, + "comments": [], + "created_at": timestamp, + "updated_at": timestamp, + } + + db_query = {"_id": {"$eq": ObjectId(post_id)}} + db_results = await mongo_store.find(MongoPost, query=db_query, limit=1) + record_in_db = db_results[0].model_dump(mode="json") + + assert got == expected + assert record_in_db == expected + + +@pytest.mark.asyncio +@pytest.mark.parametrize("index", range(len(POST_LISTS))) +async def test_update_sql_post( + client_with_sql: TestClient, + sql_store: SQLStore, + sql_posts: list[SqlPost], + index: int, + freezer, +): + """PUT to /posts/{id} updates the sql post of given id and returns updated version""" + timestamp = datetime.now().isoformat() + with client_with_sql as client: + post = sql_posts[index] + post_dict = post.model_dump(mode="json", exclude_none=True, exclude_unset=True) + id_ = post.id + update = { + **post_dict, + "title": "some other title", + "tags": [ + *post_dict["tags"], + {"title": "another one"}, + {"title": "another one again"}, + ], + "comments": [*post_dict["comments"], *COMMENT_LIST[index:]], + } + + response = client.put(f"/posts/{id_}", json=update, headers=_HEADERS) + + got = response.json() + expected = { + **post.model_dump(mode="json"), + **update, + "comments": [ + { + **raw, + "id": final["id"], + "post_id": final["post_id"], + "author_id": 1, + "created_at": timestamp, + "updated_at": timestamp, + } + for raw, final in zip(update["comments"], got["comments"]) + ], + "tags": [ + { + **raw, + "id": final["id"], + } + for raw, final in zip(update["tags"], got["tags"]) + ], + } + db_query = {"id": {"$eq": id_}} + db_results = await sql_store.find(SqlPost, query=db_query, limit=1) + record_in_db = db_results[0].model_dump(mode="json") + + 
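+        # both the endpoint response and the freshly re-read DB record should reflect the update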
assert got == expected + assert record_in_db == expected + + +@pytest.mark.asyncio +@pytest.mark.parametrize("index", range(len(POST_LISTS))) +async def test_update_redis_post( + client_with_redis: TestClient, + redis_store: RedisStore, + redis_posts: list[RedisPost], + index: int, + freezer, +): + """PUT to /posts/{id} updates the redis post of given id and returns updated version""" + timestamp = datetime.now().isoformat() + with client_with_redis as client: + post = redis_posts[index] + post_dict = post.model_dump(mode="json", exclude_none=True, exclude_unset=True) + id_ = post.id + update = { + "title": "some other title", + "tags": [ + *post_dict.get("tags", []), + {"title": "another one"}, + {"title": "another one again"}, + ], + "comments": [*post_dict.get("comments", []), *COMMENT_LIST[index:]], + } + + response = client.put(f"/posts/{id_}", json=update, headers=_HEADERS) + + got = response.json() + expected = { + **post.model_dump(mode="json"), + **update, + "comments": [ + { + **raw, + "id": final["id"], + "author": final["author"], + "pk": final["pk"], + "created_at": timestamp, + "updated_at": timestamp, + } + for raw, final in zip(update["comments"], got["comments"]) + ], + "tags": [ + { + **raw, + "id": final["id"], + "pk": final["pk"], + } + for raw, final in zip(update["tags"], got["tags"]) + ], + } + db_query = {"id": {"$eq": id_}} + db_results = await redis_store.find(RedisPost, query=db_query, limit=1) + record_in_db = db_results[0].model_dump(mode="json") + expected_in_db = { + **expected, + "tags": [ + { + **raw, + "id": final["id"], + "pk": final["pk"], + } + for raw, final in zip(expected["tags"], record_in_db["tags"]) + ], + "comments": [ + { + **raw, + "id": final["id"], + "pk": final["pk"], + } + for raw, final in zip(expected["comments"], record_in_db["comments"]) + ], + } + + assert got == expected + assert record_in_db == expected_in_db + + +@pytest.mark.asyncio +@pytest.mark.parametrize("index", range(len(POST_LISTS))) +async def test_update_mongo_post( + client_with_mongo: TestClient, + mongo_store: MongoStore, + mongo_posts: list[MongoPost], + index: int, + freezer, +): + """PUT to /posts/{id} updates the mongo post of given id and returns updated version""" + timestamp = datetime.now().isoformat() + with client_with_mongo as client: + post = mongo_posts[index] + post_dict = post.model_dump(mode="json", exclude_none=True, exclude_unset=True) + id_ = post.id + update = { + "title": "some other title", + "tags": [ + *post_dict.get("tags", []), + {"title": "another one"}, + {"title": "another one again"}, + ], + "comments": [*post_dict.get("comments", []), *COMMENT_LIST[index:]], + } + + response = client.put(f"/posts/{id_}", json=update, headers=_HEADERS) + + got = response.json() + expected = { + **post.model_dump(mode="json"), + **update, + "comments": [ + { + **raw, + "author": final["author"], + "created_at": timestamp, + "updated_at": timestamp, + } + for raw, final in zip(update["comments"], got["comments"]) + ], + } + db_query = {"_id": {"$eq": id_}} + db_results = await mongo_store.find(MongoPost, query=db_query, limit=1) + record_in_db = db_results[0].model_dump(mode="json") + + assert got == expected + assert record_in_db == expected + + +@pytest.mark.asyncio +@pytest.mark.parametrize("index", range(len(POST_LISTS))) +async def test_delete_sql_post( + client_with_sql: TestClient, + sql_store: SQLStore, + sql_posts: list[SqlPost], + index: int, +): + """DELETE /posts/{id} deletes the sql post of given id and returns deleted version""" + with 
client_with_sql as client: + post = sql_posts[index] + id_ = post.id + + response = client.delete(f"/posts/{id_}") + + got = response.json() + expected = post.model_dump(mode="json") + + db_query = {"id": {"$eq": id_}} + db_results = await sql_store.find(SqlPost, query=db_query, limit=1) + + assert got == expected + assert db_results == [] + + +@pytest.mark.asyncio +@pytest.mark.parametrize("index", range(len(POST_LISTS))) +async def test_delete_redis_post( + client_with_redis: TestClient, + redis_store: RedisStore, + redis_posts: list[RedisPost], + index: int, +): + """DELETE /posts/{id} deletes the redis post of given id and returns deleted version""" + with client_with_redis as client: + post = redis_posts[index] + id_ = post.id + + response = client.delete(f"/posts/{id_}") + + got = response.json() + expected = post.model_dump(mode="json") + + db_query = {"id": {"$eq": id_}} + db_results = await redis_store.find(RedisPost, query=db_query, limit=1) + + assert got == expected + assert db_results == [] + + +@pytest.mark.asyncio +@pytest.mark.parametrize("index", range(len(POST_LISTS))) +async def test_delete_mongo_post( + client_with_mongo: TestClient, + mongo_store: MongoStore, + mongo_posts: list[MongoPost], + index: int, +): + """DELETE /posts/{id} deletes the mongo post of given id and returns deleted version""" + with client_with_mongo as client: + post = mongo_posts[index] + id_ = post.id + + response = client.delete(f"/posts/{id_}") + + got = response.json() + expected = post.model_dump(mode="json") + + db_query = {"_id": {"$eq": id_}} + db_results = await mongo_store.find(MongoPost, query=db_query, limit=1) + + assert got == expected + assert db_results == [] + + +@pytest.mark.asyncio +@pytest.mark.parametrize("index", range(len(POST_LISTS))) +async def test_read_one_sql_post( + client_with_sql: TestClient, + sql_store: SQLStore, + sql_posts: list[SqlPost], + index: int, +): + """GET /posts/{id} gets the sql post of given id""" + with client_with_sql as client: + post = sql_posts[index] + id_ = post.id + + response = client.get(f"/posts/{id_}") + + got = response.json() + expected = post.model_dump(mode="json") + + db_query = {"id": {"$eq": id_}} + db_results = await sql_store.find(SqlPost, query=db_query, limit=1) + record_in_db = db_results[0].model_dump(mode="json") + + assert got == expected + assert record_in_db == expected + + +@pytest.mark.asyncio +@pytest.mark.parametrize("index", range(len(POST_LISTS))) +async def test_read_one_redis_post( + client_with_redis: TestClient, + redis_store: RedisStore, + redis_posts: list[RedisPost], + index: int, +): + """GET /posts/{id} gets the redis post of given id""" + with client_with_redis as client: + post = redis_posts[index] + id_ = post.id + + response = client.get(f"/posts/{id_}") + + got = response.json() + expected = post.model_dump(mode="json") + + db_query = {"id": {"$eq": id_}} + db_results = await redis_store.find(RedisPost, query=db_query, limit=1) + record_in_db = db_results[0].model_dump(mode="json") + + assert got == expected + assert record_in_db == expected + + +@pytest.mark.asyncio +@pytest.mark.parametrize("index", range(len(POST_LISTS))) +async def test_read_one_mongo_post( + client_with_mongo: TestClient, + mongo_store: MongoStore, + mongo_posts: list[MongoPost], + index: int, +): + """GET /posts/{id} gets the mongo post of given id""" + with client_with_mongo as client: + post = mongo_posts[index] + id_ = post.id + + response = client.get(f"/posts/{id_}") + + got = response.json() + expected = 
post.model_dump(mode="json") + + db_query = {"_id": {"$eq": id_}} + db_results = await mongo_store.find(MongoPost, query=db_query, limit=1) + record_in_db = db_results[0].model_dump(mode="json") + + assert got == expected + assert record_in_db == expected + + +@pytest.mark.asyncio +@pytest.mark.parametrize("q", _TITLE_SEARCH_TERMS) +async def test_search_sql_by_title( + client_with_sql: TestClient, + sql_store: SQLStore, + sql_posts: list[SqlPost], + q: str, +): + """GET /posts?title={} gets all sql posts with title containing search item""" + with client_with_sql as client: + response = client.get(f"/posts?title={q}") + + got = response.json() + expected = [ + v.model_dump(mode="json") for v in sql_posts if q in v.title.lower() + ] + + assert got == expected + + +@pytest.mark.asyncio +@pytest.mark.parametrize("q", _TITLE_SEARCH_TERMS) +async def test_search_redis_by_title( + client_with_redis: TestClient, + redis_store: RedisStore, + redis_posts: list[RedisPost], + q: str, +): + """GET /posts?title={} gets all redis posts with title containing search item""" + with client_with_redis as client: + response = client.get(f"/posts?title={q}") + + got = response.json() + expected = [ + v.model_dump(mode="json") for v in redis_posts if q in v.title.lower() + ] + + assert got == expected + + +@pytest.mark.asyncio +@pytest.mark.parametrize("q", _TITLE_SEARCH_TERMS) +async def test_search_mongo_by_title( + client_with_mongo: TestClient, + mongo_store: MongoStore, + mongo_posts: list[MongoPost], + q: str, +): + """GET /posts?title={} gets all mongo posts with title containing search item""" + with client_with_mongo as client: + response = client.get(f"/posts?title={q}") + + got = response.json() + expected = [ + v.model_dump(mode="json") for v in mongo_posts if q in v.title.lower() + ] + + assert got == expected + + +@pytest.mark.asyncio +@pytest.mark.parametrize("q", _TAG_SEARCH_TERMS) +async def test_search_sql_by_tag( + client_with_sql: TestClient, + sql_store: SQLStore, + sql_posts: list[SqlPost], + q: str, +): + """GET /posts?tag={} gets all sql posts with tag containing search item""" + with client_with_sql as client: + response = client.get(f"/posts?tag={q}") + + got = response.json() + expected = [ + v.model_dump(mode="json") + for v in sql_posts + if any([q in tag.title.lower() for tag in v.tags]) + ] + + assert got == expected + + +@pytest.mark.asyncio +@pytest.mark.parametrize("q", ["random", "another one", "another one again"]) +async def test_search_redis_by_tag( + client_with_redis: TestClient, + redis_store: RedisStore, + redis_posts: list[RedisPost], + q: str, +): + """GET /posts?tag={} gets all redis posts with tag containing search item. 
Partial searches nit supported.""" + with client_with_redis as client: + response = client.get(f"/posts?tag={q}") + + got = response.json() + expected = [ + v.model_dump(mode="json") + for v in redis_posts + if any([q in tag.title.lower() for tag in v.tags]) + ] + + assert got == expected + + +@pytest.mark.asyncio +@pytest.mark.parametrize("q", _TAG_SEARCH_TERMS) +async def test_search_mongo_by_tag( + client_with_mongo: TestClient, + mongo_store: MongoStore, + mongo_posts: list[MongoPost], + q: str, +): + """GET /posts?tag={} gets all mongo posts with tag containing search item""" + with client_with_mongo as client: + response = client.get(f"/posts?tag={q}") + + got = response.json() + expected = [ + v.model_dump(mode="json") + for v in mongo_posts + if any([q in tag.title.lower() for tag in v.tags]) + ] + + assert got == expected + + +def _get_id(item: Any) -> Any: + """Gets the id of the given record + + Args: + item: the record whose id is to be obtained + + Returns: + the id of the record + """ + try: + return item.id + except AttributeError: + return item.pk diff --git a/examples/blog/utils.py b/examples/blog/utils.py new file mode 100644 index 0000000..4a917c0 --- /dev/null +++ b/examples/blog/utils.py @@ -0,0 +1,61 @@ +"""Some random utilities for the app""" + +import copy +import sys +from datetime import datetime +from typing import Any, Literal, Optional, TypeVar, get_args + +from pydantic import BaseModel, create_model +from pydantic.main import IncEx + +from nqlstore._field import FieldInfo + +_T = TypeVar("_T", bound=BaseModel) + + +def current_timestamp() -> str: + """Gets the current timestamp as an timezone naive ISO format string + + Returns: + string of the current datetime + """ + return datetime.now().isoformat() + + +def Partial(name: str, model: type[_T]) -> type[_T]: + """Creates a partial schema from another schema, with all fields optional + + Args: + name: the name of the model + model: the original model + + Returns: + A new model with all its fields optional + """ + fields = { + k: (_make_optional(v.annotation), None) + for k, v in model.model_fields.items() # type: str, FieldInfo + } + + return create_model( + name, + # module of the calling function + __module__=sys._getframe(1).f_globals["__name__"], + __doc__=model.__doc__, + __base__=(model,), + **fields, + ) + + +def _make_optional(type_: type) -> type: + """Makes a type optional if not optional + + Args: + type_: the type to make optional + + Returns: + the optional type + """ + if type(None) in get_args(type_): + return type_ + return type_ | None diff --git a/examples/todos/schemas.py b/examples/todos/schemas.py index 3fa0d3c..d765343 100644 --- a/examples/todos/schemas.py +++ b/examples/todos/schemas.py @@ -9,7 +9,7 @@ class TodoList(BaseModel): """A list of Todos""" name: str = Field(index=True, full_text_search=True) - description: str | None = None + description: str | None = Field(default=None) todos: list["Todo"] = Relationship(back_populates="parent", default=[]) diff --git a/nqlstore/_compat.py b/nqlstore/_compat.py index 770aebc..f22a9ba 100644 --- a/nqlstore/_compat.py +++ b/nqlstore/_compat.py @@ -41,9 +41,16 @@ sql imports; and their default if sqlmodel is missing """ try: - from sqlalchemy import Column, Table + from sqlalchemy import Column, Table, func + from sqlalchemy.dialects.postgresql import insert as pg_insert + from sqlalchemy.dialects.sqlite import insert as sqlite_insert from sqlalchemy.ext.asyncio import create_async_engine - from sqlalchemy.orm import InstrumentedAttribute, 
RelationshipProperty, subqueryload + from sqlalchemy.orm import ( + InstrumentedAttribute, + RelationshipDirection, + RelationshipProperty, + subqueryload, + ) from sqlalchemy.orm.exc import DetachedInstanceError from sqlalchemy.sql._typing import ( _ColumnExpressionArgument, @@ -58,6 +65,7 @@ from sqlmodel.main import IncEx, NoArgAnyCallable, OnDeleteType from sqlmodel.main import RelationshipInfo as _RelationshipInfo except ImportError: + import types from typing import Mapping, Optional, Sequence from typing import Set from typing import Set as _ColumnExpressionArgument @@ -75,14 +83,16 @@ OnDeleteType = Literal["CASCADE", "SET NULL", "RESTRICT"] Column = Any create_async_engine = lambda *a, **k: dict(**k) - delete = insert = select = update = create_async_engine + pg_insert = sqlite_insert = delete = insert = select = update = create_async_engine AsyncSession = Any - RelationshipProperty = Set + RelationshipDirection = RelationshipProperty = Set Table = Set InstrumentedAttribute = Set subqueryload = lambda *a, **kwargs: dict(**kwargs) DetachedInstanceError = RuntimeError IncEx = Set[Any] | dict + func = types.ModuleType("func") + func.max = lambda *a, **kwargs: dict(**kwargs) class _SqlFieldInfo(_FieldInfo): ... diff --git a/nqlstore/_field.py b/nqlstore/_field.py index c980347..026ef3e 100644 --- a/nqlstore/_field.py +++ b/nqlstore/_field.py @@ -394,6 +394,7 @@ def get_field_definitions( schema: type[ModelT], embedded_models: dict[str, Type] | None = None, relationships: dict[str, Type] | None = None, + link_models: dict[str, Type] | None = None, is_for_redis: bool = False, is_for_mongo: bool = False, is_for_sql: bool = False, @@ -404,6 +405,8 @@ def get_field_definitions( schema: the model schema class embedded_models: the map of embedded models as : relationships: the map of relationships as : + link_models: a map of :Model class for all link (through) + tables in many-to-many relationships is_for_redis: whether the definitions are for redis or not is_for_mongo: whether the definitions are for mongo or not is_for_sql: whether the definitions are for sql or not @@ -417,16 +420,24 @@ def get_field_definitions( if relationships is None: relationships = {} + if link_models is None: + link_models = {} + fields = {} for field_name, field in schema.model_fields.items(): # type: str, FieldInfo class_field_definition = _get_class_field_definition(field) - if is_for_redis and getattr(class_field_definition, "disable_on_redis", False): + if not isinstance(class_field_definition, (FieldInfo, RelationshipInfo)): + raise TypeError( + f"field '{schema.__name__}.{field_name}' was not initialized with a {Field.__name__}() or {Relationship.__name__}()" + ) + + if is_for_redis and class_field_definition.disable_on_redis: continue - if is_for_mongo and getattr(class_field_definition, "disable_on_mongo", False): + if is_for_mongo and class_field_definition.disable_on_mongo: continue - if is_for_sql and getattr(class_field_definition, "disable_on_sql", False): + if is_for_sql and class_field_definition.disable_on_sql: continue field_type = field.annotation @@ -441,6 +452,7 @@ def get_field_definitions( field_type = relationships[field_name] # redefine the class so that SQLModel can redo its thing field_info = class_field_definition + field_info.link_model = link_models.get(field_name) fields[field_name] = (field_type, field_info) return fields diff --git a/nqlstore/_redis.py b/nqlstore/_redis.py index 2022ad3..fc18866 100644 --- a/nqlstore/_redis.py +++ b/nqlstore/_redis.py @@ -259,6 +259,7 @@ class 
_EmbeddedJsonModelMeta(_EmbeddedJsonModel, abc.ABC): """Base model for all EmbeddedJsonModels. Helpful with typing""" id: str | None + __embedded_models__: dict @classmethod def set_db(cls, db: Redis): @@ -275,9 +276,26 @@ def set_db(cls, db: Redis): except AttributeError: cls.Meta.database = db + # cache the embedded models on the class + embedded_models = getattr(cls, "__embedded_models__", None) + if embedded_models is None: + embedded_models = [ + model + for field in cls.model_fields.values() # type: FieldInfo + for model in _get_embed_models(field.annotation) + ] + setattr(cls, "__embedded_models__", embedded_models) + + # set db on embedded models also + for model in cls.__embedded_models__: + model.set_db(db) + def EmbeddedJsonModel( - name: str, schema: type[ModelT], / + name: str, + schema: type[ModelT], + /, + embedded_models: dict[str, Type] = None, ) -> type[_EmbeddedJsonModelMeta] | type[ModelT]: """Creates a new EmbeddedJsonModel for the given schema for redis @@ -288,11 +306,14 @@ def EmbeddedJsonModel( Args: name: the name of the model schema: the schema from which the model is to be made + embedded_models: a dict of embedded models of : annotation Returns: a EmbeddedJsonModel model class with the given name """ - fields = get_field_definitions(schema, embedded_models=None, is_for_redis=True) + fields = get_field_definitions( + schema, embedded_models=embedded_models, is_for_redis=True + ) return create_model( name, diff --git a/nqlstore/_sql.py b/nqlstore/_sql.py index 26b6ee1..18346bf 100644 --- a/nqlstore/_sql.py +++ b/nqlstore/_sql.py @@ -1,8 +1,9 @@ """SQL implementation""" +import copy import sys from collections.abc import Mapping, MutableMapping -from typing import Any, Dict, Iterable, Literal, TypeVar, Union +from typing import Any, Dict, Iterable, Literal, Sequence, TypeVar, Union from pydantic import create_model from pydantic.main import ModelT @@ -14,6 +15,7 @@ DetachedInstanceError, IncEx, InstrumentedAttribute, + RelationshipDirection, RelationshipProperty, Table, _ColumnExpressionArgument, @@ -21,8 +23,11 @@ _SQLModel, create_async_engine, delete, + func, insert, + pg_insert, select, + sqlite_insert, subqueryload, update, ) @@ -30,8 +35,86 @@ from .query.parsers import QueryParser from .query.selectors import QuerySelector -_T = TypeVar("_T", bound=_SQLModel) _Filter = _ColumnExpressionArgument[bool] | bool +_T = TypeVar("_T") + + +class _SQLModelMeta(_SQLModel): + """The base class for all SQL models""" + + id: int | None = Field(default=None, primary_key=True) + __rel_field_cache__: dict = {} + """dict of (name, Field) that have associated relationships""" + + @classmethod + def __relational_fields__(cls) -> dict[str, Any]: + """dict of (name, Field) that have associated relationships""" + + cls_fullname = f"{cls.__module__}.{cls.__qualname__}" + try: + return cls.__rel_field_cache__[cls_fullname] + except KeyError: + value = { + k: v + for k, v in cls.__mapper__.all_orm_descriptors.items() + if isinstance(v.property, RelationshipProperty) + } + cls.__rel_field_cache__[cls_fullname] = value + return value + + def model_dump( + self, + *, + mode: Union[Literal["json", "python"], str] = "python", + include: IncEx = None, + exclude: IncEx = None, + context: Union[Dict[str, Any], None] = None, + by_alias: bool = False, + exclude_unset: bool = False, + exclude_defaults: bool = False, + exclude_none: bool = False, + round_trip: bool = False, + warnings: Union[bool, Literal["none", "warn", "error"]] = True, + serialize_as_any: bool = False, + ) -> Dict[str, 
Any]: + data = super().model_dump( + mode=mode, + include=include, + exclude=exclude, + context=context, + by_alias=by_alias, + exclude_unset=exclude_unset, + exclude_defaults=exclude_defaults, + exclude_none=exclude_none, + round_trip=round_trip, + warnings=warnings, + serialize_as_any=serialize_as_any, + ) + relations_mappers = self.__class__.__relational_fields__() + for k, field in relations_mappers.items(): + if exclude is None or k not in exclude: + try: + value = getattr(self, k, None) + except DetachedInstanceError: + # ignore lazy loaded values + continue + + if value is not None or not exclude_none: + data[k] = _serialize_embedded( + value, + field=field, + mode=mode, + context=context, + by_alias=by_alias, + exclude_unset=exclude_unset, + exclude_defaults=exclude_defaults, + exclude_none=exclude_none, + round_trip=round_trip, + warnings=warnings, + serialize_as_any=serialize_as_any, + ) + + return data class SQLStore(BaseStore): @@ -41,24 +124,29 @@ def __init__(self, uri: str, parser: QueryParser | None = None, **kwargs): super().__init__(uri, parser=parser, **kwargs) self._engine = create_async_engine(uri, **kwargs) - async def register(self, models: list[type[_T]], checkfirst: bool = True): - tables = [v.__table__ for v in models] + async def register( + self, models: list[type[_SQLModelMeta]], checkfirst: bool = True + ): + tables = [v.__table__ for v in models if hasattr(v, "__table__")] async with self._engine.begin() as conn: await conn.run_sync( _SQLModel.metadata.create_all, tables=tables, checkfirst=checkfirst ) async def insert( - self, model: type[_T], items: Iterable[_T | dict], **kwargs - ) -> list[_T]: + self, + model: type[_SQLModelMeta], + items: Iterable[_SQLModelMeta | dict], + **kwargs, + ) -> list[_SQLModelMeta]: parsed_items = [ v if isinstance(v, model) else model.model_validate(v) for v in items ] - relations_mapper = _get_relations_mapper(model) + relations_mapper = model.__relational_fields__() async with AsyncSession(self._engine) as session: - stmt = insert(model).returning(model) - cursor = await session.stream_scalars(stmt, parsed_items) + insert_stmt = await _get_insert_func(session, model=model) + cursor = await session.stream_scalars(insert_stmt, parsed_items) results = await cursor.all() result_ids = [v.id for v in results] @@ -72,254 +160,116 @@ async def insert( for idx, record in enumerate(items): parent = results[idx] raw_value = _get_key_or_prop(record, k) - embedded_value = _parse_embedded(raw_value, field, parent) - if isinstance(embedded_value, Iterable): - embedded_values += embedded_value - elif isinstance(embedded_value, _SQLModel): + embedded_value = _embed_value(parent, field, raw_value) + + if isinstance(embedded_value, _SQLModel): embedded_values.append(embedded_value) + elif isinstance(embedded_value, Iterable): + embedded_values += embedded_value # insert the related items if len(embedded_values) > 0: field_model = field.property.mapper.class_ - embed_stmt = insert(field_model).returning(field_model) + embed_stmt = await _get_insert_func(session, model=field_model) await session.stream_scalars(embed_stmt, embedded_values) + # update the updated parents + session.add_all(results) + await session.commit() refreshed_results = await self.find(model, model.id.in_(result_ids)) return list(refreshed_results) async def find( self, - model: type[_T], + model: type[_SQLModelMeta], *filters: _Filter, query: QuerySelector | None = None, skip: int = 0, limit: int | None = None, sort: tuple[_ColumnExpressionOrStrLabelArgument[Any]] = (), 
**kwargs, - ) -> list[_T]: + ) -> list[_SQLModelMeta]: async with AsyncSession(self._engine) as session: if query: filters = (*filters, *self._parser.to_sql(model, query=query)) - - relations = _get_relations(model) - - # eagerly load all relationships so that no validation errors occur due - # to missing session if there is an attempt to load them lazily later - eager_load_opts = [subqueryload(v) for v in relations] - - filtered_relations = _get_filtered_relations( - filters=filters, - relations=relations, + return await _find( + session, model, *filters, skip=skip, limit=limit, sort=sort ) - # Note that we need to treat relations that are referenced in the filters - # differently from those that are not. This is because filtering basing on a relationship - # requires the use of an inner join. Yet an inner join automatically excludes rows - # that are have null for a given relationship. - # - # An outer join on the other hand would just return all the rows in the left table. - # We thus need to do an inner join on tables that are being filtered. - stmt = select(model) - for rel in filtered_relations: - stmt = stmt.join_from(model, rel) - - cursor = await session.stream_scalars( - stmt.options(*eager_load_opts) - .where(*filters) - .limit(limit) - .offset(skip) - .order_by(*sort) - ) - results = await cursor.all() - return list(results) - async def update( self, - model: type[_T], + model: type[_SQLModelMeta], *filters: _Filter, query: QuerySelector | None = None, updates: dict | None = None, **kwargs, - ) -> list[_T]: + ) -> list[_SQLModelMeta]: + updates = copy.deepcopy(updates) async with AsyncSession(self._engine) as session: if query: filters = (*filters, *self._parser.to_sql(model, query=query)) - # Construct filters that have sub queries - relations = _get_relations(model) - rel_filters, non_rel_filters = _sieve_rel_from_non_rel_filters( - filters=filters, - relations=relations, - ) - rel_filters = _to_subquery_based_filters( - model=model, - rel_filters=rel_filters, - relations=relations, - ) - - # dealing with nested models in the update - relations_mapper = _get_relations_mapper(model) - embedded_updates = {} - for k in relations_mapper: - try: - embedded_updates[k] = updates.pop(k) - except KeyError: - pass - - stmt = ( - update(model) - .where(*non_rel_filters, *rel_filters) - .values(**updates) - .returning(model.__table__) + relational_filters = _get_relational_filters(model, filters) + non_relational_filters = _get_non_relational_filters(model, filters) + + # Let's update the fields that are not embedded model fields + # and return the affected results + results = await _update_non_embedded_fields( + session, + model, + *non_relational_filters, + *relational_filters, + updates=updates, ) - - cursor = await session.stream(stmt) - raw_results = await cursor.fetchall() - results = [model.model_validate(row._mapping) for row in raw_results] result_ids = [v.id for v in results] - for k, v in embedded_updates.items(): - field = relations_mapper[k] - field_props = field.property - field_model = field_props.mapper.class_ - # fk = foreign key - fk_field_name = field_props.primaryjoin.right.name - fk_field = getattr(field_model, fk_field_name) - parent_id_field = field_props.primaryjoin.left.name - - # get the foreign keys to use in resetting all affected - # relationships; - # get parsed embedded values so that they can replace - # the old relations. 
- # Note: this operation is strictly replace, not patch - embedded_values = [] - fk_values = [] - for parent in results: - embedded_value = _parse_embedded(v, field, parent) - if isinstance(embedded_value, Iterable): - embedded_values += embedded_value - fk_values.append(getattr(parent, parent_id_field)) - elif isinstance(embedded_value, _SQLModel): - embedded_values.append(embedded_value) - fk_values.append(getattr(parent, parent_id_field)) - - # insert the related items - if len(embedded_values) > 0: - # Reset the relationship; delete all other related items - # Currently, this operation replaces all past relations - reset_stmt = delete(field_model).where(fk_field.in_(fk_values)) - await session.stream(reset_stmt) - - # insert the latest changes - embed_stmt = insert(field_model).returning(field_model) - await session.stream_scalars(embed_stmt, embedded_values) - + # Let's update the embedded fields also + await _update_embedded_fields( + session, model=model, records=results, updates=updates + ) await session.commit() - return await self.find(model, model.id.in_(result_ids)) + + refreshed_results = await self.find(model, model.id.in_(result_ids)) + return refreshed_results async def delete( self, - model: type[_T], + model: type[_SQLModelMeta], *filters: _Filter, query: QuerySelector | None = None, **kwargs, - ) -> list[_T]: + ) -> list[_SQLModelMeta]: async with AsyncSession(self._engine) as session: if query: filters = (*filters, *self._parser.to_sql(model, query=query)) deleted_items = await self.find(model, *filters) - # Construct filters that have sub queries - relations = _get_relations(model) - rel_filters, non_rel_filters = _sieve_rel_from_non_rel_filters( - filters=filters, - relations=relations, - ) - rel_filters = _to_subquery_based_filters( - model=model, - rel_filters=rel_filters, - relations=relations, - ) + relational_filters = _get_relational_filters(model, filters) + non_relational_filters = _get_non_relational_filters(model, filters) + exec_options = {} - if len(rel_filters) > 0: + if len(relational_filters) > 0: exec_options = {"is_delete_using": True} await session.stream( delete(model) - .where(*non_rel_filters, *rel_filters) + .where(*non_relational_filters, *relational_filters) .execution_options(**exec_options), ) await session.commit() return deleted_items -class _SQLModelMeta(_SQLModel): - """The base class for all SQL models""" - - id: int | None = Field(default=None, primary_key=True) - - def model_dump( - self, - *, - mode: Union[Literal["json", "python"], str] = "python", - include: IncEx = None, - exclude: IncEx = None, - context: Union[Dict[str, Any], None] = None, - by_alias: bool = False, - exclude_unset: bool = False, - exclude_defaults: bool = False, - exclude_none: bool = False, - round_trip: bool = False, - warnings: Union[bool, Literal["none", "warn", "error"]] = True, - serialize_as_any: bool = False, - ) -> Dict[str, Any]: - data = super().model_dump( - mode=mode, - include=include, - exclude=exclude, - context=context, - by_alias=by_alias, - exclude_unset=exclude_unset, - exclude_defaults=exclude_defaults, - exclude_none=exclude_none, - round_trip=round_trip, - warnings=warnings, - serialize_as_any=serialize_as_any, - ) - relations_mappers = _get_relations_mapper(self.__class__) - for k, field in relations_mappers.items(): - if exclude is None or k not in exclude: - try: - value = getattr(self, k, None) - except DetachedInstanceError: - # ignore lazy loaded values - continue - - if value is not None or not exclude_none: - data[k] = 
_serialize_embedded( - value, - field=field, - mode=mode, - context=context, - by_alias=by_alias, - exclude_unset=exclude_unset, - exclude_defaults=exclude_defaults, - exclude_none=exclude_none, - round_trip=round_trip, - warnings=warnings, - serialize_as_any=serialize_as_any, - ) - - return data - - def SQLModel( name: str, schema: type[ModelT], /, relationships: dict[str, type[Any] | type[Union[Any]]] = None, + link_models: dict[str, type[Any]] = None, + table: bool = True, + **kwargs: Any, ) -> type[_SQLModelMeta] | type[ModelT]: """Creates a new SQLModel for the given schema for redis @@ -331,55 +281,30 @@ def SQLModel( name: the name of the model schema: the schema from which the model is to be made relationships: a map of :annotation for all relationships + link_models: a map of :Model class for all link (through) + tables in many-to-many relationships + table: whether this model should have a table in the database or not; + default = True + kwargs: key-word args to pass to the SQLModel when defining it Returns: a SQLModel model class with the given name """ - fields = get_field_definitions(schema, relationships=relationships, is_for_sql=True) + fields = get_field_definitions( + schema, relationships=relationships, link_models=link_models, is_for_sql=True + ) return create_model( name, # module of the calling function __module__=sys._getframe(1).f_globals["__name__"], __doc__=schema.__doc__, - __cls_kwargs__={"table": True}, + __cls_kwargs__={"table": table, **kwargs}, __base__=(_SQLModelMeta,), **fields, ) -def _get_relations(model: type[_SQLModel]): - """Gets all the relational fields of the given model - - Args: - model: the SQL model to inspect - - Returns: - list of Fields that have associated relationships - """ - return [ - v - for v in model.__mapper__.all_orm_descriptors.values() - if isinstance(v.property, RelationshipProperty) - ] - - -def _get_relations_mapper(model: type[_SQLModel]) -> dict[str, Any]: - """Gets all the relational fields with their names of the given model - - Args: - model: the SQL model to inspect - - Returns: - dict of (name, Field) that have associated relationships - """ - return { - k: v - for k, v in model.__mapper__.all_orm_descriptors.items() - if isinstance(v.property, RelationshipProperty) - } - - def _get_filtered_tables(filters: Iterable[_Filter]) -> list[Table]: """Retrieves the tables that have been referenced in the filters @@ -413,31 +338,51 @@ def _get_filtered_relations( return [rel for rel in relations if rel.property.target in filtered_tables] -def _sieve_rel_from_non_rel_filters( - filters: Iterable[_Filter], relations: Iterable[InstrumentedAttribute[Any]] -) -> tuple[list[_Filter], list[_Filter]]: - """Separates relational filters from non-relational ones +def _get_relational_filters( + model: type[_SQLModelMeta], + filters: Iterable[_Filter], +) -> list[_Filter]: + """Gets the filters that are concerned with relationships on this model + + The filters returned are in subquery form since 'update' and 'delete' + in sqlalchemy do not have join and the only way to attach these filters + to the model is through sub queries Args: + model: the model under consideration filters: the tuple of filters to inspect - relations: all relations present on the model Returns: - tuple(rel, non_rel) where rel = list of relational filters, - and non_rel = non-relational filters + list of filters that are concerned with relationships on this model """ - rel_targets = [v.property.target for v in relations] - rel = [] - non_rel = [] - - for filter_ in 
filters: - operands = filter_.get_children() - if any([getattr(v, "table", None) in rel_targets for v in operands]): - rel.append(filter_) - else: - non_rel.append(filter_) + relationships = list(model.__relational_fields__().values()) + targets = [v.property.target for v in relationships] + plain_filters = [ + item + for item in filters + if any([getattr(v, "table", None) in targets for v in item.get_children()]) + ] + return _to_subquery_based_filters(model, plain_filters, relationships) + - return rel, non_rel +def _get_non_relational_filters( + model: type[_SQLModelMeta], filters: Iterable[_Filter] +) -> list[_Filter]: + """Gets the filters that are NOT concerned with relationships on this model + + Args: + model: the model under consideration + filters: the tuple of filters to inspect + + Returns: + list of filters that are NOT concerned with relationships on this model + """ + targets = [v.property.target for v in model.__relational_fields__().values()] + return [ + item + for item in filters + if not any([getattr(v, "table", None) in targets for v in item.get_children()]) + ] def _to_subquery_based_filters( @@ -514,45 +459,116 @@ def _with_value(obj: dict | Any, field: str, value: Any) -> Any: return obj -def _parse_embedded( - value: Iterable[dict | Any] | dict | Any, field: Any, parent: _SQLModel +def _embed_value( + parent: _SQLModel, + relationship: Any, + value: Iterable[dict | Any] | dict | Any, ) -> Iterable[_SQLModel] | _SQLModel | None: - """Parses embedded items that can be a single item or many into SQLModels + """Embeds in place a given value into the parent basing on the given relationship + + Note that the parent itself is changed to include the value Args: - value: the value to parse - field: the field on which these embedded items are - parent: the parent SQLModel to which this value is attached + parent: the model that contains the given relationships + relationship: the given relationship + value: the values correspond to the related field Returns: - An iterable of SQLModel instances or a single SQLModel instance - or None if value is None + the embedded record(s) """ if value is None: return None - props = field.property # type: RelationshipProperty[Any] + props = relationship.property # type: RelationshipProperty[Any] wrapper_type = props.collection_class - field_model = props.mapper.class_ - fk_field = props.primaryjoin.right.name - parent_id_field = props.primaryjoin.left.name - fk_value = getattr(parent, parent_id_field) + relationship_model = props.mapper.class_ + parent_foreign_key_field = props.primaryjoin.right.name + direction = props.direction + + if direction == RelationshipDirection.MANYTOONE: + related_value_id_key = props.primaryjoin.left.name + parent_foreign_key_value = value.get(related_value_id_key) + # update the foreign key value in the parent + setattr(parent, parent_foreign_key_field, parent_foreign_key_value) + # create child + child = relationship_model.model_validate(value) + # update nested relationships + for ( + field_name, + field_type, + ) in relationship_model.__relational_fields__().items(): + if isinstance(value, dict): + nested_related_value = value.get(field_name) + else: + nested_related_value = getattr(value, field_name) + + nested_related_records = _embed_value( + parent=child, relationship=field_type, value=nested_related_value + ) + setattr(child, field_name, nested_related_records) + + return child - if issubclass(wrapper_type, (list, tuple, set)): + elif direction == RelationshipDirection.ONETOMANY: + related_value_id_key 
= props.primaryjoin.left.name + parent_foreign_key_value = getattr(parent, related_value_id_key) # add a foreign key values to link back to parent - return wrapper_type( - [ - field_model.model_validate(_with_value(v, fk_field, fk_value)) - for v in value - ] - ) - elif wrapper_type is None: - # add a foreign key value to link back to parent - linked_value = _with_value(value, fk_field, fk_value) - return field_model.model_validate(linked_value) + if issubclass(wrapper_type, (list, tuple, set)): + embedded_records = [] + for v in value: + child = relationship_model.model_validate( + _with_value(v, parent_foreign_key_field, parent_foreign_key_value) + ) + + # update nested relationships + for ( + field_name, + field_type, + ) in relationship_model.__relational_fields__().items(): + if isinstance(v, dict): + nested_related_value = v.get(field_name) + else: + nested_related_value = getattr(v, field_name) + + nested_related_records = _embed_value( + parent=child, + relationship=field_type, + value=nested_related_value, + ) + setattr(child, field_name, nested_related_records) + + embedded_records.append(child) + + return wrapper_type(embedded_records) + + elif direction == RelationshipDirection.MANYTOMANY: + if issubclass(wrapper_type, (list, tuple, set)): + embedded_records = [] + for v in value: + child = relationship_model.model_validate(v) + + # update nested relationships + for ( + field_name, + field_type, + ) in relationship_model.__relational_fields__().items(): + if isinstance(v, dict): + nested_related_value = v.get(field_name) + else: + nested_related_value = getattr(v, field_name) + nested_related_records = _embed_value( + parent=child, + relationship=field_type, + value=nested_related_value, + ) + setattr(child, field_name, nested_related_records) + + embedded_records.append(child) + + return wrapper_type(embedded_records) raise NotImplementedError( - f"relationship of type annotation {wrapper_type} not supported yet" + f"relationship {direction} of type annotation {wrapper_type} not supported yet" ) @@ -576,12 +592,370 @@ def _serialize_embedded( props = field.property # type: RelationshipProperty[Any] wrapper_type = props.collection_class - if issubclass(wrapper_type, (list, tuple, set)): + if wrapper_type is None: + return value.model_dump(**kwargs) + elif issubclass(wrapper_type, (list, tuple, set)): # add a foreign key values to link back to parent return wrapper_type([v.model_dump(**kwargs) for v in value]) - elif wrapper_type is None: - return value.model_dump(**kwargs) raise NotImplementedError( f"relationship of type annotation {wrapper_type} not supported yet" ) + + +async def _get_insert_func(session: AsyncSession, model: type[_SQLModelMeta]): + """Gets the insert statement for the given session + + Args: + session: the async session connecting to the database + model: the model for which the insert statement is to be obtained + + Returns: + the insert function + """ + conn = await session.connection() + dialect = conn.dialect + dialect_name = dialect.name + + native_insert_func = insert + + if dialect_name == "sqlite": + native_insert_func = sqlite_insert + if dialect_name == "postgresql": + native_insert_func = pg_insert + + # insert the embedded items + try: + # PostgreSQL and SQLite support on_conflict_do_nothing + return native_insert_func(model).on_conflict_do_nothing().returning(model) + except AttributeError: + # MySQL supports prefix("IGNORE") + # Other databases might fail at this point + return ( + native_insert_func(model) + .prefix_with("IGNORE", 
dialect="mysql") + .returning(model) + ) + + +async def _update_non_embedded_fields( + session: AsyncSession, model: type[_SQLModelMeta], *filters: _Filter, updates: dict +): + """Updates only the non-embedded fields of the model + + It ignores any relationships and returns the updated results + + Args: + session: the sqlalchemy session + model: the model to be updated + filters: the filters against which to match the records that are to be updated + updates: the updates to add to each matched record + + Returns: + the updated records + """ + non_embedded_updates = _get_non_relational_updates(model, updates) + if len(non_embedded_updates) == 0: + # if we supplied an empty update dict to update, + # there would be an error + return await _find(session, model, *filters) + + stmt = update(model).where(*filters).values(**non_embedded_updates).returning(model) + cursor = await session.stream_scalars(stmt) + return await cursor.fetchall() + + +async def _update_embedded_fields( + session: AsyncSession, + model: type[_SQLModelMeta], + records: list[_SQLModelMeta], + updates: dict, +): + """Updates only the embedded fields of the model for the given records + + It ignores any fields in the `updates` dict that are not for embedded models + Note: this operation is replaces the values of the embedded fields with the new values + passed in the `updates` dictionary as opposed to patching the pre-existing values. + + Args: + session: the sqlalchemy session + model: the model to be updated + records: the db records to update + updates: the updates to add to each record + """ + embedded_updates = _get_relational_updates(model, updates) + relations_mapper = model.__relational_fields__() + for k, v in embedded_updates.items(): + relationship = relations_mapper[k] + link_model = model.__sqlmodel_relationships__[k].link_model + + # this does a replace operation; i.e. removes old values and replaces them with the updates + await _bulk_embedded_delete( + session, relationship=relationship, data=records, link_model=link_model + ) + await _bulk_embedded_insert( + session, + relationship=relationship, + data=records, + link_model=link_model, + payload=v, + ) + # FIXME: Should the added records be updated with their embedded values? 
+ # update the updated parents + session.add_all(records) + + +async def _bulk_embedded_insert( + session: AsyncSession, + relationship: Any, + data: list[_SQLModelMeta], + link_model: type[_SQLModelMeta] | None, + payload: Iterable[dict] | dict, +) -> Sequence[_SQLModelMeta] | None: + """Inserts the payload into the data following the given relationship + + It updates the database also + + Args: + session: the database session + relationship: the relationship the payload has with the data's schema + link_model: the model for the through table + payload: the payload to merge into each record in the data + + Returns: + the updated data including the embedded data in each record + """ + relationship_props = relationship.property # type: RelationshipProperty + relationship_model = relationship_props.mapper.class_ + + parsed_embedded_records = [_embed_value(v, relationship, payload) for v in data] + + insert_stmt = await _get_insert_func(session, model=relationship_model) + embedded_cursor = await session.stream_scalars( + insert_stmt, _flatten_list(parsed_embedded_records) + ) + embedded_db_records = await embedded_cursor.all() + + parent_embedded_map = [ + (parent, embedded_db_records[idx : idx + len(_as_list(raw_embedded))]) + for idx, (parent, raw_embedded) in enumerate(zip(data, parsed_embedded_records)) + ] + + # insert through table values + await _bulk_insert_through_table_data( + session, + relationship=relationship, + link_model=link_model, + parent_embedded_map=parent_embedded_map, + ) + + return data + + +async def _bulk_insert_through_table_data( + session: AsyncSession, + relationship: Any, + link_model: type[_SQLModelMeta] | None, + parent_embedded_map: list[tuple[_SQLModelMeta, list[_SQLModelMeta]]], +): + """Inserts the link records into the through-table represented by the link_model + + Args: + session: the database session + relationship: the relationship the embedded records are based on + link_model: the model for the through table + parent_embedded_map: the list of tuples of parent and its associated embedded db records + """ + if link_model is not None: + relationship_props = relationship.property # type: RelationshipProperty + child_id_field_name = relationship_props.secondaryjoin.left.name + parent_id_field_name = relationship_props.primaryjoin.left.name + child_fk_field_name = relationship_props.secondaryjoin.right.name + parent_fk_field_name = relationship_props.primaryjoin.right.name + + link_raw_values = [ + { + parent_fk_field_name: getattr(parent, parent_id_field_name), + child_fk_field_name: getattr(child, child_id_field_name), + } + for parent, children in parent_embedded_map + for child in children + ] + + next_id = await _get_nextid(session, link_model) + link_model_values = [ + link_model(id=next_id + idx, **v) for idx, v in enumerate(link_raw_values) + ] + + insert_stmt = await _get_insert_func(session, model=link_model) + await session.stream_scalars(insert_stmt, link_model_values) + + +async def _bulk_embedded_delete( + session: AsyncSession, + relationship: Any, + data: list[SQLModel], + link_model: type[_SQLModelMeta] | None, +): + """Deletes the embedded records of the given parent records for the given relationship + + Args: + session: the database session + relationship: the relationship whose embedded records are to be deleted for the given records + link_model: the model for the through table + """ + relationship_props = relationship.property # type: RelationshipProperty + relationship_model = relationship_props.mapper.class_ + + 
parent_id_field_name = relationship_props.primaryjoin.left.name + parent_foreign_keys = [getattr(item, parent_id_field_name) for item in data] + + if link_model is None: + reverse_foreign_key_field_name = relationship_props.primaryjoin.right.name + reverse_foreign_key_field = getattr( + relationship_model, reverse_foreign_key_field_name + ) + await session.stream( + delete(relationship_model).where( + reverse_foreign_key_field.in_(parent_foreign_keys) + ) + ) + else: + reverse_foreign_key_field = getattr(link_model, parent_id_field_name) + await session.stream( + delete(link_model).where(reverse_foreign_key_field.in_(parent_foreign_keys)) + ) + + +async def _get_nextid(session: AsyncSession, model: type[_SQLModelMeta]): + """Gets the next id generator for the given model + + It returns a generator for the auto-incremented integer ID + + Args: + session: the database session + model: the model under consideration + + Returns: + a generator for the auto-incremented integer ID for the given model + """ + # compute the next id auto-incremented + next_id = await session.scalar(func.max(model.id)) + next_id = (next_id or 0) + 1 + return next_id + + +def _flatten_list(data: list[_T | list[_T]]) -> list[_T]: + """Flattens a list that may have lists of items at some indices + + Args: + data: the list to flatten + + Returns: + the flattened list + """ + results = [] + for item in data: + if isinstance(item, Iterable) and not isinstance(item, Mapping): + results += list(item) + else: + results.append(item) + + return results + + +def _as_list(value: Any) -> list: + """Wraps the value in a list if it is not an iterable + + Args: + value: the value to wrap in a list if it is not one + + Returns: + the value as a list if it is not already one + """ + if isinstance(value, list): + return value + elif isinstance(value, Iterable) and not isinstance(value, Mapping): + return list(value) + return [value] + + +def _get_relational_updates(model: type[_SQLModelMeta], updates: dict) -> dict: + """Gets the updates that are affect only the relationships on this model + + Args: + model: the model to be updated + updates: the dict of new values to updated on the matched records + + Returns: + a dict with only updates concerning the relationships of the given model + """ + return {k: v for k, v in updates.items() if k in model.__relational_fields__()} + + +def _get_non_relational_updates(model: type[_SQLModelMeta], updates: dict) -> dict: + """Gets the updates that do not affect relationships on this model + + Args: + model: the model to be updated + updates: the dict of new values to updated on the matched records + + Returns: + a dict with only updates that do not affect relationships on this model + """ + return {k: v for k, v in updates.items() if k not in model.__relational_fields__()} + + +async def _find( + session: AsyncSession, + model: type[_SQLModelMeta], + /, + *filters: _Filter, + skip: int = 0, + limit: int | None = None, + sort: tuple[_ColumnExpressionOrStrLabelArgument[Any]] = (), +) -> list[_SQLModelMeta]: + """Finds the records that match the given filters + + Args: + session: the sqlalchemy session + model: the model that is to be searched + filters: the filters to match + skip: number of records to ignore at the top of the returned results; default is 0 + limit: maximum number of records to return; default is None. 
+ sort: fields to sort by; default is an empty tuple + + Returns: + the records that match the given filters + """ + relations = list(model.__relational_fields__().values()) + + # eagerly load all relationships so that no validation errors occur due + # to missing session if there is an attempt to load them lazily later + eager_load_opts = [subqueryload(v) for v in relations] + + filtered_relations = _get_filtered_relations( + filters=filters, + relations=relations, + ) + + # Note that we need to treat relations that are referenced in the filters + # differently from those that are not. This is because filtering based on a relationship + # requires the use of an inner join. Yet an inner join automatically excludes rows + # that have null for a given relationship. + # + # An outer join on the other hand would just return all the rows in the left table. + # We thus need to do an inner join on tables that are being filtered. + stmt = select(model) + for rel in filtered_relations: + stmt = stmt.join_from(model, rel) + + cursor = await session.stream_scalars( + stmt.options(*eager_load_opts) + .where(*filters) + .limit(limit) + .offset(skip) + .order_by(*sort) + ) + results = await cursor.all() + return list(results) diff --git a/pyproject.toml b/pyproject.toml index 582237b..63fa897 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -33,13 +33,13 @@ sql = [ "greenlet~=3.1.1", ] mongo = ["beanie~=1.29.0"] -redis = ["redis-om~=0.3.3"] +redis = ["redis-om~=0.3.3,<0.3.4"] all = [ "sqlmodel~=0.0.22", "aiosqlite~=0.20.0", "greenlet~=3.1.1", "beanie~=1.29.0", - "redis-om~=0.3.3", + "redis-om~=0.3.3,<0.3.4", ] [project.urls] diff --git a/tests/test_sql.py b/tests/test_sql.py index c494958..51913fe 100644 --- a/tests/test_sql.py +++ b/tests/test_sql.py @@ -122,8 +122,6 @@ async def test_update_native(sql_store, inserted_sql_libs): } matches_query = lambda v: v.name.startswith("Bu") and v.address == _TEST_ADDRESS - # in immediate response - # NOTE: redis startswith/contains on single letters is not supported by redis got = await sql_store.update( SqlLibrary, (SqlLibrary.name.startswith("Bu") & (SqlLibrary.address == _TEST_ADDRESS)),
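To make the new many-to-many plumbing concrete, here is a minimal sketch of how the `SQLModel()` factory, its new `link_models` parameter, and `SQLStore.insert()` could fit together. The schema and model names (`LibrarySchema`, `TagSchema`, `LibraryTagLinkSchema`), the package-root imports, and the `foreign_key` targets are assumptions made for illustration; only the factory and store signatures come from the patch above.

```python
# Illustrative sketch only. It assumes `Field`, `Relationship`, `SQLModel` and
# `SQLStore` are re-exported from the package root; the model names and the
# foreign_key strings below are hypothetical, not taken from this patch.
from pydantic import BaseModel

from nqlstore import Field, Relationship, SQLModel, SQLStore


class TagSchema(BaseModel):
    title: str = Field(index=True)


class LibrarySchema(BaseModel):
    name: str = Field(index=True)
    # every field must be defined via Field() or Relationship()
    tags: list[TagSchema] = Relationship()


class LibraryTagLinkSchema(BaseModel):
    # the through (link) table holding the many-to-many pairs;
    # the generated SQL model already carries its own integer `id` primary key
    library_id: int | None = Field(default=None, foreign_key="library.id")
    tag_id: int | None = Field(default=None, foreign_key="tag.id")


LibraryTagLink = SQLModel("LibraryTagLink", LibraryTagLinkSchema)
Tag = SQLModel("Tag", TagSchema)
Library = SQLModel(
    "Library",
    LibrarySchema,
    relationships={"tags": list[Tag]},
    # the new parameter: map each many-to-many field to its link (through) model
    link_models={"tags": LibraryTagLink},
)


async def seed() -> None:
    store = SQLStore(uri="sqlite+aiosqlite:///library.db")
    await store.register([LibraryTagLink, Tag, Library])
    # nested tags are written through the link table by the insert() logic above
    await store.insert(
        Library,
        [{"name": "Central", "tags": [{"title": "public"}, {"title": "quiet"}]}],
    )
```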
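Because `_update_embedded_fields()` deletes the existing related rows (or link rows) before re-inserting the new payload, passing a relationship field to `SQLStore.update()` is a replace, not a patch. Continuing the hypothetical models from the sketch above:

```python
# Sketch, reusing the hypothetical Library model and store from the previous example.
updated = await store.update(
    Library,
    Library.name == "Central",
    updates={
        "name": "Central Branch",         # plain column: updated in place
        "tags": [{"title": "archived"}],  # relationship: previous tag rows/links are replaced
    },
)
```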
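On the redis side, `EmbeddedJsonModel()` now accepts an `embedded_models` mapping so one embedded model can nest another, and `set_db()` walks those nested models to hand each one the database handle. A minimal sketch, again with hypothetical schema names and an assumed package-root import:

```python
# Illustrative only; the schema names and the import path are assumptions.
from pydantic import BaseModel

from nqlstore import EmbeddedJsonModel, Field


class GeoPointSchema(BaseModel):
    lat: float = Field(index=True)
    lng: float = Field(index=True)


class AddressSchema(BaseModel):
    street: str = Field(index=True)
    # this field is itself backed by an EmbeddedJsonModel
    point: GeoPointSchema = Field()


GeoPoint = EmbeddedJsonModel("GeoPoint", GeoPointSchema)
Address = EmbeddedJsonModel(
    "Address",
    AddressSchema,
    # map the nested field to the embedded model that backs it
    embedded_models={"point": GeoPoint},
)
```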