From 9e25055164718e8adc520edf668f8709e75b6cb7 Mon Sep 17 00:00:00 2001 From: Carl Hvarfner Date: Tue, 21 Apr 2026 16:03:52 -0700 Subject: [PATCH 1/5] Fix prior deserialization for priors with buffered attributes (#5167) Summary: Pull Request resolved: https://github.com/facebook/Ax/pull/5167 The Ax JSON decoder's `botorch_component_from_json` strips the `BUFFERED_PREFIX` from state_dict keys only for `TransformedDistribution` subclasses. This misses priors like `BetaPrior` whose underlying distribution (`Beta`) uses `property` descriptors delegating to an internal `Dirichlet`, causing `_bufferize_attributes` to use the prefix. Broaden the check from `TransformedDistribution` to `(TransformedDistribution, Prior)` so all gpytorch priors with buffered attributes deserialize correctly. Differential Revision: D100341242 Reviewed By: sdaulton --- ax/storage/json_store/decoders.py | 9 ++++--- .../json_store/tests/test_json_store.py | 24 +++++++++++++++++++ 2 files changed, 30 insertions(+), 3 deletions(-) diff --git a/ax/storage/json_store/decoders.py b/ax/storage/json_store/decoders.py index accf67e599f..b31be6bb617 100644 --- a/ax/storage/json_store/decoders.py +++ b/ax/storage/json_store/decoders.py @@ -54,9 +54,9 @@ from botorch.models.transforms.input import ChainedInputTransform, InputTransform from botorch.models.transforms.outcome import ChainedOutcomeTransform, OutcomeTransform from botorch.utils.types import _DefaultType, DEFAULT +from gpytorch.priors import Prior from gpytorch.priors.utils import BUFFERED_PREFIX from pyre_extensions import assert_is_instance -from torch.distributions.transformed_distribution import TransformedDistribution logger: logging.Logger = get_logger(__name__) @@ -369,8 +369,11 @@ def botorch_component_from_json(botorch_class: type[T], json: dict[str, Any]) -> for k, v in state_dict.items() } ) - if issubclass(botorch_class, TransformedDistribution): - # Extract the buffered attributes for transformed priors. + if issubclass(botorch_class, Prior): + # Extract the buffered attributes for priors. Some priors (e.g. + # BetaPrior, LogNormalPrior) store parameters with BUFFERED_PREFIX + # because their underlying distribution uses @property descriptors + # that cannot be deleted by _bufferize_attributes. for k in list(state_dict.keys()): if k.startswith(BUFFERED_PREFIX): state_dict[k[len(BUFFERED_PREFIX) :]] = state_dict.pop(k) diff --git a/ax/storage/json_store/tests/test_json_store.py b/ax/storage/json_store/tests/test_json_store.py index 47e23cda9ff..d6866fc4521 100644 --- a/ax/storage/json_store/tests/test_json_store.py +++ b/ax/storage/json_store/tests/test_json_store.py @@ -1067,6 +1067,30 @@ def test_BadStateDict(self) -> None: del expected_json["state_dict"]["lower_bound"] botorch_component_from_json(interval.__class__, expected_json) + def test_prior_roundtrip_serialization(self) -> None: + """Test encode/decode roundtrip for priors with buffered attributes. + + Priors whose underlying distribution uses @property descriptors + (e.g. BetaPrior via Dirichlet, LogNormalPrior via TransformedDistribution) + store state_dict keys with BUFFERED_PREFIX. The decoder must strip + the prefix to match __init__ arg names. + """ + from botorch.models.utils.priors import BetaPrior + from gpytorch.priors.torch_priors import GammaPrior, LogNormalPrior, NormalPrior + + priors = [ + ("BetaPrior", BetaPrior(concentration1=2.5, concentration0=1.5)), + ("GammaPrior", GammaPrior(concentration=2.0, rate=1.0)), + ("NormalPrior", NormalPrior(loc=0.0, scale=1.0)), + ("LogNormalPrior", LogNormalPrior(loc=0.0, scale=1.0)), + ] + for name, prior in priors: + with self.subTest(prior=name): + encoded = botorch_component_to_dict(prior) + decoded = botorch_component_from_json(prior.__class__, encoded) + self.assertIsInstance(decoded, prior.__class__) + self.assertEqual(decoded.state_dict(), prior.state_dict()) + def test_observation_features_backward_compatibility(self) -> None: json = { "__type": "ObservationFeatures", From 31fd63ad58190b52cfcc9a57ca0201fc4bf495a9 Mon Sep 17 00:00:00 2001 From: Carl Hvarfner Date: Tue, 21 Apr 2026 16:03:52 -0700 Subject: [PATCH 2/5] Add input_transform_argparse dispatcher for LearnedFeatureImputation (#5106) Summary: X-link: https://github.com/facebookexternal/botorch_fb/pull/34 Pull Request resolved: https://github.com/facebook/Ax/pull/5106 Wire LearnedFeatureImputation and ImputedMultiTaskGP into Ax: 1. **input_transform_argparse dispatcher**: Computes `feature_indices` and `d` from a heterogeneous MultiTaskDataset using target-first feature ordering. Validates that the dataset is a MultiTaskDataset with heterogeneous features. 2. **Storage registry**: Register ImputedMultiTaskGP in MODEL_REGISTRY and LearnedFeatureImputation in INPUT_TRANSFORM_REGISTRY. 3. **Model selection (utils.py)**: When a heterogeneous MultiTaskDataset is detected and a model class is specified (e.g. ImputedMultiTaskGP), use the specified class instead of force-overriding to HeterogeneousMTGP. Also add automatic Normalize + LearnedFeatureImputation transform chaining for ImputedMultiTaskGP. Differential Revision: D97625733 --- .../input_constructors/input_transforms.py | 77 +++++++++ .../tests/test_input_transform_argparse.py | 161 ++++++++++++++++++ ax/storage/botorch_modular_registry.py | 2 + 3 files changed, 240 insertions(+) diff --git a/ax/generators/torch/botorch_modular/input_constructors/input_transforms.py b/ax/generators/torch/botorch_modular/input_constructors/input_transforms.py index c2e427fc8b1..0f6a270048a 100644 --- a/ax/generators/torch/botorch_modular/input_constructors/input_transforms.py +++ b/ax/generators/torch/botorch_modular/input_constructors/input_transforms.py @@ -17,6 +17,7 @@ from botorch.models.transforms.input import ( FilterFeatures, InputTransform, + LearnedFeatureImputation, Normalize, Warp, ) @@ -314,3 +315,79 @@ def _input_transform_argparse_filter_features( ) return input_transform_options_copy + + +@input_transform_argparse.register(LearnedFeatureImputation) +def _input_transform_argparse_learned_feature_imputation( + input_transform_class: type[LearnedFeatureImputation], + dataset: SupervisedDataset, + search_space_digest: SearchSpaceDigest, + input_transform_options: dict[str, Any] | None = None, + torch_device: torch.device | None = None, + torch_dtype: torch.dtype | None = None, +) -> dict[str, Any]: + """Extract LearnedFeatureImputation kwargs from a MultiTaskDataset. + + Computes ``feature_indices`` and ``d`` from the heterogeneous feature sets + in the dataset, following the same convention as + ``HeterogeneousMTGP.construct_inputs``: the target task is placed first, + and the task feature column is excluded from the feature union. + + Args: + input_transform_class: Input transform class. + dataset: A ``MultiTaskDataset`` with heterogeneous features. + search_space_digest: Search space digest. + input_transform_options: Optional overrides for transform kwargs. + torch_device: Device for the transform parameters. + torch_dtype: Dtype for the transform parameters. + + Returns: + A dictionary with ``feature_indices``, ``d``, ``task_feature_index``, + ``device``, and ``dtype`` keys. + """ + if not isinstance(dataset, MultiTaskDataset): + raise ValueError( + "LearnedFeatureImputation requires a MultiTaskDataset, " + f"got {type(dataset).__name__}." + ) + if not dataset.has_heterogeneous_features: + raise ValueError( + "LearnedFeatureImputation requires a MultiTaskDataset with " + "heterogeneous features (has_heterogeneous_features=True)." + ) + input_transform_options = input_transform_options or {} + + # Order datasets: target first, then remaining (same as HeterogeneousMTGP). + child_datasets = dataset.datasets.copy() + target_dataset = child_datasets.pop(dataset.target_outcome_name) + all_datasets = [target_dataset] + list(child_datasets.values()) + + # Use target's feature order as canonical (NO alphabetical sort). + # Source-only features are appended at the end. + task_feature_index = ( + dataset.task_feature_index if (dataset.task_feature_index is not None) else -1 + ) + all_features: list[str] = list(target_dataset.feature_names[:task_feature_index]) + for ds in all_datasets[1:]: + for fn in ds.feature_names[:task_feature_index]: + if fn not in all_features: + all_features.append(fn) + d = len(all_features) + + # Map each task's features to indices in the global feature space. + feature_indices = { + task_idx: [ + all_features.index(fn) for fn in ds.feature_names[:task_feature_index] + ] + for task_idx, ds in enumerate(all_datasets) + } + + kwargs: dict[str, Any] = { + "feature_indices": feature_indices, + "d": d, + "task_feature_index": task_feature_index, + "device": torch_device, + "dtype": torch_dtype or torch.float64, + } + kwargs.update(input_transform_options) + return kwargs diff --git a/ax/generators/torch/tests/test_input_transform_argparse.py b/ax/generators/torch/tests/test_input_transform_argparse.py index 133bd2a1770..07156aa1f22 100644 --- a/ax/generators/torch/tests/test_input_transform_argparse.py +++ b/ax/generators/torch/tests/test_input_transform_argparse.py @@ -22,6 +22,7 @@ FilterFeatures, InputStandardize, InputTransform, + LearnedFeatureImputation, Normalize, Warp, ) @@ -376,3 +377,163 @@ def test_argparse_filter_features(self) -> None: "ignored_params": ["x0", "x1"], }, ) + + def test_argparse_learned_feature_imputation(self) -> None: + task_feature_name = Keys.TASK_FEATURE_NAME.value + + # Task 0 (target): features x0, x1, x2, x3 + task + dataset_target = SupervisedDataset( + X=torch.cat([torch.rand(5, 4), torch.zeros(5, 1)], dim=-1), + Y=torch.randn(5, 1), + feature_names=["x0", "x1", "x2", "x3", task_feature_name], + outcome_names=["y0"], + ) + # Task 1 (aux): features x0, x1 + task (subset of target features) + dataset_aux = SupervisedDataset( + X=torch.cat([torch.rand(5, 2), torch.ones(5, 1)], dim=-1), + Y=torch.randn(5, 1), + feature_names=["x0", "x1", task_feature_name], + outcome_names=["y1"], + ) + mtds = MultiTaskDataset( + datasets=[dataset_target, dataset_aux], + target_outcome_name="y0", + task_feature_index=-1, + ) + mt_ssd = dataclasses.replace( + self.search_space_digest, + feature_names=["x0", "x1", "x2", "x3", task_feature_name], + task_features=[-1], + bounds=[(0.0, 1.0)] * 5, + ) + + kwargs = input_transform_argparse( + LearnedFeatureImputation, + dataset=mtds, + search_space_digest=mt_ssd, + ) + + # all_features uses target-first ordering: ["x0","x1","x2","x3"] + self.assertEqual(kwargs["d"], 4) + # Target has all 4 features, aux has x0, x1 -> indices [0, 1] + self.assertEqual(kwargs["feature_indices"], {0: [0, 1, 2, 3], 1: [0, 1]}) + self.assertEqual(kwargs["task_feature_index"], -1) + self.assertEqual(kwargs["dtype"], torch.float64) + + with self.subTest("non_multitask_dataset_raises"): + with self.assertRaisesRegex(ValueError, "requires a MultiTaskDataset"): + input_transform_argparse( + LearnedFeatureImputation, + dataset=self.dataset, + search_space_digest=self.search_space_digest, + ) + + with self.subTest("homogeneous_features_raises"): + homogeneous_ds = MultiTaskDataset( + datasets=[ + SupervisedDataset( + X=torch.rand(5, 3), + Y=torch.randn(5, 1), + feature_names=["x0", "x1", task_feature_name], + outcome_names=["y0"], + ), + SupervisedDataset( + X=torch.rand(5, 3), + Y=torch.randn(5, 1), + feature_names=["x0", "x1", task_feature_name], + outcome_names=["y1"], + ), + ], + target_outcome_name="y0", + task_feature_index=-1, + ) + with self.assertRaisesRegex(ValueError, "heterogeneous features"): + input_transform_argparse( + LearnedFeatureImputation, + dataset=homogeneous_ds, + search_space_digest=mt_ssd, + ) + + with self.subTest("options_override"): + kwargs = input_transform_argparse( + LearnedFeatureImputation, + dataset=mtds, + search_space_digest=mt_ssd, + torch_dtype=torch.float32, + ) + self.assertEqual(kwargs["dtype"], torch.float32) + + def test_argparse_learned_feature_imputation_feature_ordering(self) -> None: + """Test that feature ordering preserves target's order, not alphabetical.""" + task_feature_name = Keys.TASK_FEATURE_NAME.value + + with self.subTest("target_order_preserved_not_alphabetical"): + # Target: features C, A, B (NOT alphabetical) + target_ds = SupervisedDataset( + X=torch.cat([torch.rand(3, 3), torch.zeros(3, 1)], dim=-1), + Y=torch.randn(3, 1), + feature_names=["C", "A", "B", task_feature_name], + outcome_names=["target"], + ) + # Source: features A, B (subset, different order) + source_ds = SupervisedDataset( + X=torch.cat([torch.rand(2, 2), torch.ones(2, 1)], dim=-1), + Y=torch.randn(2, 1), + feature_names=["A", "B", task_feature_name], + outcome_names=["source"], + ) + mtds = MultiTaskDataset( + datasets=[target_ds, source_ds], + target_outcome_name="target", + task_feature_index=-1, + ) + ssd = dataclasses.replace( + self.search_space_digest, + feature_names=["C", "A", "B", task_feature_name], + task_features=[-1], + bounds=[(0.0, 1.0)] * 4, + ) + kwargs = input_transform_argparse( + LearnedFeatureImputation, + dataset=mtds, + search_space_digest=ssd, + ) + # Canonical order should be C, A, B (target's order), not A, B, C + self.assertEqual(kwargs["d"], 3) + # Target: C, A, B -> [0, 1, 2]; Source: A, B -> [1, 2] + self.assertEqual(kwargs["feature_indices"], {0: [0, 1, 2], 1: [1, 2]}) + + with self.subTest("source_only_features_appended_at_end"): + # Target: A, B; Source: B, C, D (C, D are source-only) + target_ds = SupervisedDataset( + X=torch.cat([torch.rand(3, 2), torch.zeros(3, 1)], dim=-1), + Y=torch.randn(3, 1), + feature_names=["A", "B", task_feature_name], + outcome_names=["target"], + ) + source_ds = SupervisedDataset( + X=torch.cat([torch.rand(2, 3), torch.ones(2, 1)], dim=-1), + Y=torch.randn(2, 1), + feature_names=["B", "C", "D", task_feature_name], + outcome_names=["source"], + ) + mtds = MultiTaskDataset( + datasets=[target_ds, source_ds], + target_outcome_name="target", + task_feature_index=-1, + ) + ssd = dataclasses.replace( + self.search_space_digest, + feature_names=["A", "B", "C", "D", task_feature_name], + task_features=[-1], + bounds=[(0.0, 1.0)] * 5, + ) + kwargs = input_transform_argparse( + LearnedFeatureImputation, + dataset=mtds, + search_space_digest=ssd, + ) + # Canonical order: A, B (target), then C, D (source-only, appended) + self.assertEqual(kwargs["d"], 4) + # Target: A, B -> [0, 1]; Source: B, C, D -> [1, 2, 3] + self.assertEqual(kwargs["feature_indices"], {0: [0, 1], 1: [1, 2, 3]}) diff --git a/ax/storage/botorch_modular_registry.py b/ax/storage/botorch_modular_registry.py index d5a9a9c7d7b..6328349c8f5 100644 --- a/ax/storage/botorch_modular_registry.py +++ b/ax/storage/botorch_modular_registry.py @@ -86,6 +86,7 @@ FilterFeatures, InputPerturbation, InputTransform, + LearnedFeatureImputation, Normalize, Round, Warp, @@ -215,6 +216,7 @@ """ INPUT_TRANSFORM_REGISTRY: dict[type[InputTransform], str] = { ChainedInputTransform: "ChainedInputTransform", + LearnedFeatureImputation: "LearnedFeatureImputation", Normalize: "Normalize", Round: "Round", Warp: "Warp", From cae96cf26fc276b4cd8dffc22cacb91d47fce1af Mon Sep 17 00:00:00 2001 From: Carl Hvarfner Date: Tue, 21 Apr 2026 16:03:52 -0700 Subject: [PATCH 3/5] Don't override explicitly requested model class in get_transfer_learning_gs (#5183) Summary: Pull Request resolved: https://github.com/facebook/Ax/pull/5183 When the source and target search spaces are incompatible, any model class is silently overridden to `HeterogeneousMTGP`. This prevents using other models at all. Now, we only override to `HeterogeneousMTGP` in get_transfer_learning_gs, instead of in two different places in the stack. Thus, we will still have the same functionality but not double-force it. Differential Revision: D101174566 --- ax/generators/torch/botorch_modular/utils.py | 22 ++-- ax/generators/torch/tests/test_utils.py | 126 ++----------------- 2 files changed, 19 insertions(+), 129 deletions(-) diff --git a/ax/generators/torch/botorch_modular/utils.py b/ax/generators/torch/botorch_modular/utils.py index e645a994370..e85c27abcb8 100644 --- a/ax/generators/torch/botorch_modular/utils.py +++ b/ax/generators/torch/botorch_modular/utils.py @@ -274,8 +274,7 @@ def choose_model_class( dataset: The dataset on which the model will be fitted. search_space_digest: The digest of the search space the model will be fitted within. - specified_model_class: If provided, this model class will be used unless - overridden for specific cases (e.g., heterogeneous datasets). + specified_model_class: If provided, this model class will be used. Returns: A BoTorch `Model` class. @@ -295,23 +294,18 @@ def choose_model_class( "Multi-task multi-fidelity optimization not yet supported." ) - # Check for heterogeneous multi-task datasets & override model class if needed. + # Check for heterogeneous multi-task datasets. If a model class was + # explicitly specified, respect it; otherwise default to HeterogeneousMTGP. if ( search_space_digest.task_features and isinstance(dataset, MultiTaskDataset) and dataset.has_heterogeneous_features ): - if ( - specified_model_class is not None - and specified_model_class is not HeterogeneousMTGP - ): - logger.warning( - f"Detected heterogeneous features in MultiTaskDataset. " - f"Overriding specified model class {specified_model_class.__name__} " - f"with HeterogeneousMTGP for transfer learning with " - f"heterogeneous search spaces." - ) - model_class = HeterogeneousMTGP + model_class = ( + specified_model_class + if specified_model_class is not None + else HeterogeneousMTGP + ) logger.debug(f"Chose BoTorch model class: {model_class}.") return model_class diff --git a/ax/generators/torch/tests/test_utils.py b/ax/generators/torch/tests/test_utils.py index 44a3612ddbf..e8269880aad 100644 --- a/ax/generators/torch/tests/test_utils.py +++ b/ax/generators/torch/tests/test_utils.py @@ -27,7 +27,6 @@ choose_model_class, construct_acquisition_and_optimizer_options, convert_to_block_design, - copy_model_config_with_default_values, get_cv_fold, logger, ModelConfig, @@ -70,7 +69,6 @@ from botorch.models.model_list_gp_regression import ModelListGP from botorch.models.multitask import MultiTaskGP from botorch.models.pairwise_gp import PairwiseGP -from botorch.models.transforms.input import Normalize, Warp from botorch.posteriors.ensemble import EnsemblePosterior from botorch.utils.datasets import MultiTaskDataset, SupervisedDataset from botorch.utils.types import DEFAULT @@ -182,126 +180,24 @@ def test_choose_model_class_task_features(self) -> None: ), ) - def test_choose_model_class_heterogeneous_task_features(self) -> None: - # Test that HeterogeneousMTGP is chosen when MultiTaskDataset has - # heterogeneous features. + def test_choose_model_class_heterogeneous_respects_specified(self) -> None: mt_dataset = self._get_heterogeneous_mt_dataset() + ssd = dataclasses.replace(self.search_space_digest, task_features=[-1]) - # Execute: Choose model class with task features - model_class = choose_model_class( - dataset=mt_dataset, - search_space_digest=dataclasses.replace( - self.search_space_digest, task_features=[-1] - ), - ) - - # Assert: Should select HeterogeneousMTGP for heterogeneous features - self.assertEqual(HeterogeneousMTGP, model_class) - - def test_choose_model_class_heterogeneous_overrides_specified(self) -> None: - # Test that HeterogeneousMTGP overrides a pre-specified model class - # when heterogeneous features are detected - mt_dataset = self._get_heterogeneous_mt_dataset() - - # Execute: Try to specify MultiTaskGP explicitly - model_class = choose_model_class( - dataset=mt_dataset, - search_space_digest=dataclasses.replace( - self.search_space_digest, task_features=[-1] - ), - specified_model_class=MultiTaskGP, - ) - - # Assert: Should override to HeterogeneousMTGP despite specification - self.assertEqual(HeterogeneousMTGP, model_class) - - def test_choose_model_class_respects_specified_when_no_override_needed( - self, - ) -> None: - # Test that specified_model_class is used when no override is needed - model_class = choose_model_class( - dataset=self.supervised_dataset, - search_space_digest=self.search_space_digest, - specified_model_class=SingleTaskGP, - ) - self.assertEqual(SingleTaskGP, model_class) - - # Test with a different specified class - model_class = choose_model_class( - dataset=self.supervised_dataset, - search_space_digest=self.search_space_digest, - specified_model_class=MixedSingleTaskGP, - ) - self.assertEqual(MixedSingleTaskGP, model_class) - - def test_copy_model_config_adds_normalize_for_heterogeneous_mtgp(self) -> None: - # Test that Normalize input transform is added for HeterogeneousMTGP - mt_dataset = self._get_heterogeneous_mt_dataset() - - # Case 1: No input transform classes specified - model_config = ModelConfig() - updated_config = copy_model_config_with_default_values( - model_config=model_config, - dataset=mt_dataset, - search_space_digest=dataclasses.replace( - self.search_space_digest, task_features=[-1] - ), - ) - self.assertEqual(updated_config.botorch_model_class, HeterogeneousMTGP) - self.assertEqual(updated_config.input_transform_classes, [Normalize]) + # Without specified class, defaults to HeterogeneousMTGP self.assertEqual( - none_throws(updated_config.input_transform_options), - {"Normalize": {"bounds": None}}, - ) - - # Case 2: Input transform classes already specified (but not Normalize) - model_config = ModelConfig( - input_transform_classes=[Warp], input_transform_options={"Warp": {}} - ) - updated_config = copy_model_config_with_default_values( - model_config=model_config, - dataset=mt_dataset, - search_space_digest=dataclasses.replace( - self.search_space_digest, task_features=[-1] - ), + HeterogeneousMTGP, + choose_model_class(dataset=mt_dataset, search_space_digest=ssd), ) - self.assertEqual(updated_config.input_transform_classes, [Warp, Normalize]) + # With specified class, respects it self.assertEqual( - none_throws(updated_config.input_transform_options), - {"Warp": {}, "Normalize": {"bounds": None}}, - ) - - # Case 3: Normalize already in input transform classes - model_config = ModelConfig( - input_transform_classes=[Normalize], - input_transform_options={"Normalize": {"bounds": None}}, - ) - updated_config = copy_model_config_with_default_values( - model_config=model_config, - dataset=mt_dataset, - search_space_digest=dataclasses.replace( - self.search_space_digest, task_features=[-1] + MultiTaskGP, + choose_model_class( + dataset=mt_dataset, + search_space_digest=ssd, + specified_model_class=MultiTaskGP, ), ) - self.assertEqual(updated_config.input_transform_classes, [Normalize]) - self.assertEqual( - none_throws(updated_config.input_transform_options), - {"Normalize": {"bounds": None}}, - ) - - def test_copy_model_config_does_not_add_normalize_for_other_models(self) -> None: - # Test that Normalize is NOT added for non-HeterogeneousMTGP models - model_config = ModelConfig() - updated_config = copy_model_config_with_default_values( - model_config=model_config, - dataset=self.supervised_dataset, - search_space_digest=self.search_space_digest, - ) - # Should be SingleTaskGP, not HeterogeneousMTGP - self.assertEqual(updated_config.botorch_model_class, SingleTaskGP) - # Should not have added Normalize - self.assertEqual(updated_config.input_transform_classes, DEFAULT) - self.assertEqual(updated_config.input_transform_options, {}) def test_choose_model_class_discrete_features(self) -> None: # With discrete features, use MixedSingleTaskyGP. From 64c3516fdb52b57a87b3ce8bf8365e06a2e8763f Mon Sep 17 00:00:00 2001 From: Carl Hvarfner Date: Tue, 21 Apr 2026 16:03:52 -0700 Subject: [PATCH 4/5] Use search space bounds for Normalize in transfer learning adapter (#5184) Summary: Pull Request resolved: https://github.com/facebook/Ax/pull/5184 The transfer learning adapter explicitly passed `bounds=None` to Normalize, forcing `learn_bounds=True`. This caused Normalize bounds to be learned from data instead of fixed to the search space, resulting in bounds that drift during training and differ between benchmark configs despite identical search spaces. Remove the `bounds=None` override so that `_set_default_bounds` provides the correct search space bounds from the SearchSpaceDigest. Differential Revision: D100669010 --- ax/adapter/transfer_learning/adapter.py | 78 +++++++++++++++++-- .../torch/botorch_modular/surrogate.py | 5 +- 2 files changed, 75 insertions(+), 8 deletions(-) diff --git a/ax/adapter/transfer_learning/adapter.py b/ax/adapter/transfer_learning/adapter.py index d962bc14143..ce060bac8cb 100644 --- a/ax/adapter/transfer_learning/adapter.py +++ b/ax/adapter/transfer_learning/adapter.py @@ -7,6 +7,7 @@ from __future__ import annotations +import dataclasses import warnings from collections.abc import Mapping, Sequence from logging import Logger @@ -38,7 +39,7 @@ from ax.core.observation import ObservationData, ObservationFeatures from ax.core.optimization_config import OptimizationConfig from ax.core.parameter import FixedParameter, RangeParameter -from ax.core.search_space import SearchSpace +from ax.core.search_space import SearchSpace, SearchSpaceDigest from ax.exceptions.core import DataRequiredError, UnsupportedError, UserInputError from ax.generation_strategy.best_model_selector import ( ReductionCriterion, @@ -504,6 +505,56 @@ def _get_task_datasets( ) return task_datasets + def _expand_ssd_to_joint_space( + self, + search_space_digest: SearchSpaceDigest, + ) -> SearchSpaceDigest: + """Expand SSD bounds and feature_names to cover the joint search space. + + The SSD produced by ``_get_fit_args`` reflects the target search space. + When source experiments have additional parameters, the model operates + in the full joint feature space (via ``ImputedMultiTaskGP`` / + ``HeterogeneousMTGP``). This method appends bounds and feature names + for source-only parameters so that input transforms (Normalize, + LearnedFeatureImputation) receive correct full-space bounds. + """ + existing_names = set(search_space_digest.feature_names) + extra_names: list[str] = [] + extra_bounds: list[tuple[int | float, int | float]] = [] + for name, param in self.joint_search_space.parameters.items(): + if name not in existing_names and isinstance(param, RangeParameter): + extra_names.append(name) + extra_bounds.append((param.lower, param.upper)) + if not extra_names: + return search_space_digest + # Insert source-only params before the task feature (which must + # remain the last column for MultiTaskGP / ImputedMultiTaskGP). + task_features = search_space_digest.task_features + if len(task_features) == 1: + tf_idx = task_features[0] + names = list(search_space_digest.feature_names) + bounds = list(search_space_digest.bounds) + names[tf_idx:tf_idx] = extra_names + bounds[tf_idx:tf_idx] = extra_bounds + # Task feature index shifts by the number of inserted params. + new_task_features = [tf_idx + len(extra_names)] + new_target_values = dict(search_space_digest.target_values) + if tf_idx in new_target_values: + new_target_values[new_task_features[0]] = new_target_values.pop(tf_idx) + return dataclasses.replace( + search_space_digest, + feature_names=names, + bounds=bounds, + task_features=new_task_features, + target_values=new_target_values, + ) + # No task feature — just append. + return dataclasses.replace( + search_space_digest, + feature_names=search_space_digest.feature_names + extra_names, + bounds=search_space_digest.bounds + extra_bounds, + ) + def _fit( self, search_space: SearchSpace, @@ -525,6 +576,10 @@ def _fit( experiment_data=experiment_data, update_outcomes_and_parameters=True, ) + # Expand SSD bounds to cover source-only params from the joint search + # space. This ensures Normalize (and other input transforms) get bounds + # for the full feature space, not just the target dims. + search_space_digest = self._expand_ssd_to_joint_space(search_space_digest) if experiment_data.arm_data.empty: self.outcomes = outcomes # Temporarily set datasets to None. We will construct empty datasets @@ -567,6 +622,7 @@ def _cross_validate( experiment_data=cv_training_data, update_outcomes_and_parameters=False, ) + search_space_digest = self._expand_ssd_to_joint_space(search_space_digest) # Add the task feature to SSD, to ensure that a multi-task model is selected. if len(search_space_digest.task_features) > 1: raise UnsupportedError( @@ -633,6 +689,18 @@ def gen( if fixed_features is None: fixed_features = ObservationFeatures(parameters={}) fixed_features.parameters.setdefault(name, target_p.value) + # Fix source-only params during acquisition optimization so the + # optimizer doesn't search over dims that only exist in sources. + # The fixed value is irrelevant: LearnedFeatureImputation.transform + # unconditionally overwrites source-only columns with learned + # imputation values (see lines 2173-2176 in input.py). + for name, joint_p in self.joint_search_space.parameters.items(): + if name not in search_space.parameters and isinstance( + joint_p, RangeParameter + ): + if fixed_features is None: + fixed_features = ObservationFeatures(parameters={}) + fixed_features.parameters.setdefault(name, joint_p.lower) generator_run = super().gen( n=n, search_space=search_space, @@ -719,12 +787,8 @@ def transfer_learning_generator_specs_constructor( selector in case there is model selection enabled. """ input_transform_classes: list[type[InputTransform]] = [Normalize] - input_transform_options = { - "Normalize": { - # None for bounds here ensures we do not use bounds from - # the search space digest. - "bounds": None, - } + input_transform_options: dict[str, dict[str, Any]] = { + "Normalize": {}, } transforms = transforms or MBM_X_trans + [MetadataToTask] + Y_trans transform_configs = get_derelativize_config( diff --git a/ax/generators/torch/botorch_modular/surrogate.py b/ax/generators/torch/botorch_modular/surrogate.py index 3b525d41d84..7699add58a9 100644 --- a/ax/generators/torch/botorch_modular/surrogate.py +++ b/ax/generators/torch/botorch_modular/surrogate.py @@ -742,7 +742,10 @@ def fit( # the feature names from the search space digest. Otherwise we only # keep the model within self._submodels as it may be models fitted on # auxiliary data such as the preference model for BOPE - if set(dataset.feature_names) == feature_names_set: + if set(dataset.feature_names) == feature_names_set or ( + isinstance(dataset, MultiTaskDataset) + and set(dataset.feature_names).issubset(feature_names_set) + ): models.append(model) outcome_names.extend(dataset.outcome_names) From d82837dd0ed3b99f5ead3f2c327a9927bfa9d05c Mon Sep 17 00:00:00 2001 From: Carl Hvarfner Date: Tue, 21 Apr 2026 16:17:15 -0700 Subject: [PATCH 5/5] Add [0, 1] bounds to LearnedFeatureImputation in input transform argparse (#5185) Summary: Pull Request resolved: https://github.com/facebook/Ax/pull/5185 The LearnedFeatureImputation transform supports constraining imputation values to a bounded range via sigmoid, but the Ax input transform argparse dispatcher never passed bounds. This left imputation values unconstrained, allowing them to drift far outside [0, 1] during MLL optimization. Pass bounds=[[0,...,0], [1,...,1]] to constrain imputation values to the normalized input range, since the preceding Normalize transform maps features to [0, 1]. Differential Revision: D101058043 --- .../input_constructors/input_transforms.py | 15 +++++++++++++-- 1 file changed, 13 insertions(+), 2 deletions(-) diff --git a/ax/generators/torch/botorch_modular/input_constructors/input_transforms.py b/ax/generators/torch/botorch_modular/input_constructors/input_transforms.py index 0f6a270048a..8c12fcfe743 100644 --- a/ax/generators/torch/botorch_modular/input_constructors/input_transforms.py +++ b/ax/generators/torch/botorch_modular/input_constructors/input_transforms.py @@ -343,7 +343,7 @@ def _input_transform_argparse_learned_feature_imputation( Returns: A dictionary with ``feature_indices``, ``d``, ``task_feature_index``, - ``device``, and ``dtype`` keys. + ``bounds``, ``device``, and ``dtype`` keys. """ if not isinstance(dataset, MultiTaskDataset): raise ValueError( @@ -382,12 +382,23 @@ def _input_transform_argparse_learned_feature_imputation( for task_idx, ds in enumerate(all_datasets) } + dtype = torch_dtype or torch.float64 + # Constrain imputation values to [0, 1] since the preceding Normalize + # maps features to this range. Without bounds, imputation values are + # unconstrained and can drift far from the valid input range. + bounds = torch.stack( + [ + torch.zeros(d, dtype=dtype, device=torch_device), + torch.ones(d, dtype=dtype, device=torch_device), + ] + ) kwargs: dict[str, Any] = { "feature_indices": feature_indices, "d": d, "task_feature_index": task_feature_index, + "bounds": bounds, "device": torch_device, - "dtype": torch_dtype or torch.float64, + "dtype": dtype, } kwargs.update(input_transform_options) return kwargs