Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -858,6 +858,7 @@ def collect_node_ctes(
ctx: BuildContext,
nodes_to_include: list[Node],
needed_columns_by_node: Optional[dict[str, set[str]]] = None,
injected_filters: Optional[dict[str, ast.Expression]] = None,
) -> tuple[list[tuple[str, ast.Query]], list[str]]:
"""
Collect CTEs for all non-source nodes, recursively expanding table references.
Expand All @@ -873,6 +874,10 @@ def collect_node_ctes(
nodes_to_include: List of nodes to create CTEs for
needed_columns_by_node: Optional dict of node_name -> set of column names
If provided, CTEs will only select the needed columns.
injected_filters: Optional dict of node_name -> filter expression to inject
as a WHERE clause into that node's CTE. Used to push temporal partition
filters down into upstream CTEs (e.g. a date-spine) rather than applying
them on the outer query after an expensive join.

Returns:
Tuple of (cte_list, scanned_sources):
Expand Down Expand Up @@ -987,6 +992,17 @@ def collect_refs(node: Node, visited: set[str]) -> None:
if needed_cols: # pragma: no branch
query_ast = filter_cte_projection(query_ast, needed_cols)

# Inject pushed-down filter (e.g. temporal partition) into this CTE's WHERE clause
if injected_filters and node.name in injected_filters:
injected = injected_filters[node.name]
if query_ast.select.where: # pragma: no cover
query_ast.select.where = ast.BinaryOp.And(
query_ast.select.where,
injected,
)
else:
query_ast.select.where = injected

ctes.append((cte_name, query_ast))

return ctes, list(scanned_source_names)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -600,11 +600,36 @@ def add_table_prefix(e):
else:
needed_columns_by_node[dim_node.name] = dim_cols

# Build temporal filter and attempt to push it into the upstream date-spine CTE.
# Without pushdown, the filter lands on the outer grain-group query AFTER the
# rolling join has already been evaluated across all dates, which is very expensive.
# Pushdown injects the filter directly into the driving CTE (the date spine)
# so the join only runs for the target date(s).
temporal_filter_ast, fk_col_name = build_temporal_filter(
ctx,
parent_node,
main_alias,
)
injected_cte_filters: dict[str, ast.Expression] = {}
if temporal_filter_ast and fk_col_name:
upstream_node = find_upstream_temporal_source_node(
ctx,
parent_node,
fk_col_name,
)
if upstream_node:
# Build an unaliased version of the filter for injection into the upstream CTE
unaliased_filter_ast, _ = build_temporal_filter(ctx, parent_node, None)
if unaliased_filter_ast: # pragma: no branch
injected_cte_filters[upstream_node.name] = unaliased_filter_ast
temporal_filter_ast = None # Don't also apply on the outer query

# Build CTEs for all non-source nodes with column filtering
ctes, scanned_sources = collect_node_ctes(
ctx,
nodes_for_ctes,
needed_columns_by_node,
injected_filters=injected_cte_filters or None,
)

# Build FROM clause with main table (use materialized table if available)
Expand All @@ -626,10 +651,7 @@ def add_table_prefix(e):

from_clause = ast.From(relations=[relation])

# Inject temporal partition filters for incremental materialization
# This ensures partition pruning at the source level
all_filters = list(filters or [])
temporal_filter_ast = build_temporal_filter(ctx, parent_node, main_alias)

# Build WHERE clause from filters
where_clause: Optional[ast.Expression] = None
Expand Down Expand Up @@ -727,21 +749,21 @@ def add_table_prefix(e):
def build_temporal_filter(
ctx: BuildContext,
parent_node: Node,
table_alias: str,
) -> Optional[ast.Expression]:
table_alias: Optional[str],
) -> tuple[Optional[ast.Expression], Optional[str]]:
"""
Build temporal filter expression based on cube's temporal partition columns.

Checks if the parent node has dimension links to any of the cube's temporal
partition columns, and generates filters for those columns.

Returns:
- BinaryOp (col = expr) for exact partition match
- Between (col BETWEEN start AND end) for lookback window
- None if no temporal partition columns from cube are linked to this parent
Tuple of (filter_expression, fk_col_name):
- filter_expression: BinaryOp (col = expr) for exact match, Between for lookback, or None
- fk_col_name: the parent node's FK column name used in the filter, or None
"""
if not ctx.temporal_partition_columns or not parent_node.current:
return None
return None, None

# For each temporal partition column specified by the cube
for partition_col_ref, partition_metadata in ctx.temporal_partition_columns.items():
Expand Down Expand Up @@ -789,16 +811,77 @@ def build_temporal_filter(
expr=col_ref,
low=start_expr,
high=end_expr,
)
), parent_col_name
elif end_expr: # pragma: no branch
# No lookback - exact partition match
return ast.BinaryOp(
left=col_ref,
right=end_expr,
op=ast.BinaryOpKind.Eq,
)
), parent_col_name

return None # pragma: no cover
return None, None # pragma: no cover


def find_upstream_temporal_source_node(
    ctx: BuildContext,
    parent_node: Node,
    fk_col_name: str,
) -> Optional[Node]:
    """
    Locate the upstream node that directly supplies ``fk_col_name`` to ``parent_node``.

    Inspects the primary (driving) relation of each entry in parent_node's FROM
    clause and returns the corresponding DJ node when it exposes the FK column.
    The caller uses this to push a temporal filter down into that node's CTE
    instead of filtering the outer grain-group query after the join has run.

    For example, when parent_node is a rolling-window transform that joins a
    date-spine transform against a windowed fact, the date-spine node is
    returned so the filter can be injected there — avoiding a full cross-date
    join before filtering.

    Returns:
        The upstream Node providing the FK column, or None when no suitable
        candidate exists (caller then falls back to the outer WHERE clause).
    """
    if not parent_node.current or not parent_node.current.query:  # pragma: no cover
        return None

    # Parsing is best-effort: on any failure we simply skip pushdown
    try:
        parsed = ctx.get_parsed_query(parent_node)
    except Exception:  # pragma: no cover
        return None

    if not parsed.select or not parsed.select.from_:
        return None

    for rel in parsed.select.from_.relations:
        candidate = rel.primary

        # Strip an alias wrapper (e.g. Alias(child=Table(...), alias=Name("dd")))
        # so the underlying table expression can be inspected directly.
        if isinstance(candidate, ast.Alias):
            candidate = getattr(candidate, "child", candidate)

        if not isinstance(candidate, ast.Table):
            continue  # pragma: no cover

        node = ctx.nodes.get(str(candidate.name))
        if not node or not node.current:  # pragma: no cover
            continue

        # Source nodes never become CTEs; injecting a filter there would be
        # silently dropped, so they are never valid pushdown targets.
        if node.type == NodeType.SOURCE:
            continue

        # The candidate qualifies only if it actually exposes the FK column
        exposed = {col.name for col in (node.current.columns or [])}
        if fk_col_name in exposed:
            return node

    return None


# TODO: Remove this once we have a way to test pre-aggregations
Expand Down
Loading
Loading