1616
1717import abc
1818import dataclasses
19+ import functools
1920import itertools
2021import typing
2122from typing import Generator , Mapping , TypeVar , Union
2223
2324import pandas as pd
2425
26+ from bigframes import dtypes
27+ from bigframes .core import field
2528import bigframes .core .identifiers as ids
26- import bigframes .dtypes as dtypes
2729import bigframes .operations
2830import bigframes .operations .aggregations as agg_ops
2931
@@ -50,7 +52,7 @@ class Aggregation(abc.ABC):
5052
@abc.abstractmethod
def output_type(
    self, input_fields: Mapping[ids.ColumnId, field.Field]
) -> dtypes.ExpressionType:
    """Return the type this aggregation produces for the given input fields.

    Args:
        input_fields: Fields keyed by column id.
    """
    ...
5658
@@ -72,7 +74,7 @@ class NullaryAggregation(Aggregation):
7274 op : agg_ops .NullaryWindowOp = dataclasses .field ()
7375
def output_type(
    self, input_fields: Mapping[ids.ColumnId, field.Field]
) -> dtypes.ExpressionType:
    """Return the output type reported by the op.

    `input_fields` is accepted for interface compatibility and is not read.
    """
    result_type = self.op.output_type()
    return result_type
7880
@@ -86,13 +88,17 @@ def remap_column_refs(
8688
8789@dataclasses .dataclass (frozen = True )
8890class UnaryAggregation (Aggregation ):
89- op : agg_ops .UnaryWindowOp = dataclasses . field ()
90- arg : Union [DerefOp , ScalarConstantExpression ] = dataclasses . field ()
91+ op : agg_ops .UnaryWindowOp
92+ arg : Union [DerefOp , ScalarConstantExpression ]
9193
def output_type(
    self, input_fields: Mapping[ids.ColumnId, field.Field]
) -> dtypes.ExpressionType:
    """Return the op's output type for this aggregation's single argument.

    The argument is first bound against `input_fields` so that its type is
    available, then that type is handed to the op.

    Args:
        input_fields: Fields keyed by column id.
    """
    # TODO(b/419300717) Remove resolutions once defers are cleaned up.
    bound_arg = bind_schema_fields(self.arg, input_fields)
    assert bound_arg.is_resolved

    arg_type = bound_arg.output_type
    return self.op.output_type(arg_type)
96102
97103 @property
98104 def column_references (self ) -> typing .Tuple [ids .ColumnId , ...]:
@@ -118,10 +124,16 @@ class BinaryAggregation(Aggregation):
118124 right : Union [DerefOp , ScalarConstantExpression ] = dataclasses .field ()
119125
def output_type(
    self, input_fields: Mapping[ids.ColumnId, field.Field]
) -> dtypes.ExpressionType:
    """Return the op's output type for this aggregation's two operands.

    Both operands are bound against `input_fields` so that their types are
    available, then the (left, right) types are handed to the op.

    Args:
        input_fields: Fields keyed by column id.

    Returns:
        Whatever `self.op.output_type` returns for the two operand types.
    """
    # TODO(b/419300717) Remove resolutions once defers are cleaned up.
    left_resolved_expr = bind_schema_fields(self.left, input_fields)
    assert left_resolved_expr.is_resolved
    right_resolved_expr = bind_schema_fields(self.right, input_fields)
    assert right_resolved_expr.is_resolved

    # The second argument must be the RIGHT operand's type; passing the left
    # type twice gives wrong results whenever the operand types differ.
    return self.op.output_type(
        left_resolved_expr.output_type, right_resolved_expr.output_type
    )
126138
127139 @property
@@ -189,10 +201,17 @@ def remap_column_refs(
189201 def is_const (self ) -> bool :
190202 ...
191203
@property
@abc.abstractmethod
def is_resolved(self) -> bool:
    """
    Whether this expression's output type and nullability are available.

    True if and only if both can be determined.
    """
    ...
211+
@property
@abc.abstractmethod
def output_type(self) -> dtypes.ExpressionType:
    """The type this expression evaluates to."""
    ...
197216
198217 @abc .abstractmethod
@@ -256,9 +275,12 @@ def column_references(self) -> typing.Tuple[ids.ColumnId, ...]:
256275 def nullable (self ) -> bool :
257276 return pd .isna (self .value ) # type: ignore
258277
259- def output_type (
260- self , input_types : dict [ids .ColumnId , bigframes .dtypes .Dtype ]
261- ) -> dtypes .ExpressionType :
@property
def is_resolved(self) -> bool:
    """Always True: this expression returns its stored `dtype` as its output type."""
    return True
281+
@property
def output_type(self) -> dtypes.ExpressionType:
    """The stored `dtype` of this expression."""
    stored_dtype = self.dtype
    return stored_dtype
263285
264286 def bind_variables (
@@ -308,9 +330,12 @@ def is_const(self) -> bool:
308330 def column_references (self ) -> typing .Tuple [ids .ColumnId , ...]:
309331 return ()
310332
311- def output_type (
312- self , input_types : dict [ids .ColumnId , bigframes .dtypes .Dtype ]
313- ) -> dtypes .ExpressionType :
@property
def is_resolved(self) -> bool:
    """Always False: this expression's `output_type` raises instead of returning a type."""
    return False
336+
@property
def output_type(self) -> dtypes.ExpressionType:
    """Always raises: no type is available for this expression.

    Raises:
        ValueError: On every access.
    """
    message = f"Type of variable {self.id} has not been fixed."
    raise ValueError(message)
315340
316341 def bind_refs (
@@ -340,7 +365,7 @@ def is_identity(self) -> bool:
340365
341366@dataclasses .dataclass (frozen = True )
342367class DerefOp (Expression ):
343- """A variable expression representing an unbound variable ."""
368+ """An expression that refers to a column by ID ."""
344369
345370 id : ids .ColumnId
346371
@@ -357,13 +382,13 @@ def nullable(self) -> bool:
357382 # Safe default, need to actually bind input schema to determine
358383 return True
359384
360- def output_type (
361- self , input_types : dict [ ids . ColumnId , bigframes . dtypes . Dtype ]
362- ) -> dtypes . ExpressionType :
363- if self . id in input_types :
364- return input_types [ self . id ]
365- else :
366- raise ValueError (f"Type of variable { self .id } has not been fixed." )
@property
def is_resolved(self) -> bool:
    """Always False: this expression's `output_type` raises instead of returning a type."""
    return False
388+
@property
def output_type(self) -> dtypes.ExpressionType:
    """Always raises: no type is available for this expression.

    Raises:
        ValueError: On every access.
    """
    message = f"Type of variable {self.id} has not been fixed."
    raise ValueError(message)
367392
368393 def bind_variables (
369394 self , bindings : Mapping [str , Expression ], allow_partial_bindings : bool = False
@@ -390,6 +415,55 @@ def is_identity(self) -> bool:
390415 return True
391416
392417
@dataclasses.dataclass(frozen=True)
class SchemaFieldRefExpression(Expression):
    """An expression bound to a concrete schema field.

    Essentially a DerefOp with the input schema already bound: the type and
    nullability are read directly from the wrapped field.
    """

    # The schema field this expression refers to.
    field: field.Field

    @property
    def column_references(self) -> typing.Tuple[ids.ColumnId, ...]:
        """A one-element tuple holding the wrapped field's column id."""
        return (self.field.id,)

    @property
    def is_const(self) -> bool:
        """Always False."""
        return False

    @property
    def nullable(self) -> bool:
        """Nullability as reported by the wrapped field."""
        return self.field.nullable

    @property
    def is_resolved(self) -> bool:
        """Always True: type and nullability are read from the wrapped field."""
        return True

    @property
    def output_type(self) -> dtypes.ExpressionType:
        """The dtype of the wrapped field."""
        return self.field.dtype

    def bind_variables(
        self, bindings: Mapping[str, Expression], allow_partial_bindings: bool = False
    ) -> Expression:
        """Return this expression unchanged; neither argument is read."""
        return self

    def bind_refs(
        self,
        bindings: Mapping[ids.ColumnId, Expression],
        allow_partial_bindings: bool = False,
    ) -> Expression:
        """Return the binding for this field's column id, or this expression if none.

        `allow_partial_bindings` is not consulted: a missing binding always
        yields `self`.
        """
        return bindings.get(self.field.id, self)

    @property
    def is_bijective(self) -> bool:
        """Always True."""
        return True

    @property
    def is_identity(self) -> bool:
        """Always True."""
        return True
465+
466+
393467@dataclasses .dataclass (frozen = True )
394468class OpExpression (Expression ):
395469 """An expression representing a scalar operation applied to 1 or more argument sub-expressions."""
@@ -429,13 +503,18 @@ def nullable(self) -> bool:
429503 )
430504 return not null_free
431505
432- def output_type (
433- self , input_types : dict [ids .ColumnId , dtypes .ExpressionType ]
434- ) -> dtypes .ExpressionType :
435- operand_types = tuple (
436- map (lambda x : x .output_type (input_types = input_types ), self .inputs )
437- )
438- return self .op .output_type (* operand_types )
@functools.cached_property
def is_resolved(self) -> bool:
    """True when every input sub-expression is resolved (cached after first access)."""
    for operand in self.inputs:
        if not operand.is_resolved:
            return False
    return True
509+
@functools.cached_property
def output_type(self) -> dtypes.ExpressionType:
    """The op's output type for the types of the input sub-expressions.

    The value is cached after the first successful access.

    Raises:
        ValueError: If any input sub-expression is not resolved.
    """
    if self.is_resolved:
        operand_types = tuple(operand.output_type for operand in self.inputs)
        return self.op.output_type(*operand_types)

    raise ValueError(f"Type of expression {self.op.name} has not been fixed.")
439518
440519 def bind_variables (
441520 self , bindings : Mapping [str , Expression ], allow_partial_bindings : bool = False
@@ -475,4 +554,22 @@ def deterministic(self) -> bool:
475554 )
476555
477556
def bind_schema_fields(
    expr: Expression, field_by_id: Mapping[ids.ColumnId, field.Field]
) -> Expression:
    """
    Replaces column-ID references in `expr` with the schema fields (columns) they name.

    An expression that is already resolved is returned unchanged. Otherwise each
    entry of `field_by_id` is wrapped in a `SchemaFieldRefExpression` and the
    resulting mapping is passed to `expr.bind_refs`.

    An expression's output type and nullability can only be deduced after schema
    fields have been bound to all of its deref expressions.
    """
    if expr.is_resolved:
        return expr

    refs_by_id = {
        column_id: SchemaFieldRefExpression(schema_field)
        for column_id, schema_field in field_by_id.items()
    }
    return expr.bind_refs(refs_by_id)
573+
574+
478575RefOrConstant = Union [DerefOp , ScalarConstantExpression ]
0 commit comments