Skip to content

Commit 832fce1

Browse files
authored
Merge branch 'main' into fix/databricks-oauth-shared-connection
2 parents 4f5eeaf + 3be5bba commit 832fce1

18 files changed

Lines changed: 280 additions & 73 deletions

File tree

sqlmesh/core/config/connection.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -34,6 +34,7 @@
3434
ValidationInfo,
3535
field_validator,
3636
model_validator,
37+
validation_data,
3738
validation_error_message,
3839
get_concrete_types_from_typehint,
3940
)
@@ -1083,7 +1084,7 @@ def validate_execution_project(
10831084
v: t.Optional[str],
10841085
info: ValidationInfo,
10851086
) -> t.Optional[str]:
1086-
if v and not info.data.get("project"):
1087+
if v and not validation_data(info).get("project"):
10871088
raise ConfigError(
10881089
"If the `execution_project` field is specified, you must also specify the `project` field to provide a default object location."
10891090
)
@@ -1095,7 +1096,7 @@ def validate_quota_project(
10951096
v: t.Optional[str],
10961097
info: ValidationInfo,
10971098
) -> t.Optional[str]:
1098-
if v and not info.data.get("project"):
1099+
if v and not validation_data(info).get("project"):
10991100
raise ConfigError(
11001101
"If the `quota_project` field is specified, you must also specify the `project` field to provide a default object location."
11011102
)

sqlmesh/core/context.py

Lines changed: 10 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1605,9 +1605,11 @@ def plan_builder(
16051605
backfill_models = None
16061606

16071607
models_override: t.Optional[UniqueKeyDict[str, Model]] = None
1608+
selected_fqns: t.Set[str] = set()
1609+
selected_deletion_fqns: t.Set[str] = set()
16081610
if select_models:
16091611
try:
1610-
models_override = model_selector.select_models(
1612+
models_override, selected_fqns = model_selector.select_models(
16111613
select_models,
16121614
environment,
16131615
fallback_env_name=create_from or c.PROD,
@@ -1622,12 +1624,17 @@ def plan_builder(
16221624
# Only backfill selected models unless explicitly specified.
16231625
backfill_models = model_selector.expand_model_selections(select_models)
16241626

1627+
if not backfill_models:
1628+
# The selection matched nothing locally. Check whether it matched models
1629+
# in the deployed environment that were deleted locally.
1630+
selected_deletion_fqns = selected_fqns - set(self._models)
1631+
16251632
expanded_restate_models = None
16261633
if restate_models is not None:
16271634
expanded_restate_models = model_selector.expand_model_selections(restate_models)
16281635

16291636
if (restate_models is not None and not expanded_restate_models) or (
1630-
backfill_models is not None and not backfill_models
1637+
backfill_models is not None and not backfill_models and not selected_deletion_fqns
16311638
):
16321639
raise PlanError(
16331640
"Selector did not return any models. Please check your model selection and try again."
@@ -1636,7 +1643,7 @@ def plan_builder(
16361643
if always_include_local_changes is None:
16371644
# default behaviour - if restatements are detected; we operate entirely out of state and ignore local changes
16381645
force_no_diff = restate_models is not None or (
1639-
backfill_models is not None and not backfill_models
1646+
backfill_models is not None and not backfill_models and not selected_deletion_fqns
16401647
)
16411648
else:
16421649
force_no_diff = not always_include_local_changes

sqlmesh/core/environment.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -56,7 +56,8 @@ def _sanitize_name(cls, v: str) -> str:
5656
@classmethod
5757
def _validate_boolean_field(cls, v: t.Any, info: ValidationInfo) -> bool:
5858
if v is None:
59-
return info.field_name == "normalize_name"
59+
# Pydantic 2.13+ sets field_name to None during model_validate_json()
60+
return (info.field_name or "") == "normalize_name"
6061
return bool(v)
6162

6263
@t.overload

sqlmesh/core/metric/definition.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,7 @@
1010
from sqlmesh.core.node import str_or_exp_to_str
1111
from sqlmesh.utils import UniqueKeyDict
1212
from sqlmesh.utils.errors import ConfigError
13-
from sqlmesh.utils.pydantic import PydanticModel, ValidationInfo, field_validator
13+
from sqlmesh.utils.pydantic import PydanticModel, ValidationInfo, field_validator, validation_data
1414

1515
MeasureAndDimTables = t.Tuple[str, t.Tuple[str, ...]]
1616

@@ -89,7 +89,7 @@ def _string_validator(cls, v: t.Any) -> t.Optional[str]:
8989
@field_validator("expression", mode="before")
9090
def _validate_expression(cls, v: t.Any, info: ValidationInfo) -> exp.Expr:
9191
if isinstance(v, str):
92-
dialect = info.data.get("dialect")
92+
dialect = validation_data(info).get("dialect")
9393
return d.parse_one(v, dialect=dialect)
9494
if isinstance(v, exp.Expr):
9595
return v

sqlmesh/core/model/common.py

Lines changed: 12 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,13 @@
2121
prepare_env,
2222
serialize_env,
2323
)
24-
from sqlmesh.utils.pydantic import PydanticModel, ValidationInfo, field_validator, get_dialect
24+
from sqlmesh.utils.pydantic import (
25+
PydanticModel,
26+
ValidationInfo,
27+
field_validator,
28+
get_dialect,
29+
validation_data,
30+
)
2531

2632
if t.TYPE_CHECKING:
2733
from sqlglot.dialects.dialect import DialectType
@@ -479,7 +485,7 @@ def parse_expression(
479485
if callable(v):
480486
return v
481487

482-
dialect = info.data.get("dialect") if info else ""
488+
dialect = validation_data(info).get("dialect") if info else ""
483489

484490
if isinstance(v, list):
485491
return [
@@ -519,7 +525,7 @@ def parse_properties(
519525
if v is None:
520526
return v
521527

522-
dialect = info.data.get("dialect") if info else ""
528+
dialect = validation_data(info).get("dialect") if info else ""
523529

524530
if isinstance(v, str):
525531
v = d.parse_one(v, dialect=dialect)
@@ -557,8 +563,9 @@ def default_catalog(cls: t.Type, v: t.Any) -> t.Optional[str]:
557563

558564

559565
def depends_on(cls: t.Type, v: t.Any, info: ValidationInfo) -> t.Optional[t.Set[str]]:
560-
dialect = info.data.get("dialect")
561-
default_catalog = info.data.get("default_catalog")
566+
data = validation_data(info)
567+
dialect = data.get("dialect")
568+
default_catalog = data.get("default_catalog")
562569

563570
if isinstance(v, exp.Paren):
564571
v = v.unnest()

sqlmesh/core/model/meta.py

Lines changed: 11 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -44,6 +44,7 @@
4444
list_of_fields_validator,
4545
model_validator,
4646
get_dialect,
47+
validation_data,
4748
)
4849

4950
if t.TYPE_CHECKING:
@@ -135,7 +136,7 @@ def _func_call_validator(cls, v: t.Any, field: t.Any) -> t.Any:
135136

136137
@field_validator("tags", mode="before")
137138
def _value_or_tuple_validator(cls, v: t.Any, info: ValidationInfo) -> t.Any:
138-
return ensure_list(cls._validate_value_or_tuple(v, info.data))
139+
return ensure_list(cls._validate_value_or_tuple(v, validation_data(info)))
139140

140141
@classmethod
141142
def _validate_value_or_tuple(
@@ -164,7 +165,7 @@ def _normalize(value: t.Any) -> t.Any:
164165
@field_validator("table_format", "storage_format", mode="before")
165166
def _format_validator(cls, v: t.Any, info: ValidationInfo) -> t.Optional[str]:
166167
if isinstance(v, exp.Expr) and not (isinstance(v, (exp.Literal, exp.Identifier))):
167-
return v.sql(info.data.get("dialect"))
168+
return v.sql(validation_data(info).get("dialect"))
168169
return str_or_exp_to_str(v)
169170

170171
@field_validator("dialect", mode="before")
@@ -192,7 +193,7 @@ def _partition_and_cluster_validator(cls, v: t.Any, info: ValidationInfo) -> t.L
192193
if (
193194
isinstance(v, list)
194195
and all(isinstance(i, str) for i in v)
195-
and info.field_name == "partitioned_by_"
196+
and (info.field_name or "") == "partitioned_by_"
196197
):
197198
# this branch gets hit when we are deserializing from json because `partitioned_by` is stored as a List[str]
198199
# however, we should only invoke this if the list contains strings because this validator is also
@@ -205,7 +206,7 @@ def _partition_and_cluster_validator(cls, v: t.Any, info: ValidationInfo) -> t.L
205206
)
206207
v = parsed.this.expressions if isinstance(parsed.this, exp.Schema) else v
207208

208-
expressions = list_of_fields_validator(v, info.data)
209+
expressions = list_of_fields_validator(v, validation_data(info))
209210

210211
for expression in expressions:
211212
num_cols = len(list(expression.find_all(exp.Column)))
@@ -228,7 +229,7 @@ def _columns_validator(
228229
cls, v: t.Any, info: ValidationInfo
229230
) -> t.Optional[t.Dict[str, exp.DataType]]:
230231
columns_to_types = {}
231-
dialect = info.data.get("dialect")
232+
dialect = validation_data(info).get("dialect")
232233

233234
if isinstance(v, exp.Schema):
234235
for column in v.expressions:
@@ -280,7 +281,8 @@ def _columns_validator(
280281
def _column_descriptions_validator(
281282
cls, vs: t.Any, info: ValidationInfo
282283
) -> t.Optional[t.Dict[str, str]]:
283-
dialect = info.data.get("dialect")
284+
data = validation_data(info)
285+
dialect = data.get("dialect")
284286

285287
if vs is None:
286288
return None
@@ -302,23 +304,23 @@ def _column_descriptions_validator(
302304
for k, v in raw_col_descriptions.items()
303305
}
304306

305-
columns_to_types = info.data.get("columns_to_types_")
307+
columns_to_types = data.get("columns_to_types_")
306308
if columns_to_types:
307309
from sqlmesh.core.console import get_console
308310

309311
console = get_console()
310312
for column_name in list(col_descriptions):
311313
if column_name not in columns_to_types:
312314
console.log_warning(
313-
f"In model '{info.data['name']}', a description is provided for column '{column_name}' but it is not a column in the model."
315+
f"In model '{data.get('name', '<unknown>')}', a description is provided for column '{column_name}' but it is not a column in the model."
314316
)
315317
del col_descriptions[column_name]
316318

317319
return col_descriptions
318320

319321
@field_validator("grains", "references", mode="before")
320322
def _refs_validator(cls, vs: t.Any, info: ValidationInfo) -> t.List[exp.Expr]:
321-
dialect = info.data.get("dialect")
323+
dialect = validation_data(info).get("dialect")
322324

323325
if isinstance(vs, exp.Paren):
324326
vs = vs.unnest()

sqlmesh/core/selector.py

Lines changed: 36 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -62,7 +62,7 @@ def select_models(
6262
target_env_name: str,
6363
fallback_env_name: t.Optional[str] = None,
6464
ensure_finalized_snapshots: bool = False,
65-
) -> UniqueKeyDict[str, Model]:
65+
) -> t.Tuple[UniqueKeyDict[str, Model], t.Set[str]]:
6666
"""Given a set of selections returns models from the current state with names matching the
6767
selection while sourcing the remaining models from the target environment.
6868
@@ -76,29 +76,11 @@ def select_models(
7676
the environment is not finalized.
7777
7878
Returns:
79-
A dictionary of models.
79+
A tuple of (models dict, set of all matched FQNs including env models).
8080
"""
81-
target_env = self._state_reader.get_environment(Environment.sanitize_name(target_env_name))
82-
if target_env and target_env.expired:
83-
target_env = None
84-
85-
if not target_env and fallback_env_name:
86-
target_env = self._state_reader.get_environment(
87-
Environment.sanitize_name(fallback_env_name)
88-
)
89-
90-
env_models: t.Dict[str, Model] = {}
91-
if target_env:
92-
environment_snapshot_infos = (
93-
target_env.snapshots
94-
if not ensure_finalized_snapshots
95-
else target_env.finalized_or_current_snapshots
96-
)
97-
env_models = {
98-
s.name: s.model
99-
for s in self._state_reader.get_snapshots(environment_snapshot_infos).values()
100-
if s.is_model
101-
}
81+
env_models = self._load_env_models(
82+
target_env_name, fallback_env_name, ensure_finalized_snapshots
83+
)
10284

10385
all_selected_models = self.expand_model_selections(
10486
model_selections, models={**env_models, **self._models}
@@ -166,7 +148,37 @@ def get_model(fqn: str) -> t.Optional[Model]:
166148
if needs_update:
167149
update_model_schemas(dag, models=models, cache_dir=self._cache_dir)
168150

169-
return models
151+
return models, all_selected_models
152+
153+
def _load_env_models(
154+
self,
155+
target_env_name: str,
156+
fallback_env_name: t.Optional[str] = None,
157+
ensure_finalized_snapshots: bool = False,
158+
) -> t.Dict[str, "Model"]:
159+
"""Loads models from the target environment, falling back to the fallback environment if needed."""
160+
target_env = self._state_reader.get_environment(Environment.sanitize_name(target_env_name))
161+
if target_env and target_env.expired:
162+
target_env = None
163+
164+
if not target_env and fallback_env_name:
165+
target_env = self._state_reader.get_environment(
166+
Environment.sanitize_name(fallback_env_name)
167+
)
168+
169+
if not target_env:
170+
return {}
171+
172+
environment_snapshot_infos = (
173+
target_env.snapshots
174+
if not ensure_finalized_snapshots
175+
else target_env.finalized_or_current_snapshots
176+
)
177+
return {
178+
s.name: s.model
179+
for s in self._state_reader.get_snapshots(environment_snapshot_infos).values()
180+
if s.is_model
181+
}
170182

171183
def expand_model_selections(
172184
self, model_selections: t.Iterable[str], models: t.Optional[t.Dict[str, Node]] = None

sqlmesh/core/state_sync/common.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,7 @@
1111
from pydantic_core.core_schema import ValidationInfo
1212
from sqlglot import exp
1313

14-
from sqlmesh.utils.pydantic import PydanticModel, field_validator
14+
from sqlmesh.utils.pydantic import PydanticModel, field_validator, validation_data
1515
from sqlmesh.core.environment import Environment, EnvironmentStatements, EnvironmentNamingInfo
1616
from sqlmesh.core.snapshot import (
1717
Snapshot,
@@ -269,7 +269,7 @@ class PromotionResult(PydanticModel):
269269
def _validate_removed_environment_naming_info(
270270
cls, v: t.Optional[EnvironmentNamingInfo], info: ValidationInfo
271271
) -> t.Optional[EnvironmentNamingInfo]:
272-
if v and not info.data.get("removed"):
272+
if v and not validation_data(info).get("removed"):
273273
raise ValueError("removed_environment_naming_info must be None if removed is empty")
274274
return v
275275

sqlmesh/core/state_sync/db/snapshot.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -138,7 +138,7 @@ def unpause_snapshots(
138138
):
139139
self.engine_adapter.update_table(
140140
self.snapshots_table,
141-
{"unpaused_ts": None, "updated_ts": updated_ts},
141+
{"unpaused_ts": None},
142142
where=where,
143143
)
144144

sqlmesh/core/user.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@
22
from enum import Enum
33

44
from sqlmesh.core.notification_target import BasicSMTPNotificationTarget, NotificationTarget
5-
from sqlmesh.utils.pydantic import PydanticModel, ValidationInfo, field_validator
5+
from sqlmesh.utils.pydantic import PydanticModel, ValidationInfo, field_validator, validation_data
66

77

88
class UserRole(str, Enum):
@@ -42,7 +42,7 @@ def validate_notification_targets(
4242
v: t.List[NotificationTarget],
4343
info: ValidationInfo,
4444
) -> t.List[NotificationTarget]:
45-
email = info.data["email"]
45+
email = validation_data(info).get("email")
4646
for target in v:
4747
if isinstance(target, BasicSMTPNotificationTarget) and target.recipients != {email}:
4848
raise ValueError("Recipient emails do not match user email")

0 commit comments

Comments
 (0)