Skip to content

Commit 9ef6028

Browse files
committed
support Databricks query tags in session properties
1 parent 228c223 commit 9ef6028

2 files changed

Lines changed: 128 additions & 0 deletions

File tree

sqlmesh/core/engine_adapter/databricks.py

Lines changed: 41 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -30,6 +30,25 @@
3030
logger = logging.getLogger(__name__)
3131

3232

33+
def _query_tags(query_tags: t.Any) -> t.Optional[t.Dict[str, t.Optional[str]]]:
34+
if not query_tags:
35+
return None
36+
37+
if not isinstance(query_tags, dict):
38+
raise SQLMeshError("Invalid value for `session_properties.query_tags`. Must be a dict.")
39+
40+
if not all(
41+
isinstance(key, str) and (value is None or isinstance(value, str))
42+
for key, value in query_tags.items()
43+
):
44+
raise SQLMeshError(
45+
"Invalid value for `session_properties.query_tags`. Keys must be strings "
46+
"and values must be strings or None."
47+
)
48+
49+
return query_tags
50+
51+
3352
class DatabricksEngineAdapter(SparkEngineAdapter, GrantsFromInfoSchemaMixin):
3453
DIALECT = "databricks"
3554
INSERT_OVERWRITE_STRATEGY = InsertOverwriteStrategy.REPLACE_WHERE
@@ -98,6 +117,13 @@ def _use_spark_session(self) -> bool:
98117
def is_spark_session_connection(self) -> bool:
99118
return isinstance(self.connection, SparkSessionConnection)
100119

120+
@property
def _is_databricks_sql_connector_connection(self) -> bool:
    """True when queries go through the Databricks SQL connector.

    That is: neither a Spark session connection nor a connection that was
    routed to the Spark engine adapter.
    """
    uses_spark = self.is_spark_session_connection or self._connection_pool.get_attribute(
        "use_spark_engine_adapter"
    )
    return not uses_spark
126+
101127
def _set_spark_engine_adapter_if_needed(self) -> None:
102128
self._spark_engine_adapter = None
103129

def _begin_session(self, properties: SessionProperties) -> t.Any:
    """Begin a new session.

    Aligns every connector on the default catalog and stores the validated
    query tags on the connection pool for later query execution.
    """
    # Align the different possible connectors to a single catalog
    self.set_current_catalog(self.default_catalog)  # type: ignore
    validated_tags = _query_tags(properties.get("query_tags"))
    self._connection_pool.set_attribute("query_tags", validated_tags)
184213

185214
def _end_session(self) -> None:
215+
self._connection_pool.set_attribute("query_tags", None)
186216
self._connection_pool.set_attribute("use_spark_engine_adapter", False)
187217

218+
def _execute(self, sql: str, track_rows_processed: bool = False, **kwargs: t.Any) -> None:
    """Execute SQL, attaching the session's query tags when applicable.

    Tags are only injected for the Databricks SQL connector (other
    connections don't accept a `query_tags` kwarg) and never override an
    explicitly supplied `query_tags` argument.
    """
    session_tags = self._connection_pool.get_attribute("query_tags")
    should_attach_tags = (
        bool(session_tags)
        and "query_tags" not in kwargs
        and self._is_databricks_sql_connector_connection
    )
    if should_attach_tags:
        kwargs["query_tags"] = session_tags

    return super()._execute(sql, track_rows_processed, **kwargs)
228+
188229
def _df_to_source_queries(
189230
self,
190231
df: DF,

tests/core/engine_adapter/test_databricks.py

Lines changed: 87 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@
1010
from sqlmesh.core.engine_adapter import DatabricksEngineAdapter
1111
from sqlmesh.core.engine_adapter.shared import DataObject, DataObjectType
1212
from sqlmesh.core.node import IntervalUnit
13+
from sqlmesh.utils.errors import SQLMeshError
1314
from tests.core.engine_adapter import to_sql_calls
1415

1516
pytestmark = [pytest.mark.databricks, pytest.mark.engine]
@@ -117,6 +118,92 @@ def test_set_current_catalog(mocker: MockFixture, make_mocked_engine_adapter: t.
117118
assert to_sql_calls(adapter) == ["USE CATALOG `test_catalog2`"]
118119

119120

121+
def test_session_query_tags(mocker: MockFixture, make_mocked_engine_adapter: t.Callable):
    """Session-level query tags are attached inside the session and dropped after it."""
    mocker.patch(
        "sqlmesh.core.engine_adapter.databricks.DatabricksEngineAdapter.set_current_catalog"
    )
    engine_adapter = make_mocked_engine_adapter(
        DatabricksEngineAdapter, default_catalog="test_catalog"
    )

    tags = {"team": "data-eng", "app": "sqlmesh"}
    with engine_adapter.session({"query_tags": tags}):
        engine_adapter.execute("SELECT 1")

    engine_adapter.cursor.execute.assert_called_with("SELECT 1", query_tags=tags)

    # After the session ends, queries must run without tags.
    engine_adapter.execute("SELECT 2")
    engine_adapter.cursor.execute.assert_called_with("SELECT 2")
137+
138+
139+
def test_session_query_tags_do_not_override_explicit_query_tags(
    mocker: MockFixture, make_mocked_engine_adapter: t.Callable
):
    """A per-call `query_tags` argument takes precedence over the session tags."""
    mocker.patch(
        "sqlmesh.core.engine_adapter.databricks.DatabricksEngineAdapter.set_current_catalog"
    )
    engine_adapter = make_mocked_engine_adapter(
        DatabricksEngineAdapter, default_catalog="test_catalog"
    )

    with engine_adapter.session({"query_tags": {"team": "data-eng"}}):
        engine_adapter.execute("SELECT 1", query_tags={"team": "analytics"})

    engine_adapter.cursor.execute.assert_called_with("SELECT 1", query_tags={"team": "analytics"})
151+
152+
153+
def test_session_query_tags_not_applied_to_spark_session_connection(
    mocker: MockFixture, make_mocked_engine_adapter: t.Callable
):
    """Spark session connections don't take a `query_tags` kwarg, so none is injected."""
    mocker.patch(
        "sqlmesh.core.engine_adapter.databricks.DatabricksEngineAdapter.set_current_catalog"
    )
    engine_adapter = make_mocked_engine_adapter(
        DatabricksEngineAdapter, default_catalog="test_catalog"
    )
    mocker.patch.object(
        DatabricksEngineAdapter,
        "is_spark_session_connection",
        new_callable=mocker.PropertyMock,
        return_value=True,
    )

    with engine_adapter.session({"query_tags": {"team": "data-eng"}}):
        engine_adapter.execute("SELECT 1")

    engine_adapter.cursor.execute.assert_called_with("SELECT 1")
171+
172+
173+
def test_session_query_tags_not_applied_to_spark_engine_adapter(
    mocker: MockFixture, make_mocked_engine_adapter: t.Callable
):
    """Queries routed to the Spark engine adapter run without query tags."""
    mocker.patch(
        "sqlmesh.core.engine_adapter.databricks.DatabricksEngineAdapter.set_current_catalog"
    )
    engine_adapter = make_mocked_engine_adapter(
        DatabricksEngineAdapter, default_catalog="test_catalog"
    )
    spark_cursor = mocker.Mock()
    engine_adapter._spark_engine_adapter = mocker.Mock(cursor=spark_cursor)
    engine_adapter._connection_pool.set_attribute("use_spark_engine_adapter", True)

    with engine_adapter.session({"query_tags": {"team": "data-eng"}}):
        # Re-set inside the session since opening it may reset the attribute.
        engine_adapter._connection_pool.set_attribute("use_spark_engine_adapter", True)
        engine_adapter.execute("SELECT 1")

    spark_cursor.execute.assert_called_with("SELECT 1")
189+
190+
191+
@pytest.mark.parametrize(
    "query_tags",
    [
        "team:data-eng",
        {"team": 1},
        {1: "data-eng"},
    ],
)
def test_session_query_tags_invalid(query_tags, make_mocked_engine_adapter: t.Callable):
    """Non-dict tags and non-string keys/values are rejected when the session opens."""
    engine_adapter = make_mocked_engine_adapter(
        DatabricksEngineAdapter, default_catalog="test_catalog"
    )

    with pytest.raises(SQLMeshError, match="session_properties.query_tags"):
        with engine_adapter.session({"query_tags": query_tags}):
            pass
205+
206+
120207
def test_get_current_catalog(mocker: MockFixture, make_mocked_engine_adapter: t.Callable):
121208
mocker.patch(
122209
"sqlmesh.core.engine_adapter.databricks.DatabricksEngineAdapter.set_current_catalog"

0 commit comments

Comments
 (0)