Skip to content
This repository was archived by the owner on Mar 13, 2026. It is now read-only.

Commit d947102

Browse files
committed
feat: support named schemas
Add support for SELECT, UPDATE and DELETE statements against tables in schemas. Schema names are not allowed in Spanner SELECT statements. We need to avoid generating SQL like ```sql SELECT schema.tbl.id FROM schema.tbl ``` To do so, we alias the table in order to produce SQL like: ```sql SELECT tbl_1.id, tbl_1.col FROM schema.tbl AS tbl_1 ```
1 parent e17c5ef commit d947102

3 files changed

Lines changed: 134 additions & 2 deletions

File tree

README.rst

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -465,7 +465,6 @@ Other limitations
465465
~~~~~~~~~~~~~~~~~
466466

467467
- WITH RECURSIVE statement is not supported.
468-
- Named schemas are not supported.
469468
- Temporary tables are not supported.
470469
- Numeric type dimensions (scale and precision) are constant. See the
471470
`docs <https://cloud.google.com/spanner/docs/data-types#numeric_types>`__.

google/cloud/sqlalchemy_spanner/sqlalchemy_spanner.py

Lines changed: 84 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -233,6 +233,10 @@ class SpannerSQLCompiler(SQLCompiler):
233233

234234
compound_keywords = _compound_keywords
235235

236+
def __init__(self, *args, **kwargs):
237+
self.tablealiases = {}
238+
super().__init__(*args, **kwargs)
239+
236240
def get_from_hint_text(self, _, text):
237241
"""Return a hint text.
238242
@@ -378,8 +382,10 @@ def limit_clause(self, select, **kw):
378382
return text
379383

380384
def returning_clause(self, stmt, returning_cols, **kw):
385+
# Set include_table=False because although table names are allowed in
386+
# RETURNING clauses, schema names are not.
381387
columns = [
382-
self._label_select_column(None, c, True, False, {})
388+
self._label_select_column(None, c, True, False, {}, include_table=False)
383389
for c in expression._select_iterables(returning_cols)
384390
]
385391

@@ -391,6 +397,83 @@ def visit_sequence(self, seq, **kw):
391397
seq
392398
)
393399

400+
def visit_table(self, table, spanner_aliased=False, iscrud=False, **kwargs):
401+
"""Produces the table name.
402+
403+
Schema names are not allowed in Spanner SELECT statements. We
404+
need to avoid generating SQL like
405+
406+
SELECT schema.tbl.id
407+
FROM schema.tbl
408+
409+
To do so, we alias the table in order to produce SQL like:
410+
411+
SELECT tbl_1.id, tbl_1.col
412+
FROM schema.tbl AS tbl_1
413+
414+
And do similar for UPDATE and DELETE statements.
415+
416+
We don't need to correct INSERT statements, which is fortunate
417+
because INSERT statements actually do not currently result in
418+
calls to `visit_table`.
419+
420+
This closely mirrors the mssql dialect which also avoids
421+
schema-qualified columns in SELECTs, although the behaviour is
422+
currently behind a deprecated 'legacy_schema_aliasing' flag.
423+
"""
424+
if spanner_aliased is table or self.isinsert:
425+
return super().visit_table(table, **kwargs)
426+
427+
# alias schema-qualified tables
428+
alias = self._schema_aliased_table(table)
429+
if alias is not None:
430+
return self.process(alias, spanner_aliased=table, **kwargs)
431+
else:
432+
return super().visit_table(table, **kwargs)
433+
434+
def visit_alias(self, alias, **kw):
435+
"""Produces alias statements."""
436+
# translate for schema-qualified table aliases
437+
kw["spanner_aliased"] = alias.element
438+
return super().visit_alias(alias, **kw)
439+
440+
def visit_column(self, column, add_to_result_map=None, **kw):
441+
"""Produces column expressions.
442+
443+
In tandem with visit_table, replaces schema-qualified column
444+
names with column names qualified against an alias.
445+
"""
446+
if column.table is not None and not self.isinsert or self.is_subquery():
447+
# translate for schema-qualified table aliases
448+
t = self._schema_aliased_table(column.table)
449+
if t is not None:
450+
converted = elements._corresponding_column_or_error(t, column)
451+
if add_to_result_map is not None:
452+
add_to_result_map(
453+
column.name,
454+
column.name,
455+
(column, column.name, column.key),
456+
column.type,
457+
)
458+
459+
return super().visit_column(converted, **kw)
460+
461+
return super().visit_column(column, add_to_result_map=add_to_result_map, **kw)
462+
463+
def _schema_aliased_table(self, table):
464+
"""Creates an alias for the table if it is schema-qualified.
465+
466+
If the table is schema-qualified, returns an alias for the
467+
table and caches the alias for future references to the
468+
table. If the table is not schema-qualified, returns None.
469+
"""
470+
if getattr(table, "schema", None) is not None:
471+
if table not in self.tablealiases:
472+
self.tablealiases[table] = table.alias()
473+
return self.tablealiases[table]
474+
else:
475+
return None
476+
394477

395478
class SpannerDDLCompiler(DDLCompiler):
396479
"""Spanner DDL statements compiler."""

test/system/test_basics.py

Lines changed: 50 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,8 @@
2525
Boolean,
2626
BIGINT,
2727
select,
28+
update,
29+
delete,
2830
)
2931
from sqlalchemy.orm import Session, DeclarativeBase, Mapped, mapped_column
3032
from sqlalchemy.types import REAL
@@ -58,6 +60,16 @@ def define_tables(cls, metadata):
5860
Column("name", String(20)),
5961
)
6062

63+
with cls.bind.begin() as conn:
64+
conn.execute(text("CREATE SCHEMA IF NOT EXISTS schema"))
65+
Table(
66+
"users",
67+
metadata,
68+
Column("ID", Integer, primary_key=True),
69+
Column("name", String(20)),
70+
schema="schema",
71+
)
72+
6173
def test_hello_world(self, connection):
6274
greeting = connection.execute(text("select 'Hello World'"))
6375
eq_("Hello World", greeting.fetchone()[0])
@@ -139,6 +151,12 @@ class User(Base):
139151
ID: Mapped[int] = mapped_column(primary_key=True)
140152
name: Mapped[str] = mapped_column(String(20))
141153

154+
class SchemaUser(Base):
155+
__tablename__ = "users"
156+
__table_args__ = {"schema": "schema"}
157+
ID: Mapped[int] = mapped_column(primary_key=True)
158+
name: Mapped[str] = mapped_column(String(20))
159+
142160
engine = connection.engine
143161
with Session(engine) as session:
144162
number = Number(
@@ -156,3 +174,35 @@ class User(Base):
156174
users = session.scalars(statement).all()
157175
eq_(1, len(users))
158176
is_true(users[0].ID > 0)
177+
178+
with Session(engine) as session:
179+
user = SchemaUser(name="SchemaTest")
180+
session.add(user)
181+
session.commit()
182+
183+
users = session.scalars(
184+
select(SchemaUser).where(SchemaUser.name == "SchemaTest")
185+
).all()
186+
eq_(1, len(users))
187+
is_true(users[0].ID > 0)
188+
189+
session.execute(
190+
update(SchemaUser)
191+
.where(SchemaUser.name == "SchemaTest")
192+
.values(name="NewName")
193+
)
194+
session.commit()
195+
196+
users = session.scalars(
197+
select(SchemaUser).where(SchemaUser.name == "NewName")
198+
).all()
199+
eq_(1, len(users))
200+
is_true(users[0].ID > 0)
201+
202+
session.execute(delete(SchemaUser).where(SchemaUser.name == "NewName"))
203+
session.commit()
204+
205+
users = session.scalars(
206+
select(SchemaUser).where(SchemaUser.name == "NewName")
207+
).all()
208+
eq_(0, len(users))

0 commit comments

Comments
 (0)