Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion .pre-commit-config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@ repos:
- id: mixed-line-ending
- id: trailing-whitespace
- repo: https://github.com/charliermarsh/ruff-pre-commit
rev: "v0.15.7"
rev: "v0.15.8"
hooks:
- id: ruff
args: ["--fix"]
Expand Down
1 change: 1 addition & 0 deletions docs/conf.py
Original file line number Diff line number Diff line change
Expand Up @@ -159,6 +159,7 @@
]

autosummary_generate = False
smartquotes = False


autosectionlabel_prefix_document = True
Expand Down
14 changes: 5 additions & 9 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -70,7 +70,7 @@ uuid = ["uuid-utils"]

[dependency-groups]
benchmarks = ["sqlalchemy[asyncio]", "psutil", "types-psutil", "duckdb-engine>=0.17.0"]
build = ["bump-my-version", "hatch-mypyc", "pydantic-settings"]
build = ["bump-my-version", "hatch-mypyc", "mypy>=1.19.1", "pydantic-settings"]
dev = [
{ include-group = "extras" },
{ include-group = "lint" },
Expand Down Expand Up @@ -115,7 +115,7 @@ extras = [
"dishka",
]
lint = [
"mypy>=1.13.0",
"mypy>=1.19.1",
"pre-commit>=3.5.0",
"pyright>=1.1.386",
"ruff>=0.7.1",
Expand Down Expand Up @@ -176,7 +176,7 @@ packages = ["sqlspec"]


[tool.hatch.build.targets.wheel.hooks.mypyc]
dependencies = ["hatch-mypyc", "hatch-cython"]
dependencies = ["hatch-mypyc", "hatch-cython", "mypy>=1.19.1"]
enable-by-default = false
exclude = [
"tests/**", # Test files
Expand All @@ -187,21 +187,17 @@ exclude = [
"sqlspec/**/_typing.py", # Type aliases (mypyc-incompatible)
"sqlspec/config.py", # Main config
"sqlspec/extensions/**", # All extensions
"sqlspec/dialects/**/*.py", # Keep SQLGlot dialect subclasses interpreted
"sqlspec/**/__init__.py", # Init files (usually just imports)
"sqlspec/protocols.py", # Protocol definitions
"sqlspec/adapters/mock/**", # Mock adapter (testing only)
"sqlspec/migrations/commands.py", # Migration command CLI (dynamic imports)
"sqlspec/data_dictionary/_loader.py", # Loader relies on __file__ which fails in compiled modules
"sqlspec/dialects/postgres/_pgvector.py", # Dialect shell patches compiled sqlglot parser/generator registries
"sqlspec/dialects/postgres/_paradedb.py", # Dialect shell patches compiled sqlglot parser/generator registries
"sqlspec/dialects/spanner/_spanner.py", # Dialect shell patches compiled sqlglot parser hooks and post-processes DDL
"sqlspec/dialects/spanner/_spangres.py", # Dialect shell patches compiled sqlglot parser hooks
"sqlspec/extensions/fastapi/providers.py", # Uses SingletonMeta metaclass
"sqlspec/extensions/litestar/providers.py", # Uses SingletonMeta metaclass
"sqlspec/adapters/**/data_dictionary.py", # Cross-module inheritance causes mypyc segfaults
"sqlspec/observability/_formatting.py", # Inherits from non-compiled logging.Formatter
"sqlspec/utils/arrow_helpers.py", # Arrow operations cause segfaults when compiled
"sqlspec/storage/backends/_iterators.py", # Async __anext__ + asyncio.to_thread causes mypyc segfault
]
include = [
"sqlspec/core/**/*.py", # Core module
Expand All @@ -212,7 +208,7 @@ include = [
"sqlspec/driver/**/*.py", # Driver module
"sqlspec/storage/registry.py", # Safe storage registry/runtime routing
"sqlspec/storage/errors.py", # Safe storage error normalization
"sqlspec/storage/backends/base.py", # Storage backend runtime base classes (iterators in _iterators.py)
"sqlspec/storage/backends/base.py", # Storage backend runtime base classes
"sqlspec/data_dictionary/**/*.py", # Data dictionary mixin (required for adapter inheritance)
"sqlspec/adapters/**/core.py", # Adapter compiled helpers
"sqlspec/adapters/**/type_converter.py", # All adapters type converters
Expand Down
6 changes: 6 additions & 0 deletions sqlspec/adapters/psycopg/_typing.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,9 @@

from psycopg import AsyncConnection, AsyncCursor, Connection, Cursor
from psycopg.rows import DictRow as PsycopgDictRow
from psycopg.sql import SQL as PsycopgSQL # noqa: N811
from psycopg.sql import Composed as PsycopgComposed
from psycopg.sql import Identifier as PsycopgIdentifier

if TYPE_CHECKING:
from collections.abc import Callable
Expand Down Expand Up @@ -203,8 +206,11 @@ async def __aexit__(
"PsycopgAsyncCursor",
"PsycopgAsyncRawCursor",
"PsycopgAsyncSessionContext",
"PsycopgComposed",
"PsycopgDictRow",
"PsycopgIdentifier",
"PsycopgPipelineDriver",
"PsycopgSQL",
"PsycopgSyncConnection",
"PsycopgSyncCursor",
"PsycopgSyncRawCursor",
Expand Down
20 changes: 10 additions & 10 deletions sqlspec/adapters/psycopg/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,9 +4,9 @@
from collections.abc import Sized
from typing import TYPE_CHECKING, Any, NamedTuple, cast

from psycopg import sql as psycopg_sql
from typing_extensions import LiteralString

from sqlspec.adapters.psycopg._typing import PsycopgComposed, PsycopgIdentifier, PsycopgSQL
from sqlspec.core import (
SQL,
DriverParameterProfile,
Expand Down Expand Up @@ -88,7 +88,7 @@ class PreparedStackOperation(NamedTuple):
operation_index: int
operation: "StackOperation"
statement: "SQL"
sql: "LiteralString | psycopg_sql.SQL"
sql: "LiteralString | PsycopgSQL | PsycopgComposed"
parameters: "tuple[Any, ...] | dict[str, Any] | None"


Expand All @@ -113,23 +113,23 @@ def pipeline_supported() -> bool:
return False


def _compose_table_identifier(table: str) -> "psycopg_sql.Composed":
def _compose_table_identifier(table: str) -> "PsycopgComposed":
parts = [part for part in table.split(".") if part]
if not parts:
msg = "Table name must not be empty"
raise SQLSpecError(msg)
identifiers = [psycopg_sql.Identifier(part) for part in parts]
return psycopg_sql.SQL(".").join(identifiers)
identifiers = [PsycopgIdentifier(part) for part in parts]
return PsycopgSQL(".").join(identifiers)


def build_copy_from_command(table: str, columns: "list[str]") -> "psycopg_sql.Composed":
def build_copy_from_command(table: str, columns: "list[str]") -> "PsycopgComposed":
table_identifier = _compose_table_identifier(table)
column_sql = psycopg_sql.SQL(", ").join([psycopg_sql.Identifier(column) for column in columns])
return psycopg_sql.SQL("COPY {} ({}) FROM STDIN").format(table_identifier, column_sql)
column_sql = PsycopgSQL(", ").join([PsycopgIdentifier(column) for column in columns])
return PsycopgSQL("COPY {} ({}) FROM STDIN").format(table_identifier, column_sql)


def build_truncate_command(table: str) -> "psycopg_sql.Composed":
return psycopg_sql.SQL("TRUNCATE TABLE {}").format(_compose_table_identifier(table))
def build_truncate_command(table: str) -> "PsycopgComposed":
return PsycopgSQL("TRUNCATE TABLE {}").format(_compose_table_identifier(table))


def _identity(value: Any) -> Any:
Expand Down
9 changes: 5 additions & 4 deletions sqlspec/adapters/psycopg/driver.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,13 +5,14 @@
from typing import TYPE_CHECKING, Any, cast

import psycopg
from psycopg import sql as psycopg_sql
from typing_extensions import LiteralString

from sqlspec.adapters.psycopg._typing import (
PsycopgAsyncConnection,
PsycopgAsyncCursor,
PsycopgAsyncSessionContext,
PsycopgComposed,
PsycopgSQL,
PsycopgSyncConnection,
PsycopgSyncCursor,
PsycopgSyncSessionContext,
Expand Down Expand Up @@ -111,7 +112,7 @@ def _prepare_pipeline_operations(self, stack: "StatementStack") -> "list[Prepare
operation_index=index,
operation=operation,
statement=sql_statement,
sql=cast("LiteralString | psycopg_sql.SQL", sql_text),
sql=cast("LiteralString | PsycopgSQL | PsycopgComposed", sql_text),
parameters=prepared_parameters,
)
)
Expand Down Expand Up @@ -396,7 +397,7 @@ def _raise_pending_exception(exception_ctx: "PsycopgSyncExceptionHandler") -> No
cursor = resource_stack.enter_context(self.with_cursor(self.connection))

try:
sql = cast("LiteralString | psycopg_sql.SQL", prepared.sql) # type: ignore[redundant-cast]
sql = cast("LiteralString | PsycopgSQL | PsycopgComposed", prepared.sql) # type: ignore[redundant-cast]
if prepared.parameters:
cursor.execute(sql, prepared.parameters)
else:
Expand Down Expand Up @@ -855,7 +856,7 @@ def _raise_pending_exception(exception_ctx: "PsycopgAsyncExceptionHandler") -> N
cursor = await resource_stack.enter_async_context(self.with_cursor(self.connection))

try:
sql = cast("LiteralString | psycopg_sql.SQL", prepared.sql) # type: ignore[redundant-cast]
sql = cast("LiteralString | PsycopgSQL | PsycopgComposed", prepared.sql) # type: ignore[redundant-cast]
if prepared.parameters:
await cursor.execute(sql, prepared.parameters)
else:
Expand Down
17 changes: 11 additions & 6 deletions sqlspec/builder/_temporal.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,24 +18,29 @@
from sqlglot.dialects.postgres import Postgres
from sqlglot.dialects.snowflake import Snowflake
from sqlglot.generator import Generator
from sqlglot.generators.bigquery import BigQueryGenerator
from sqlglot.generators.duckdb import DuckDBGenerator
from sqlglot.generators.oracle import OracleGenerator
from sqlglot.generators.postgres import PostgresGenerator
from sqlglot.generators.snowflake import SnowflakeGenerator

__all__ = ("create_temporal_table", "register_version_generators")


def _oracle_version_sql(self: "Oracle.Generator", expression: exp.Version) -> str:
def _oracle_version_sql(self: OracleGenerator, expression: exp.Version) -> str:
"""Oracle: AS OF TIMESTAMP timestamp or AS OF SCN scn."""
expr = self.sql(expression, "expression")
this = expression.name or "TIMESTAMP"
return f"AS OF {this} {expr}"


def _bigquery_version_sql(self: "BigQuery.Generator", expression: exp.Version) -> str:
def _bigquery_version_sql(self: BigQueryGenerator, expression: exp.Version) -> str:
"""BigQuery: FOR SYSTEM_TIME AS OF timestamp."""
expr = self.sql(expression, "expression")
return f"FOR SYSTEM_TIME AS OF {expr}"


def _snowflake_version_sql(self: "Snowflake.Generator", expression: exp.Version) -> str:
def _snowflake_version_sql(self: SnowflakeGenerator, expression: exp.Version) -> str:
"""Snowflake: AT (TIMESTAMP => timestamp) or BEFORE (TIMESTAMP => ...).

AS OF is mapped to AT, and BEFORE is supported for point-before queries.
Expand All @@ -48,19 +53,19 @@ def _snowflake_version_sql(self: "Snowflake.Generator", expression: exp.Version)
return f"AT ({this} => {expr})"


def _duckdb_version_sql(self: "DuckDB.Generator", expression: exp.Version) -> str:
def _duckdb_version_sql(self: DuckDBGenerator, expression: exp.Version) -> str:
"""DuckDB: AT (TIMESTAMP => timestamp)."""
expr = self.sql(expression, "expression")
return f"AT (TIMESTAMP => {expr})"


def _cockroachdb_version_sql(self: "Postgres.Generator", expression: exp.Version) -> str:
def _cockroachdb_version_sql(self: PostgresGenerator, expression: exp.Version) -> str:
"""CockroachDB (via Postgres dialect): AS OF SYSTEM TIME timestamp."""
expr = self.sql(expression, "expression")
return f"AS OF SYSTEM TIME {expr}"


def _default_version_sql(self: "Generator", expression: exp.Version) -> str:
def _default_version_sql(self: Generator, expression: exp.Version) -> str:
"""Default: AS OF SYSTEM TIME timestamp (CockroachDB style).

When no dialect is specified, we default to CockroachDB/Postgres style
Expand Down
44 changes: 34 additions & 10 deletions sqlspec/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -968,20 +968,21 @@ class DatabaseConfigProtocol(ABC, Generic[ConnectionT, PoolT, DriverT]):

__slots__ = (
"_migration_commands",
"_migration_config",
"_migration_loader",
"_observability_runtime",
"_storage_capabilities",
"bind_key",
"connection_instance",
"driver_features",
"extension_config",
"migration_config",
"observability_config",
"statement_config",
)

_migration_loader: "SQLFileLoader"
_migration_commands: "SyncMigrationCommands[Any] | AsyncMigrationCommands[Any]"
_migration_config: "dict[str, Any] | MigrationConfig"
driver_type: "ClassVar[type[Any]]"
connection_type: "ClassVar[type[Any]]"
is_async: "ClassVar[bool]" = False
Expand All @@ -998,7 +999,6 @@ class DatabaseConfigProtocol(ABC, Generic[ConnectionT, PoolT, DriverT]):
bind_key: "str | None"
statement_config: "StatementConfig"
connection_instance: "PoolT | None"
migration_config: "dict[str, Any] | MigrationConfig"
extension_config: "ExtensionConfigs"
driver_features: "dict[str, Any]"
_storage_capabilities: "StorageCapabilities | None"
Expand All @@ -1022,6 +1022,20 @@ def __repr__(self) -> str:
])
return f"{type(self).__name__}({parts})"

@property
def migration_config(self) -> "dict[str, Any] | MigrationConfig":
"""Return the current migration configuration."""
return self._migration_config

@migration_config.setter
def migration_config(self, value: "dict[str, Any] | MigrationConfig | None") -> None:
"""Store migration configuration and refresh derived migration helpers."""
object.__setattr__(self, "_migration_config", dict(cast("dict[str, Any]", value) or {}))
if self._has_initialized_attribute("extension_config"):
self._ensure_extension_migrations()
if self._migration_components_ready():
self._initialize_migration_components()

def storage_capabilities(self) -> "StorageCapabilities":
"""Return cached storage capabilities for this configuration."""

Expand All @@ -1034,6 +1048,20 @@ def reset_storage_capabilities_cache(self) -> None:

self._storage_capabilities = None

def _has_initialized_attribute(self, attribute_name: str) -> bool:
"""Return whether a slot-backed attribute has been initialized."""
try:
object.__getattribute__(self, attribute_name)
except AttributeError:
return False
return True

def _migration_components_ready(self) -> bool:
"""Return whether migration helpers have already been initialized."""
return self._has_initialized_attribute("_migration_loader") and self._has_initialized_attribute(
"_migration_commands"
)

def _ensure_extension_migrations(self) -> None:
"""Auto-include extension migrations when extension_config has them configured.

Expand Down Expand Up @@ -1473,8 +1501,7 @@ def __init__(
self.connection_instance = connection_instance
self.connection_config = connection_config or {}
self.extension_config = extension_config or {}
self.migration_config: dict[str, Any] | MigrationConfig = migration_config or {}
self._ensure_extension_migrations()
self.migration_config = migration_config or {}
self._init_observability(observability_config)
self._initialize_migration_components()

Expand Down Expand Up @@ -1637,8 +1664,7 @@ def __init__(
self.connection_instance = connection_instance
self.connection_config = connection_config or {}
self.extension_config = extension_config or {}
self.migration_config: dict[str, Any] | MigrationConfig = migration_config or {}
self._ensure_extension_migrations()
self.migration_config = migration_config or {}
self._init_observability(observability_config)
self._initialize_migration_components()

Expand Down Expand Up @@ -1806,8 +1832,7 @@ def __init__(
self.connection_instance = connection_instance
self.connection_config = connection_config or {}
self.extension_config = extension_config or {}
self.migration_config: dict[str, Any] | MigrationConfig = migration_config or {}
self._ensure_extension_migrations()
self.migration_config = migration_config or {}
self._init_observability(observability_config)
self._initialize_migration_components()

Expand Down Expand Up @@ -2017,8 +2042,7 @@ def __init__(
self.connection_instance = connection_instance
self.connection_config = connection_config or {}
self.extension_config = extension_config or {}
self.migration_config: dict[str, Any] | MigrationConfig = migration_config or {}
self._ensure_extension_migrations()
self.migration_config = migration_config or {}
self._init_observability(observability_config)
self._initialize_migration_components()

Expand Down
Loading
Loading