Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions ax/analysis/healthcheck/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@
from ax.analysis.healthcheck.regression_analysis import RegressionAnalysis
from ax.analysis.healthcheck.search_space_analysis import SearchSpaceAnalysis
from ax.analysis.healthcheck.should_generate_candidates import ShouldGenerateCandidates
from ax.analysis.healthcheck.transfer_learning_analysis import TransferLearningAnalysis

__all__ = [
"create_healthcheck_analysis_card",
Expand All @@ -39,4 +40,5 @@
"ComplexityRatingAnalysis",
"PredictableMetricsAnalysis",
"BaselineImprovementAnalysis",
"TransferLearningAnalysis",
]
167 changes: 167 additions & 0 deletions ax/analysis/healthcheck/tests/test_transfer_learning_analysis.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,167 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

# pyre-strict

from unittest.mock import patch

from ax.analysis.healthcheck.healthcheck_analysis import HealthcheckStatus
from ax.analysis.healthcheck.transfer_learning_analysis import TransferLearningAnalysis
from ax.core.auxiliary import TransferLearningMetadata
from ax.core.experiment import Experiment
from ax.core.parameter import ParameterType, RangeParameter
from ax.core.search_space import SearchSpace
from ax.exceptions.core import UserInputError
from ax.utils.common.testutils import TestCase


def _make_experiment(
    param_names: list[str],
    experiment_type: str | None = None,
) -> Experiment:
    """Build a minimal experiment named ``test_experiment`` whose search space
    has one float RangeParameter in [0.0, 1.0] per name in ``param_names``."""
    range_params = [
        RangeParameter(
            name=param_name,
            parameter_type=ParameterType.FLOAT,
            lower=0.0,
            upper=1.0,
        )
        for param_name in param_names
    ]
    return Experiment(
        search_space=SearchSpace(parameters=range_params),
        name="test_experiment",
        experiment_type=experiment_type,
    )


_MOCK_TARGET = "ax.storage.sqa_store.load.identify_transferable_experiments"


class TestTransferLearningAnalysis(TestCase):
    def test_no_experiment_type_returns_pass(self) -> None:
        """PASS when the experiment has no type and no explicit
        experiment_types were supplied to the analysis."""
        exp = _make_experiment(["x1", "x2"], experiment_type=None)
        result = TransferLearningAnalysis().compute(experiment=exp)
        self.assertEqual(result.get_status(), HealthcheckStatus.PASS)
        self.assertTrue(result.is_passing())
        self.assertIn("No experiment type set", result.subtitle)

    @patch(_MOCK_TARGET, return_value={})
    def test_no_candidates_returns_pass(self, mock_identify: object) -> None:
        exp = _make_experiment(["x1", "x2"], experiment_type="my_type")
        result = TransferLearningAnalysis().compute(experiment=exp)
        self.assertTrue(result.df.empty)
        self.assertTrue(result.is_passing())
        self.assertEqual(result.get_status(), HealthcheckStatus.PASS)

    @patch(_MOCK_TARGET)
    def test_single_candidate_returns_warning(self, mock_identify: object) -> None:
        exp = _make_experiment(
            ["x1", "x2", "x3", "x4", "x5"], experiment_type="my_type"
        )
        mock_identify.return_value = {  # pyre-ignore[16]
            "source_exp": TransferLearningMetadata(
                overlap_parameters=["x1", "x2", "x3", "x4"],
            ),
        }
        result = TransferLearningAnalysis().compute(experiment=exp)
        self.assertEqual(result.get_status(), HealthcheckStatus.WARNING)
        self.assertFalse(result.is_passing())
        # Subtitle names the source and its overlap percentage (4/5 = 80%).
        for fragment in ("source_exp", "80.0%"):
            self.assertIn(fragment, result.subtitle)
        self.assertEqual(len(result.df), 1)
        top_row = result.df.iloc[0]
        self.assertEqual(top_row["Experiment"], "source_exp")
        self.assertEqual(top_row["Overlapping Parameters"], 4)
        self.assertEqual(top_row["Overlap (%)"], 80.0)

    @patch(_MOCK_TARGET)
    def test_multiple_candidates_sorted_by_count(self, mock_identify: object) -> None:
        exp = _make_experiment(
            ["x1", "x2", "x3", "x4", "x5"], experiment_type="my_type"
        )
        mock_identify.return_value = {  # pyre-ignore[16]
            "exp_low": TransferLearningMetadata(
                overlap_parameters=["x1"],
            ),
            "exp_high": TransferLearningMetadata(
                overlap_parameters=["x1", "x2", "x3", "x4"],
            ),
            "exp_mid": TransferLearningMetadata(
                overlap_parameters=["x1", "x2", "x3"],
            ),
        }
        result = TransferLearningAnalysis().compute(experiment=exp)
        self.assertEqual(result.get_status(), HealthcheckStatus.WARNING)

        # Rows come back ordered by descending overlap count.
        expected_order = [("exp_high", 4), ("exp_mid", 3), ("exp_low", 1)]
        for position, (exp_name, overlap_count) in enumerate(expected_order):
            self.assertEqual(result.df.iloc[position]["Experiment"], exp_name)
            self.assertEqual(
                result.df.iloc[position]["Overlapping Parameters"], overlap_count
            )

        # Every candidate is named in the subtitle, along with the total.
        for exp_name in ("exp_high", "exp_mid", "exp_low"):
            self.assertIn(exp_name, result.subtitle)
        self.assertIn("We found **3 eligible source experiment(s)**", result.subtitle)

    @patch(_MOCK_TARGET)
    def test_percentage_calculation(self, mock_identify: object) -> None:
        exp = _make_experiment(["x1", "x2", "x3"], experiment_type="my_type")
        mock_identify.return_value = {  # pyre-ignore[16]
            "exp_a": TransferLearningMetadata(
                overlap_parameters=["x1"],
            ),
        }
        result = TransferLearningAnalysis().compute(experiment=exp)
        # 1 of 3 parameters overlap -> 33.3% after rounding to one decimal.
        self.assertEqual(result.df.iloc[0]["Overlap (%)"], 33.3)

    @patch(_MOCK_TARGET)
    def test_parameters_listed_alphabetically(self, mock_identify: object) -> None:
        exp = _make_experiment(
            ["alpha", "beta", "gamma", "delta"], experiment_type="my_type"
        )
        mock_identify.return_value = {  # pyre-ignore[16]
            "exp_a": TransferLearningMetadata(
                overlap_parameters=["gamma", "alpha", "delta"],
            ),
        }
        result = TransferLearningAnalysis().compute(experiment=exp)
        self.assertEqual(result.df.iloc[0]["Parameters"], "alpha, delta, gamma")

    def test_requires_experiment(self) -> None:
        with self.assertRaises(UserInputError):
            TransferLearningAnalysis().compute(experiment=None)

    @patch(_MOCK_TARGET)
    def test_target_experiment_filtered_out(self, mock_identify: object) -> None:
        """The analysis must never list the target experiment as its own
        transfer learning source."""
        exp = _make_experiment(["x1", "x2", "x3"], experiment_type="my_type")
        mock_identify.return_value = {  # pyre-ignore[16]
            "test_experiment": TransferLearningMetadata(
                overlap_parameters=["x1", "x2", "x3"],
            ),
            "other_exp": TransferLearningMetadata(
                overlap_parameters=["x1"],
            ),
        }
        result = TransferLearningAnalysis().compute(experiment=exp)
        self.assertEqual(result.get_status(), HealthcheckStatus.WARNING)
        self.assertEqual(len(result.df), 1)
        self.assertEqual(result.df.iloc[0]["Experiment"], "other_exp")
149 changes: 149 additions & 0 deletions ax/analysis/healthcheck/transfer_learning_analysis.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,149 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

# pyre-strict

from __future__ import annotations

from typing import final, TYPE_CHECKING

import pandas as pd
from ax.adapter.base import Adapter
from ax.analysis.analysis import Analysis
from ax.analysis.healthcheck.healthcheck_analysis import (
create_healthcheck_analysis_card,
HealthcheckAnalysisCard,
HealthcheckStatus,
)
from ax.core.experiment import Experiment
from ax.exceptions.core import UserInputError
from ax.generation_strategy.generation_strategy import GenerationStrategy
from pyre_extensions import override

if TYPE_CHECKING:
from ax.storage.sqa_store.sqa_config import SQAConfig


@final
class TransferLearningAnalysis(Analysis):
    """Healthcheck that surfaces past experiments eligible as transfer
    learning sources for the given experiment.

    Queries storage for experiments of the requested (or the target's own)
    experiment type whose search spaces overlap the target's, and reports
    each candidate's parameter overlap. The card is PASS when nothing is
    found and WARNING when candidates exist, prompting the user to review
    them before enabling transfer learning.
    """

    def __init__(
        self,
        experiment_types: list[str] | None = None,
        overlap_threshold: float = 0.25,
        max_num_exps: int = 10,
        config: SQAConfig | None = None,
    ) -> None:
        """Initialize the analysis.

        Args:
            experiment_types: Experiment types to search for; defaults to
                the target experiment's own type when None.
            overlap_threshold: Minimum fraction of overlapping parameters
                for a stored experiment to qualify as a candidate.
            max_num_exps: Maximum number of candidate experiments to fetch.
            config: Optional SQA config forwarded to the storage lookup.
        """
        self.experiment_types = experiment_types
        self.overlap_threshold = overlap_threshold
        self.max_num_exps = max_num_exps
        self.config = config

    @override
    def compute(
        self,
        experiment: Experiment | None = None,
        generation_strategy: GenerationStrategy | None = None,
        adapter: Adapter | None = None,
    ) -> HealthcheckAnalysisCard:
        """Look up transferable experiments and summarize their overlap.

        Args:
            experiment: The target experiment; required.
            generation_strategy: Unused.
            adapter: Unused.

        Returns:
            A healthcheck card: PASS with an empty dataframe when there is
            no experiment type or no candidates; WARNING with one row per
            candidate (name, overlap count, overlap %, sorted parameter
            names) otherwise.

        Raises:
            UserInputError: If ``experiment`` is None.
        """
        if experiment is None:
            raise UserInputError(
                "TransferLearningAnalysis requires a non-null experiment to compute "
                "overlap percentages. Please provide an experiment."
            )

        # Resolve which experiment types to query storage for.
        types_to_query = self.experiment_types
        if types_to_query is None:
            if experiment.experiment_type is None:
                # Without a type there is nothing to match against.
                return create_healthcheck_analysis_card(
                    name=self.__class__.__name__,
                    title="Transfer Learning Eligibility",
                    subtitle=(
                        "No experiment type set on this experiment. "
                        "Cannot search for transferable experiments."
                    ),
                    df=pd.DataFrame(),
                    status=HealthcheckStatus.PASS,
                )
            types_to_query = [experiment.experiment_type]

        # Imported lazily: sqa_store depends on healthcheck_analysis, so a
        # module-level import would create a circular dependency.
        from ax.storage.sqa_store.load import identify_transferable_experiments

        # Query storage and drop the target experiment itself, which the
        # lookup may legitimately return.
        candidates = {
            source_name: source_metadata
            for source_name, source_metadata in identify_transferable_experiments(
                search_space=experiment.search_space,
                experiment_types=types_to_query,
                overlap_threshold=self.overlap_threshold,
                max_num_exps=self.max_num_exps,
                config=self.config,
            ).items()
            if source_name != experiment.name
        }

        if not candidates:
            return create_healthcheck_analysis_card(
                name=self.__class__.__name__,
                title="Transfer Learning Eligibility",
                subtitle="No eligible source experiments found for transfer learning.",
                df=pd.DataFrame(),
                status=HealthcheckStatus.PASS,
            )

        num_params = len(experiment.search_space.parameters)

        def _overlap_pct(overlap_count: int) -> float:
            # Guard against an (unlikely) empty search space.
            if num_params <= 0:
                return 0.0
            return round(overlap_count / num_params * 100, 1)

        # One row per candidate, ordered by descending overlap count
        # (sorted() is stable, so ties keep the lookup's ordering).
        rows = sorted(
            (
                {
                    "Experiment": source_name,
                    "Overlapping Parameters": len(source_metadata.overlap_parameters),
                    "Overlap (%)": _overlap_pct(
                        len(source_metadata.overlap_parameters)
                    ),
                    "Parameters": ", ".join(
                        sorted(source_metadata.overlap_parameters)
                    ),
                }
                for source_name, source_metadata in candidates.items()
            ),
            key=lambda row: row["Overlapping Parameters"],
            reverse=True,
        )

        bullet_lines = "\n".join(
            f"- **{row['Experiment']}** ({row['Overlap (%)']:.1f}% parameter overlap)"
            for row in rows
        )
        num_found = len(rows)
        subtitle = (
            "Transfer learning can improve optimization by leveraging data "
            "from similar past experiments. We found "
            f"**{num_found} eligible source experiment(s)** "
            "for transfer learning:\n\n"
            f"{bullet_lines}\n\n"
            "Caution: Only use source experiments that are closely related "
            "to your current experiment. "
            "Using data from unrelated experiments can lead to negative "
            "transfer, which may hurt "
            "optimization performance. Review the overlapping parameters "
            "before enabling transfer learning."
        )

        return create_healthcheck_analysis_card(
            name=self.__class__.__name__,
            title="Transfer Learning Eligibility",
            subtitle=subtitle,
            df=pd.DataFrame(rows),
            status=HealthcheckStatus.WARNING,
        )
4 changes: 4 additions & 0 deletions ax/analysis/overview.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@
from ax.analysis.healthcheck.predictable_metrics import PredictableMetricsAnalysis
from ax.analysis.healthcheck.search_space_analysis import SearchSpaceAnalysis
from ax.analysis.healthcheck.should_generate_candidates import ShouldGenerateCandidates
from ax.analysis.healthcheck.transfer_learning_analysis import TransferLearningAnalysis
from ax.analysis.insights import InsightsAnalysis
from ax.analysis.results import ResultsAnalysis
from ax.analysis.trials import AllTrialsAnalysis
Expand Down Expand Up @@ -114,6 +115,7 @@ def __init__(
options: OrchestratorOptions | None = None,
tier_metadata: dict[str, Any] | None = None,
model_fit_threshold: float | None = None,
sqa_config: Any = None,
) -> None:
super().__init__()
self.can_generate = can_generate
Expand All @@ -124,6 +126,7 @@ def __init__(
self.options = options
self.tier_metadata = tier_metadata
self.model_fit_threshold = model_fit_threshold
self.sqa_config = sqa_config

@override
def validate_applicable_state(
Expand Down Expand Up @@ -229,6 +232,7 @@ def compute(
if not has_batch_trials
else None,
BaselineImprovementAnalysis() if not has_batch_trials else None,
TransferLearningAnalysis(config=self.sqa_config),
*[
SearchSpaceAnalysis(trial_index=trial.index)
for trial in candidate_trials
Expand Down
Loading