diff --git a/CHANGELOG.md b/CHANGELOG.md index 335ebbd..c7cb8c7 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -1,6 +1,11 @@ # Change Log All notable changes to this project will be documented in this file. +## 2.1.3 - 2026-02 + ### Runner + - Separate writer from runner, sorting schema to align to written table + - Added a debug run option that return dataframe without writing anything + ## 2.1.2 - 2026-01 ### Maker - Feature type normalization to Double and Long diff --git a/pyproject.toml b/pyproject.toml index f7eae2b..4a850c3 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,6 +1,6 @@ [project] name = "rialto" -version = "2.1.2" +version = "2.1.3" description = "Rialto is a framework for building and deploying machine learning features in a scalable and reusable way. It provides a set of tools that make it easy to define and deploy features and models, and it provides a way to orchestrate the execution of these features and models." authors = [ { name = "Marek Dobransky", email = "marekdobr@gmail.com" }, diff --git a/rialto/common/utils.py b/rialto/common/utils.py index 875ed7d..541e977 100644 --- a/rialto/common/utils.py +++ b/rialto/common/utils.py @@ -91,11 +91,11 @@ def normalize_types(df: DataFrame) -> DataFrame: return df.select( [ - F.col(f.name).cast(DoubleType()) + F.when(F.col(f.name).isNotNull(), F.col(f.name).cast(DoubleType())).otherwise(F.lit(None)).alias(f.name) if isinstance(f.dataType, float_types) - else F.col(f.name).cast(LongType()) + else F.when(F.col(f.name).isNotNull(), F.col(f.name).cast(LongType())).otherwise(F.lit(None)).alias(f.name) if isinstance(f.dataType, int_types) - else F.col(f.name) + else F.when(F.col(f.name).isNotNull(), F.col(f.name)).otherwise(F.lit(None)).alias(f.name) for f in df.schema.fields ] ) diff --git a/rialto/runner/runner.py b/rialto/runner/runner.py index e3d945c..2807a77 100644 --- a/rialto/runner/runner.py +++ b/rialto/runner/runner.py @@ -30,6 +30,7 @@ from rialto.runner.reporting.tracker import Tracker from rialto.runner.table import Table from rialto.runner.transformation import Transformation +from rialto.runner.writer import Writer class Runner: @@ -52,7 +53,7 @@ def __init__( self.rerun = rerun self.skip_dependencies = skip_dependencies self.op = op - self.merge_schema = merge_schema + self.writer = Writer(spark, merge_schema=merge_schema) self.tracker = Tracker( mail_cfg=self.config.runner.mail, bookkeeping=self.config.runner.bookkeeping, spark=spark ) @@ -96,34 +97,6 @@ def _execute(self, instance: Transformation, run_date: date, pipeline: PipelineC return df - def _create_schema(self, table: Table): - """ - Create schema if it doesn't exist - - :param schema_path: path to schema - """ - self.spark.sql(f"CREATE SCHEMA IF NOT EXISTS {table.get_schema_path()}") - - def _write(self, df: DataFrame, info_date: date, table: Table) -> None: - """ - Write dataframe to storage - - :param df: dataframe to write - :param info_date: date to partition - :param table: path to write to - :return: None - """ - self._create_schema(table) - - df = df.withColumn(table.partition, F.lit(info_date)) - if self.merge_schema is True: - df.write.partitionBy(table.partition).mode("overwrite").option("mergeSchema", "true").saveAsTable( - table.get_table_path() - ) - else: - df.write.partitionBy(table.partition).mode("overwrite").saveAsTable(table.get_table_path()) - logger.info(f"Results writen to {table.get_table_path()}") - def _check_written(self, info_date: date, table: Table) -> int: """ Check if there are records written for given date @@ -238,7 +211,7 @@ def _run_one_date(self, pipeline: PipelineConfig, run_date: date, info_date: dat feature_group = utils.load_module(pipeline.module) df = self._execute(feature_group, run_date, pipeline) - self._write(df, info_date, target) + self.writer.write(df, info_date, target) records = self._check_written(info_date, target) logger.info(f"Generated {records} records") if records == 0: @@ -331,3 +304,23 @@ def __call__(self): print(self.tracker.records) self.tracker.report_by_mail() logger.info("Execution finished") + + def debug(self) -> DataFrame: + """Debug mode - run only first op for one date and return the resulting dataframe""" + logger.info("Running in debug mode") + if self.op: + pipeline = [p for p in self.config.pipelines if p.name == self.op][0] + else: + pipeline = self.config.pipelines[0] + + target = Table( + schema_path=pipeline.target.target_schema, + class_name=pipeline.module.python_class, + partition=pipeline.target.target_partition_column, + ) + selected_run_dates, selected_info_dates = self._select_run_dates(pipeline, target) + if len(selected_run_dates) > 0: + df = self._execute(utils.load_module(pipeline.module), selected_run_dates[0], pipeline) + return self.writer._process(df, selected_info_dates[0], target) + else: + logger.info("No dates to run in debug mode") diff --git a/rialto/runner/writer.py b/rialto/runner/writer.py new file mode 100644 index 0000000..3e31dac --- /dev/null +++ b/rialto/runner/writer.py @@ -0,0 +1,96 @@ +# Copyright 2022 ABSA Group Limited +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +__all__ = ["Writer"] + +from datetime import date +from typing import List + +import pyspark.sql.functions as F +from loguru import logger +from pyspark.sql import DataFrame, SparkSession + +from rialto.runner.table import Table + + +class Writer: + """Supporting class for runner""" + + def __init__(self, spark: SparkSession, merge_schema=False): + self.spark = spark + self.merge_schema = merge_schema + + def _create_schema(self, table: Table): + """ + Create schema if it doesn't exist + + :param schema_path: path to schema + """ + self.spark.sql(f"CREATE SCHEMA IF NOT EXISTS {table.get_schema_path()}") + + def _get_existing_columns(self, table: Table): + """ + Get existing schema of table if it exists + + :param table: table to check for + :return: existing columns or None + """ + try: + return self.spark.table(table.get_table_path()).columns + except Exception as e: + logger.warning(f"Could not get existing schema for {table.get_table_path()}: {e}") + return None + + def _align_schema(self, df: DataFrame, existing_columns: List) -> DataFrame: + """ + Align schema of dataframe to existing schema of table if it exists + + :param df: dataframe to align + :param table: table to check for existing schema + :return: dataframe with aligned schema + """ + if existing_columns is not None: + return df.select( + *[F.col(c) for c in existing_columns if c in df.columns], + *[F.col(c) for c in df.columns if c not in existing_columns], + ) + return df + + def _process(self, df: DataFrame, info_date: date, table: Table) -> DataFrame: + df = df.withColumn(table.partition, F.lit(info_date)) + + df = self._align_schema(df, self._get_existing_columns(table)) + + return df + + def write(self, df: DataFrame, info_date: date, table: Table) -> None: + """ + Write dataframe to storage + + :param df: dataframe to write + :param info_date: date to partition + :param table: path to write to + :return: None + """ + self._create_schema(table) + + df = self._process(df, info_date, table) + + if self.merge_schema is True: + df.write.partitionBy(table.partition).mode("overwrite").option("mergeSchema", "true").saveAsTable( + table.get_table_path() + ) + else: + df.write.partitionBy(table.partition).mode("overwrite").saveAsTable(table.get_table_path()) + logger.info(f"Results writen to {table.get_table_path()}") diff --git a/uv.lock b/uv.lock index 7872703..eae6d1f 100644 --- a/uv.lock +++ b/uv.lock @@ -1078,7 +1078,7 @@ wheels = [ [[package]] name = "rialto" -version = "2.1.2" +version = "2.1.3" source = { editable = "." } dependencies = [ { name = "delta-spark" },