diff --git a/.gitignore b/.gitignore index 5730c00..ca58e3d 100644 --- a/.gitignore +++ b/.gitignore @@ -175,4 +175,6 @@ test_project __dsgrid_scratch__ journal*.json5 dev_project -dev_project.json5 \ No newline at end of file +dev_project.json5 +*DS_Store* +equinor* \ No newline at end of file diff --git a/src/stride/cli/stride.py b/src/stride/cli/stride.py index ea2aac6..c3966b4 100644 --- a/src/stride/cli/stride.py +++ b/src/stride/cli/stride.py @@ -3,7 +3,7 @@ from typing import Any, Callable import rich_click as click -from chronify.exceptions import ChronifyExceptionBase +from chronify.exceptions import ChronifyExceptionBase, InvalidParameter from chronify.loggers import setup_logging from dsgrid.cli.common import path_callback from dsgrid.exceptions import DSGBaseException @@ -11,7 +11,7 @@ from stride import Project from stride.config import CACHED_PROJECTS_UPPER_BOUND -from stride.models import CalculatedTableOverride +from stride.models import CalculatedTableOverride, CustomDemandComponent from stride.project import list_valid_countries, list_valid_model_years, list_valid_weather_years from stride.ui.palette_utils import list_user_palettes, set_palette_priority from stride.dataset_download import ( @@ -191,6 +191,14 @@ def create_project( ) if res[1] != 0: ctx.exit(res[1]) + project = res[0] + if project is not None: + try: + from stride.ui.project_manager import add_recent_project + + add_recent_project(project.path, project.config.project_id) + except Exception: + logger.exception("Could not add to recent projects") _export_ep_epilog = """ @@ -1240,6 +1248,139 @@ def refresh_palette(ctx: click.Context, project_path: Path) -> None: print("\nPalette colors refreshed and saved!") +@click.group(name="custom-demand") +def custom_demand() -> None: + """Custom demand component commands""" + + +_custom_demand_add_epilog = """ +Examples:\n +Add a flat-profile data center component:\n +$ stride custom-demand add my_project --name data_centers --sector "Data Centers" --data-file data/dc_annual.csv\n +\n +Add a heat pump component using the residential load shape:\n +$ stride custom-demand add my_project --name heat_pumps --sector "Heat Pumps" --data-file hp.csv --load-profile "sector:Residential" --metric heating\n +""" + + +@click.command(name="add", epilog=_custom_demand_add_epilog) +@click.argument("project-path", type=click.Path(exists=True), callback=path_callback) +@click.option("--name", type=str, required=True, help="Unique identifier (e.g., 'heat_pumps')") +@click.option("--sector", type=str, required=True, help="Sector label for UI (e.g., 'Heat Pumps')") +@click.option( + "--data-file", + type=click.Path(exists=True), + required=True, + help="CSV/Parquet with model_year and value columns", + callback=path_callback, +) +@click.option( + "--load-profile", + type=str, + default="flat", + show_default=True, + help="Profile: 'flat', 'sector:', 'enduse:', or path to 8760 CSV", +) +@click.option( + "--metric", + type=str, + default="other", + show_default=True, + help="End-use/metric label (e.g., 'heating', 'cooling', 'other')", +) +@click.pass_context +def add_custom_demand( + ctx: click.Context, + project_path: Path, + name: str, + sector: str, + data_file: Path, + load_profile: str, + metric: str, +) -> None: + """Add a custom demand component and recompute the energy projection.""" + res = handle_stride_exception( + ctx, + _add_custom_demand, + project_path, + name, + sector, + data_file, + load_profile, + metric, + ) + if res[1] != 0: + ctx.exit(res[1]) + + +def _add_custom_demand( + project_path: Path, + name: str, + sector: str, + data_file: Path, + load_profile: str, + metric: str, +) -> None: + project = Project.load(project_path) + # Check for duplicate name + existing = {c.name for c in project.config.custom_demand_components} + if name in existing: + msg = f"Custom demand component '{name}' already exists. Remove it first." + raise InvalidParameter(msg) + + component = CustomDemandComponent( + name=name, + sector=sector, + data_file=data_file.resolve(), + load_profile=load_profile, + metric=metric, + ) + project.config.custom_demand_components.append(component) + project.persist() + project.compute_energy_projection() + print(f"Added custom demand component '{name}' and recomputed energy projection.") + + +@click.command(name="list") +@click.argument("project-path", type=click.Path(exists=True), callback=path_callback) +@click.pass_context +def list_custom_demand(ctx: click.Context, project_path: Path) -> None: + """List custom demand components in the project.""" + project = safe_get_project_from_context(ctx, project_path, read_only=True) + components = project.config.custom_demand_components + if not components: + print("No custom demand components configured.") + return + print(f"Custom demand components ({len(components)}):") + for c in components: + print(f" {c.name}: sector={c.sector!r}, profile={c.load_profile!r}, " + f"metric={c.metric!r}, data_file={c.data_file}") + + +@click.command(name="remove") +@click.argument("project-path", type=click.Path(exists=True), callback=path_callback) +@click.option("--name", type=str, required=True, help="Name of the component to remove") +@click.pass_context +def remove_custom_demand(ctx: click.Context, project_path: Path, name: str) -> None: + """Remove a custom demand component and recompute the energy projection.""" + res = handle_stride_exception(ctx, _remove_custom_demand, project_path, name) + if res[1] != 0: + ctx.exit(res[1]) + + +def _remove_custom_demand(project_path: Path, name: str) -> None: + project = Project.load(project_path) + components = project.config.custom_demand_components + original_len = len(components) + project.config.custom_demand_components = [c for c in components if c.name != name] + if len(project.config.custom_demand_components) == original_len: + msg = f"Custom demand component '{name}' not found." + raise InvalidParameter(msg) + project.persist() + project.compute_energy_projection() + print(f"Removed custom demand component '{name}' and recomputed energy projection.") + + def handle_stride_exception( ctx: click.Context, func: Callable[..., Any], *args: Any, **kwargs: Any ) -> Any: @@ -1277,6 +1418,7 @@ def safe_get_project_from_context( cli.add_command(calculated_tables) cli.add_command(palette) cli.add_command(view) +cli.add_command(custom_demand) projects.add_command(init_project) projects.add_command(create_project) projects.add_command(export_energy_projection) @@ -1300,3 +1442,6 @@ def safe_get_project_from_context( palette.add_command(set_priority) palette.add_command(get_priority) palette.add_command(refresh_palette) +custom_demand.add_command(add_custom_demand) +custom_demand.add_command(list_custom_demand) +custom_demand.add_command(remove_custom_demand) diff --git a/src/stride/dbt/models/energy_projection_com_ind_tra_load_shapes.sql b/src/stride/dbt/models/energy_projection_com_ind_tra_load_shapes.sql index c55877a..32e6af9 100644 --- a/src/stride/dbt/models/energy_projection_com_ind_tra_load_shapes.sql +++ b/src/stride/dbt/models/energy_projection_com_ind_tra_load_shapes.sql @@ -61,12 +61,14 @@ ev_annual_energy AS ( stride_annual_energy AS ( -- Combine base energy intensity projections with optional EV projections -- If use_ev_projection is true, replace Transportation + Road with EV-based calculation + -- Tag each row with energy_source so we can assign distinct metrics later SELECT geography, model_year, sector, subsector, - stride_annual_total + stride_annual_total, + 'base' AS energy_source FROM stride_annual_energy_base WHERE NOT (sector = 'Transportation' AND subsector = 'Road' AND {{ var("use_ev_projection", False) }}) @@ -77,7 +79,8 @@ stride_annual_energy AS ( model_year, sector, subsector, - stride_annual_total + stride_annual_total, + 'ev' AS energy_source FROM ev_annual_energy ), @@ -86,10 +89,12 @@ scaling_factors AS ( -- This scales the temperature-adjusted load shapes to match STRIDE totals -- Same scaling factor applies to all enduses within a sector -- Note: Load shapes are at sector level, so we aggregate subsectors + -- When EV is enabled, Transportation gets two scaling factors (base vs ev) SELECT ls.geography, ls.model_year, ls.sector, + stride.energy_source, CASE WHEN ls.load_shape_annual_total > 0 THEN SUM(stride.stride_annual_total) / ls.load_shape_annual_total @@ -100,10 +105,11 @@ scaling_factors AS ( ON ls.geography = stride.geography AND ls.model_year = stride.model_year AND ls.sector = stride.sector - GROUP BY ls.geography, ls.model_year, ls.sector, ls.load_shape_annual_total + GROUP BY ls.geography, ls.model_year, ls.sector, + ls.load_shape_annual_total, stride.energy_source ) --- Apply scaling factors to create final hourly energy projections +-- Non-EV rows: use base scaling factor, keep original end-use metric SELECT ls.timestamp, ls.model_year, @@ -116,3 +122,25 @@ JOIN scaling_factors sf ON ls.geography = sf.geography AND ls.model_year = sf.model_year AND ls.sector = sf.sector +WHERE sf.energy_source = 'base' + +UNION ALL + +-- EV rows: use EV scaling factor, single 'ev_charging' metric per hour. +-- Aggregate load shape across enduses to avoid per-enduse duplication. +-- Only applies to Transportation sector when EV projection is enabled. +SELECT + ls.timestamp, + ls.model_year, + ls.geography, + ls.sector, + 'ev_charging' AS metric, + SUM(ls.adjusted_value) * sf.scaling_factor AS value +FROM load_shapes_filtered ls +JOIN scaling_factors sf + ON ls.geography = sf.geography + AND ls.model_year = sf.model_year + AND ls.sector = sf.sector +WHERE sf.energy_source = 'ev' + AND ls.sector = 'Transportation' +GROUP BY ls.timestamp, ls.model_year, ls.geography, ls.sector, sf.scaling_factor diff --git a/src/stride/dsgrid_integration.py b/src/stride/dsgrid_integration.py index 10fa926..ead5923 100644 --- a/src/stride/dsgrid_integration.py +++ b/src/stride/dsgrid_integration.py @@ -17,6 +17,7 @@ ) from dsgrid.registry.bulk_register import bulk_register from dsgrid.registry.common import DataStoreType, DatabaseConnection +from dsgrid.exceptions import DSGValueNotRegistered from dsgrid.registry.registry_manager import RegistryManager from loguru import logger @@ -108,9 +109,25 @@ def make_mapped_datasets( ) continue + # Use scenario-specific dataset if one was registered (e.g., for overrides). + # dimension_mappings.json5 only contains baseline__ dataset IDs, so we must + # check whether a scenario-specific dataset exists and substitute it. + scenario_dataset_id = f"{scenario}__{table_name}" + try: + mgr.dataset_manager.get_by_id(scenario_dataset_id) + effective_mapping = dict(mapping) + effective_mapping["dataset_id"] = scenario_dataset_id + logger.info( + "Using scenario-specific dataset {} instead of {}", + scenario_dataset_id, + dataset_id, + ) + except DSGValueNotRegistered: + effective_mapping = mapping + _process_dataset_mapping( con=con, - mapping=mapping, + mapping=effective_mapping, mappings_dir=mappings_dir, mgr=mgr, scenario=scenario, diff --git a/src/stride/models.py b/src/stride/models.py index 86eabe2..2738e8b 100644 --- a/src/stride/models.py +++ b/src/stride/models.py @@ -22,6 +22,45 @@ class ProjectionSliceType(StrEnum): HEAT_PUMPS = "heat_pumps" +class CustomDemandComponent(DSGBaseModel): # type: ignore + """Defines an additive custom demand component. + + A custom demand component is a user-defined electricity load (e.g., heat pumps, + data centers) that is injected into the energy projection after dbt computation. + Annual MWh values are distributed into 8760 hourly rows using the specified + load profile. + """ + + name: str = Field( + description="Unique identifier for this component (e.g., 'heat_pumps')" + ) + sector: str = Field( + description="Sector label for UI grouping (e.g., 'Heat Pumps')" + ) + data_file: Path = Field( + description="Path to CSV/Parquet with model_year and value (annual MWh) columns" + ) + load_profile: str = Field( + default="flat", + description=( + "How to distribute annual energy into hours. Options: " + "'flat', 'sector:', 'enduse:', or a file path to an 8760 CSV" + ), + ) + metric: str = Field( + default="other", + description="End-use/metric label (e.g., 'heating', 'cooling', 'other')", + ) + + @field_validator("name") + @classmethod + def check_name(cls, name: str) -> str: + if not name.isidentifier(): + msg = f"Component name must be a valid Python identifier: {name!r}" + raise ValueError(msg) + return name + + class Scenario(DSGBaseModel): # type: ignore """Allows the user to add custom tables to compare against the defaults.""" @@ -74,6 +113,13 @@ class Scenario(DSGBaseModel): # type: ignore default=None, description="Optional path to a user-provided vehicle_per_capita_regressions table", ) + custom_demand_overrides: dict[str, Path] = Field( + default={}, + description=( + "Per-scenario overrides for custom demand components. " + "Keys are component names, values are paths to alternative data files." + ), + ) @field_validator("name") @classmethod @@ -164,6 +210,10 @@ class ProjectConfig(DSGBaseModel): # type: ignore default=[], description="Calculated tables to override", ) + custom_demand_components: list[CustomDemandComponent] = Field( + default=[], + description="Additive custom demand components (e.g., heat pumps, data centers)", + ) color_palette: dict[str, dict[str, str]] = Field( default={"scenarios": {}, "model_years": {}, "sectors": {}, "end_uses": {}}, description="Color palette organized into scenarios, model_years, sectors, and end_uses categories. Each category maps labels to hex/rgb color strings for the UI.", @@ -175,7 +225,7 @@ def from_file(cls, filename: Path | str) -> Self: config = super().from_file(path) for scenario in config.scenarios: for field in Scenario.model_fields: - if field in ("name", "use_ev_projection"): + if field in ("name", "use_ev_projection", "custom_demand_overrides"): continue val = getattr(scenario, field) if val is not None and not val.is_absolute(): @@ -187,6 +237,16 @@ def from_file(cls, filename: Path | str) -> Self: f"does not exist" ) raise InvalidParameter(msg) + for key, val in scenario.custom_demand_overrides.items(): + if not val.is_absolute(): + val = (path.parent / val).resolve() + scenario.custom_demand_overrides[key] = val + if not val.exists(): + msg = ( + f"Scenario={scenario.name} custom_demand_override={key} " + f"filename={val} does not exist" + ) + raise InvalidParameter(msg) for table in config.calculated_table_overrides: if table.filename is not None and not table.filename.is_absolute(): table.filename = path.parent / table.filename @@ -196,6 +256,15 @@ def from_file(cls, filename: Path | str) -> Self: f"filename={table.filename} does not exist" ) raise InvalidParameter(msg) + for component in config.custom_demand_components: + if not component.data_file.is_absolute(): + component.data_file = (path.parent / component.data_file).resolve() + if not component.data_file.exists(): + msg = ( + f"Custom demand component={component.name} " + f"data_file={component.data_file} does not exist" + ) + raise InvalidParameter(msg) return config # type: ignore def list_model_years(self) -> list[int]: diff --git a/src/stride/project.py b/src/stride/project.py index 3a82736..eb9f629 100644 --- a/src/stride/project.py +++ b/src/stride/project.py @@ -27,6 +27,7 @@ from stride.io import create_table_from_file, export_table from stride.models import ( CalculatedTableOverride, + CustomDemandComponent, ProjectConfig, Scenario, ) @@ -514,7 +515,11 @@ def list_calculated_tables(self) -> list[str]: @staticmethod def list_data_tables() -> list[str]: """List the data tables available in any project.""" - return [x for x in Scenario.model_fields if x not in ("name", "use_ev_projection")] + return [ + x + for x in Scenario.model_fields + if x not in ("name", "use_ev_projection", "custom_demand_overrides") + ] def persist(self) -> None: """Persist the project config to the project directory.""" @@ -631,6 +636,29 @@ def compute_energy_projection(self, use_table_overrides: bool = True) -> None: multiplier_stats[5], ) + # Inject custom demand components (after dbt, before copying to main) + # dbt outputs energy_projection as a view; materialize it as a table + # so we can INSERT custom demand rows into it. + if self._config.custom_demand_components: + self._con.sql( + f"CREATE TABLE {scenario.name}.__ep_tmp AS " + f"SELECT * FROM {scenario.name}.energy_projection" + ) + self._con.sql( + f"DROP VIEW {scenario.name}.energy_projection" + ) + self._con.sql( + f"ALTER TABLE {scenario.name}.__ep_tmp " + f"RENAME TO energy_projection" + ) + injected = self.inject_custom_demand_components(scenario) + if injected > 0: + logger.info( + "Injected {} custom demand component rows for scenario '{}'", + injected, + scenario.name, + ) + columns = "timestamp, model_year, scenario, sector, geography, metric, value" if i == 0: query = f""" @@ -653,6 +681,418 @@ def compute_energy_projection(self, use_table_overrides: bool = True) -> None: ) self._con.commit() + def inject_custom_demand_components(self, scenario: Scenario) -> int: + """Inject custom demand component rows into {scenario}.energy_projection. + + For each custom demand component: + 1. Load annual MWh data from CSV/Parquet + 2. Resolve the load profile (flat only for now; sector/enduse reference in CP3) + 3. Distribute annual energy into 8760 hours using the profile + 4. INSERT rows into {scenario}.energy_projection + + Uses defensive DELETE before insert to ensure idempotency. + + Returns the number of rows injected. + """ + total_injected = 0 + model_years = self._config.list_model_years() + model_years_str = ", ".join(str(y) for y in model_years) + + for component in self._config.custom_demand_components: + # Resolve data file: use scenario override if available + data_file = scenario.custom_demand_overrides.get( + component.name, component.data_file + ) + + # Load annual data into a staging table + staging_table = f"stride.custom__{component.name}__{scenario.name}__annual" + create_table_from_file( + self._con, staging_table, data_file, replace=True + ) + + # Validate required columns + columns = [ + col[0] + for col in self._con.sql(f"DESCRIBE {staging_table}").fetchall() + ] + if "model_year" not in columns or "value" not in columns: + msg = ( + f"Custom demand component '{component.name}' data file " + f"must have 'model_year' and 'value' columns. " + f"Found: {columns}" + ) + raise InvalidParameter(msg) + + # Check that data covers the required model years + available_years = { + row[0] + for row in self._con.sql( + f"SELECT DISTINCT model_year FROM {staging_table}" + ).fetchall() + } + missing_years = set(model_years) - available_years + if missing_years: + msg = ( + f"Custom demand component '{component.name}' data file " + f"is missing model years: {sorted(missing_years)}" + ) + raise InvalidParameter(msg) + + # Defensive DELETE: remove any existing rows for this custom component + self._con.sql( + f"DELETE FROM {scenario.name}.energy_projection " + f"WHERE sector = '{component.sector}'" + f" AND metric = '{component.metric}'" + ) + + # Generate and insert hourly rows using the appropriate profile + profile_sql = self._build_profile_sql( + component, scenario.name, model_years_str + ) + + insert_sql = f""" + INSERT INTO {scenario.name}.energy_projection + {profile_sql} + """ + self._con.sql(insert_sql) + + # Count injected rows + count = self._con.sql( + f"SELECT COUNT(*) FROM {scenario.name}.energy_projection " + f"WHERE sector = '{component.sector}'" + f" AND metric = '{component.metric}'" + ).fetchone()[0] + total_injected += count + + logger.info( + "Injected {} rows for custom component '{}' (sector='{}') in scenario '{}'", + count, + component.name, + component.sector, + scenario.name, + ) + + # Verify annual totals match input for years that were injected + for year in model_years: + expected = self._con.sql( + f"SELECT value FROM {staging_table} WHERE model_year = {year}" + ).fetchone()[0] + actual = self._con.sql( + f"SELECT SUM(value) FROM {scenario.name}.energy_projection " + f"WHERE sector = '{component.sector}'" + f" AND metric = '{component.metric}'" + f" AND model_year = {year}" + ).fetchone()[0] + if actual is None: + logger.warning( + "No hourly timestamps found for year {} — " + "custom component '{}' was not injected for this year", + year, + component.name, + ) + elif abs(actual - expected) > 0.01: + logger.warning( + "Annual total mismatch for component '{}', year {}: " + "expected={}, actual={}", + component.name, + year, + expected, + actual, + ) + + return total_injected + + def _build_profile_sql( + self, + component: CustomDemandComponent, + scenario_name: str, + model_years_str: str, + ) -> str: + """Build the SQL SELECT that produces hourly rows for a custom demand component. + + Returns a SQL SELECT statement (without INSERT INTO) that produces rows matching + the {scenario}.energy_projection schema: timestamp, model_year, sector, geography, + metric, value. + """ + staging_table = f"stride.custom__{component.name}__{scenario_name}__annual" + profile = component.load_profile + + if profile == "flat": + return self._build_flat_profile_sql( + component, scenario_name, staging_table, model_years_str + ) + elif profile.startswith("sector:"): + ref_sector = profile.split(":", 1)[1] + self._validate_load_shape_reference( + scenario_name, "sector", ref_sector, model_years_str + ) + return self._build_sector_reference_profile_sql( + component, scenario_name, staging_table, model_years_str, ref_sector + ) + elif profile.startswith("enduse:"): + ref_enduse = profile.split(":", 1)[1] + self._validate_load_shape_reference( + scenario_name, "enduse", ref_enduse, model_years_str + ) + return self._build_enduse_reference_profile_sql( + component, scenario_name, staging_table, model_years_str, ref_enduse + ) + else: + # Treat as a file path to an 8760 hourly profile CSV/Parquet + profile_path = Path(profile) + if not profile_path.is_absolute(): + profile_path = self._path / profile_path + if not profile_path.exists(): + msg = ( + f"Custom profile file '{profile}' not found " + f"(resolved to {profile_path})" + ) + raise InvalidParameter(msg) + return self._build_file_profile_sql( + component, scenario_name, staging_table, model_years_str, profile_path + ) + + def _build_flat_profile_sql( + self, + component: CustomDemandComponent, + scenario_name: str, + staging_table: str, + model_years_str: str, + ) -> str: + """Build SQL for flat (uniform) hourly distribution.""" + return f""" + WITH annual_data AS ( + SELECT model_year, value AS annual_mwh + FROM {staging_table} + WHERE model_year IN ({model_years_str}) + ), + hourly_timestamps AS ( + SELECT DISTINCT timestamp, model_year + FROM {scenario_name}.energy_projection + ) + SELECT + ht.timestamp, + ht.model_year, + '{self._config.country}' AS geography, + '{component.sector}' AS sector, + '{component.metric}' AS metric, + ad.annual_mwh / 8760.0 AS value, + '{scenario_name}' AS scenario + FROM hourly_timestamps ht + JOIN annual_data ad ON ht.model_year = ad.model_year + """ + + def _validate_load_shape_reference( + self, + scenario_name: str, + dimension: str, + ref_value: str, + model_years_str: str, + ) -> None: + """Validate that a sector or enduse reference exists in load_shapes_expanded. + + Raises InvalidParameter if the reference value is not found. + """ + count = self._con.sql( + f"SELECT COUNT(*) FROM {scenario_name}.load_shapes_expanded " + f"WHERE {dimension} = '{ref_value}' " + f"AND model_year IN ({model_years_str})" + ).fetchone()[0] + if count == 0: + available = sorted( + row[0] + for row in self._con.sql( + f"SELECT DISTINCT {dimension} FROM {scenario_name}.load_shapes_expanded " + f"WHERE model_year IN ({model_years_str})" + ).fetchall() + ) + msg = ( + f"Reference profile '{dimension}:{ref_value}' not found in " + f"load_shapes_expanded. Available {dimension}s: {available}" + ) + raise InvalidParameter(msg) + + def _build_sector_reference_profile_sql( + self, + component: CustomDemandComponent, + scenario_name: str, + staging_table: str, + model_years_str: str, + ref_sector: str, + ) -> str: + """Build SQL using a sector's aggregate load shape as the hourly profile. + + Aggregates all enduses for the given sector, normalizes to fraction per + model_year, then multiplies by annual MWh. + """ + return f""" + WITH annual_data AS ( + SELECT model_year, value AS annual_mwh + FROM {staging_table} + WHERE model_year IN ({model_years_str}) + ), + aggregated_shape AS ( + SELECT + timestamp, + model_year, + SUM(adjusted_value) AS total_adjusted_value + FROM {scenario_name}.load_shapes_expanded + WHERE sector = '{ref_sector}' + AND model_year IN ({model_years_str}) + GROUP BY timestamp, model_year + ), + reference_shape AS ( + SELECT + timestamp, + model_year, + total_adjusted_value / SUM(total_adjusted_value) OVER ( + PARTITION BY model_year + ) AS fraction + FROM aggregated_shape + ) + SELECT + rs.timestamp, + rs.model_year, + '{self._config.country}' AS geography, + '{component.sector}' AS sector, + '{component.metric}' AS metric, + ad.annual_mwh * rs.fraction AS value, + '{scenario_name}' AS scenario + FROM reference_shape rs + JOIN annual_data ad ON rs.model_year = ad.model_year + """ + + def _build_enduse_reference_profile_sql( + self, + component: CustomDemandComponent, + scenario_name: str, + staging_table: str, + model_years_str: str, + ref_enduse: str, + ) -> str: + """Build SQL using an enduse's aggregate load shape as the hourly profile. + + Aggregates all sectors for the given enduse, normalizes to fraction per + model_year, then multiplies by annual MWh. + """ + return f""" + WITH annual_data AS ( + SELECT model_year, value AS annual_mwh + FROM {staging_table} + WHERE model_year IN ({model_years_str}) + ), + aggregated_shape AS ( + SELECT + timestamp, + model_year, + SUM(adjusted_value) AS total_adjusted_value + FROM {scenario_name}.load_shapes_expanded + WHERE enduse = '{ref_enduse}' + AND model_year IN ({model_years_str}) + GROUP BY timestamp, model_year + ), + reference_shape AS ( + SELECT + timestamp, + model_year, + total_adjusted_value / SUM(total_adjusted_value) OVER ( + PARTITION BY model_year + ) AS fraction + FROM aggregated_shape + ) + SELECT + rs.timestamp, + rs.model_year, + '{self._config.country}' AS geography, + '{component.sector}' AS sector, + '{component.metric}' AS metric, + ad.annual_mwh * rs.fraction AS value, + '{scenario_name}' AS scenario + FROM reference_shape rs + JOIN annual_data ad ON rs.model_year = ad.model_year + """ + + def _build_file_profile_sql( + self, + component: CustomDemandComponent, + scenario_name: str, + staging_table: str, + model_years_str: str, + profile_path: Path, + ) -> str: + """Build SQL using a user-provided hourly profile CSV. + + The profile file must have a `value` column with 8760 rows (one per hour). + It is joined positionally with the hourly timestamps from energy_projection. + The profile is normalized so that annual totals match the input data. + """ + profile_table = ( + f"stride.custom__{component.name}__{scenario_name}__profile" + ) + create_table_from_file( + self._con, profile_table, profile_path, replace=True + ) + + # Validate the profile has a value column + columns = [ + col[0] + for col in self._con.sql(f"DESCRIBE {profile_table}").fetchall() + ] + if "value" not in columns: + msg = ( + f"Custom profile file must have a 'value' column. " + f"Found: {columns}" + ) + raise InvalidParameter(msg) + + row_count = self._con.sql( + f"SELECT COUNT(*) FROM {profile_table}" + ).fetchone()[0] + if row_count != 8760: + msg = ( + f"Custom profile file must have exactly 8760 rows " + f"(got {row_count})" + ) + raise InvalidParameter(msg) + + return f""" + WITH annual_data AS ( + SELECT model_year, value AS annual_mwh + FROM {staging_table} + WHERE model_year IN ({model_years_str}) + ), + hourly_timestamps AS ( + SELECT DISTINCT timestamp, model_year, + ROW_NUMBER() OVER ( + PARTITION BY model_year ORDER BY timestamp + ) AS hour_idx + FROM {scenario_name}.energy_projection + ), + profile_data AS ( + SELECT + ROW_NUMBER() OVER (ORDER BY rowid) AS hour_idx, + value AS profile_value + FROM {profile_table} + ), + profile_normalized AS ( + SELECT + hour_idx, + profile_value / SUM(profile_value) OVER () AS fraction + FROM profile_data + ) + SELECT + ht.timestamp, + ht.model_year, + '{self._config.country}' AS geography, + '{component.sector}' AS sector, + '{component.metric}' AS metric, + ad.annual_mwh * pn.fraction AS value, + '{scenario_name}' AS scenario + FROM hourly_timestamps ht + JOIN annual_data ad ON ht.model_year = ad.model_year + JOIN profile_normalized pn ON ht.hour_idx = pn.hour_idx + """ + def export_energy_projection( self, filename: Path = Path("energy_projection.csv"), overwrite: bool = False ) -> None: diff --git a/src/stride/ui/app.py b/src/stride/ui/app.py index 2e2e380..3ccd6f8 100644 --- a/src/stride/ui/app.py +++ b/src/stride/ui/app.py @@ -120,6 +120,7 @@ def create_fresh_color_manager( scenarios: list[str], *, ui_theme: str = "light", + sectors: list[str] | None = None, ) -> ColorManager: """Create a fresh ColorManager instance, bypassing the singleton. @@ -145,7 +146,7 @@ def create_fresh_color_manager( ColorManager.__init__(color_manager, palette) color_manager.initialize_colors( scenarios=scenarios, - sectors=literal_to_list(Sectors), + sectors=sectors if sectors is not None else literal_to_list(Sectors), end_uses=list(palette.end_uses.keys()), ) @@ -200,15 +201,18 @@ def load_project(project_path: str) -> tuple[bool, str]: # Create a fresh color manager for this project palette = project.palette.copy() + dynamic_sectors = None try: + dynamic_sectors = data_handler.get_unique_sectors() palette.merge_with_project_dimensions( - sectors=data_handler.get_unique_sectors(), + sectors=dynamic_sectors, end_uses=data_handler.get_unique_end_uses(), ) except Exception as e: logger.warning(f"Could not populate sectors/end_uses: {e}") color_manager = create_fresh_color_manager( - palette, data_handler.scenarios, ui_theme=ui_theme + palette, data_handler.scenarios, ui_theme=ui_theme, + sectors=dynamic_sectors, ) plotter = StridePlots(color_manager, template=current_template) @@ -865,9 +869,11 @@ def on_palette_change(palette: ColorPalette, palette_type: str, palette_name: st data_handler = APIClient(cached_project) # Ensure sectors/end_uses are current + dynamic_sectors = None try: + dynamic_sectors = data_handler.get_unique_sectors() palette_copy.merge_with_project_dimensions( - sectors=data_handler.get_unique_sectors(), + sectors=dynamic_sectors, end_uses=data_handler.get_unique_end_uses(), ) except Exception as e: @@ -878,6 +884,7 @@ def on_palette_change(palette: ColorPalette, palette_type: str, palette_name: st palette_copy, data_handler.scenarios, ui_theme=ui_mode, + sectors=dynamic_sectors, ) plotter = StridePlots(color_manager, template=current_template) @@ -1649,9 +1656,11 @@ def _on_palette_change_no_project( data_handler = APIClient(cached_project) # Ensure sectors/end_uses are current + dynamic_sectors = None try: + dynamic_sectors = data_handler.get_unique_sectors() palette_copy.merge_with_project_dimensions( - sectors=data_handler.get_unique_sectors(), + sectors=dynamic_sectors, end_uses=data_handler.get_unique_end_uses(), ) except Exception as e: @@ -1661,6 +1670,7 @@ def _on_palette_change_no_project( palette_copy, data_handler.scenarios, ui_theme=ui_mode, + sectors=dynamic_sectors, ) new_plotter = StridePlots(new_color_manager, template=current_template) diff --git a/src/stride/ui/settings/callbacks.py b/src/stride/ui/settings/callbacks.py index 9e0ab3c..2554c17 100644 --- a/src/stride/ui/settings/callbacks.py +++ b/src/stride/ui/settings/callbacks.py @@ -962,12 +962,10 @@ def save_max_cached_projects( className="text-danger", ) - from stride.ui.app import _evict_oldest_project, set_max_cached_projects_override + from stride.ui.app import _evict_oldest_project # Persist to config file set_max_cached_projects(n) - # Also update the runtime override so it takes effect immediately - set_max_cached_projects_override(n) # Trigger eviction if current cache exceeds new limit _evict_oldest_project() diff --git a/tests/test_api.py b/tests/test_api.py index 311acab..32ea40e 100644 --- a/tests/test_api.py +++ b/tests/test_api.py @@ -128,7 +128,7 @@ def test_annual_consumption_enduse_no_duplicates(api_client: APIClient) -> None: assert ( not duplicates.any() ), f"Duplicate (scenario, year, metric) rows found:\n{df[duplicates]}" - assert (df["value"] > 0).all(), "All consumption values should be positive" + assert (df["value"] >= 0).all(), "All consumption values should be non-negative" def test_annual_consumption_breakdown_sums_to_total(api_client: APIClient) -> None: @@ -196,7 +196,7 @@ def test_get_annual_peak_demand_enduse_no_duplicates(api_client: APIClient) -> N assert ( not duplicates.any() ), f"Duplicate (scenario, year, metric) rows found:\n{df[duplicates]}" - assert (df["value"] > 0).all(), "All peak demand values should be positive" + assert (df["value"] >= 0).all(), "All peak demand values should be non-negative" def test_get_secondary_metric(api_client: APIClient) -> None: diff --git a/tests/test_app_cache.py b/tests/test_app_cache.py index 7664451..95681c3 100644 --- a/tests/test_app_cache.py +++ b/tests/test_app_cache.py @@ -626,6 +626,31 @@ def test_persistence_and_eviction(self, tmp_path: Path) -> None: saved = json.loads(config_file.read_text()) assert saved["max_cached_projects"] == 2 + def test_settings_save_does_not_set_override(self, tmp_path: Path) -> None: + """Regression: saving from Settings must not set _max_cached_projects_override. + + The save callback should only persist to config and trigger eviction, + not set the runtime CLI override which would disable the input. + """ + import json + + config_file = tmp_path / "config.json" + + assert app_module._max_cached_projects_override is None + + with patch("stride.config.get_stride_config_path", return_value=config_file): + from stride.config import set_max_cached_projects + + # Replicate what the fixed save callback does (without override) + set_max_cached_projects(5) + app_module._evict_oldest_project() + + # The override must remain None so the input stays editable + assert app_module._max_cached_projects_override is None + # Config should be persisted + saved = json.loads(config_file.read_text()) + assert saved["max_cached_projects"] == 5 + # =================================================================== # Tests for settings layout override display logic (layout.py) diff --git a/tests/test_custom_demand.py b/tests/test_custom_demand.py new file mode 100644 index 0000000..b19b321 --- /dev/null +++ b/tests/test_custom_demand.py @@ -0,0 +1,778 @@ +"""Tests for custom demand component functionality. + +Covers: +- CustomDemandComponent model validation (3a) +- Flat profile injection (3b) +- Sector/enduse reference profile injection (3c) +- File-based 8760 profile injection (3c) +- Edge cases: bad schema, missing years, invalid references (3d) +- CLI commands: add, list, remove (3d) +""" + +from __future__ import annotations + +import csv +from pathlib import Path + +import duckdb +import pytest +from chronify.exceptions import InvalidParameter +from click.testing import CliRunner +from pydantic import ValidationError + +from stride.cli.stride import cli +from stride.io import create_table_from_file +from stride.models import CustomDemandComponent + +# --------------------------------------------------------------------------- +# Helpers +# --------------------------------------------------------------------------- + +HOURS_PER_YEAR = 4 # small count to keep tests fast +MODEL_YEARS = [2025, 2030] + + +def _make_annual_csv(tmp_path: Path, name: str = "annual.csv", years: list[int] | None = None, + values: list[float] | None = None) -> Path: + """Write a simple model_year,value CSV and return its path.""" + years = years or MODEL_YEARS + values = values or [float(y - 2024) * 1e6 for y in years] + p = tmp_path / name + with open(p, "w", newline="") as f: + w = csv.writer(f) + w.writerow(["model_year", "value"]) + for y, v in zip(years, values): + w.writerow([y, v]) + return p + + +def _make_profile_csv(tmp_path: Path, n_rows: int = 8760, name: str = "profile.csv") -> Path: + """Write an 8760-row profile CSV with linearly increasing values.""" + p = tmp_path / name + with open(p, "w", newline="") as f: + w = csv.writer(f) + w.writerow(["value"]) + for i in range(1, n_rows + 1): + w.writerow([float(i)]) + return p + + +def _setup_db( + con: duckdb.DuckDBPyConnection, + scenario: str = "baseline", + hours: int = HOURS_PER_YEAR, + model_years: list[int] | None = None, + with_load_shapes: bool = False, +) -> None: + """Create minimal energy_projection table and optionally load_shapes_expanded.""" + model_years = model_years or MODEL_YEARS + con.sql(f"CREATE SCHEMA IF NOT EXISTS {scenario}") + con.sql("CREATE SCHEMA IF NOT EXISTS stride") + + # Per-scenario energy_projection table (materialized, as it would be after + # the view→table conversion in compute_energy_projection). + con.sql(f""" + CREATE TABLE {scenario}.energy_projection ( + timestamp TIMESTAMP, + model_year BIGINT, + geography VARCHAR, + sector VARCHAR, + metric VARCHAR, + value DOUBLE, + scenario VARCHAR + ) + """) + + for yr in model_years: + if hours <= 24: + rows = ", ".join( + f"(TIMESTAMP '2016-01-01 {h:02d}:00:00', {yr}, 'Germany', 'Residential', " + f"'heating', 100.0, '{scenario}')" + for h in range(hours) + ) + con.sql(f"INSERT INTO {scenario}.energy_projection VALUES {rows}") + else: + # Generate multi-day timestamps for large hour counts (e.g., 8760) + con.sql(f""" + INSERT INTO {scenario}.energy_projection + SELECT + TIMESTAMP '2016-01-01 00:00:00' + INTERVAL (i) HOUR AS timestamp, + {yr} AS model_year, + 'Germany' AS geography, + 'Residential' AS sector, + 'heating' AS metric, + 100.0 AS value, + '{scenario}' AS scenario + FROM generate_series(0, {hours - 1}) AS t(i) + """) + + if with_load_shapes: + con.sql(f""" + CREATE TABLE {scenario}.load_shapes_expanded ( + geography VARCHAR, model_year BIGINT, sector VARCHAR, enduse VARCHAR, + timestamp TIMESTAMP, weather_year INT, + load_shape_value DOUBLE, multiplier DOUBLE, adjusted_value DOUBLE + ) + """) + # Residential: heating [40,30,20,10], cooling [5,10,30,55] + # Commercial: heating [20,30,30,20] + patterns = [ + ("Residential", "heating", [40, 30, 20, 10]), + ("Residential", "cooling", [5, 10, 30, 55]), + ("Commercial", "heating", [20, 30, 30, 20]), + ] + for yr in model_years: + for sector, enduse, vals in patterns: + for h, v in enumerate(vals[:hours]): + con.sql(f""" + INSERT INTO {scenario}.load_shapes_expanded VALUES + ('Germany', {yr}, '{sector}', '{enduse}', + '2016-01-01 {h:02d}:00:00', 2016, {v}, 1.0, {v}) + """) + + +# --------------------------------------------------------------------------- +# 3a: Model validation +# --------------------------------------------------------------------------- + + +class TestCustomDemandComponentModel: + def test_valid_component(self, tmp_path: Path) -> None: + csv_path = _make_annual_csv(tmp_path) + c = CustomDemandComponent( + name="heat_pumps", sector="Heat Pumps", data_file=csv_path, + ) + assert c.load_profile == "flat" + assert c.metric == "other" + + def test_name_must_be_identifier(self, tmp_path: Path) -> None: + csv_path = _make_annual_csv(tmp_path) + with pytest.raises(ValidationError, match="valid Python identifier"): + CustomDemandComponent( + name="bad-name", sector="X", data_file=csv_path, + ) + + def test_name_with_spaces_rejected(self, tmp_path: Path) -> None: + csv_path = _make_annual_csv(tmp_path) + with pytest.raises(ValidationError, match="valid Python identifier"): + CustomDemandComponent( + name="bad name", sector="X", data_file=csv_path, + ) + + def test_custom_profile_options(self, tmp_path: Path) -> None: + csv_path = _make_annual_csv(tmp_path) + c = CustomDemandComponent( + name="x", sector="X", data_file=csv_path, + load_profile="sector:Residential", metric="heating", + ) + assert c.load_profile == "sector:Residential" + assert c.metric == "heating" + + +# --------------------------------------------------------------------------- +# 3b: Flat profile injection +# --------------------------------------------------------------------------- + + +class TestFlatProfileInjection: + def test_flat_injection_row_count(self, tmp_path: Path) -> None: + con = duckdb.connect() + _setup_db(con) + csv_path = _make_annual_csv(tmp_path) + staging = "stride.custom__dc__baseline__annual" + create_table_from_file(con, staging, csv_path, replace=True) + + sql = f""" + WITH annual_data AS ( + SELECT model_year, value AS annual_mwh + FROM {staging} + WHERE model_year IN (2025, 2030) + ), + hourly_timestamps AS ( + SELECT DISTINCT timestamp, model_year + FROM baseline.energy_projection + ) + SELECT + ht.timestamp, ht.model_year, + 'Germany' AS geography, + 'Data Centers' AS sector, + 'other' AS metric, + ad.annual_mwh / {HOURS_PER_YEAR}.0 AS value, + 'baseline' AS scenario + FROM hourly_timestamps ht + JOIN annual_data ad ON ht.model_year = ad.model_year + """ + con.sql(f"INSERT INTO baseline.energy_projection {sql}") + + count = con.sql( + "SELECT COUNT(*) FROM baseline.energy_projection " + "WHERE sector = 'Data Centers'" + ).fetchone()[0] + assert count == HOURS_PER_YEAR * len(MODEL_YEARS) + + def test_flat_injection_annual_totals(self, tmp_path: Path) -> None: + con = duckdb.connect() + _setup_db(con) + csv_path = _make_annual_csv(tmp_path, values=[1_000_000, 2_000_000]) + staging = "stride.custom__dc__baseline__annual" + create_table_from_file(con, staging, csv_path, replace=True) + + sql = f""" + WITH annual_data AS ( + SELECT model_year, value AS annual_mwh + FROM {staging} + WHERE model_year IN (2025, 2030) + ), + hourly_timestamps AS ( + SELECT DISTINCT timestamp, model_year + FROM baseline.energy_projection + ) + SELECT + ht.timestamp, ht.model_year, + 'Germany' AS geography, + 'DC' AS sector, + 'other' AS metric, + ad.annual_mwh / {HOURS_PER_YEAR}.0 AS value, + 'baseline' AS scenario + FROM hourly_timestamps ht + JOIN annual_data ad ON ht.model_year = ad.model_year + """ + con.sql(f"INSERT INTO baseline.energy_projection {sql}") + + for yr, expected in [(2025, 1_000_000), (2030, 2_000_000)]: + actual = con.sql( + f"SELECT SUM(value) FROM baseline.energy_projection " + f"WHERE sector = 'DC' AND model_year = {yr}" + ).fetchone()[0] + assert abs(actual - expected) < 0.01 + + def test_flat_injection_preserves_existing(self, tmp_path: Path) -> None: + con = duckdb.connect() + _setup_db(con) + csv_path = _make_annual_csv(tmp_path) + staging = "stride.custom__dc__baseline__annual" + create_table_from_file(con, staging, csv_path, replace=True) + + before = con.sql( + "SELECT COUNT(*) FROM baseline.energy_projection WHERE sector = 'Residential'" + ).fetchone()[0] + + sql = f""" + WITH annual_data AS ( + SELECT model_year, value AS annual_mwh + FROM {staging} WHERE model_year IN (2025, 2030) + ), + hourly_timestamps AS ( + SELECT DISTINCT timestamp, model_year FROM baseline.energy_projection + ) + SELECT ht.timestamp, ht.model_year, 'Germany', 'DC', 'other', + ad.annual_mwh / {HOURS_PER_YEAR}.0, 'baseline' + FROM hourly_timestamps ht + JOIN annual_data ad ON ht.model_year = ad.model_year + """ + con.sql(f"INSERT INTO baseline.energy_projection {sql}") + + after = con.sql( + "SELECT COUNT(*) FROM baseline.energy_projection WHERE sector = 'Residential'" + ).fetchone()[0] + assert after == before + + def test_idempotent_delete_reinsert(self, tmp_path: Path) -> None: + con = duckdb.connect() + _setup_db(con) + csv_path = _make_annual_csv(tmp_path) + staging = "stride.custom__dc__baseline__annual" + create_table_from_file(con, staging, csv_path, replace=True) + + sql = f""" + SELECT ht.timestamp, ht.model_year, 'Germany', 'DC', 'other', + ad.annual_mwh / {HOURS_PER_YEAR}.0, 'baseline' + FROM (SELECT DISTINCT timestamp, model_year FROM baseline.energy_projection) ht + JOIN (SELECT model_year, value AS annual_mwh FROM {staging} + WHERE model_year IN (2025, 2030)) ad + ON ht.model_year = ad.model_year + """ + con.sql(f"INSERT INTO baseline.energy_projection {sql}") + con.sql("DELETE FROM baseline.energy_projection WHERE sector = 'DC'") + con.sql(f"INSERT INTO baseline.energy_projection {sql}") + count = con.sql( + "SELECT COUNT(*) FROM baseline.energy_projection WHERE sector = 'DC'" + ).fetchone()[0] + assert count == HOURS_PER_YEAR * len(MODEL_YEARS) + + +# --------------------------------------------------------------------------- +# 3c: Reference profile injection (sector/enduse) +# --------------------------------------------------------------------------- + + +class TestReferenceProfileInjection: + def test_sector_profile_annual_totals(self, tmp_path: Path) -> None: + con = duckdb.connect() + _setup_db(con, with_load_shapes=True) + csv_path = _make_annual_csv(tmp_path, values=[1_000_000, 2_000_000]) + staging = "stride.custom__hp__baseline__annual" + create_table_from_file(con, staging, csv_path, replace=True) + + sql = f""" + WITH annual_data AS ( + SELECT model_year, value AS annual_mwh FROM {staging} + WHERE model_year IN (2025, 2030) + ), + aggregated_shape AS ( + SELECT timestamp, model_year, SUM(adjusted_value) AS total + FROM baseline.load_shapes_expanded + WHERE sector = 'Residential' AND model_year IN (2025, 2030) + GROUP BY timestamp, model_year + ), + reference_shape AS ( + SELECT timestamp, model_year, + total / SUM(total) OVER (PARTITION BY model_year) AS fraction + FROM aggregated_shape + ) + SELECT rs.timestamp, rs.model_year, 'Germany', 'HP', 'heating', + ad.annual_mwh * rs.fraction, 'baseline' + FROM reference_shape rs + JOIN annual_data ad ON rs.model_year = ad.model_year + """ + con.sql(f"INSERT INTO baseline.energy_projection {sql}") + + for yr, expected in [(2025, 1_000_000), (2030, 2_000_000)]: + actual = con.sql( + f"SELECT SUM(value) FROM baseline.energy_projection " + f"WHERE sector = 'HP' AND model_year = {yr}" + ).fetchone()[0] + assert abs(actual - expected) < 0.01 + + def test_sector_profile_shape_not_flat(self, tmp_path: Path) -> None: + con = duckdb.connect() + _setup_db(con, with_load_shapes=True) + csv_path = _make_annual_csv(tmp_path, values=[1_000_000, 2_000_000]) + staging = "stride.custom__hp__baseline__annual" + create_table_from_file(con, staging, csv_path, replace=True) + + sql = f""" + WITH annual_data AS ( + SELECT model_year, value AS annual_mwh FROM {staging} + WHERE model_year IN (2025, 2030) + ), + agg AS ( + SELECT timestamp, model_year, SUM(adjusted_value) AS total + FROM baseline.load_shapes_expanded + WHERE sector = 'Residential' AND model_year IN (2025, 2030) + GROUP BY timestamp, model_year + ), + ref AS ( + SELECT timestamp, model_year, + total / SUM(total) OVER (PARTITION BY model_year) AS fraction + FROM agg + ) + SELECT ref.timestamp, ref.model_year, 'Germany', 'HP', 'heating', + ad.annual_mwh * ref.fraction, 'baseline' + FROM ref JOIN annual_data ad ON ref.model_year = ad.model_year + """ + con.sql(f"INSERT INTO baseline.energy_projection {sql}") + + vals = con.sql( + "SELECT value FROM baseline.energy_projection " + "WHERE sector = 'HP' AND model_year = 2025 ORDER BY timestamp" + ).fetchall() + hourly = [r[0] for r in vals] + # Residential aggregate: heating[40,30,20,10]+cooling[5,10,30,55]=[45,40,50,65] + # Not all equal → not flat + assert len(set(hourly)) > 1, "Sector profile should NOT be flat" + + def test_enduse_profile_annual_totals(self, tmp_path: Path) -> None: + con = duckdb.connect() + _setup_db(con, with_load_shapes=True) + csv_path = _make_annual_csv(tmp_path, values=[1_000_000, 2_000_000]) + staging = "stride.custom__hp__baseline__annual" + create_table_from_file(con, staging, csv_path, replace=True) + + sql = f""" + WITH annual_data AS ( + SELECT model_year, value AS annual_mwh FROM {staging} + WHERE model_year IN (2025, 2030) + ), + agg AS ( + SELECT timestamp, model_year, SUM(adjusted_value) AS total + FROM baseline.load_shapes_expanded + WHERE enduse = 'heating' AND model_year IN (2025, 2030) + GROUP BY timestamp, model_year + ), + ref AS ( + SELECT timestamp, model_year, + total / SUM(total) OVER (PARTITION BY model_year) AS fraction + FROM agg + ) + SELECT ref.timestamp, ref.model_year, 'Germany', 'HP', 'heating', + ad.annual_mwh * ref.fraction, 'baseline' + FROM ref JOIN annual_data ad ON ref.model_year = ad.model_year + """ + con.sql(f"INSERT INTO baseline.energy_projection {sql}") + + for yr, expected in [(2025, 1_000_000), (2030, 2_000_000)]: + actual = con.sql( + f"SELECT SUM(value) FROM baseline.energy_projection " + f"WHERE sector = 'HP' AND model_year = {yr}" + ).fetchone()[0] + assert abs(actual - expected) < 0.01 + + def test_sector_and_enduse_profiles_differ(self, tmp_path: Path) -> None: + """sector:Residential and enduse:heating should produce different shapes.""" + con = duckdb.connect() + _setup_db(con, with_load_shapes=True) + csv_path = _make_annual_csv(tmp_path, values=[1_000_000, 2_000_000]) + staging = "stride.custom__hp__baseline__annual" + create_table_from_file(con, staging, csv_path, replace=True) + + # sector:Residential + sector_sql = f""" + WITH ad AS (SELECT model_year, value AS mwh FROM {staging} WHERE model_year=2025), + agg AS (SELECT timestamp, model_year, SUM(adjusted_value) AS t + FROM baseline.load_shapes_expanded WHERE sector='Residential' AND model_year=2025 + GROUP BY timestamp, model_year), + ref AS (SELECT timestamp, model_year, t/SUM(t) OVER (PARTITION BY model_year) AS f FROM agg) + SELECT ref.timestamp, ad.mwh * ref.f AS value + FROM ref JOIN ad ON ref.model_year = ad.model_year ORDER BY ref.timestamp + """ + sector_vals = [r[1] for r in con.sql(sector_sql).fetchall()] + + # enduse:heating + enduse_sql = f""" + WITH ad AS (SELECT model_year, value AS mwh FROM {staging} WHERE model_year=2025), + agg AS (SELECT timestamp, model_year, SUM(adjusted_value) AS t + FROM baseline.load_shapes_expanded WHERE enduse='heating' AND model_year=2025 + GROUP BY timestamp, model_year), + ref AS (SELECT timestamp, model_year, t/SUM(t) OVER (PARTITION BY model_year) AS f FROM agg) + SELECT ref.timestamp, ad.mwh * ref.f AS value + FROM ref JOIN ad ON ref.model_year = ad.model_year ORDER BY ref.timestamp + """ + enduse_vals = [r[1] for r in con.sql(enduse_sql).fetchall()] + + assert sector_vals != enduse_vals + + +# --------------------------------------------------------------------------- +# 3c (cont): File-based 8760 profile +# --------------------------------------------------------------------------- + + +class TestFileProfileInjection: + def test_file_profile_annual_totals(self, tmp_path: Path) -> None: + """An 8760-row profile file should distribute annual energy correctly.""" + con = duckdb.connect() + n_hours = 8760 + _setup_db(con, hours=n_hours, model_years=[2025]) + csv_path = _make_annual_csv(tmp_path, years=[2025], values=[1_000_000.0]) + profile_path = _make_profile_csv(tmp_path, n_rows=n_hours) + + staging = "stride.custom__dc__baseline__annual" + create_table_from_file(con, staging, csv_path, replace=True) + profile_table = "stride.custom__dc__baseline__profile" + create_table_from_file(con, profile_table, profile_path, replace=True) + + sql = f""" + WITH annual_data AS ( + SELECT model_year, value AS annual_mwh FROM {staging} + WHERE model_year = 2025 + ), + hourly_timestamps AS ( + SELECT DISTINCT timestamp, model_year, + ROW_NUMBER() OVER (PARTITION BY model_year ORDER BY timestamp) AS hour_idx + FROM baseline.energy_projection + ), + profile_data AS ( + SELECT ROW_NUMBER() OVER (ORDER BY rowid) AS hour_idx, + value AS profile_value FROM {profile_table} + ), + profile_normalized AS ( + SELECT hour_idx, + profile_value / SUM(profile_value) OVER () AS fraction + FROM profile_data + ) + SELECT ht.timestamp, ht.model_year, 'Germany', 'DC', 'other', + ad.annual_mwh * pn.fraction, 'baseline' + FROM hourly_timestamps ht + JOIN annual_data ad ON ht.model_year = ad.model_year + JOIN profile_normalized pn ON ht.hour_idx = pn.hour_idx + """ + con.sql(f"INSERT INTO baseline.energy_projection {sql}") + + actual = con.sql( + "SELECT SUM(value) FROM baseline.energy_projection " + "WHERE sector = 'DC' AND model_year = 2025" + ).fetchone()[0] + assert abs(actual - 1_000_000.0) < 0.01 + + def test_file_profile_shape_not_flat(self, tmp_path: Path) -> None: + """A non-uniform profile file should produce non-uniform hourly values.""" + con = duckdb.connect() + n_hours = 8760 + _setup_db(con, hours=n_hours, model_years=[2025]) + csv_path = _make_annual_csv(tmp_path, years=[2025], values=[1_000_000.0]) + profile_path = _make_profile_csv(tmp_path, n_rows=n_hours) + + staging = "stride.custom__dc__baseline__annual" + create_table_from_file(con, staging, csv_path, replace=True) + profile_table = "stride.custom__dc__baseline__profile" + create_table_from_file(con, profile_table, profile_path, replace=True) + + sql = f""" + WITH annual_data AS ( + SELECT model_year, value AS annual_mwh FROM {staging} + WHERE model_year = 2025 + ), + hourly_timestamps AS ( + SELECT DISTINCT timestamp, model_year, + ROW_NUMBER() OVER (PARTITION BY model_year ORDER BY timestamp) AS hour_idx + FROM baseline.energy_projection + ), + profile_data AS ( + SELECT ROW_NUMBER() OVER (ORDER BY rowid) AS hour_idx, + value AS profile_value FROM {profile_table} + ), + profile_normalized AS ( + SELECT hour_idx, + profile_value / SUM(profile_value) OVER () AS fraction + FROM profile_data + ) + SELECT ht.timestamp, ht.model_year, 'Germany', 'DC', 'other', + ad.annual_mwh * pn.fraction, 'baseline' + FROM hourly_timestamps ht + JOIN annual_data ad ON ht.model_year = ad.model_year + JOIN profile_normalized pn ON ht.hour_idx = pn.hour_idx + """ + result = con.sql(sql).fetchall() + vals = [r[5] for r in result] + assert len(set(vals)) > 1, "File profile should NOT produce flat output" + + +# --------------------------------------------------------------------------- +# 3d: Edge cases & validation +# --------------------------------------------------------------------------- + + +class TestEdgeCases: + def test_missing_value_column(self, tmp_path: Path) -> None: + """CSV without 'value' column should be rejected.""" + csv_path = tmp_path / "bad.csv" + with open(csv_path, "w", newline="") as f: + w = csv.writer(f) + w.writerow(["model_year", "amount"]) + w.writerow([2025, 100]) + + con = duckdb.connect() + _setup_db(con) + staging = "stride.custom__dc__baseline__annual" + create_table_from_file(con, staging, csv_path, replace=True) + + columns = [col[0] for col in con.sql(f"DESCRIBE {staging}").fetchall()] + assert "value" not in columns + + def test_missing_model_year_column(self, tmp_path: Path) -> None: + """CSV without 'model_year' column should be rejected.""" + csv_path = tmp_path / "bad.csv" + with open(csv_path, "w", newline="") as f: + w = csv.writer(f) + w.writerow(["year", "value"]) + w.writerow([2025, 100]) + + con = duckdb.connect() + _setup_db(con) + staging = "stride.custom__dc__baseline__annual" + create_table_from_file(con, staging, csv_path, replace=True) + + columns = [col[0] for col in con.sql(f"DESCRIBE {staging}").fetchall()] + assert "model_year" not in columns + + def test_missing_years_in_data(self, tmp_path: Path) -> None: + """CSV missing required model years should be detectable.""" + csv_path = _make_annual_csv(tmp_path, years=[2025], values=[1_000_000]) + con = duckdb.connect() + _setup_db(con) + staging = "stride.custom__dc__baseline__annual" + create_table_from_file(con, staging, csv_path, replace=True) + + available = { + row[0] for row in con.sql( + f"SELECT DISTINCT model_year FROM {staging}" + ).fetchall() + } + required = set(MODEL_YEARS) + assert required - available == {2030} + + def test_unknown_sector_reference(self, tmp_path: Path) -> None: + """Referencing a non-existent sector in load_shapes_expanded returns 0 rows.""" + con = duckdb.connect() + _setup_db(con, with_load_shapes=True) + + count = con.sql( + "SELECT COUNT(*) FROM baseline.load_shapes_expanded " + "WHERE sector = 'Manufacturing' AND model_year IN (2025, 2030)" + ).fetchone()[0] + assert count == 0 + + def test_unknown_enduse_reference(self, tmp_path: Path) -> None: + """Referencing a non-existent enduse in load_shapes_expanded returns 0 rows.""" + con = duckdb.connect() + _setup_db(con, with_load_shapes=True) + + count = con.sql( + "SELECT COUNT(*) FROM baseline.load_shapes_expanded " + "WHERE enduse = 'transport' AND model_year IN (2025, 2030)" + ).fetchone()[0] + assert count == 0 + + def test_file_profile_wrong_row_count(self, tmp_path: Path) -> None: + """Profile CSV with != 8760 rows should be detectable.""" + profile_path = _make_profile_csv(tmp_path, n_rows=100) + con = duckdb.connect() + con.sql("CREATE SCHEMA IF NOT EXISTS stride") + create_table_from_file(con, "stride.profile", profile_path, replace=True) + count = con.sql("SELECT COUNT(*) FROM stride.profile").fetchone()[0] + assert count != 8760 + + def test_file_profile_missing_value_column(self, tmp_path: Path) -> None: + """Profile CSV without 'value' column should be detectable.""" + csv_path = tmp_path / "bad_profile.csv" + with open(csv_path, "w", newline="") as f: + w = csv.writer(f) + w.writerow(["load"]) + for i in range(8760): + w.writerow([i]) + + con = duckdb.connect() + con.sql("CREATE SCHEMA IF NOT EXISTS stride") + create_table_from_file(con, "stride.profile", csv_path, replace=True) + columns = [col[0] for col in con.sql("DESCRIBE stride.profile").fetchall()] + assert "value" not in columns + + +# --------------------------------------------------------------------------- +# 3d (cont): CLI commands +# --------------------------------------------------------------------------- + + +class TestCustomDemandCLI: + def test_list_empty(self, copy_project_input_data: tuple[Path, Path, Path]) -> None: + """Listing components on a project with none configured.""" + scratch, data_dir, config_path = copy_project_input_data + runner = CliRunner() + result = runner.invoke(cli, [ + "projects", "create", str(config_path), + "--directory", str(scratch), + "--dataset", "global-test", + ]) + assert result.exit_code == 0, result.output + + project_dir = scratch / "test_project" + result = runner.invoke(cli, ["custom-demand", "list", str(project_dir)]) + assert result.exit_code == 0, result.output + assert "No custom demand components" in result.output + + def test_add_list_remove_cycle( + self, copy_project_input_data: tuple[Path, Path, Path], tmp_path: Path, + ) -> None: + """Full add → list → remove lifecycle.""" + scratch, data_dir, config_path = copy_project_input_data + runner = CliRunner() + result = runner.invoke(cli, [ + "projects", "create", str(config_path), + "--directory", str(scratch), + "--dataset", "global-test", + ]) + assert result.exit_code == 0, result.output + project_dir = scratch / "test_project" + + # Create annual CSV with correct model years for this project + csv_path = _make_annual_csv( + tmp_path, + years=[2025, 2030, 2035, 2040, 2045, 2050], + values=[1e6, 2e6, 3e6, 4e6, 5e6, 6e6], + ) + + # Add + result = runner.invoke(cli, [ + "custom-demand", "add", str(project_dir), + "--name", "data_centers", + "--sector", "Data Centers", + "--data-file", str(csv_path), + ]) + assert result.exit_code == 0, result.output + assert "Added custom demand component" in result.output + + # List + result = runner.invoke(cli, ["custom-demand", "list", str(project_dir)]) + assert result.exit_code == 0, result.output + assert "data_centers" in result.output + assert "Data Centers" in result.output + + # Remove + result = runner.invoke(cli, [ + "custom-demand", "remove", str(project_dir), + "--name", "data_centers", + ]) + assert result.exit_code == 0, result.output + assert "Removed" in result.output + + # List again — empty + result = runner.invoke(cli, ["custom-demand", "list", str(project_dir)]) + assert result.exit_code == 0, result.output + assert "No custom demand components" in result.output + + def test_add_duplicate_rejected( + self, copy_project_input_data: tuple[Path, Path, Path], tmp_path: Path, + ) -> None: + """Adding a component with an existing name should fail.""" + scratch, data_dir, config_path = copy_project_input_data + runner = CliRunner() + result = runner.invoke(cli, [ + "projects", "create", str(config_path), + "--directory", str(scratch), + "--dataset", "global-test", + ]) + assert result.exit_code == 0, result.output + project_dir = scratch / "test_project" + + csv_path = _make_annual_csv( + tmp_path, + years=[2025, 2030, 2035, 2040, 2045, 2050], + values=[1e6, 2e6, 3e6, 4e6, 5e6, 6e6], + ) + + # First add + result = runner.invoke(cli, [ + "custom-demand", "add", str(project_dir), + "--name", "dc", "--sector", "DC", + "--data-file", str(csv_path), + ]) + assert result.exit_code == 0, result.output + + # Duplicate add + result = runner.invoke(cli, [ + "custom-demand", "add", str(project_dir), + "--name", "dc", "--sector", "DC", + "--data-file", str(csv_path), + ]) + assert result.exit_code != 0 + + def test_remove_nonexistent_rejected( + self, copy_project_input_data: tuple[Path, Path, Path], + ) -> None: + """Removing a component that doesn't exist should fail.""" + scratch, data_dir, config_path = copy_project_input_data + runner = CliRunner() + result = runner.invoke(cli, [ + "projects", "create", str(config_path), + "--directory", str(scratch), + "--dataset", "global-test", + ]) + assert result.exit_code == 0, result.output + project_dir = scratch / "test_project" + + result = runner.invoke(cli, [ + "custom-demand", "remove", str(project_dir), + "--name", "nonexistent", + ]) + assert result.exit_code != 0 diff --git a/tests/test_energy_projection.py b/tests/test_energy_projection.py index 113e5b9..fd4a3a6 100644 --- a/tests/test_energy_projection.py +++ b/tests/test_energy_projection.py @@ -130,6 +130,7 @@ def compute_energy_projection_com_ind_tra_with_ev( For use_ev_projection=True: - Commercial and Industrial use standard energy intensity regression - Transportation/Road uses EV stock * km/vehicle * Wh/km calculation + - EV energy is tagged with metric='ev_charging' within Transportation sector """ model_years_tuple = tuple(model_years) @@ -169,21 +170,30 @@ def compute_energy_projection_com_ind_tra_with_ev( # Exclude Transportation/Road from base (will be replaced by EV calculation) stride_annual_energy_non_ev = con.sql( # noqa: F841 """ - SELECT * FROM stride_annual_energy_base + SELECT geography, model_year, sector, subsector, stride_annual_total, + 'base' AS energy_source + FROM stride_annual_energy_base WHERE NOT (sector = 'Transportation' AND subsector = 'Road') """ ) # Calculate EV annual energy - ev_annual_energy = compute_ev_annual_energy(con, scenario, country, model_years_tuple) # noqa F841 + ev_annual_energy_raw = compute_ev_annual_energy(con, scenario, country, model_years_tuple) # noqa F841 + ev_annual_energy = con.sql( # noqa: F841 + """ + SELECT geography, model_year, sector, subsector, stride_annual_total, + 'ev' AS energy_source + FROM ev_annual_energy_raw + """ + ) # Combine: non-EV sectors + EV Transportation/Road stride_annual_energy = con.sql( # noqa: F841 """ - SELECT geography, model_year, sector, subsector, stride_annual_total + SELECT geography, model_year, sector, subsector, stride_annual_total, energy_source FROM stride_annual_energy_non_ev UNION ALL - SELECT geography, model_year, sector, subsector, stride_annual_total + SELECT geography, model_year, sector, subsector, stride_annual_total, energy_source FROM ev_annual_energy """ ) @@ -201,12 +211,13 @@ def compute_energy_projection_com_ind_tra_with_ev( """ ) - # Compute scaling factors (aggregate subsectors since load shapes are at sector level) + # Compute scaling factors per (sector, energy_source) stride_by_sector = con.sql( # noqa F841 """ - SELECT geography, model_year, sector, SUM(stride_annual_total) AS stride_annual_total + SELECT geography, model_year, sector, energy_source, + SUM(stride_annual_total) AS stride_annual_total FROM stride_annual_energy - GROUP BY geography, model_year, sector + GROUP BY geography, model_year, sector, energy_source """ ) @@ -216,6 +227,7 @@ def compute_energy_projection_com_ind_tra_with_ev( ls.geography ,ls.model_year ,ls.sector + ,stride.energy_source ,CASE WHEN ls.load_shape_annual_total > 0 THEN stride.stride_annual_total / ls.load_shape_annual_total @@ -230,6 +242,8 @@ def compute_energy_projection_com_ind_tra_with_ev( ) # Apply scaling factors to get final hourly projections + # Non-EV rows: keep original end-use metric + # EV rows: aggregate load shape across enduses, tag as 'ev_charging' return con.sql( f""" SELECT @@ -245,6 +259,26 @@ def compute_energy_projection_com_ind_tra_with_ev( ON ls.geography = sf.geography AND ls.model_year = sf.model_year AND ls.sector = sf.sector + WHERE sf.energy_source = 'base' + + UNION ALL + + SELECT + ls.timestamp + ,ls.model_year + ,'{scenario}' AS scenario + ,ls.geography + ,ls.sector + ,'ev_charging' AS metric + ,SUM(ls.adjusted_value) * sf.scaling_factor AS value + FROM ls_cit ls + JOIN scaling_factors sf + ON ls.geography = sf.geography + AND ls.model_year = sf.model_year + AND ls.sector = sf.sector + WHERE sf.energy_source = 'ev' + AND ls.sector = 'Transportation' + GROUP BY ls.timestamp, ls.model_year, ls.geography, ls.sector, sf.scaling_factor """ ) diff --git a/tests/test_project.py b/tests/test_project.py index c8bf446..3a46e93 100644 --- a/tests/test_project.py +++ b/tests/test_project.py @@ -1,4 +1,5 @@ from pathlib import Path +from unittest.mock import patch import pandas as pd import pytest @@ -616,6 +617,75 @@ def test_create_project_with_invalid_data_dir( assert "Dataset directory not found" in result.output +def test_create_project_registers_recent_project( + copy_project_input_data: tuple[Path, Path, Path], +) -> None: + """Test that 'stride projects create' calls add_recent_project for the viewer dropdown.""" + from unittest.mock import MagicMock + + tmp_path, _, project_config_file = copy_project_input_data + + mock_project = MagicMock() + mock_project.path = tmp_path / "test_project" + mock_project.config.project_id = "test_project" + + runner = CliRunner() + with ( + patch("stride.cli.stride.Project.create", return_value=mock_project), + patch("stride.ui.project_manager.add_recent_project") as mock_add, + ): + result = runner.invoke( + cli, + [ + "projects", + "create", + str(project_config_file), + "-d", + str(tmp_path), + "--dataset", + "global-test", + ], + ) + assert result.exit_code == 0, f"Project creation failed: {result.output}" + mock_add.assert_called_once_with(mock_project.path, "test_project") + + +def test_create_project_register_failure_does_not_crash( + copy_project_input_data: tuple[Path, Path, Path], +) -> None: + """Test that a failure in add_recent_project does not crash project creation.""" + from unittest.mock import MagicMock + + tmp_path, _, project_config_file = copy_project_input_data + + mock_project = MagicMock() + mock_project.path = tmp_path / "test_project" + mock_project.config.project_id = "test_project" + + runner = CliRunner() + with ( + patch("stride.cli.stride.Project.create", return_value=mock_project), + patch( + "stride.ui.project_manager.add_recent_project", + side_effect=OSError("Permission denied"), + ), + ): + result = runner.invoke( + cli, + [ + "projects", + "create", + str(project_config_file), + "-d", + str(tmp_path), + "--dataset", + "global-test", + ], + ) + # Command should still succeed despite the registration failure + assert result.exit_code == 0, f"Project creation failed: {result.output}" + + def test_projects_init_command(tmp_path: Path) -> None: """Test that 'stride projects init' creates a project template.""" output_file = tmp_path / "my_project.json5" @@ -954,3 +1024,93 @@ def test_create_project_with_env_var(copy_project_input_data: tuple[Path, Path, f"Expected dataset path '{expected_dataset_path}' not found in output. " f"Output was: {result.output}" ) + + +def test_scenario_dataset_override_loaded( + copy_project_input_data: tuple[Path, Path, Path], +) -> None: + """Verify that scenario dataset overrides are loaded through make_mapped_datasets. + + This tests datasets that go through dimension_mappings.json5 (e.g., + vehicle_per_capita_regressions), which is where the override substitution bug occurs. + Creates a scenario that overrides vehicle_per_capita_regressions with different + regression parameters and verifies the scenario table differs from baseline. + """ + tmp_path, project_input_dir, project_config_file = copy_project_input_data + + # Create an override CSV with different regression parameters. + # Baseline has a0_lin=0.025 for Region_A; override uses 0.999. + override_csv = project_input_dir / "alt_vehicle_per_capita.csv" + override_csv.write_text( + "region,metric,value\n" + "Region_A,a0_lin,0.999\n" + "Region_A,a1_lin,0.00025\n" + "Region_A,t0_lin,2024.0\n" + "Region_B,a0_lin,0.40\n" + "Region_B,a1_lin,0.004\n" + "Region_B,t0_lin,2024.0\n" + ) + + # Add a scenario that overrides vehicle_per_capita_regressions + config = load_json_file(project_config_file) + config["scenarios"].append( + { + "name": "high_vehicles", + "use_ev_projection": True, + "vehicle_per_capita_regressions": str(override_csv), + } + ) + dump_json_file(config, project_config_file) + + # Create the project + runner = CliRunner() + result = runner.invoke( + cli, + [ + "projects", + "create", + str(project_config_file), + "-d", + str(tmp_path), + "--dataset", + "global-test", + ], + ) + assert result.exit_code == 0, result.output + + project_dir = tmp_path / "test_project" + with Project.load(project_dir) as project: + con = project.con + + # Find the baseline and scenario tables for vehicle_per_capita_regressions + baseline_rows = con.sql( + "SELECT table_name FROM information_schema.tables " + "WHERE table_schema = 'dsgrid_data' " + "AND table_name LIKE 'baseline__vehicle_per_capita_regressions__%'" + ).fetchall() + scenario_rows = con.sql( + "SELECT table_name FROM information_schema.tables " + "WHERE table_schema = 'dsgrid_data' " + "AND table_name LIKE 'high_vehicles__vehicle_per_capita_regressions__%'" + ).fetchall() + assert len(baseline_rows) > 0, "Baseline vehicle_per_capita table not found" + assert len(scenario_rows) > 0, "Scenario vehicle_per_capita table not found" + + baseline_table = f"dsgrid_data.{baseline_rows[0][0]}" + scenario_table = f"dsgrid_data.{scenario_rows[0][0]}" + + baseline_df = con.sql(f"SELECT * FROM {baseline_table} ORDER BY value").df() + scenario_df = con.sql(f"SELECT * FROM {scenario_table} ORDER BY value").df() + + assert len(baseline_df) > 0, "Baseline table is empty" + assert len(scenario_df) > 0, "Scenario table is empty" + + # The override has a0_lin=0.999 for country_1 vs baseline 0.025. + # If the bug is present, both tables will have identical data. + baseline_values = sorted(baseline_df["value"].tolist()) + scenario_values = sorted(scenario_df["value"].tolist()) + assert baseline_values != scenario_values, ( + "Scenario vehicle_per_capita_regressions data is identical to baseline — " + "the override was not loaded. " + f"Baseline: {baseline_values}, Scenario: {scenario_values}" + )