Skip to content

Commit b5f8f0b

Browse files
authored
Feat: Use internal state instead of Airflow Variables to store Plan DAG specs (#1533)
1 parent b11cbb0 commit b5f8f0b

12 files changed

Lines changed: 180 additions & 72 deletions

File tree

sqlmesh/core/plan/evaluator.py

Lines changed: 2 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -350,7 +350,7 @@ def client(self) -> BaseAirflowClient:
350350
return self._mwaa_client
351351

352352
def _apply_plan(self, plan: Plan, plan_request_id: str) -> None:
353-
from sqlmesh.schedulers.airflow.plan import create_plan_dag_spec
353+
from sqlmesh.schedulers.airflow.plan import PlanDagState, create_plan_dag_spec
354354

355355
plan_application_request = airflow_common.PlanApplicationRequest(
356356
new_snapshots=list(plan.new_snapshots),
@@ -367,13 +367,7 @@ def _apply_plan(self, plan: Plan, plan_request_id: str) -> None:
367367
forward_only=plan.forward_only,
368368
)
369369
plan_dag_spec = create_plan_dag_spec(plan_application_request, self.state_sync)
370-
371-
_, stderr = self._mwaa_client.set_variable(
372-
airflow_common.plan_dag_spec_key(plan_request_id), plan_dag_spec.json()
373-
)
374-
375-
if stderr:
376-
logger.warning("MWAA CLI stderr:\n%s", stderr)
370+
PlanDagState.from_state_sync(self.state_sync).add_dag_spec(plan_dag_spec)
377371

378372

379373
def can_evaluate_before_promote(

sqlmesh/core/state_sync/engine_adapter.py

Lines changed: 13 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -85,6 +85,7 @@ def __init__(
8585
self.environments_table = nullsafe_join(".", self.schema, "_environments")
8686
self.seeds_table = nullsafe_join(".", self.schema, "_seeds")
8787
self.intervals_table = nullsafe_join(".", self.schema, "_intervals")
88+
self.plan_dags_table = nullsafe_join(".", self.schema, "_plan_dags")
8889
self.versions_table = nullsafe_join(".", self.schema, "_versions")
8990

9091
self._snapshot_columns_to_types = {
@@ -306,7 +307,7 @@ def _update_snapshot(self, snapshot: Snapshot) -> None:
306307
snapshot.updated_ts = now_timestamp()
307308
self.engine_adapter.update_table(
308309
self.snapshots_table,
309-
{"snapshot": snapshot.json()},
310+
{"snapshot": _snapshot_to_json(snapshot)},
310311
where=self._snapshot_id_filter([snapshot.snapshot_id]),
311312
contains_json=True,
312313
)
@@ -726,22 +727,22 @@ def rollback(self) -> None:
726727
"""Rollback to the previous migration."""
727728
logger.info("Starting migration rollback.")
728729
tables = (self.snapshots_table, self.environments_table, self.versions_table)
730+
optional_tables = (self.seeds_table, self.intervals_table, self.plan_dags_table)
729731
versions = self.get_versions(validate=False)
730732
if versions.schema_version == 0:
731733
# Clean up state tables
732-
for table in tables + (self.seeds_table, self.intervals_table):
734+
for table in tables + optional_tables:
733735
self.engine_adapter.drop_table(table)
734736
else:
735737
if not all(self.engine_adapter.table_exists(f"{table}_backup") for table in tables):
736738
raise SQLMeshError("There are no prior migrations to roll back to.")
737739
for table in tables:
738740
self._restore_table(table, _backup_table_name(table))
739741

740-
if self.engine_adapter.table_exists(_backup_table_name(self.seeds_table)):
741-
self._restore_table(self.seeds_table, _backup_table_name(self.seeds_table))
742+
for optional_table in optional_tables:
743+
if self.engine_adapter.table_exists(_backup_table_name(optional_table)):
744+
self._restore_table(optional_table, _backup_table_name(optional_table))
742745

743-
if self.engine_adapter.table_exists(_backup_table_name(self.intervals_table)):
744-
self._restore_table(self.intervals_table, _backup_table_name(self.intervals_table))
745746
logger.info("Migration rollback successful.")
746747

747748
def _backup_state(self) -> None:
@@ -751,6 +752,7 @@ def _backup_state(self) -> None:
751752
self.versions_table,
752753
self.seeds_table,
753754
self.intervals_table,
755+
self.plan_dags_table,
754756
):
755757
if self.engine_adapter.table_exists(table):
756758
with self.engine_adapter.transaction(TransactionType.DDL):
@@ -1000,7 +1002,7 @@ def _snapshots_to_df(snapshots: t.Iterable[Snapshot]) -> pd.DataFrame:
10001002
"name": snapshot.name,
10011003
"identifier": snapshot.identifier,
10021004
"version": snapshot.version,
1003-
"snapshot": snapshot.json(exclude={"intervals", "dev_intervals"}),
1005+
"snapshot": _snapshot_to_json(snapshot),
10041006
"kind_name": snapshot.model_kind_name.value if snapshot.model_kind_name else None,
10051007
}
10061008
for snapshot in snapshots
@@ -1033,3 +1035,7 @@ def _environment_to_df(environment: Environment) -> pd.DataFrame:
10331035

10341036
def _backup_table_name(table_name: str) -> str:
10351037
return f"{table_name}_backup"
1038+
1039+
1040+
def _snapshot_to_json(snapshot: Snapshot) -> str:
1041+
return snapshot.json(exclude={"intervals", "dev_intervals"})
Lines changed: 28 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,28 @@
1+
"""Creates the '_plan_dags' table if Airflow is used."""
2+
from sqlglot import exp
3+
4+
from sqlmesh.utils.migration import index_text_type
5+
6+
7+
def migrate(state_sync): # type: ignore
8+
engine_adapter = state_sync.engine_adapter
9+
schema = state_sync.schema
10+
plan_dags_table = "_plan_dags"
11+
12+
if schema:
13+
engine_adapter.create_schema(schema)
14+
plan_dags_table = f"{schema}.{plan_dags_table}"
15+
16+
text_type = index_text_type(engine_adapter.dialect)
17+
18+
engine_adapter.create_state_table(
19+
plan_dags_table,
20+
{
21+
"request_id": exp.DataType.build(text_type),
22+
"dag_id": exp.DataType.build(text_type),
23+
"dag_spec": exp.DataType.build("text"),
24+
},
25+
primary_key=("request_id",),
26+
)
27+
28+
engine_adapter.create_index(plan_dags_table, "dag_id_idx", ("dag_id",))

sqlmesh/schedulers/airflow/api.py

Lines changed: 3 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -5,14 +5,13 @@
55
from functools import wraps
66

77
from airflow.api_connexion import security
8-
from airflow.models import Variable
98
from airflow.www.app import csrf
109
from flask import Blueprint, Response, jsonify, make_response, request
1110

1211
from sqlmesh.core import constants as c
1312
from sqlmesh.core.snapshot import SnapshotId, SnapshotNameVersion
1413
from sqlmesh.schedulers.airflow import common, util
15-
from sqlmesh.schedulers.airflow.plan import create_plan_dag_spec
14+
from sqlmesh.schedulers.airflow.plan import PlanDagState, create_plan_dag_spec
1615
from sqlmesh.utils.errors import SQLMeshError
1716
from sqlmesh.utils.pydantic import PydanticModel
1817

@@ -40,13 +39,11 @@ def apply_plan() -> Response:
4039
plan = common.PlanApplicationRequest.parse_obj(request.json or {})
4140
with util.scoped_state_sync() as state_sync:
4241
spec = create_plan_dag_spec(plan, state_sync)
42+
PlanDagState.from_state_sync(state_sync).add_dag_spec(spec)
43+
return make_response(jsonify(request_id=spec.request_id), 201)
4344
except Exception as ex:
4445
return _error(str(ex))
4546

46-
Variable.set(common.plan_dag_spec_key(spec.request_id), spec.json())
47-
48-
return make_response(jsonify(request_id=spec.request_id), 201)
49-
5047

5148
@sqlmesh_api_v1.route("/environments/<name>")
5249
@csrf.exempt

sqlmesh/schedulers/airflow/common.py

Lines changed: 4 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -68,17 +68,17 @@ class PlanDagSpec(PydanticModel):
6868
promoted_snapshots: t.List[SnapshotTableInfo]
6969
demoted_snapshots: t.List[SnapshotTableInfo]
7070
start: TimeLike
71-
end: t.Optional[TimeLike]
72-
unpaused_dt: t.Optional[TimeLike]
71+
end: t.Optional[TimeLike] = None
72+
unpaused_dt: t.Optional[TimeLike] = None
7373
no_gaps: bool
7474
plan_id: str
75-
previous_plan_id: t.Optional[str]
75+
previous_plan_id: t.Optional[str] = None
7676
notification_targets: t.List[NotificationTarget]
7777
backfill_concurrent_tasks: int
7878
ddl_concurrent_tasks: int
7979
users: t.List[User]
8080
is_dev: bool
81-
forward_only: t.Optional[bool]
81+
forward_only: t.Optional[bool] = None
8282
environment_expiration_ts: t.Optional[int] = None
8383
dag_start_ts: t.Optional[int] = None
8484

@@ -143,10 +143,6 @@ def plan_application_dag_id(environment: str, request_id: str) -> str:
143143
return f"sqlmesh_plan_application__{environment}__{request_id}"
144144

145145

146-
def environment_key(env: str) -> str:
147-
return f"{ENV_KEY_PREFIX}__{env}"
148-
149-
150146
def plan_dag_spec_key(request_id: str) -> str:
151147
return f"{PLAN_DAG_SPEC_KEY_PREFIX}__{request_id}"
152148

sqlmesh/schedulers/airflow/integration.py

Lines changed: 8 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@
1515
from sqlmesh.schedulers.airflow import common, util
1616
from sqlmesh.schedulers.airflow.dag_generator import SnapshotDagGenerator
1717
from sqlmesh.schedulers.airflow.operators import targets
18+
from sqlmesh.schedulers.airflow.plan import PlanDagState
1819

1920
logger = logging.getLogger(__name__)
2021

@@ -99,6 +100,9 @@ def dags(self) -> t.List[DAG]:
99100
"""
100101
with util.scoped_state_sync() as state_sync:
101102
stored_snapshots = state_sync.get_snapshots(None)
103+
plan_dag_specs = PlanDagState.from_state_sync(state_sync).get_dag_specs()
104+
# TODO: Remove this once all DAG specs are moved into the internal state (after about 1 week)
105+
plan_dag_specs += _get_plan_dag_specs_from_variables()
102106

103107
dag_generator = SnapshotDagGenerator(
104108
self._engine_operator,
@@ -112,7 +116,7 @@ def dags(self) -> t.List[DAG]:
112116
cadence_dags = dag_generator.generate_cadence_dags(prod_env.snapshots) if prod_env else []
113117

114118
plan_application_dags = [
115-
dag_generator.generate_plan_application_dag(s) for s in _get_plan_dag_specs()
119+
dag_generator.generate_plan_application_dag(s) for s in plan_dag_specs
116120
]
117121

118122
system_dags = [
@@ -187,6 +191,8 @@ def _janitor_task(
187191
ttl=plan_application_dag_ttl, session=session
188192
)
189193
logger.info("Deleting expired Plan Application DAGs: %s", plan_application_dag_ids)
194+
PlanDagState.from_state_sync(state_sync).delete_dag_specs(plan_application_dag_ids)
195+
# TODO: Remove this once all DAG specs are moved into the internal state (after about 1 week)
190196
util.delete_variables(
191197
{common.plan_dag_spec_key_from_dag_id(dag_id) for dag_id in plan_application_dag_ids},
192198
session=session,
@@ -197,7 +203,7 @@ def _janitor_task(
197203

198204

199205
@provide_session
200-
def _get_plan_dag_specs(
206+
def _get_plan_dag_specs_from_variables(
201207
session: Session = util.PROVIDED_SESSION,
202208
) -> t.List[common.PlanDagSpec]:
203209
records = (

sqlmesh/schedulers/airflow/mwaa_client.py

Lines changed: 0 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -21,19 +21,6 @@ def __init__(self, airflow_url: str, auth_token: str, console: t.Optional[Consol
2121
{"Authorization": f"Bearer {auth_token}", "Content-Type": "text/plain"}
2222
)
2323

24-
def set_variable(self, key: str, value: str) -> t.Tuple[str, str]:
25-
"""Sets the Airflow variable.
26-
27-
Args:
28-
key: The name of the variable.
29-
value: The value of the variable.
30-
31-
Returns:
32-
A tuple of stdout and stderr from the MWAA CLI.
33-
"""
34-
value = value.replace("\\", "\\\\").replace('"', '\\"')
35-
return self._post(f'variables set {key} "{value}"')
36-
3724
def get_first_dag_run_id(self, dag_id: str) -> t.Optional[str]:
3825
dag_runs = self._list_dag_runs(dag_id)
3926
if dag_runs:

sqlmesh/schedulers/airflow/operators/hwm_sensor.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -60,6 +60,6 @@ def _compute_target_high_water_mark(
6060
self, dag_run: DagRun, target_snapshot: Snapshot
6161
) -> datetime:
6262
target_date = to_datetime(dag_run.data_interval_end)
63-
target_prev = to_datetime(target_snapshot.model.cron_floor(target_date))
64-
this_prev = to_datetime(self.this_snapshot.model.cron_floor(target_date))
63+
target_prev = to_datetime(target_snapshot.node.cron_floor(target_date))
64+
this_prev = to_datetime(self.this_snapshot.node.cron_floor(target_date))
6565
return min(target_prev, this_prev)

sqlmesh/schedulers/airflow/plan.py

Lines changed: 69 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,16 +2,84 @@
22

33
import typing as t
44

5+
import pandas as pd
6+
from sqlglot import exp
7+
58
from sqlmesh.core import scheduler
9+
from sqlmesh.core.engine_adapter import EngineAdapter
610
from sqlmesh.core.environment import Environment
711
from sqlmesh.core.plan import can_evaluate_before_promote
812
from sqlmesh.core.snapshot import SnapshotTableInfo
9-
from sqlmesh.core.state_sync import StateSync
13+
from sqlmesh.core.state_sync import EngineAdapterStateSync, StateSync
14+
from sqlmesh.core.state_sync.base import DelegatingStateSync
1015
from sqlmesh.schedulers.airflow import common
1116
from sqlmesh.utils.date import now, now_timestamp
1217
from sqlmesh.utils.errors import SQLMeshError
1318

1419

20+
class PlanDagState:
21+
def __init__(self, engine_adapter: EngineAdapter, plan_dags_table: str):
22+
self.engine_adapter = engine_adapter
23+
24+
self._plan_dags_table = plan_dags_table
25+
self._plan_dag_columns_to_types = {
26+
"request_id": exp.DataType.build("text"),
27+
"dag_id": exp.DataType.build("text"),
28+
"dag_spec": exp.DataType.build("text"),
29+
}
30+
31+
@classmethod
32+
def from_state_sync(cls, state_sync: StateSync) -> PlanDagState:
33+
while isinstance(state_sync, DelegatingStateSync):
34+
state_sync = state_sync.state_sync
35+
if not isinstance(state_sync, EngineAdapterStateSync):
36+
raise SQLMeshError(f"Unsupported state sync {state_sync.__class__.__name__}")
37+
return cls(state_sync.engine_adapter, state_sync.plan_dags_table)
38+
39+
def add_dag_spec(self, spec: common.PlanDagSpec) -> None:
40+
"""Adds a new DAG spec to the state.
41+
42+
Args:
43+
spec: the plan DAG spec to add.
44+
"""
45+
df = pd.DataFrame(
46+
[
47+
{
48+
"request_id": spec.request_id,
49+
"dag_id": common.plan_application_dag_id(
50+
spec.environment_naming_info.name, spec.request_id
51+
),
52+
"dag_spec": spec.json(),
53+
}
54+
]
55+
)
56+
self.engine_adapter.insert_append(
57+
self._plan_dags_table,
58+
df,
59+
columns_to_types=self._plan_dag_columns_to_types,
60+
contains_json=True,
61+
)
62+
63+
def get_dag_specs(self) -> t.List[common.PlanDagSpec]:
64+
"""Returns all DAG specs in the state."""
65+
query = exp.select("dag_spec").from_(self._plan_dags_table)
66+
return [
67+
common.PlanDagSpec.parse_raw(row[0])
68+
for row in self.engine_adapter.fetchall(
69+
query, ignore_unsupported_errors=True, quote_identifiers=True
70+
)
71+
]
72+
73+
def delete_dag_specs(self, dag_ids: t.Collection[str]) -> None:
74+
"""Deletes the DAG specs with the given DAG IDs."""
75+
if not dag_ids:
76+
return
77+
self.engine_adapter.delete_from(
78+
self._plan_dags_table,
79+
where=exp.column("dag_id").isin(*dag_ids),
80+
)
81+
82+
1583
def create_plan_dag_spec(
1684
request: common.PlanApplicationRequest, state_sync: StateSync
1785
) -> common.PlanDagSpec:

tests/core/test_plan_evaluator.py

Lines changed: 7 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -125,18 +125,21 @@ def test_mwaa_evaluator(sushi_plan: Plan, mocker: MockerFixture):
125125

126126
state_sync_mock = mocker.Mock()
127127

128-
plan_dag_spec_json = """{"request_id": "test_request_id"}"""
129-
130128
plan_dag_spec_mock = mocker.Mock()
131-
plan_dag_spec_mock.json.return_value = plan_dag_spec_json
132129

133130
create_plan_dag_spec_mock = mocker.patch("sqlmesh.schedulers.airflow.plan.create_plan_dag_spec")
134131
create_plan_dag_spec_mock.return_value = plan_dag_spec_mock
135132

133+
plan_dag_state_mock = mocker.Mock()
134+
mocker.patch(
135+
"sqlmesh.schedulers.airflow.plan.PlanDagState.from_state_sync",
136+
return_value=plan_dag_state_mock,
137+
)
138+
136139
evaluator = MWAAPlanEvaluator(mwaa_client_mock, state_sync_mock)
137140
evaluator.evaluate(sushi_plan)
138141

139-
mwaa_client_mock.set_variable.assert_called_once_with(mocker.ANY, plan_dag_spec_json)
142+
plan_dag_state_mock.add_dag_spec.assert_called_once_with(plan_dag_spec_mock)
140143

141144
mwaa_client_mock.wait_for_dag_run_completion.assert_called_once()
142145
mwaa_client_mock.wait_for_first_dag_run.assert_called_once()

0 commit comments

Comments (0)