diff --git a/ax/analysis/healthcheck/__init__.py b/ax/analysis/healthcheck/__init__.py index 87f819e7436..c51edcf57fd 100644 --- a/ax/analysis/healthcheck/__init__.py +++ b/ax/analysis/healthcheck/__init__.py @@ -24,6 +24,7 @@ from ax.analysis.healthcheck.regression_analysis import RegressionAnalysis from ax.analysis.healthcheck.search_space_analysis import SearchSpaceAnalysis from ax.analysis.healthcheck.should_generate_candidates import ShouldGenerateCandidates +from ax.analysis.healthcheck.transfer_learning_analysis import TransferLearningAnalysis __all__ = [ "create_healthcheck_analysis_card", @@ -39,4 +40,5 @@ "ComplexityRatingAnalysis", "PredictableMetricsAnalysis", "BaselineImprovementAnalysis", + "TransferLearningAnalysis", ] diff --git a/ax/analysis/healthcheck/tests/test_transfer_learning_analysis.py b/ax/analysis/healthcheck/tests/test_transfer_learning_analysis.py new file mode 100644 index 00000000000..6382c5b2993 --- /dev/null +++ b/ax/analysis/healthcheck/tests/test_transfer_learning_analysis.py @@ -0,0 +1,167 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. 
def _make_experiment(
    param_names: list[str],
    experiment_type: str | None = None,
) -> Experiment:
    """Build a minimal ``Experiment`` fixture for the tests in this module.

    Args:
        param_names: One float ``RangeParameter`` on ``[0.0, 1.0]`` is created
            for each name, in the order given.
        experiment_type: Forwarded unchanged as the experiment's type.

    Returns:
        An ``Experiment`` named ``"test_experiment"`` whose search space holds
        exactly the parameters described above.
    """

    def _unit_interval(param_name: str) -> RangeParameter:
        # Every fixture parameter shares the same type and bounds; only the
        # name varies.
        return RangeParameter(
            name=param_name,
            parameter_type=ParameterType.FLOAT,
            lower=0.0,
            upper=1.0,
        )

    search_space = SearchSpace(
        parameters=[_unit_interval(param_name) for param_name in param_names]
    )
    return Experiment(
        search_space=search_space,
        name="test_experiment",
        experiment_type=experiment_type,
    )
@patch(_MOCK_TARGET)
def test_multiple_candidates_sorted_by_count(self, mock_identify: object) -> None:
    """Rows are ordered by overlap count (largest first) and every source
    experiment is named in the subtitle."""
    # Deliberately not in sorted order, so the assertions below exercise the
    # sort rather than the insertion order.
    overlaps_by_source = {
        "exp_low": ["x1"],
        "exp_high": ["x1", "x2", "x3", "x4"],
        "exp_mid": ["x1", "x2", "x3"],
    }
    mock_identify.return_value = {  # pyre-ignore[16]
        source: TransferLearningMetadata(overlap_parameters=params)
        for source, params in overlaps_by_source.items()
    }
    target = _make_experiment(
        ["x1", "x2", "x3", "x4", "x5"], experiment_type="my_type"
    )

    card = TransferLearningAnalysis().compute(experiment=target)
    self.assertEqual(card.get_status(), HealthcheckStatus.WARNING)

    # Verify sorted descending by overlap count
    expected_order = [("exp_high", 4), ("exp_mid", 3), ("exp_low", 1)]
    for position, (source, overlap_count) in enumerate(expected_order):
        row = card.df.iloc[position]
        self.assertEqual(row["Experiment"], source)
        self.assertEqual(row["Overlapping Parameters"], overlap_count)

    # All experiments listed in subtitle
    for source, _ in expected_order:
        self.assertIn(source, card.subtitle)
    self.assertIn("We found **3 eligible source experiment(s)**", card.subtitle)
@final
class TransferLearningAnalysis(Analysis):
    """Healthcheck that lists stored experiments usable as transfer-learning
    sources for the current experiment.

    ``compute`` queries ``identify_transferable_experiments`` with the target
    experiment's search space, removes the target experiment itself from the
    result, and reports each remaining experiment together with the number and
    percentage of the target's parameters it overlaps with.

    The resulting card has status:
        * ``PASS`` when there is nothing to report: either no experiment type
          is available to query for, or no eligible source experiment remains
          after filtering.
        * ``WARNING`` when at least one eligible source experiment is found.
          The subtitle lists the experiments and cautions against negative
          transfer.
    """

    def __init__(
        self,
        experiment_types: list[str] | None = None,
        overlap_threshold: float = 0.25,
        max_num_exps: int = 10,
        config: SQAConfig | None = None,
    ) -> None:
        """Store the query settings used by ``compute``.

        Args:
            experiment_types: Experiment types to search within. When ``None``,
                ``compute`` falls back to the target experiment's own
                ``experiment_type``.
            overlap_threshold: Forwarded unchanged to
                ``identify_transferable_experiments``. NOTE(review): presumably
                the minimum fraction of overlapping parameters required --
                confirm against that function.
            max_num_exps: Forwarded unchanged to
                ``identify_transferable_experiments``. NOTE(review): presumably
                caps the number of experiments returned -- confirm.
            config: Optional ``SQAConfig`` forwarded unchanged to
                ``identify_transferable_experiments``.
        """
        self.experiment_types = experiment_types
        self.overlap_threshold = overlap_threshold
        self.max_num_exps = max_num_exps
        self.config = config

    @override
    def compute(
        self,
        experiment: Experiment | None = None,
        generation_strategy: GenerationStrategy | None = None,
        adapter: Adapter | None = None,
    ) -> HealthcheckAnalysisCard:
        """Build the transfer-learning eligibility card for ``experiment``.

        Args:
            experiment: Target experiment. Required.
            generation_strategy: Unused by this analysis.
            adapter: Unused by this analysis.

        Returns:
            A ``HealthcheckAnalysisCard`` titled "Transfer Learning
            Eligibility". When eligible source experiments exist, its
            dataframe has one row per experiment with the columns
            ``Experiment``, ``Overlapping Parameters``, ``Overlap (%)`` and
            ``Parameters``, sorted by ``Overlapping Parameters`` descending;
            otherwise the dataframe is empty.

        Raises:
            UserInputError: If ``experiment`` is ``None``.
        """
        if experiment is None:
            raise UserInputError(
                "TransferLearningAnalysis requires a non-null experiment to compute "
                "overlap percentages. Please provide an experiment."
            )

        # Determine experiment types to query for: explicit configuration wins,
        # otherwise fall back to the target experiment's own type.
        experiment_types = self.experiment_types
        if experiment_types is None:
            if experiment.experiment_type is None:
                return self._make_card(
                    subtitle=(
                        "No experiment type set on this experiment. "
                        "Cannot search for transferable experiments."
                    ),
                    df=pd.DataFrame(),
                    status=HealthcheckStatus.PASS,
                )
            experiment_types = [experiment.experiment_type]

        # Lazy import to avoid circular dependency (sqa_store depends on
        # healthcheck_analysis).
        from ax.storage.sqa_store.load import identify_transferable_experiments

        transferable_experiments = identify_transferable_experiments(
            search_space=experiment.search_space,
            experiment_types=experiment_types,
            overlap_threshold=self.overlap_threshold,
            max_num_exps=self.max_num_exps,
            config=self.config,
        )

        # Filter out the target experiment itself from results.
        # NOTE(review): `experiment.name` is read once per returned entry; if
        # it can raise for unnamed experiments this filter will too -- confirm.
        source_experiments = {
            name: metadata
            for name, metadata in transferable_experiments.items()
            if name != experiment.name
        }

        if not source_experiments:
            return self._make_card(
                subtitle="No eligible source experiments found for transfer learning.",
                df=pd.DataFrame(),
                status=HealthcheckStatus.PASS,
            )

        total_parameters = len(experiment.search_space.parameters)
        rows = [
            self._overlap_row(
                exp_name=exp_name,
                overlap_parameters=metadata.overlap_parameters,
                total_parameters=total_parameters,
            )
            for exp_name, metadata in source_experiments.items()
        ]

        # Sort by overlapping parameter count descending. The sort is stable,
        # so ties keep the order in which the experiments were returned.
        rows.sort(key=lambda r: r["Overlapping Parameters"], reverse=True)

        exp_lines = "\n".join(
            f"- **{r['Experiment']}** ({r['Overlap (%)']:.1f}% parameter overlap)"
            for r in rows
        )
        subtitle = (
            "Transfer learning can improve optimization by leveraging data "
            "from similar past experiments. We found "
            f"**{len(rows)} eligible source experiment(s)** "
            "for transfer learning:\n\n"
            f"{exp_lines}\n\n"
            "Caution: Only use source experiments that are closely related "
            "to your current experiment. "
            "Using data from unrelated experiments can lead to negative "
            "transfer, which may hurt "
            "optimization performance. Review the overlapping parameters "
            "before enabling transfer learning."
        )

        return self._make_card(
            subtitle=subtitle,
            df=pd.DataFrame(rows),
            status=HealthcheckStatus.WARNING,
        )

    def _make_card(
        self,
        subtitle: str,
        df: pd.DataFrame,
        status: HealthcheckStatus,
    ) -> HealthcheckAnalysisCard:
        """Create this analysis's card with its fixed name and title."""
        return create_healthcheck_analysis_card(
            name=self.__class__.__name__,
            title="Transfer Learning Eligibility",
            subtitle=subtitle,
            df=df,
            status=status,
        )

    @staticmethod
    def _overlap_row(
        exp_name: str,
        overlap_parameters: list[str],
        total_parameters: int,
    ) -> dict[str, str | int | float]:
        """Build one dataframe row describing a source experiment's overlap."""
        overlap_count = len(overlap_parameters)
        # Guard against an empty target search space to avoid dividing by zero.
        overlap_pct = (
            (overlap_count / total_parameters * 100) if total_parameters > 0 else 0.0
        )
        return {
            "Experiment": exp_name,
            "Overlapping Parameters": overlap_count,
            "Overlap (%)": round(overlap_pct, 1),
            # Sorted so the listing does not depend on the stored order.
            "Parameters": ", ".join(sorted(overlap_parameters)),
        }
compute( if not has_batch_trials else None, BaselineImprovementAnalysis() if not has_batch_trials else None, + TransferLearningAnalysis(config=self.sqa_config), *[ SearchSpaceAnalysis(trial_index=trial.index) for trial in candidate_trials