diff --git a/packages/bigframes/bigframes/bigquery/_operations/ml.py b/packages/bigframes/bigframes/bigquery/_operations/ml.py index 04d88c9b6a68..412b49b888f5 100644 --- a/packages/bigframes/bigframes/bigquery/_operations/ml.py +++ b/packages/bigframes/bigframes/bigquery/_operations/ml.py @@ -25,6 +25,7 @@ import bigframes.dataframe as dataframe import bigframes.ml.base import bigframes.session +import bigframes.core.col as col from bigframes.bigquery._operations import utils @@ -50,7 +51,9 @@ def create_model( input_schema: Optional[Mapping[str, str]] = None, output_schema: Optional[Mapping[str, str]] = None, connection_name: Optional[str] = None, - options: Optional[Mapping[str, Union[str, int, float, bool, list]]] = None, + options: Optional[ + Mapping[str, Union[str, int, float, bool, list, "col.Expression"]] + ] = None, training_data: Optional[Union[pd.DataFrame, dataframe.DataFrame, str]] = None, custom_holiday: Optional[Union[pd.DataFrame, dataframe.DataFrame, str]] = None, session: Optional[bigframes.session.Session] = None, @@ -78,7 +81,7 @@ def create_model( The OUTPUT clause, which specifies the schema of the output data. connection_name (str, optional): The connection to use for the model. - options (Mapping[str, Union[str, int, float, bool, list]], optional): + options (Mapping[str, Union[str, int, float, bool, list, bigframes.core.col.Expression]], optional): The OPTIONS clause, which specifies the model options. training_data (Union[bigframes.pandas.DataFrame, str], optional): The query or DataFrame to use for training the model. 
diff --git a/packages/bigframes/bigframes/core/sql/ml.py b/packages/bigframes/bigframes/core/sql/ml.py index 9694e19ad30e..8d971e6c3e8d 100644 --- a/packages/bigframes/bigframes/core/sql/ml.py +++ b/packages/bigframes/bigframes/core/sql/ml.py @@ -16,7 +16,9 @@ from typing import Any, Dict, List, Mapping, Optional, Union +import bigframes.core.col as col from bigframes.core.compile.sqlglot import sql as sg_sql +from bigframes.core.compile.sqlglot.expression_compiler import expression_compiler def create_model_ddl( @@ -28,7 +30,9 @@ def create_model_ddl( input_schema: Optional[Mapping[str, str]] = None, output_schema: Optional[Mapping[str, str]] = None, connection_name: Optional[str] = None, - options: Optional[Mapping[str, Union[str, int, float, bool, list]]] = None, + options: Optional[ + Mapping[str, Union[str, int, float, bool, list, "col.Expression"]] + ] = None, training_data: Optional[str] = None, custom_holiday: Optional[str] = None, ) -> str: @@ -70,7 +74,10 @@ def create_model_ddl( if options: rendered_options = [] for option_name, option_value in options.items(): - if isinstance(option_value, (list, tuple)): + if isinstance(option_value, col.Expression): + sg_expr = expression_compiler.compile_expression(option_value._value) + rendered_val = sg_sql.to_sql(sg_expr) + elif isinstance(option_value, (list, tuple)): # Handle list options like model_registry="vertex_ai" # wait, usually options are key=value. 
# if value is list, it is [val1, val2] diff --git a/packages/bigframes/tests/unit/core/sql/snapshots/test_ml/test_create_model_expression_option/create_model_expression_option.sql b/packages/bigframes/tests/unit/core/sql/snapshots/test_ml/test_create_model_expression_option/create_model_expression_option.sql new file mode 100644 index 000000000000..e328aab5fbfd --- /dev/null +++ b/packages/bigframes/tests/unit/core/sql/snapshots/test_ml/test_create_model_expression_option/create_model_expression_option.sql @@ -0,0 +1,3 @@ +CREATE MODEL `my_model` +OPTIONS(l2_reg = 0.1 * 10, booster_type = 'gbtree') +AS SELECT * FROM t diff --git a/packages/bigframes/tests/unit/core/sql/test_ml.py b/packages/bigframes/tests/unit/core/sql/test_ml.py index bb3b61a949cf..61296638eec2 100644 --- a/packages/bigframes/tests/unit/core/sql/test_ml.py +++ b/packages/bigframes/tests/unit/core/sql/test_ml.py @@ -14,7 +14,11 @@ import pytest +import bigframes.core.col as col +import bigframes.core.expression as ex import bigframes.core.sql.ml +import bigframes.dtypes as dtypes +import bigframes.operations.numeric_ops as numeric_ops pytest.importorskip("pytest_snapshot") @@ -97,6 +101,26 @@ def test_create_model_list_option(snapshot): snapshot.assert_match(sql, "create_model_list_option.sql") +def test_create_model_expression_option(snapshot): + # An expression that multiplies two literal constants, + # e.g. 
0.1 * 10 + literal_expr = ex.ScalarConstantExpression(0.1, dtypes.FLOAT_DTYPE) + multiplier_expr = ex.ScalarConstantExpression(10, dtypes.INT_DTYPE) + math_expr = col.Expression( + ex.OpExpression(op=numeric_ops.mul_op, inputs=(literal_expr, multiplier_expr)) + ) + + sql = bigframes.core.sql.ml.create_model_ddl( + model_name="my_model", + options={ + "l2_reg": math_expr, + "booster_type": "gbtree", + }, + training_data="SELECT * FROM t", + ) + snapshot.assert_match(sql, "create_model_expression_option.sql") + + def test_evaluate_model_basic(snapshot): sql = bigframes.core.sql.ml.evaluate( model_name="my_project.my_dataset.my_model",