diff --git a/CHANGELOG.md b/CHANGELOG.md index 5a3c927..e34e1fa 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -1,6 +1,9 @@ # Change Log All notable changes to this project will be documented in this file. +## 2.2.1 - 2026-05-31 + ### Runner + - added option to use day="last" for monthly scheduling ## 2.2.0 - 2026-03 ### Runner diff --git a/rialto/common/table_reader.py b/rialto/common/table_reader.py index 228d59b..0d8645c 100644 --- a/rialto/common/table_reader.py +++ b/rialto/common/table_reader.py @@ -71,6 +71,16 @@ def get_table( """ raise NotImplementedError + @abc.abstractmethod + def table_exists(self, table: str) -> bool: + """ + Check table exists in storage + + :param table: full table path + :return: bool + """ + raise NotImplementedError + class TableReader(DataReader): """An implementation of data reader for databricks tables""" @@ -165,3 +175,12 @@ def get_table( if uppercase_columns: df = self._uppercase_column_names(df) return df + + def table_exists(self, table: str) -> bool: + """ + Check table exists in spark catalog + + :param table: full table path + :return: bool + """ + return self.spark.catalog.tableExists(table) diff --git a/rialto/metadata/utils.py b/rialto/metadata/utils.py index 0cb591c..efbbf76 100644 --- a/rialto/metadata/utils.py +++ b/rialto/metadata/utils.py @@ -17,7 +17,7 @@ def class_to_catalog_name(class_name) -> str: """ - Map python class name of feature group (CammelCase) to databricks compatible format (lowercase with underscores) + Map python class name of feature group (CamelCase) to databricks compatible format (lowercase with underscores) :param class_name: Python class name :return: feature storage name diff --git a/rialto/runner/config_loader.py b/rialto/runner/config_loader.py index 7978ac5..1aa5fba 100644 --- a/rialto/runner/config_loader.py +++ b/rialto/runner/config_loader.py @@ -13,12 +13,12 @@ # limitations under the License. 
__all__ = [ - "get_pipelines_config", + "ConfigLoader", ] -from typing import Dict, List, Optional +from typing import Dict, List, Optional, Union -from pydantic import BaseModel, ConfigDict +from pydantic import BaseModel, ConfigDict, Field from rialto.common.utils import load_yaml from rialto.runner.config_overrides import override_config @@ -35,8 +35,10 @@ class IntervalConfig(BaseConfig): class ScheduleConfig(BaseConfig): frequency: str - day: Optional[int] = 0 - info_date_shift: Optional[List[IntervalConfig]] = IntervalConfig(units="days", value=0) + day: Optional[Union[int, str]] = 0 + info_date_shift: Optional[Union[IntervalConfig, List[IntervalConfig]]] = Field( + default_factory=lambda: IntervalConfig(units="days", value=0) + ) class DependencyConfig(BaseConfig): @@ -88,16 +90,16 @@ class PipelineConfig(BaseConfig): name: str module: ModuleConfig schedule: ScheduleConfig - dependencies: Optional[List[DependencyConfig]] = [] - target: TargetConfig = None + dependencies: Optional[List[DependencyConfig]] = Field(default_factory=list) + target: Optional[TargetConfig] = None metadata_manager: Optional[MetadataManagerConfig] = None feature_loader: Optional[FeatureLoaderConfig] = None - extras: Optional[Dict] = {} + extras: Optional[Dict] = Field(default_factory=dict) class PipelinesConfig(BaseConfig): runner: RunnerConfig - pipelines: list[PipelineConfig] + pipelines: List[PipelineConfig] def get_pipelines_config(path: str, overrides: Dict) -> PipelinesConfig: @@ -108,3 +110,12 @@ def get_pipelines_config(path: str, overrides: Dict) -> PipelinesConfig: return PipelinesConfig(**cfg) else: return PipelinesConfig(**raw_config) + + +class ConfigLoader: + """Loader for pipelines config""" + + @staticmethod + def load_yaml(path: str, overrides: Dict) -> PipelinesConfig: + """Load yaml config and apply overrides""" + return get_pipelines_config(path, overrides) diff --git a/rialto/runner/data_checker.py b/rialto/runner/data_checker.py new file mode 100644 index 
0000000..787c56a --- /dev/null +++ b/rialto/runner/data_checker.py @@ -0,0 +1,103 @@ +# Copyright 2022 ABSA Group Limited +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +__all__ = ["DataChecker"] + +from datetime import date +from typing import Dict + +from loguru import logger +from pyspark.sql import DataFrame + +from rialto.common import DataReader +from rialto.runner.table import Table + + +class DataChecker: + """Checks if data for given date or date range is present in storage""" + + def __init__(self, reader: DataReader): + self.reader = reader + + def check_date(self, target: Table, partition_date: date) -> bool: + """Check if data for given date is present in target + + :param target: target Table to check + :param partition_date: Date to check + :return: True if data for given date is present, False otherwise + """ + return self.check_range(target, partition_date, partition_date) + + def check_range(self, target: Table, start_date: date, end_date: date) -> bool: + """Check if data for given date range is present in target + + :param target: target Table to check + :param start_date: Starting date of the range to check + :param end_date: Ending date of the range to check + :return: True if data for given date range is present, False otherwise + """ + if self.reader.table_exists(target.get_table_path()): + df = self.reader.get_table( + target.get_table_path(), + date_column=target.partition, + date_from=start_date, + date_to=end_date, + filters=target.filters, 
+ ) + data_exists = df.count() > 0 + if ( + data_exists + and (target.filters is None or target.filters == {}) + and target.secondary_partitions is not None + ): + logger.warning( + f"Overwriting {target.get_table_path()} completion status for {start_date} due to presence of " + f"secondary partitions and no filters." + ) + data_exists = False + return data_exists + else: + logger.warning(f"Target table {target.get_table_path()} doesn't exist yet.") + return False + + def _get_filters(self, target: Table, df: DataFrame) -> Dict: + if target.filters is not None: + return target.filters + elif target.secondary_partitions: + filters = {} + logger.info("Inferring target sub-partition values from generated data.") + row = df.select(*target.secondary_partitions).distinct().collect()[0] + for c in target.secondary_partitions: + filters[c] = row[c] + return filters + else: + return {} + + def check_written(self, target: Table, partition_date: date, df: DataFrame) -> int: + """Check how many records were written + + :param target: target Table to check + :param partition_date: Date to check + :param df: DataFrame that was written, used to determine filters if not provided in config + :return: Number of records for given date + """ + filters = self._get_filters(target, df) + df = self.reader.get_table( + target.get_table_path(), + date_column=target.partition, + date_from=partition_date, + date_to=partition_date, + filters=filters, + ) + + return df.count() diff --git a/rialto/runner/date_manager.py b/rialto/runner/date_manager.py index 1bcef7b..9b3422b 100644 --- a/rialto/runner/date_manager.py +++ b/rialto/runner/date_manager.py @@ -18,13 +18,40 @@ from typing import List from dateutil.relativedelta import relativedelta +from loguru import logger -from rialto.runner.config_loader import ScheduleConfig +from rialto.runner.config_loader import RunnerConfig, ScheduleConfig class DateManager: """Date generation and shifts based on configuration""" + def __init__(self, config: 
RunnerConfig, run_date: str = None): + if run_date: + run_date = self.str_to_date(run_date) + else: + run_date = date.today() + + self.date_from = self.date_subtract( + input_date=run_date, + units=config.watched_period_units, + value=config.watched_period_value, + ) + + self.date_until = run_date + + if self.date_from > self.date_until: + raise ValueError(f"Invalid date range from {self.date_from} until {self.date_until}") + logger.info(f"Running period set to: {self.date_from} - {self.date_until}") + + def get_date_from(self) -> date: + """Get starting date of the execution window""" + return self.date_from + + def get_date_until(self) -> date: + """Get ending date of the execution window""" + return self.date_until + @staticmethod def str_to_date(str_date: str) -> date: """ @@ -33,26 +60,29 @@ def str_to_date(str_date: str) -> date: :param str_date: string date :return: date """ - return datetime.strptime(str_date, "%Y-%m-%d").date() + try: + return datetime.strptime(str_date, "%Y-%m-%d").date() + except ValueError: + raise ValueError(f"Invalid date format: {str_date}. 
Expected YYYY-MM-DD.")
 
     @staticmethod
-    def date_subtract(run_date: date, units: str, value: int) -> date:
+    def date_subtract(input_date: date, units: str, value: int) -> date:
         """
-        Generate starting date from given date and config
+        Subtract given number of units from input date
 
-        :param run_date: base date
+        :param input_date: base date
         :param units: units: years, months, weeks, days
         :param value: number of units to subtract
         :return: Starting date
         """
         if units == "years":
-            return run_date - relativedelta(years=value)
+            return input_date - relativedelta(years=value)
         if units == "months":
-            return run_date - relativedelta(months=value)
+            return input_date - relativedelta(months=value)
         if units == "weeks":
-            return run_date - relativedelta(weeks=value)
+            return input_date - relativedelta(weeks=value)
         if units == "days":
-            return run_date - relativedelta(days=value)
+            return input_date - relativedelta(days=value)
         raise ValueError(f"Unknown time unit {units}")
 
     @staticmethod
@@ -69,27 +99,41 @@ def all_dates(date_from: date, date_to: date) -> List[date]:
 
         return [date_from + relativedelta(days=n) for n in range((date_to - date_from).days + 1)]
 
-    @staticmethod
-    def run_dates(date_from: date, date_to: date, schedule: ScheduleConfig) -> List[date]:
+    def get_execution_and_partition_dates(self, schedule: ScheduleConfig) -> List[tuple[date, date]]:
+        """
+        Get list of execution and partition dates for given configuration
+
+        :return: List of tuples with execution and partition dates
+        """
+        execution_dates = self._execution_dates(schedule)
+        return [(ex_date, self._to_partition_date(ex_date, schedule)) for ex_date in execution_dates]
+
+    def _execution_dates(self, schedule: ScheduleConfig) -> List[date]:
         """
         Select dates inside given interval depending on frequency and selected day
 
-        :param date_from: interval start
-        :param date_to: interval end
         :param schedule: schedule config
-        :return: list of dates
+        :return: List of execution dates
+        :raises ValueError: if the frequency is unknown or the day is not valid for it
         """
-        options = DateManager.all_dates(date_from, date_to)
-        if schedule.frequency == "daily":
+        options = self.all_dates(self.date_from, self.date_until)
+        frequency = schedule.frequency.lower()
+        if frequency == "daily":
             return options
-        if schedule.frequency == "weekly":
+        if frequency == "weekly":
+            # day may be an int, a string or None; reject non-ints before the range comparison
+            if not isinstance(schedule.day, int) or not (1 <= schedule.day <= 7):
+                raise ValueError(f"Invalid day for weekly frequency: {schedule.day}. Must be 1-7.")
             return [x for x in options if x.isoweekday() == schedule.day]
-        if schedule.frequency == "monthly":
+        if frequency == "monthly":
+            if schedule.day == "last":
+                return [x for x in options if (x + relativedelta(days=1)).month != x.month]
+            if not isinstance(schedule.day, int) or not (1 <= schedule.day <= 31):
+                raise ValueError(f"Invalid day for monthly frequency: {schedule.day}. Must be 1-31 or last.")
             return [x for x in options if x.day == schedule.day]
-        raise ValueError(f"Unknown frequency {schedule.frequency}")
+        raise ValueError(f"Unknown frequency: {schedule.frequency}")
 
-    @staticmethod
-    def to_info_date(date: date, schedule: ScheduleConfig) -> date:
+    def _to_partition_date(self, date: date, schedule: ScheduleConfig) -> date:
         """
         Shift given date according to config
 
@@ -97,11 +141,9 @@ def to_info_date(date: date, schedule: ScheduleConfig) -> date:
         :param schedule: schedule config
         :return: date
         """
-        if isinstance(schedule.info_date_shift, List):
+        if isinstance(schedule.info_date_shift, list):
             for shift in schedule.info_date_shift:
-                date = DateManager.date_subtract(date, units=shift.units, value=shift.value)
+                date = self.date_subtract(date, units=shift.units, value=shift.value)
         else:
-            date = DateManager.date_subtract(
-                date, units=schedule.info_date_shift.units, value=schedule.info_date_shift.value
-            )
+            date = self.date_subtract(date, units=schedule.info_date_shift.units, value=schedule.info_date_shift.value)
         return date
diff --git a/rialto/runner/executor.py b/rialto/runner/executor.py
new file mode 100644
index 0000000..4ae1644
--- /dev/null
+++ b/rialto/runner/executor.py
@@ -0,0 +1,55
@@ +# Copyright 2022 ABSA Group Limited +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +__all__ = ["PipelineExecutor"] + +from loguru import logger +from pyspark.sql import DataFrame, SparkSession + +import rialto.runner.utils as utils +from rialto.common import DataReader +from rialto.runner.data_checker import DataChecker +from rialto.runner.task_registry import PipelineTask + + +class PipelineExecutor: + """Executes a single pipeline task.""" + + def __init__(self, spark: SparkSession, reader: DataReader, checker: DataChecker): + self.spark = spark + self.reader = reader + self.checker = checker + + @logger.catch + def execute(self, pipeline: PipelineTask) -> DataFrame: + """ + Execute the pipeline task. + + :param pipeline: Pipeline object to execute. + :return: DataFrame resulting from pipeline execution. 
+ """ + logger.info(f"Executing pipeline {pipeline.op} for partition date {pipeline.partition_date}") + + # Load and run the job + job = utils.load_module(pipeline.config.module) + metadata_manager, feature_loader = utils.init_tools(self.spark, pipeline.config) + df = job.run( + spark=self.spark, + run_date=pipeline.execution_date, + config=pipeline.config, + reader=self.reader, + metadata_manager=metadata_manager, + feature_loader=feature_loader, + ) + return df diff --git a/rialto/runner/runner.py b/rialto/runner/runner.py index 384998f..826f8bc 100644 --- a/rialto/runner/runner.py +++ b/rialto/runner/runner.py @@ -14,22 +14,19 @@ __all__ = ["Runner"] -import datetime -from datetime import date -from typing import Dict, List, Tuple +from typing import Dict, List -from loguru import logger from pyspark.sql import DataFrame, SparkSession -import rialto.runner.utils as utils from rialto.common import TableReader -from rialto.runner.config_loader import PipelineConfig, get_pipelines_config +from rialto.runner.config_loader import ConfigLoader, PipelineConfig +from rialto.runner.data_checker import DataChecker from rialto.runner.date_manager import DateManager -from rialto.runner.reporting.record import Record +from rialto.runner.executor import PipelineExecutor from rialto.runner.reporting.tracker import Tracker -from rialto.runner.table import Table -from rialto.runner.transformation import Transformation -from rialto.runner.writer import Writer +from rialto.runner.task_registry import TaskRegistry +from rialto.runner.task_status_checker import TaskStatusChecker +from rialto.runner.writer import DatabricksWriter class Runner: @@ -46,311 +43,75 @@ def __init__( overrides: Dict = None, merge_schema: bool = False, ): - self.spark = spark - self.config = get_pipelines_config(config_path, overrides) - self.reader = TableReader(spark) + self.config = ConfigLoader().load_yaml(config_path, overrides) + self.date_manager = DateManager(self.config.runner, run_date) self.rerun 
= rerun - self.skip_dependencies = skip_dependencies self.op = op - self.writer = Writer(spark, merge_schema=merge_schema) + self.skip_dependencies = skip_dependencies + self.writer = DatabricksWriter(spark, merge_schema=merge_schema) + + reader = TableReader(spark) + data_checker = DataChecker(reader) + self.task_checker = TaskStatusChecker(data_checker) + self.registry = TaskRegistry(spark, date_manager=self.date_manager) + self.executor = PipelineExecutor( + spark=spark, + reader=reader, + checker=data_checker, + ) self.tracker = Tracker( mail_cfg=self.config.runner.mail, bookkeeping=self.config.runner.bookkeeping, spark=spark ) - if run_date: - run_date = DateManager.str_to_date(run_date) - else: - run_date = date.today() - - self.date_from = DateManager.date_subtract( - run_date=run_date, - units=self.config.runner.watched_period_units, - value=self.config.runner.watched_period_value, - ) - - self.date_until = run_date - - if self.date_from > self.date_until: - raise ValueError(f"Invalid date range from {self.date_from} until {self.date_until}") - logger.info(f"Running period set to: {self.date_from} - {self.date_until}") - - def _execute(self, instance: Transformation, run_date: date, pipeline: PipelineConfig) -> DataFrame: - """ - Run the job - - :param instance: Instance of Transformation - :param run_date: date to run for - :param pipeline: pipeline configuration - :return: Dataframe - """ - metadata_manager, feature_loader = utils.init_tools(self.spark, pipeline) - - df = instance.run( - spark=self.spark, - run_date=run_date, - config=pipeline, - reader=self.reader, - metadata_manager=metadata_manager, - feature_loader=feature_loader, - ) - - return df - - def _check_written(self, info_date: date, table: Table, df: DataFrame, pipeline: PipelineConfig) -> int: - """ - Check if there are records written for given date - - :param info_date: date to check - :param table: target table object - :return: number of records - """ - filters = {} - if 
pipeline.target.rerun_filters is not None: - filters = pipeline.target.rerun_filters + def _select_pipelines(self) -> List[PipelineConfig]: + """Select pipelines to run based on config and input parameters""" + if self.op: + selected = [p for p in self.config.pipelines if p.name == self.op] + if len(selected) < 1: + raise ValueError(f"Unknown operation selected: {self.op}") + return selected else: - if table.secondary_partitions: - row = df.select(*table.secondary_partitions).distinct().collect()[0] - for c in table.secondary_partitions: - val = row[0][c] - filters[c] = val - - df = self.reader.get_table( - table.get_table_path(), date_column=table.partition, date_from=info_date, date_to=info_date, filters=filters - ) + return self.config.pipelines - return df.count() - - def check_dates_have_data(self, table: Table, dates: List[date], target_filters: Dict = None) -> List[bool]: - """ - For given list of dates, check if there is a matching partition for each - - :param table: Table object - :param dates: list of dates to check - :return: list of bool - """ - if utils.table_exists(self.spark, table.get_table_path()): - checks = [] - for check_date in dates: - df = self.reader.get_table( - table.get_table_path(), - date_column=table.partition, - date_from=check_date, - date_to=check_date, - filters=target_filters, + def _register_tasks(self, pipelines: List[PipelineConfig]) -> None: + for pipeline in pipelines: + for exec_date, partition_date in self.date_manager.get_execution_and_partition_dates(pipeline.schedule): + self.registry.add_task( + name=pipeline.name, execution_date=exec_date, partition_date=partition_date, config=pipeline ) - data_exists = df.count() > 0 - if data_exists and target_filters is None and table.secondary_partitions is not None: - # ensure rerun if the write consideres secondary partitions but the filter doesn't - data_exists = False - checks.append(data_exists) - return checks - else: - logger.info(f"Table {table.get_table_path()} doesn't 
exist!") - return [False for _ in dates] - - def check_dependencies(self, pipeline: PipelineConfig, run_date: date) -> bool: - """ - Check for all dependencies in config if they have available partitions - - :param pipeline: configuration - :param run_date: run date - :return: bool - """ - logger.info(f"{pipeline.name} checking dependencies for {run_date}") - - error = "" - - for dependency in pipeline.dependencies: - dep_from = DateManager.date_subtract(run_date, dependency.interval.units, dependency.interval.value) - logger.info(f"Looking for {dependency.table} from {dep_from} until {run_date}") - - possible_dep_dates = DateManager.all_dates(dep_from, run_date) - - logger.debug(f"Date column for {dependency.table} is {dependency.date_col}") - - source = Table(table_path=dependency.table, partition=dependency.date_col) - if True in self.check_dates_have_data(source, possible_dep_dates, dependency.filters): - logger.info(f"Dependency for {dependency.table} from {dep_from} until {run_date} is fulfilled") - else: - msg = f"Missing dependency for {dependency.table} from {dep_from} until {run_date}" - logger.info(msg) - error = error + msg + "\n" - if error != "": - self.tracker.last_error = error - return False - - return True - - def _get_completion(self, target: Table, info_dates: List[date], filters: Dict = None) -> List[bool]: - """ - Check if model has run for given dates - - :param target_path: Table object - :param info_dates: list of dates - :return: bool list - """ - if self.rerun: - return [False for _ in info_dates] - else: - return self.check_dates_have_data(target, info_dates, filters) - - def _select_run_dates(self, pipeline: PipelineConfig, table: Table, filters: Dict = None) -> Tuple[List, List]: - """ - Select run dates and info dates based on completion - - :param pipeline: pipeline config - :param table: table path - :return: list of run dates and list of info dates - """ - possible_run_dates = DateManager.run_dates(self.date_from, self.date_until, 
pipeline.schedule) - possible_info_dates = [DateManager.to_info_date(x, pipeline.schedule) for x in possible_run_dates] - current_state = self._get_completion(table, possible_info_dates, filters) - - selection = [ - (run, info) for run, info, state in zip(possible_run_dates, possible_info_dates, current_state) if not state - ] - - if not len(selection): - logger.info(f"{pipeline.name} has no dates to run") - return [], [] - - selected_run_dates, selected_info_dates = zip(*selection) - logger.info(f"{pipeline.name} identified to run for {selected_run_dates}") - - return list(selected_run_dates), list(selected_info_dates) - - def _run_one_date(self, pipeline: PipelineConfig, run_date: date, info_date: date, target: Table) -> int: - """ - Run one pipeline for one date - - :param pipeline: pipeline cfg - :param run_date: run date - :param info_date: information date - :param target: target Table - :return: success bool - """ - if self.skip_dependencies or self.check_dependencies(pipeline, run_date): - logger.info(f"Running {pipeline.name} for {run_date}") - - feature_group = utils.load_module(pipeline.module) - df = self._execute(feature_group, run_date, pipeline) - self.writer.write(df, info_date, target) - records = self._check_written(info_date, target, df, pipeline) - logger.info(f"Generated {records} records") - if records == 0: - raise RuntimeError("No records generated") - else: - return records - return 0 - - def _run_pipeline(self, pipeline: PipelineConfig): - """ - Run single pipeline for all required dates - - :param pipeline: pipeline cfg - :return: success bool - """ - target = Table( - schema_path=pipeline.target.target_schema, - class_name=pipeline.module.python_class, - partition=pipeline.target.target_partition_column, - secondary_partitions=pipeline.target.secondary_partition_columns, - table=pipeline.target.custom_name, - ) - logger.info(f"Loaded pipeline {pipeline.name}") - - selected_run_dates, selected_info_dates = self._select_run_dates( - 
pipeline, target, pipeline.target.rerun_filters
-        )
-
-        # ----------- Checking dependencies available ----------
-        for run_date, info_date in zip(selected_run_dates, selected_info_dates):
-            run_start = datetime.datetime.now()
-            try:
-                records = self._run_one_date(pipeline, run_date, info_date, target)
-                if records > 0:
-                    status = "Success"
-                    message = ""
-                else:
-                    status = "Failure"
-                    message = self.tracker.last_error
-                self.tracker.add(
-                    Record(
-                        job=pipeline.name,
-                        target=target.get_table_path(),
-                        date=info_date,
-                        time=datetime.datetime.now() - run_start,
-                        records=records,
-                        status=status,
-                        reason=message,
-                    )
-                )
-            except Exception as error:
-                logger.error(f"An exception occurred in pipeline {pipeline.name}")
-                logger.exception(error)
-                self.tracker.add(
-                    Record(
-                        job=pipeline.name,
-                        target=target.get_table_path(),
-                        date=info_date,
-                        time=datetime.datetime.now() - run_start,
-                        records=0,
-                        status="Error",
-                        reason="Exception",
-                        exception=str(error),
-                    )
-                )
-            except KeyboardInterrupt:
-                logger.error(f"Pipeline {pipeline.name} interrupted")
-                self.tracker.add(
-                    Record(
-                        job=pipeline.name,
-                        target=target.get_table_path(),
-                        date=info_date,
-                        time=datetime.datetime.now() - run_start,
-                        records=0,
-                        status="Error",
-                        reason="Interrupted by user",
-                    )
-                )
-                raise KeyboardInterrupt
+    def _check_tasks(self) -> None:
+        for task in self.registry.tasks:
+            if not self.rerun:
+                self.task_checker.check_completion(task)
+            if not self.skip_dependencies:
+                self.task_checker.check_pipeline_dependencies(task)
+
+    def _run_tasks(self) -> None:
+        """Execute every runnable task and write its result"""
+        for task in self.registry.tasks:
+            if (not task.completion or self.rerun) and (task.dependencies_complete or self.skip_dependencies):
+                # run_start = datetime.now()
+                df = self.executor.execute(task)
+                # PipelineExecutor.execute is wrapped in @logger.catch, which logs the exception and
+                # returns None instead of raising; skip the write so a failed job is not handed to the writer
+                if df is None:
+                    continue
+                self.writer.write(df, task.partition_date, task.target)
+                # records = self.checker.check_written(task.target, task.partition_date, df)
 
     def __call__(self):
         """Execute pipelines"""
-        logger.info("Executing pipelines")
-        try:
-
if self.op: - selected = [p for p in self.config.pipelines if p.name == self.op] - if len(selected) < 1: - raise ValueError(f"Unknown operation selected: {self.op}") - self._run_pipeline(selected[0]) - else: - for pipeline in self.config.pipelines: - self._run_pipeline(pipeline) - finally: - print(self.tracker.records) - self.tracker.report_by_mail() - logger.info("Execution finished") + pipelines = self._select_pipelines() + self._register_tasks(pipelines) + self._check_tasks() + self.registry.log_status() + self._run_tasks() + + def dry_run(self): + """Dry run - log status of pipelines without executing""" + pipelines = self._select_pipelines() + self._register_tasks(pipelines) + self._check_tasks() + self.registry.log_status() def debug(self) -> DataFrame: """Debug mode - run only first op for one date and return the resulting dataframe""" - logger.info("Running in debug mode") - if self.op: - pipeline = [p for p in self.config.pipelines if p.name == self.op][0] - else: - pipeline = self.config.pipelines[0] - - target = Table( - schema_path=pipeline.target.target_schema, - class_name=pipeline.module.python_class, - partition=pipeline.target.target_partition_column, - secondary_partitions=pipeline.target.secondary_partition_columns, - table=pipeline.target.custom_name, - ) - selected_run_dates, selected_info_dates = self._select_run_dates(pipeline, target) - if len(selected_run_dates) > 0: - df = self._execute(utils.load_module(pipeline.module), selected_run_dates[0], pipeline) - return self.writer._process(df, selected_info_dates[0], target) - else: - logger.info("No dates to run in debug mode") + pipelines = self._select_pipelines() + self._register_tasks(pipelines) + return self.executor.execute(self.registry.tasks[0]) diff --git a/rialto/runner/runner_old.py b/rialto/runner/runner_old.py new file mode 100644 index 0000000..4265e26 --- /dev/null +++ b/rialto/runner/runner_old.py @@ -0,0 +1,342 @@ +# # Copyright 2022 ABSA Group Limited +# # +# # Licensed under 
the Apache License, Version 2.0 (the "License"); +# # you may not use this file except in compliance with the License. +# # You may obtain a copy of the License at +# # +# # http://www.apache.org/licenses/LICENSE-2.0 +# # +# # Unless required by applicable law or agreed to in writing, software +# # distributed under the License is distributed on an "AS IS" BASIS, +# # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# # See the License for the specific language governing permissions and +# # limitations under the License. +# +# __all__ = ["Runner"] +# +# import datetime +# from datetime import date +# from typing import Dict, List, Tuple +# +# from loguru import logger +# from pyspark.sql import DataFrame, SparkSession +# +# import rialto.runner.utils as utils +# from rialto.common import TableReader +# from rialto.runner.config_loader import ConfigLoader, PipelineConfig +# from rialto.runner.date_manager import DateManager +# from rialto.runner.reporting.record import Record +# from rialto.runner.reporting.tracker import Tracker +# from rialto.runner.table import Table +# from rialto.runner.transformation import Transformation +# from rialto.runner.writer import Writer +# +# +# class Runner: +# """A scheduler and dependency checker for feature runs""" +# +# def __init__( +# self, +# spark: SparkSession, +# config_path: str, +# run_date: str = None, +# rerun: bool = False, +# op: str = None, +# skip_dependencies: bool = False, +# overrides: Dict = None, +# merge_schema: bool = False, +# ): +# self.spark = spark +# self.config = ConfigLoader().load_yaml(config_path, overrides) +# self.reader = TableReader(spark) +# self.rerun = rerun +# self.skip_dependencies = skip_dependencies +# self.op = op +# self.writer = Writer(spark, merge_schema=merge_schema) +# self.tracker = Tracker( +# mail_cfg=self.config.runner.mail, bookkeeping=self.config.runner.bookkeeping, spark=spark +# ) +# self.date_manager = DateManager(self.config, run_date) +# +# def 
_execute(self, instance: Transformation, run_date: date, pipeline: PipelineConfig) -> DataFrame: +# """ +# Run the job +# +# :param instance: Instance of Transformation +# :param run_date: date to run for +# :param pipeline: pipeline configuration +# :return: Dataframe +# """ +# metadata_manager, feature_loader = utils.init_tools(self.spark, pipeline) +# +# df = instance.run( +# spark=self.spark, +# run_date=run_date, +# config=pipeline, +# reader=self.reader, +# metadata_manager=metadata_manager, +# feature_loader=feature_loader, +# ) +# +# return df +# +# def _check_written(self, info_date: date, table: Table, df: DataFrame, pipeline: PipelineConfig) -> int: +# """ +# Check if there are records written for given date +# +# :param info_date: date to check +# :param table: target table object +# :return: number of records +# """ +# filters = {} +# if pipeline.target.rerun_filters is not None: +# filters = pipeline.target.rerun_filters +# else: +# if table.secondary_partitions: +# row = df.select(*table.secondary_partitions).distinct().collect()[0] +# for c in table.secondary_partitions: +# val = row[0][c] +# filters[c] = val +# +# df = self.reader.get_table( +# table.get_table_path(), date_column=table.partition, +# date_from=info_date, date_to=info_date, filters=filters +# ) +# +# return df.count() +# +# def check_dates_have_data(self, table: Table, dates: List[date], target_filters: Dict = None) -> List[bool]: +# """ +# For given list of dates, check if there is a matching partition for each +# +# :param table: Table object +# :param dates: list of dates to check +# :return: list of bool +# """ +# if self.reader.table_exists(self.spark, table.get_table_path()): +# checks = [] +# for check_date in dates: +# df = self.reader.get_table( +# table.get_table_path(), +# date_column=table.partition, +# date_from=check_date, +# date_to=check_date, +# filters=target_filters, +# ) +# data_exists = df.count() > 0 +# if data_exists and target_filters is None and 
table.secondary_partitions is not None: +# # ensure rerun if the write consideres secondary partitions but the filter doesn't +# data_exists = False +# checks.append(data_exists) +# return checks +# else: +# logger.info(f"Table {table.get_table_path()} doesn't exist!") +# return [False for _ in dates] +# +# def check_dependencies(self, pipeline: PipelineConfig, run_date: date) -> bool: +# """ +# Check for all dependencies in config if they have available partitions +# +# :param pipeline: configuration +# :param run_date: run date +# :return: bool +# """ +# logger.info(f"{pipeline.name} checking dependencies for {run_date}") +# +# error = "" +# +# for dependency in pipeline.dependencies: +# dep_from = date_manager.date_subtract(run_date, dependency.interval.units, dependency.interval.value) +# logger.info(f"Looking for {dependency.table} from {dep_from} until {run_date}") +# +# possible_dep_dates = date_manager.all_dates(dep_from, run_date) +# +# logger.debug(f"Date column for {dependency.table} is {dependency.date_col}") +# +# source = Table(table_path=dependency.table, partition=dependency.date_col) +# if True in self.check_dates_have_data(source, possible_dep_dates, dependency.filters): +# logger.info(f"Dependency for {dependency.table} from {dep_from} until {run_date} is fulfilled") +# else: +# msg = f"Missing dependency for {dependency.table} from {dep_from} until {run_date}" +# logger.info(msg) +# error = error + msg + "\n" +# +# if error != "": +# self.tracker.last_error = error +# return False +# +# return True +# +# def _get_completion(self, target: Table, info_dates: List[date], filters: Dict = None) -> List[bool]: +# """ +# Check if model has run for given dates +# +# :param target_path: Table object +# :param info_dates: list of dates +# :return: bool list +# """ +# if self.rerun: +# return [False for _ in info_dates] +# else: +# return self.check_dates_have_data(target, info_dates, filters) +# +# def _select_run_dates(self, pipeline: PipelineConfig, 
table: Table, filters: Dict = None) -> Tuple[List, List]: +# """ +# Select run dates and info dates based on completion +# +# :param pipeline: pipeline config +# :param table: table path +# :return: list of run dates and list of info dates +# """ +# possible_run_dates = date_manager.execution_dates(pipeline.schedule) +# possible_info_dates = [DateManager.to_partition_date(x, pipeline.schedule) for x in possible_run_dates] +# current_state = self._get_completion(table, possible_info_dates, filters) +# +# selection = [ +# (run, info) for run, info, state in zip(possible_run_dates, +# possible_info_dates, current_state) if not state +# ] +# +# if not len(selection): +# logger.info(f"{pipeline.name} has no dates to run") +# return [], [] +# +# selected_run_dates, selected_info_dates = zip(*selection) +# logger.info(f"{pipeline.name} identified to run for {selected_run_dates}") +# +# return list(selected_run_dates), list(selected_info_dates) +# +# def _run_one_date(self, pipeline: PipelineConfig, run_date: date, info_date: date, target: Table) -> int: +# """ +# Run one pipeline for one date +# +# :param pipeline: pipeline cfg +# :param run_date: run date +# :param info_date: information date +# :param target: target Table +# :return: success bool +# """ +# if self.skip_dependencies or self.check_dependencies(pipeline, run_date): +# logger.info(f"Running {pipeline.name} for {run_date}") +# +# feature_group = utils.load_module(pipeline.module) +# df = self._execute(feature_group, run_date, pipeline) +# self.writer.write(df, info_date, target) +# records = self._check_written(info_date, target, df, pipeline) +# logger.info(f"Generated {records} records") +# if records == 0: +# raise RuntimeError("No records generated") +# else: +# return records +# return 0 +# +# def _run_pipeline(self, pipeline: PipelineConfig): +# """ +# Run single pipeline for all required dates +# +# :param pipeline: pipeline cfg +# :return: success bool +# """ +# target = Table( +# 
schema_path=pipeline.target.target_schema, +# class_name=pipeline.module.python_class, +# partition=pipeline.target.target_partition_column, +# secondary_partitions=pipeline.target.secondary_partition_columns, +# table=pipeline.target.custom_name, +# ) +# logger.info(f"Loaded pipeline {pipeline.name}") +# +# selected_run_dates, selected_info_dates = self._select_run_dates( +# pipeline, target, pipeline.target.rerun_filters +# ) +# +# # ----------- Checking dependencies available ---------- +# for run_date, info_date in zip(selected_run_dates, selected_info_dates): +# run_start = datetime.datetime.now() +# try: +# records = self._run_one_date(pipeline, run_date, info_date, target) +# if records > 0: +# status = "Success" +# message = "" +# else: +# status = "Failure" +# message = self.tracker.last_error +# self.tracker.add( +# Record( +# job=pipeline.name, +# target=target.get_table_path(), +# date=info_date, +# time=datetime.datetime.now() - run_start, +# records=records, +# status=status, +# reason=message, +# ) +# ) +# except Exception as error: +# logger.error(f"An exception occurred in pipeline {pipeline.name}") +# logger.exception(error) +# self.tracker.add( +# Record( +# job=pipeline.name, +# target=target.get_table_path(), +# date=info_date, +# time=datetime.datetime.now() - run_start, +# records=0, +# status="Error", +# reason="Exception", +# exception=str(error), +# ) +# ) +# except KeyboardInterrupt: +# logger.error(f"Pipeline {pipeline.name} interrupted") +# self.tracker.add( +# Record( +# job=pipeline.name, +# target=target.get_table_path(), +# date=info_date, +# time=datetime.datetime.now() - run_start, +# records=0, +# status="Error", +# reason="Interrupted by user", +# ) +# ) +# raise KeyboardInterrupt +# +# def __call__(self): +# """Execute pipelines""" +# logger.info("Executing pipelines") +# try: +# if self.op: +# selected = [p for p in self.config.pipelines if p.name == self.op] +# if len(selected) < 1: +# raise ValueError(f"Unknown operation 
selected: {self.op}") +# self._run_pipeline(selected[0]) +# else: +# for pipeline in self.config.pipelines: +# self._run_pipeline(pipeline) +# finally: +# print(self.tracker.records) +# self.tracker.report_by_mail() +# logger.info("Execution finished") +# +# def debug(self) -> DataFrame: +# """Debug mode - run only first op for one date and return the resulting dataframe""" +# logger.info("Running in debug mode") +# if self.op: +# pipeline = [p for p in self.config.pipelines if p.name == self.op][0] +# else: +# pipeline = self.config.pipelines[0] +# +# target = Table( +# schema_path=pipeline.target.target_schema, +# class_name=pipeline.module.python_class, +# partition=pipeline.target.target_partition_column, +# secondary_partitions=pipeline.target.secondary_partition_columns, +# table=pipeline.target.custom_name, +# ) +# selected_run_dates, selected_info_dates = self._select_run_dates(pipeline, target) +# if len(selected_run_dates) > 0: +# df = self._execute(utils.load_module(pipeline.module), selected_run_dates[0], pipeline) +# return self.writer._process(df, selected_info_dates[0], target) +# else: +# logger.info("No dates to run in debug mode") diff --git a/rialto/runner/table.py b/rialto/runner/table.py index 2d44498..f97d001 100644 --- a/rialto/runner/table.py +++ b/rialto/runner/table.py @@ -14,14 +14,48 @@ __all__ = ["Table"] -from typing import List +from typing import Dict, List from rialto.metadata import class_to_catalog_name +from rialto.runner.config_loader import DependencyConfig, PipelineConfig class Table: """Handler for databricks catalog paths""" + @classmethod + def from_target_config(cls, config: PipelineConfig) -> "Table": + """ + Create table object from pipeline config target section + + :param config: Pipeline configuration + + :return: Table object + """ + return cls( + schema_path=config.target.target_schema, + class_name=config.module.python_class, + partition=config.target.target_partition_column, + 
secondary_partitions=config.target.secondary_partition_columns, + table=config.target.custom_name, + filters=config.target.rerun_filters, + ) + + @classmethod + def from_dependency_config(cls, config: DependencyConfig) -> "Table": + """ + Create table object from pipeline config dependency section + + :param config: Dependency configuration + + :return: Table object + """ + return cls( + table_path=config.table, + partition=config.date_col, + filters=config.filters, + ) + def __init__( self, catalog: str = None, @@ -32,12 +66,14 @@ def __init__( class_name: str = None, partition: str = None, secondary_partitions: List[str] = None, + filters: Dict = None, ): self.catalog = catalog self.schema = schema self.table = table self.partition = partition self.secondary_partitions = secondary_partitions + self.filters = filters if schema_path: schema_path = schema_path.split(".") self.catalog = schema_path[0] @@ -58,7 +94,7 @@ def get_table_path(self) -> str: """Get full table path""" return f"{self.catalog}.{self.schema}.{self.table}" - def get_all_partitions(self) -> List[str]: + def get_all_partition_columns(self) -> List[str]: """Get list of all partitions""" if self.secondary_partitions: return [self.partition] + self.secondary_partitions diff --git a/rialto/runner/task_registry.py b/rialto/runner/task_registry.py new file mode 100644 index 0000000..6f0b415 --- /dev/null +++ b/rialto/runner/task_registry.py @@ -0,0 +1,102 @@ +# Copyright 2022 ABSA Group Limited +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
class TaskRegistry:
    """
    Registry for pipeline tasks to be executed

    Tasks are kept in the order they were added; iterating the registry yields them
    in that order.
    """

    def __init__(self, spark: SparkSession, date_manager: DateManager):
        self.spark = spark
        self.date_manager = date_manager
        self.tasks = []

    def add_task(self, name: str, execution_date: date, partition_date: date, config: PipelineConfig) -> None:
        """
        Add task to registry

        Builds the target table from the pipeline config and one PipelineDependency per
        configured dependency, each spanning from ``execution_date`` minus the dependency
        interval until ``execution_date``.

        :param name: Name of the pipeline
        :param execution_date: Date when the pipeline is scheduled to run
        :param partition_date: Date for which the pipeline is processing data
        :param config: PipelineConfig object with pipeline configuration

        :return: None, appends a PipelineTask object to self.tasks
        """
        new_pipe = PipelineTask(
            op=name,
            execution_date=execution_date,
            partition_date=partition_date,
            config=config,
            target=Table.from_target_config(config),
        )

        # dependencies is Optional in the config model; treat an explicit null like an empty list
        for dependency_config in config.dependencies or []:
            dependency_table = Table.from_dependency_config(dependency_config)
            dependency_from = self.date_manager.date_subtract(
                execution_date, dependency_config.interval.units, dependency_config.interval.value
            )
            new_pipe.dependencies.append(
                PipelineDependency(table=dependency_table, date_from=dependency_from, date_until=execution_date)
            )

        self.tasks.append(new_pipe)

    def __iter__(self) -> Iterator[PipelineTask]:
        """Allow iteration over tasks in execution plan"""
        return iter(self.tasks)

    def _format_status(self) -> str:
        """Render the status table (header, rule, one row per task) as a single string"""
        check = "\u2714"  # ✔
        cross = "\u2718"  # ✘
        header = f"{'Job Name':<50} {'Partition Date':<15} {'Complete':<8} {'Dependencies':<12}"
        # leading empty entry keeps the table on its own lines below the log prefix;
        # the rule is as wide as the header so it underlines every column
        lines = ["", header, "-" * len(header)]
        for task in self.tasks:
            complete_icon = check if task.completion else cross
            deps_icon = check if task.dependencies_complete else cross
            lines.append(f"{task.op:<50} {str(task.partition_date):<15} {complete_icon:^8} {deps_icon:^12}")
        # trailing empty entry produces the final newline
        lines.append("")
        return "\n".join(lines)

    def log_status(self) -> None:
        """Log status of all tasks in registry, showing completion and dependency status"""
        logger.info(self._format_status())
class TaskStatusChecker:
    """Handles completion and dependency checks for pipeline tasks."""

    def __init__(self, checker: DataChecker):
        self.checker = checker

    def check_completion(self, pipeline: PipelineTask) -> None:
        """
        Check if pipeline is complete by checking if target data exists for partition date

        :param pipeline: Pipeline task for which to check completion

        :return: None, the result is stored on pipeline.completion
        """
        pipeline.completion = self.checker.check_date(pipeline.target, pipeline.partition_date)
        logger.info(
            f"Job {pipeline.op} completion status for partition date "
            f"{pipeline.partition_date}: {pipeline.completion}"
        )

    def check_pipeline_dependencies(self, pipeline: PipelineTask) -> None:
        """
        Check if dependencies are complete by checking if data exists for each dependency in date range

        Every dependency is checked and logged (no short-circuit), and its result is stored on
        dependency.complete.

        :param pipeline: Pipeline task for which to check dependencies

        :return: None, the overall result is stored on pipeline.dependencies_complete
        """
        for dependency in pipeline.dependencies:
            dependency.complete = self.checker.check_range(
                dependency.table, dependency.date_from, dependency.date_until
            )
            logger.info(
                f"Dependency {dependency.table.get_table_path()} completion status for date range "
                f"{dependency.date_from} - {dependency.date_until}: {dependency.complete}"
            )
        # all() of an empty iterable is True, so a task without dependencies counts as fulfilled
        pipeline.dependencies_complete = all(dependency.complete for dependency in pipeline.dependencies)
-__all__ = ["load_module", "table_exists", "get_partitions", "init_tools", "find_dependency"] +__all__ = ["load_module", "init_tools", "find_dependency"] -from datetime import date from importlib import import_module -from typing import List, Tuple +from typing import Tuple from pyspark.sql import SparkSession -from rialto.common import DataReader from rialto.loader import PysparkFeatureLoader from rialto.metadata import MetadataManager from rialto.runner.config_loader import ModuleConfig, PipelineConfig -from rialto.runner.table import Table from rialto.runner.transformation import Transformation @@ -40,32 +37,6 @@ def load_module(cfg: ModuleConfig) -> Transformation: return class_obj() -def table_exists(spark: SparkSession, table: str) -> bool: - """ - Check table exists in spark catalog - - :param table: full table path - :return: bool - """ - return spark.catalog.tableExists(table) - - -def get_partitions(reader: DataReader, table: Table) -> List[date]: - """ - Get partition values - - :param table: Table object - :return: List of partition values - """ - rows = ( - reader.get_table(table.get_table_path(), date_column=table.partition) - .select(table.partition) - .distinct() - .collect() - ) - return [r[table.partition] for r in rows] - - def init_tools(spark: SparkSession, pipeline: PipelineConfig) -> Tuple[MetadataManager, PysparkFeatureLoader]: """ Initialize metadata manager and feature loader diff --git a/rialto/runner/writer.py b/rialto/runner/writer.py index bc147fd..185fafa 100644 --- a/rialto/runner/writer.py +++ b/rialto/runner/writer.py @@ -12,10 +12,11 @@ # See the License for the specific language governing permissions and # limitations under the License. 
class Writer(ABC):
    """
    Supporting class for runner

    Abstract interface: concrete writers implement write for a specific storage.
    """

    @abstractmethod
    def write(self, df: DataFrame, partition_date: date, table: Table) -> None:
        """
        Write dataframe to storage

        :param df: dataframe to write
        :param partition_date: date to partition
        :param table: path to write to
        :return: None
        """
def _get_replace_expression(self, key: str, value: Any) -> str:
    """
    Render one ``column = value`` predicate as SQL text

    :param key: column name
    :param value: value the column is compared to
    :return: ``IS NULL`` test for None, an unquoted comparison for numbers,
        a quoted string literal comparison for everything else
    """
    if value is None:
        return f"{key} IS NULL"
    if isinstance(value, (int, float)):
        # numeric values (bool included, as a subclass of int) are rendered unquoted
        return f"{key} = {value}"
    # anything else is rendered through str() as a quoted literal; backslashes and single
    # quotes are escaped so the value cannot terminate the literal early
    literal = str(value).replace("\\", "\\\\").replace("'", "\\'")
    return f"{key} = '{literal}'"
table, partition_date) df.write.format("delta").partitionBy(table.partition).mode("overwrite").option( "mergeSchema", "true" if self.merge_schema else "false" diff --git a/tests/runner/conftest.py b/tests/runner/conftest.py index 4e527be..7ded329 100644 --- a/tests/runner/conftest.py +++ b/tests/runner/conftest.py @@ -37,6 +37,6 @@ def spark(request): return spark -@pytest.fixture(scope="function") +@pytest.fixture(scope="session") def basic_runner(spark): return Runner(spark, config_path="tests/runner/transformations/config.yaml", run_date="2023-03-31") diff --git a/tests/runner/test_overrides.py b/tests/runner/test_config_overrides.py similarity index 100% rename from tests/runner/test_overrides.py rename to tests/runner/test_config_overrides.py diff --git a/tests/runner/test_data_checker.py b/tests/runner/test_data_checker.py new file mode 100644 index 0000000..0ce98da --- /dev/null +++ b/tests/runner/test_data_checker.py @@ -0,0 +1,230 @@ +# Copyright 2022 ABSA Group Limited +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+from datetime import date +from unittest.mock import MagicMock + +import pytest +from pyspark.sql.types import DateType, IntegerType, StringType, StructField, StructType + +from rialto.common import TableReader +from rialto.runner.data_checker import DataChecker +from rialto.runner.table import Table + + +@pytest.fixture(scope="module") +def simple_dataframe(spark): + df = [ + ("A", date(2023, 3, 5)), + ("B", date(2023, 3, 12)), + ("C", date(2023, 3, 19)), + ] + schema = StructType([StructField("KEY", StringType(), True), StructField("DATE", DateType(), True)]) + return spark.createDataFrame(df, schema=schema) + + +@pytest.fixture(scope="module") +def partitioned_dataframe(spark): + df = [ + ("W", 1, "A", date(2023, 3, 5)), + ("E", 1, "B", date(2023, 3, 5)), + ("R", 2, "B", date(2023, 3, 5)), + ("T", 1, "B", date(2023, 3, 12)), + ("Y", 2, "A", date(2023, 3, 19)), + ] + schema = StructType( + [ + StructField("VALUE", StringType(), True), + StructField("VERSION", IntegerType(), True), + StructField("TYPE", StringType(), True), + StructField("DATE", DateType(), True), + ] + ) + return spark.createDataFrame(df, schema=schema) + + +@pytest.fixture(scope="module") +def new_insert_partitioned_dataframe(spark): + df = [ + ("E", 1, "B", date(2023, 3, 5)), + ("T", 1, "B", date(2023, 3, 5)), + ] + schema = StructType( + [ + StructField("VALUE", StringType(), True), + StructField("VERSION", IntegerType(), True), + StructField("TYPE", StringType(), True), + StructField("DATE", DateType(), True), + ] + ) + return spark.createDataFrame(df, schema=schema) + + +@pytest.mark.parametrize( + "partition_date, expected", + [ + (date(2023, 3, 12), True), + (date(2023, 3, 10), False), + (date(2023, 3, 19), True), + (date(2023, 3, 26), False), + ], +) +def test_check_date(mocker, spark, simple_dataframe, partition_date, expected): + mocker.patch("rialto.common.table_reader.TableReader.table_exists", return_value=True) + mocker.patch("rialto.common.table_reader.TableReader._get_raw_data", 
def test_check_range_no_table(
    mocker,
    spark,
):
    """A table reported as missing by the reader must yield 'no data' from the checker."""
    mocker.patch("rialto.common.table_reader.TableReader.table_exists", return_value=False)

    data_checker = DataChecker(TableReader(spark))
    table = Table(table_path="catalog.schema.simple_group", partition="DATE")

    assert data_checker.check_date(table, date(2023, 3, 12)) is False
    # The test is named after check_range but previously exercised only check_date.
    # NOTE(review): assumes check_range also short-circuits on a missing table - confirm
    # against DataChecker before merging.
    assert not data_checker.check_range(table, date(2023, 3, 12), date(2023, 3, 19))
table_path="catalog.schema.simple_group", + partition="DATE", + secondary_partitions=["VERSION", "TYPE"], + filters={"version": 1, "type": "A"}, + ) + result = data_checker.check_date(table, partition_date) + assert result == expected + + +@pytest.mark.parametrize( + "partition_date, expected", + [ + (date(2023, 2, 26), False), + (date(2023, 3, 5), False), + (date(2023, 3, 12), False), + (date(2023, 3, 19), False), + (date(2023, 3, 26), False), + ], +) +def test_check_date_secondary_partitions_no_filters(mocker, spark, partitioned_dataframe, partition_date, expected): + mocker.patch("rialto.common.table_reader.TableReader.table_exists", return_value=True) + mocker.patch("rialto.common.table_reader.TableReader._get_raw_data", return_value=partitioned_dataframe) + + data_checker = DataChecker(TableReader(spark)) + table = Table( + table_path="catalog.schema.simple_group", + partition="DATE", + secondary_partitions=["VERSION", "TYPE"], + filters=None, + ) + result = data_checker.check_date(table, partition_date) + assert result == expected + + +def test_check_written_with_no_filters_or_secondary_partitions(): + mock_reader = MagicMock() + mock_df = MagicMock() + mock_reader.get_table.return_value = mock_df + mock_df.count.return_value = 42 + + checker = DataChecker(mock_reader) + table = Table(table_path="dummy.table.path", partition="DATE") + result = checker.check_written(table, date(2023, 3, 5), MagicMock()) + assert result == 42 + mock_reader.get_table.assert_called_once_with( + "dummy.table.path", + date_column="DATE", + date_from=date(2023, 3, 5), + date_to=date(2023, 3, 5), + filters={}, + ) + + +def test_check_written_with_filters(): + mock_reader = MagicMock() + mock_df = MagicMock() + mock_reader.get_table.return_value = mock_df + mock_df.count.return_value = 42 + + checker = DataChecker(mock_reader) + table = Table(table_path="dummy.table.path", partition="DATE", filters={"foo": "bar"}) + result = checker.check_written(table, date(2023, 3, 5), MagicMock()) 
+ assert result == 42 + mock_reader.get_table.assert_called_once_with( + "dummy.table.path", + date_column="DATE", + date_from=date(2023, 3, 5), + date_to=date(2023, 3, 5), + filters={"foo": "bar"}, + ) + + +def test_check_written_with_secondary_partitions(mocker, new_insert_partitioned_dataframe): + # Setup + mock_reader = MagicMock() + mock_df = MagicMock() + mock_df.count.return_value = 7 + mock_reader.get_table.return_value = mock_df + + checker = DataChecker(mock_reader) + table = Table( + table_path="dummy.table.path", partition="DATE", filters=None, secondary_partitions=["VERSION", "TYPE"] + ) + result = checker.check_written(table, date(2023, 3, 5), new_insert_partitioned_dataframe) + assert result == 7 + mock_reader.get_table.assert_called_once_with( + "dummy.table.path", + date_column="DATE", + date_from=date(2023, 3, 5), + date_to=date(2023, 3, 5), + filters={"VERSION": 1, "TYPE": "B"}, + ) diff --git a/tests/runner/test_date_manager.py b/tests/runner/test_date_manager.py index 73b61b8..916c751 100644 --- a/tests/runner/test_date_manager.py +++ b/tests/runner/test_date_manager.py @@ -11,11 +11,11 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. 
-from datetime import datetime +from datetime import date, datetime import pytest -from rialto.runner.config_loader import IntervalConfig, ScheduleConfig +from rialto.runner.config_loader import IntervalConfig, RunnerConfig, ScheduleConfig from rialto.runner.date_manager import DateManager @@ -27,16 +27,16 @@ def test_str_to_date(): "units , value, res", [("days", 7, "2023-02-26"), ("weeks", 3, "2023-02-12"), ("months", 5, "2022-10-05"), ("years", 2, "2021-03-5")], ) -def test_date_from(units, value, res): +def test_date_subtract(units, value, res): rundate = DateManager.str_to_date("2023-03-05") - date_from = DateManager.date_subtract(run_date=rundate, units=units, value=value) + date_from = DateManager.date_subtract(input_date=rundate, units=units, value=value) assert date_from == DateManager.str_to_date(res) -def test_date_from_bad(): +def test_date_subtract_bad(): rundate = DateManager.str_to_date("2023-03-05") with pytest.raises(ValueError) as exception: - DateManager.date_subtract(run_date=rundate, units="random", value=1) + DateManager.date_subtract(input_date=rundate, units="random", value=1) assert str(exception.value) == "Unknown time unit random" @@ -58,114 +58,126 @@ def test_all_dates_reversed(): assert all_dates[1] == DateManager.str_to_date("2023-02-06") -def test_run_dates_weekly(): - cfg = ScheduleConfig(frequency="weekly", day=5) +def test_date_from(): + runner_cfg = RunnerConfig(watched_period_units="months", watched_period_value=3) + date_manager = DateManager(config=runner_cfg, run_date="2023-03-05") + assert date_manager.get_date_from() == DateManager.str_to_date("2022-12-05") - run_dates = DateManager.run_dates( - date_from=DateManager.str_to_date("2023-02-05"), - date_to=DateManager.str_to_date("2023-04-07"), - schedule=cfg, - ) - expected = [ - "2023-02-10", - "2023-02-17", - "2023-02-24", - "2023-03-03", - "2023-03-10", - "2023-03-17", - "2023-03-24", - "2023-03-31", - "2023-04-07", +def test_date_to(): + runner_cfg = 
RunnerConfig(watched_period_units="months", watched_period_value=3) + date_manager = DateManager(config=runner_cfg, run_date="2023-03-05") + assert date_manager.get_date_until() == DateManager.str_to_date("2023-03-05") + + +def test_run_dates_daily_no_shift(): + runner_cfg = RunnerConfig(watched_period_units="weeks", watched_period_value=1) + cfg = ScheduleConfig(frequency="daily") + manager = DateManager(config=runner_cfg, run_date="2026-05-20") + + exec, part = zip(*manager.get_execution_and_partition_dates(schedule=cfg)) + + expected_execution_dates = [ + date(2026, 5, 13), + date(2026, 5, 14), + date(2026, 5, 15), + date(2026, 5, 16), + date(2026, 5, 17), + date(2026, 5, 18), + date(2026, 5, 19), + date(2026, 5, 20), + ] + + expected_partition_dates = [ + date(2026, 5, 13), + date(2026, 5, 14), + date(2026, 5, 15), + date(2026, 5, 16), + date(2026, 5, 17), + date(2026, 5, 18), + date(2026, 5, 19), + date(2026, 5, 20), ] - expected = [DateManager.str_to_date(d) for d in expected] - assert run_dates == expected + assert expected_execution_dates == list(exec) + assert expected_partition_dates == list(part) -def test_run_dates_monthly(): - cfg = ScheduleConfig(frequency="monthly", day=5) +def test_run_dates_weekly_backwards_shift(): + runner_cfg = RunnerConfig(watched_period_units="months", watched_period_value=1) + cfg = ScheduleConfig(frequency="weekly", day=5, info_date_shift=IntervalConfig(units="days", value=2)) + manager = DateManager(config=runner_cfg, run_date="2026-05-20") - run_dates = DateManager.run_dates( - date_from=DateManager.str_to_date("2022-08-05"), - date_to=DateManager.str_to_date("2023-04-07"), - schedule=cfg, - ) + exec, part = zip(*manager.get_execution_and_partition_dates(schedule=cfg)) + + expected_execution_dates = [ + date(2026, 4, 24), + date(2026, 5, 1), + date(2026, 5, 8), + date(2026, 5, 15), + ] - expected = [ - "2022-08-05", - "2022-09-05", - "2022-10-05", - "2022-11-05", - "2022-12-05", - "2023-01-05", - "2023-02-05", - 
"2023-03-05", - "2023-04-05", + expected_partition_dates = [ + date(2026, 4, 22), + date(2026, 4, 29), + date(2026, 5, 6), + date(2026, 5, 13), ] - expected = [DateManager.str_to_date(d) for d in expected] - assert run_dates == expected + assert expected_execution_dates == list(exec) + assert expected_partition_dates == list(part) -def test_run_dates_daily(): - cfg = ScheduleConfig(frequency="daily") +def test_run_dates_monthly_with_forward_shift(): + runner_cfg = RunnerConfig(watched_period_units="months", watched_period_value=3) + cfg = ScheduleConfig(frequency="monthly", day=5, info_date_shift=IntervalConfig(units="days", value=-2)) + manager = DateManager(config=runner_cfg, run_date="2026-05-20") - run_dates = DateManager.run_dates( - date_from=DateManager.str_to_date("2023-03-28"), - date_to=DateManager.str_to_date("2023-04-03"), - schedule=cfg, - ) + exec, part = zip(*manager.get_execution_and_partition_dates(schedule=cfg)) - expected = [ - "2023-03-28", - "2023-03-29", - "2023-03-30", - "2023-03-31", - "2023-04-01", - "2023-04-02", - "2023-04-03", + expected_execution_dates = [ + date(2026, 3, 5), + date(2026, 4, 5), + date(2026, 5, 5), ] - expected = [DateManager.str_to_date(d) for d in expected] - assert run_dates == expected + expected_partition_dates = [ + date(2026, 3, 7), + date(2026, 4, 7), + date(2026, 5, 7), + ] + assert expected_execution_dates == list(exec) + assert expected_partition_dates == list(part) -def test_run_dates_invalid(): - cfg = ScheduleConfig(frequency="random") - with pytest.raises(ValueError) as exception: - DateManager.run_dates( - date_from=DateManager.str_to_date("2023-03-28"), - date_to=DateManager.str_to_date("2023-04-03"), - schedule=cfg, - ) - assert str(exception.value) == "Unknown frequency random" +def test_run_dates_monthly_last(): + runner_cfg = RunnerConfig(watched_period_units="months", watched_period_value=3) + cfg = ScheduleConfig(frequency="monthly", day="last") + manager = DateManager(config=runner_cfg, 
run_date="2026-05-20") -@pytest.mark.parametrize( - "shift, res", - [(7, "2023-02-26"), (3, "2023-03-02"), (-5, "2023-03-10"), (0, "2023-03-05")], -) -def test_to_info_date(shift, res): - cfg = ScheduleConfig(frequency="daily", info_date_shift=[IntervalConfig(units="days", value=shift)]) - base = DateManager.str_to_date("2023-03-05") - info = DateManager.to_info_date(base, cfg) - assert DateManager.str_to_date(res) == info + exec, part = zip(*manager.get_execution_and_partition_dates(schedule=cfg)) + expected_execution_dates = [ + date(2026, 2, 28), + date(2026, 3, 31), + date(2026, 4, 30), + ] -@pytest.mark.parametrize( - "unit, result", - [("days", "2023-03-02"), ("weeks", "2023-02-12"), ("months", "2022-12-05"), ("years", "2020-03-05")], -) -def test_info_date_shift_units(unit, result): - cfg = ScheduleConfig(frequency="daily", info_date_shift=[IntervalConfig(units=unit, value=3)]) - base = DateManager.str_to_date("2023-03-05") - info = DateManager.to_info_date(base, cfg) - assert DateManager.str_to_date(result) == info + expected_partition_dates = [ + date(2026, 2, 28), + date(2026, 3, 31), + date(2026, 4, 30), + ] + assert expected_execution_dates == list(exec) + assert expected_partition_dates == list(part) -def test_info_date_shift_combined(): - cfg = ScheduleConfig( - frequency="daily", - info_date_shift=[IntervalConfig(units="months", value=3), IntervalConfig(units="days", value=4)], - ) - base = DateManager.str_to_date("2023-03-05") - info = DateManager.to_info_date(base, cfg) - assert DateManager.str_to_date("2022-12-01") == info +def test_invalid_days(): + runner_cfg = RunnerConfig(watched_period_units="months", watched_period_value=3) + weekly_cfg = ScheduleConfig(frequency="weekly", day=12) + monthly_cfg = ScheduleConfig(frequency="monthly", day=42) + manager = DateManager(config=runner_cfg, run_date="2026-05-20") + + with pytest.raises(ValueError): + manager.get_execution_and_partition_dates(schedule=weekly_cfg) + + with pytest.raises(ValueError): + 
manager.get_execution_and_partition_dates(schedule=monthly_cfg) diff --git a/tests/runner/test_executor.py b/tests/runner/test_executor.py new file mode 100644 index 0000000..edc4b85 --- /dev/null +++ b/tests/runner/test_executor.py @@ -0,0 +1,49 @@ +# Copyright 2022 ABSA Group Limited +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from datetime import date +from unittest.mock import MagicMock, patch + +from rialto.runner.executor import PipelineExecutor +from rialto.runner.task_registry import PipelineTask + + +def test_execute_calls_job_run_and_returns_df(): + mock_spark = MagicMock() + mock_reader = MagicMock() + mock_checker = MagicMock() + executor = PipelineExecutor(mock_spark, mock_reader, mock_checker) + + config = MagicMock() + config.module = "some.module" + + task = PipelineTask( + op="test_op", + execution_date=date(2026, 1, 1), + partition_date=date(2026, 1, 1), + config=config, + target=MagicMock(), + ) + + mock_job = MagicMock() + mock_df = MagicMock() + mock_job.run.return_value = mock_df + + with patch("rialto.runner.utils.load_module", return_value=mock_job) as load_module, patch( + "rialto.runner.utils.init_tools", return_value=(None, None) + ): + result = executor.execute(task) + + load_module.assert_called_once_with("some.module") + mock_job.run.assert_called_once() + assert result == mock_df diff --git a/tests/runner/test_table.py b/tests/runner/test_table.py index f1e4ead..e770743 100644 --- a/tests/runner/test_table.py +++ 
b/tests/runner/test_table.py @@ -1,3 +1,19 @@ +# Copyright 2022 ABSA Group Limited +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from unittest.mock import Mock + +from rialto.runner.config_loader import TargetConfig from rialto.runner.table import Table @@ -31,13 +47,13 @@ def test_table_path_init(): def test_table_secondary_partitions(): t = Table(catalog="cat", schema="sch", table="tab", partition="part", secondary_partitions=["sec1", "sec2"]) - assert t.get_all_partitions() == ["part", "sec1", "sec2"] + assert t.get_all_partition_columns() == ["part", "sec1", "sec2"] def test_table_get_partitions_only_main(): t = Table(catalog="cat", schema="sch", table="tab", partition="part") - assert t.get_all_partitions() == ["part"] + assert t.get_all_partition_columns() == ["part"] def test_table_prioritize_table_name(): @@ -48,3 +64,44 @@ def test_table_prioritize_table_name(): assert t.catalog == "cat" assert t.schema == "sch" assert t.table == "custom" + + +def test_from_target_config(): + tconfig = TargetConfig( + target_schema="cat.sch", + target_partition_column="part", + secondary_partition_columns=["sec1", "sec2"], + custom_name=None, + rerun_filters={"col": "value"}, + ) + + pipeline_cfg = Mock() + pipeline_cfg.module.python_class = "TestClass" + pipeline_cfg.target = tconfig + + t = Table.from_target_config(pipeline_cfg) + + assert t.get_table_path() == "cat.sch.test_class" + assert t.get_schema_path() == "cat.sch" + assert t.catalog == "cat" + assert t.schema == 
"sch" + assert t.table == "test_class" + assert t.get_all_partition_columns() == ["part", "sec1", "sec2"] + assert t.filters == {"col": "value"} + + +def test_from_dependency_config(): + dconfig = Mock() + dconfig.table = "cat.sch.tab" + dconfig.date_col = "date" + dconfig.filters = {"col": "value"} + + t = Table.from_dependency_config(dconfig) + + assert t.get_table_path() == "cat.sch.tab" + assert t.get_schema_path() == "cat.sch" + assert t.catalog == "cat" + assert t.schema == "sch" + assert t.table == "tab" + assert t.partition == "date" + assert t.filters == {"col": "value"} diff --git a/tests/runner/test_task_registry.py b/tests/runner/test_task_registry.py new file mode 100644 index 0000000..2d008a7 --- /dev/null +++ b/tests/runner/test_task_registry.py @@ -0,0 +1,173 @@ +# Copyright 2022 ABSA Group Limited +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+from datetime import date +from unittest.mock import patch + +import pytest + +from rialto.runner.config_loader import ( + DependencyConfig, + IntervalConfig, + ModuleConfig, + PipelineConfig, + RunnerConfig, + ScheduleConfig, + TargetConfig, +) +from rialto.runner.date_manager import DateManager +from rialto.runner.task_registry import PipelineDependency, PipelineTask, TaskRegistry + + +@pytest.fixture(scope="module") +def date_manager(): + runner_cfg = RunnerConfig(watched_period_units="months", watched_period_value=3) + return DateManager(runner_cfg, "2020-01-01") + + +@pytest.fixture +def pipeline_config_no_deps(): + return PipelineConfig( + name="test_pipeline", + module=ModuleConfig(python_module="some.module", python_class="TestClass"), + schedule=ScheduleConfig(frequency="monthly", day=1), + target=TargetConfig( + target_schema="cat.sch", + target_partition_column="part", + secondary_partition_columns=["sec1", "sec2"], + custom_name=None, + rerun_filters={"col": "value"}, + ), + ) + + +@pytest.fixture +def pipeline_config_with_deps(): + return PipelineConfig( + name="test_pipeline_with_deps", + module=ModuleConfig(python_module="some.module", python_class="TestClass"), + schedule=ScheduleConfig(frequency="monthly", day=1), + target=TargetConfig( + target_schema="cat.sch", + target_partition_column="part", + ), + dependencies=[ + DependencyConfig( + table="cat.sch.dep_table", + date_col="part", + interval=IntervalConfig(units="months", value=1), + ) + ], + ) + + +def test_registry_initializes_empty(spark, date_manager): + registry = TaskRegistry(spark, date_manager) + assert list(registry) == [] + + +def test_add_task_no_dependencies(spark, date_manager, pipeline_config_no_deps): + registry = TaskRegistry(spark, date_manager) + registry.add_task( + name="test_pipeline", + execution_date=date(2020, 1, 1), + partition_date=date(2019, 12, 31), + config=pipeline_config_no_deps, + ) + + tasks = list(registry) + assert len(tasks) == 1 + + task = tasks[0] + assert 
isinstance(task, PipelineTask) + assert task.op == "test_pipeline" + assert task.execution_date == date(2020, 1, 1) + assert task.partition_date == date(2019, 12, 31) + assert task.config is pipeline_config_no_deps + assert task.dependencies == [] + assert task.completion is False + assert task.dependencies_complete is False + + +def test_add_task_target_table(spark, date_manager, pipeline_config_no_deps): + registry = TaskRegistry(spark, date_manager) + registry.add_task( + name="test_pipeline", + execution_date=date(2020, 1, 1), + partition_date=date(2019, 12, 31), + config=pipeline_config_no_deps, + ) + + task = list(registry)[0] + assert task.target.catalog == "cat" + assert task.target.schema == "sch" + assert task.target.partition == "part" + assert task.target.secondary_partitions == ["sec1", "sec2"] + assert task.target.filters == {"col": "value"} + + +def test_add_task_with_dependencies(spark, date_manager, pipeline_config_with_deps): + registry = TaskRegistry(spark, date_manager) + registry.add_task( + name="test_pipeline_with_deps", + execution_date=date(2020, 1, 1), + partition_date=date(2019, 12, 31), + config=pipeline_config_with_deps, + ) + + task = list(registry)[0] + assert len(task.dependencies) == 1 + + dep = task.dependencies[0] + assert isinstance(dep, PipelineDependency) + assert dep.table.get_table_path() == "cat.sch.dep_table" + assert dep.date_until == date(2020, 1, 1) + assert dep.date_from == date(2019, 12, 1) # 1 month subtracted + assert dep.complete is False + + +def test_add_multiple_tasks(spark, date_manager, pipeline_config_no_deps, pipeline_config_with_deps): + registry = TaskRegistry(spark, date_manager) + registry.add_task("pipeline_a", date(2020, 1, 1), date(2019, 12, 31), pipeline_config_no_deps) + registry.add_task("pipeline_b", date(2020, 1, 1), date(2019, 12, 31), pipeline_config_with_deps) + + tasks = list(registry) + assert len(tasks) == 2 + assert tasks[0].op == "pipeline_a" + assert tasks[1].op == "pipeline_b" + + +def 
test_iteration(spark, date_manager, pipeline_config_no_deps): + registry = TaskRegistry(spark, date_manager) + registry.add_task("p1", date(2020, 1, 1), date(2019, 12, 31), pipeline_config_no_deps) + registry.add_task("p2", date(2020, 1, 1), date(2019, 12, 31), pipeline_config_no_deps) + + names = [task.op for task in registry] + assert names == ["p1", "p2"] + + +def test_log_status_contains_task_info(spark, date_manager, pipeline_config_no_deps): + registry = TaskRegistry(spark, date_manager) + registry.add_task("test_pipeline", date(2020, 1, 1), date(2019, 12, 31), pipeline_config_no_deps) + registry.tasks[0].completion = True + + with patch("rialto.runner.task_registry.logger") as mock_logger: + registry.log_status() + + mock_logger.info.assert_called_once() + logged_output = mock_logger.info.call_args[0][0] + + assert "test_pipeline" in logged_output + assert "2019-12-31" in logged_output + assert "✔" in logged_output + assert "✘" in logged_output diff --git a/tests/runner/test_task_status_checker.py b/tests/runner/test_task_status_checker.py new file mode 100644 index 0000000..d8772c1 --- /dev/null +++ b/tests/runner/test_task_status_checker.py @@ -0,0 +1,132 @@ +# Copyright 2022 ABSA Group Limited +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +from datetime import date +from unittest.mock import Mock, patch + +import pytest + +from rialto.runner.table import Table +from rialto.runner.task_registry import PipelineDependency +from rialto.runner.task_status_checker import TaskStatusChecker + + +def make_task(op="my_pipeline", partition_date=date(2020, 1, 1), dependencies=None): + task = Mock() + task.op = op + task.partition_date = partition_date + task.target = Table(schema_path="cat.sch", class_name="TestClass", partition="part") + task.dependencies = dependencies or [] + task.completion = False + task.dependencies_complete = False + return task + + +def make_dependency(table_path="cat.sch.dep_table", date_from=date(2019, 12, 1), date_until=date(2020, 1, 1)): + table = Table(table_path=table_path, partition="part") + return PipelineDependency(table=table, date_from=date_from, date_until=date_until) + + +@pytest.fixture +def mock_data_checker(): + return Mock() # Mock of DataChecker + + +@pytest.fixture +def status_checker(mock_data_checker): + return TaskStatusChecker(checker=mock_data_checker) + + +def test_check_completion_sets_true_when_data_exists(status_checker, mock_data_checker): + mock_data_checker.check_date.return_value = True + task = make_task() + + status_checker.check_completion(task) + + assert task.completion is True + mock_data_checker.check_date.assert_called_once_with(task.target, task.partition_date) + + +def test_check_completion_sets_false_when_no_data(status_checker, mock_data_checker): + mock_data_checker.check_date.return_value = False + task = make_task() + + status_checker.check_completion(task) + + assert task.completion is False + + +# --- check_pipeline_dependencies --- + + +def test_check_pipeline_dependencies_no_deps(status_checker, mock_data_checker): + task = make_task(dependencies=[]) + + status_checker.check_pipeline_dependencies(task) + + assert task.dependencies_complete is True # all([]) == True + mock_data_checker.check_range.assert_not_called() + + +def 
test_check_pipeline_dependencies_all_complete(status_checker, mock_data_checker): + mock_data_checker.check_range.return_value = True + dep1 = make_dependency("cat.sch.table_a") + dep2 = make_dependency("cat.sch.table_b") + task = make_task(dependencies=[dep1, dep2]) + + status_checker.check_pipeline_dependencies(task) + + assert dep1.complete is True + assert dep2.complete is True + assert task.dependencies_complete is True + assert mock_data_checker.check_range.call_count == 2 + + +def test_check_pipeline_dependencies_one_incomplete(status_checker, mock_data_checker): + mock_data_checker.check_range.side_effect = [True, False] + dep1 = make_dependency("cat.sch.table_a") + dep2 = make_dependency("cat.sch.table_b") + task = make_task(dependencies=[dep1, dep2]) + + status_checker.check_pipeline_dependencies(task) + + assert dep1.complete is True + assert dep2.complete is False + assert task.dependencies_complete is False + + +def test_check_pipeline_dependencies_passes_correct_dates(status_checker, mock_data_checker): + mock_data_checker.check_range.return_value = True + dep = make_dependency(date_from=date(2019, 10, 1), date_until=date(2020, 1, 1)) + task = make_task(dependencies=[dep]) + + status_checker.check_pipeline_dependencies(task) + + mock_data_checker.check_range.assert_called_once_with(dep.table, date(2019, 10, 1), date(2020, 1, 1)) + + +def test_check_completion_logs_status(): + checker = Mock() + checker.check_date.return_value = True + task = make_task(op="logged_pipeline", partition_date=date(2020, 3, 1)) + status_checker = TaskStatusChecker(checker=checker) + + with patch("rialto.runner.task_status_checker.logger") as mock_logger: + status_checker.check_completion(task) + + mock_logger.info.assert_called_once() + log_msg = mock_logger.info.call_args[0][0] + assert "logged_pipeline" in log_msg + assert "2020-03-01" in log_msg + assert "True" in log_msg diff --git a/tests/runner/test_utils.py b/tests/runner/test_utils.py new file mode 100644 index 
0000000..e69de29 diff --git a/tests/runner/test_writer.py b/tests/runner/test_writer.py index 5ec9ec6..a58fe58 100644 --- a/tests/runner/test_writer.py +++ b/tests/runner/test_writer.py @@ -12,47 +12,160 @@ # See the License for the specific language governing permissions and # limitations under the License. -from datetime import date +from datetime import date, datetime +from unittest.mock import MagicMock, Mock, patch import pytest +from pyspark.sql import Row -from rialto.runner.writer import Writer +from rialto.runner.table import Table +from rialto.runner.writer import DatabricksWriter @pytest.fixture -def sample_multi_partition(spark): - df = spark.createDataFrame( - [ - ("REGION_A", 3, date(2023, 1, 1), 100), - ("REGION_A", 3, date(2023, 1, 1), 300), - ], - schema="region string, version int, info_date date, value int", - ) - return df +def writer(spark): + return DatabricksWriter(spark, merge_schema=False) + + +@pytest.fixture +def writer_merge(spark): + return DatabricksWriter(spark, merge_schema=True) + + +@pytest.fixture +def simple_table(): + return Table(schema_path="default.test_schema", class_name="MyTable", partition="info_date") @pytest.fixture -def sample_multi_partition_non_unique(spark): +def table_with_secondary(simple_table): + simple_table.secondary_partitions = ["region"] + return simple_table + + +# --- _align_schema --- + + +def test_align_schema_no_existing_columns_returns_df_unchanged(spark, writer): + df = spark.createDataFrame([Row(a=1, b=2)]) + result = writer._align_schema(df, None) + assert result.columns == df.columns + + +def test_align_schema_reorders_to_existing_columns(spark, writer): + df = spark.createDataFrame([Row(a=1, b=2, c=3)]) + result = writer._align_schema(df, ["c", "a", "b"]) + assert result.columns == ["c", "a", "b"] + + +def test_align_schema_new_columns_appended_after_existing(spark, writer): + df = spark.createDataFrame([Row(a=1, b=2, new_col=3)]) + result = writer._align_schema(df, ["a", "b"]) + assert 
result.columns == ["a", "b", "new_col"] + + +def test_align_schema_missing_existing_columns_raises_value_error(spark, writer): + # column "gone" is in existing but not in df — should raise ValueError + df = spark.createDataFrame([Row(a=1, b=2)]) + with pytest.raises(ValueError): + writer._align_schema(df, ["gone", "a", "b"]) + + +# --- _get_replace_condition --- + + +def test_get_replace_condition_string_value(spark, writer): + df = spark.createDataFrame([Row(info_date=date(2020, 1, 1))]) + target = Table(schema_path="default.test_schema", class_name="MyTable", partition="info_date") + condition = writer._get_replace_condition(df, target, datetime(2020, 1, 1)) + assert condition == "info_date = '2020-01-01'" + + +def test_get_replace_condition_second_value_no_filters(spark, writer): + df = spark.createDataFrame([Row(information_date="2020-01-01", region=1)]) + target = Table( + schema_path="default.test_schema", + class_name="MyTable", + partition="information_date", + secondary_partitions=["region"], + ) + condition = writer._get_replace_condition(df, target, datetime(2020, 1, 1)) + assert condition == "information_date = '2020-01-01' AND region = 1" + + +def test_get_replace_condition_second_value_with_filters(spark, writer): + df = spark.createDataFrame([Row(information_date="2020-01-01", region=1)]) + target = Table( + schema_path="default.test_schema", + class_name="MyTable", + partition="information_date", + secondary_partitions=["region"], + filters={"region": 1}, + ) + condition = writer._get_replace_condition(df, target, datetime(2020, 1, 1)) + assert condition == "information_date = '2020-01-01' AND region = 1" + + +def test_get_replace_condition_raises_on_multiple_distinct_values(spark, writer): df = spark.createDataFrame( - [ - ("REGION_A", 1, date(2023, 1, 1), 100), - ("REGION_A", 2, date(2023, 1, 1), 300), - ], - schema="region string, version int, info_date date, value int", + [Row(information_date="2020-01-01", region=1), 
Row(information_date="2020-01-01", region=2)] + ) + target = Table( + schema_path="default.test_schema", + class_name="MyTable", + partition="information_date", + secondary_partitions=["region"], ) - return df + with pytest.raises(ValueError, match="more than 1 distinct value"): + writer._get_replace_condition(df, target, datetime(2020, 1, 1)) -def test_replace_condition(sample_multi_partition): - writer = Writer(spark=None) - condition = writer._get_replace_condition(sample_multi_partition, partition_cols=["region", "version", "info_date"]) - expected_condition = "region = 'REGION_A' AND version = 3 AND info_date = '2023-01-01'" - assert condition == expected_condition +# --- _process --- -def test_replace_condition_non_unique(sample_multi_partition_non_unique): - writer = Writer(spark=None) - with pytest.raises(ValueError): - writer._get_replace_condition( - sample_multi_partition_non_unique, partition_cols=["region", "version", "info_date"] - ) +def test_process_adds_partition_column(spark, writer, simple_table): + df = spark.createDataFrame([Row(a=1)]) + with patch.object(writer, "_get_existing_columns", return_value=None): + result = writer._process(df, date(2020, 1, 1), simple_table) + assert "info_date" in result.columns + assert result.collect()[0]["info_date"] == date(2020, 1, 1) + + +# --- write (integration of internal steps) --- + + +def test_write_calls_create_schema(spark, writer, simple_table): + with patch.object(writer, "_create_schema") as mock_create, patch.object(writer, "_process"), patch.object( + writer, "_get_replace_condition" + ): + df = Mock() + df.write = Mock() + writer.write(df, Mock(), simple_table) + + mock_create.assert_called_once_with(simple_table) + + +def test_write_merge_schema_option(spark, writer_merge, simple_table): + df = MagicMock() + df.write = MagicMock() + + with patch.object(writer_merge, "_create_schema"), patch.object( + writer_merge, "_process", return_value=df + ), patch.object(writer_merge, 
"_get_replace_condition"): + writer_merge.write(df, date(2020, 1, 1), simple_table) + + option_calls = df.write.format.return_value.partitionBy.return_value.mode.return_value.option.call_args_list + assert any(call.args == ("mergeSchema", "true") for call in option_calls) + + +def test_write_not_merge_schema_option(spark, writer, simple_table): + df = MagicMock() + df.write = MagicMock() + + with patch.object(writer, "_create_schema"), patch.object(writer, "_process", return_value=df), patch.object( + writer, "_get_replace_condition" + ): + writer.write(df, date(2020, 1, 1), simple_table) + + option_calls = df.write.format.return_value.partitionBy.return_value.mode.return_value.option.call_args_list + assert any(call.args == ("mergeSchema", "false") for call in option_calls)