Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
78 changes: 71 additions & 7 deletions ax/adapter/transfer_learning/adapter.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@

from __future__ import annotations

import dataclasses
import warnings
from collections.abc import Mapping, Sequence
from logging import Logger
Expand Down Expand Up @@ -38,7 +39,7 @@
from ax.core.observation import ObservationData, ObservationFeatures
from ax.core.optimization_config import OptimizationConfig
from ax.core.parameter import FixedParameter, RangeParameter
from ax.core.search_space import SearchSpace
from ax.core.search_space import SearchSpace, SearchSpaceDigest
from ax.exceptions.core import DataRequiredError, UnsupportedError, UserInputError
from ax.generation_strategy.best_model_selector import (
ReductionCriterion,
Expand Down Expand Up @@ -504,6 +505,56 @@ def _get_task_datasets(
)
return task_datasets

def _expand_ssd_to_joint_space(
    self,
    search_space_digest: SearchSpaceDigest,
) -> SearchSpaceDigest:
    """Return an SSD whose bounds/feature names span the joint search space.

    The digest produced by ``_get_fit_args`` only reflects the target
    search space. When source experiments carry additional parameters, the
    model operates over the full joint feature space (via
    ``ImputedMultiTaskGP`` / ``HeterogeneousMTGP``), so input transforms
    (Normalize, LearnedFeatureImputation) need bounds and names for every
    joint-space dimension. Source-only ``RangeParameter``s are added here;
    if there are none, the digest is returned unchanged.
    """
    known = set(search_space_digest.feature_names)
    source_only = [
        (name, (param.lower, param.upper))
        for name, param in self.joint_search_space.parameters.items()
        if name not in known and isinstance(param, RangeParameter)
    ]
    if not source_only:
        return search_space_digest
    added_names = [name for name, _ in source_only]
    added_bounds = [bound for _, bound in source_only]
    task_features = search_space_digest.task_features
    if len(task_features) != 1:
        # Without a single task feature there is no column-ordering
        # constraint — append the source-only dimensions at the end.
        return dataclasses.replace(
            search_space_digest,
            feature_names=search_space_digest.feature_names + added_names,
            bounds=search_space_digest.bounds + added_bounds,
        )
    # The task feature must remain the LAST column for MultiTaskGP /
    # ImputedMultiTaskGP, so splice the new dimensions just before it.
    tf_idx = task_features[0]
    new_names = (
        list(search_space_digest.feature_names[:tf_idx])
        + added_names
        + list(search_space_digest.feature_names[tf_idx:])
    )
    new_bounds = (
        list(search_space_digest.bounds[:tf_idx])
        + added_bounds
        + list(search_space_digest.bounds[tf_idx:])
    )
    # Everything inserted before the task feature shifts its index.
    shifted_tf_idx = tf_idx + len(added_names)
    target_values = dict(search_space_digest.target_values)
    if tf_idx in target_values:
        target_values[shifted_tf_idx] = target_values.pop(tf_idx)
    return dataclasses.replace(
        search_space_digest,
        feature_names=new_names,
        bounds=new_bounds,
        task_features=[shifted_tf_idx],
        target_values=target_values,
    )

def _fit(
self,
search_space: SearchSpace,
Expand All @@ -525,6 +576,10 @@ def _fit(
experiment_data=experiment_data,
update_outcomes_and_parameters=True,
)
# Expand SSD bounds to cover source-only params from the joint search
# space. This ensures Normalize (and other input transforms) get bounds
# for the full feature space, not just the target dims.
search_space_digest = self._expand_ssd_to_joint_space(search_space_digest)
if experiment_data.arm_data.empty:
self.outcomes = outcomes
# Temporarily set datasets to None. We will construct empty datasets
Expand Down Expand Up @@ -567,6 +622,7 @@ def _cross_validate(
experiment_data=cv_training_data,
update_outcomes_and_parameters=False,
)
search_space_digest = self._expand_ssd_to_joint_space(search_space_digest)
# Add the task feature to SSD, to ensure that a multi-task model is selected.
if len(search_space_digest.task_features) > 1:
raise UnsupportedError(
Expand Down Expand Up @@ -633,6 +689,18 @@ def gen(
if fixed_features is None:
fixed_features = ObservationFeatures(parameters={})
fixed_features.parameters.setdefault(name, target_p.value)
# Fix source-only params during acquisition optimization so the
# optimizer doesn't search over dims that only exist in sources.
# The fixed value is irrelevant: LearnedFeatureImputation.transform
# unconditionally overwrites source-only columns with learned
# imputation values (see lines 2173-2176 in input.py).
for name, joint_p in self.joint_search_space.parameters.items():
if name not in search_space.parameters and isinstance(
joint_p, RangeParameter
):
if fixed_features is None:
fixed_features = ObservationFeatures(parameters={})
fixed_features.parameters.setdefault(name, joint_p.lower)
generator_run = super().gen(
n=n,
search_space=search_space,
Expand Down Expand Up @@ -719,12 +787,8 @@ def transfer_learning_generator_specs_constructor(
selector in case there is model selection enabled.
"""
input_transform_classes: list[type[InputTransform]] = [Normalize]
input_transform_options = {
"Normalize": {
# None for bounds here ensures we do not use bounds from
# the search space digest.
"bounds": None,
}
input_transform_options: dict[str, dict[str, Any]] = {
"Normalize": {},
}
transforms = transforms or MBM_X_trans + [MetadataToTask] + Y_trans
transform_configs = get_derelativize_config(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
from botorch.models.transforms.input import (
FilterFeatures,
InputTransform,
LearnedFeatureImputation,
Normalize,
Warp,
)
Expand Down Expand Up @@ -314,3 +315,90 @@ def _input_transform_argparse_filter_features(
)

return input_transform_options_copy


@input_transform_argparse.register(LearnedFeatureImputation)
def _input_transform_argparse_learned_feature_imputation(
    input_transform_class: type[LearnedFeatureImputation],
    dataset: SupervisedDataset,
    search_space_digest: SearchSpaceDigest,
    input_transform_options: dict[str, Any] | None = None,
    torch_device: torch.device | None = None,
    torch_dtype: torch.dtype | None = None,
) -> dict[str, Any]:
    """Extract LearnedFeatureImputation kwargs from a MultiTaskDataset.

    Computes ``feature_indices`` and ``d`` from the heterogeneous feature sets
    in the dataset, following the same convention as
    ``HeterogeneousMTGP.construct_inputs``: the target task is placed first,
    and the task feature column is excluded from the feature union.

    Args:
        input_transform_class: Input transform class.
        dataset: A ``MultiTaskDataset`` with heterogeneous features.
        search_space_digest: Search space digest.
        input_transform_options: Optional overrides for transform kwargs.
        torch_device: Device for the transform parameters.
        torch_dtype: Dtype for the transform parameters.

    Returns:
        A dictionary with ``feature_indices``, ``d``, ``task_feature_index``,
        ``bounds``, ``device``, and ``dtype`` keys.

    Raises:
        ValueError: If ``dataset`` is not a ``MultiTaskDataset`` or does not
            have heterogeneous features.
    """
    if not isinstance(dataset, MultiTaskDataset):
        raise ValueError(
            "LearnedFeatureImputation requires a MultiTaskDataset, "
            f"got {type(dataset).__name__}."
        )
    if not dataset.has_heterogeneous_features:
        raise ValueError(
            "LearnedFeatureImputation requires a MultiTaskDataset with "
            "heterogeneous features (has_heterogeneous_features=True)."
        )
    input_transform_options = input_transform_options or {}

    # Order datasets: target first, then remaining (same as HeterogeneousMTGP).
    child_datasets = dataset.datasets.copy()
    target_dataset = child_datasets.pop(dataset.target_outcome_name)
    all_datasets = [target_dataset] + list(child_datasets.values())

    # Use target's feature order as canonical (NO alphabetical sort).
    # Source-only features are appended at the end. Slicing with
    # `[:task_feature_index]` drops the task column; this assumes the task
    # feature is the last column (index -1 or d-1) — NOTE(review): confirm
    # against MultiTaskDataset's column convention.
    task_feature_index = (
        dataset.task_feature_index if (dataset.task_feature_index is not None) else -1
    )
    all_features: list[str] = list(target_dataset.feature_names[:task_feature_index])
    # Track membership in a set: O(1) per check instead of O(d) list scans.
    seen = set(all_features)
    for ds in all_datasets[1:]:
        for fn in ds.feature_names[:task_feature_index]:
            if fn not in seen:
                seen.add(fn)
                all_features.append(fn)
    d = len(all_features)

    # Map each task's features to indices in the global feature space.
    # Precompute positions once rather than calling list.index per feature,
    # which would be O(d) per lookup.
    position = {fn: idx for idx, fn in enumerate(all_features)}
    feature_indices = {
        task_idx: [position[fn] for fn in ds.feature_names[:task_feature_index]]
        for task_idx, ds in enumerate(all_datasets)
    }

    dtype = torch_dtype or torch.float64
    # Constrain imputation values to [0, 1] since the preceding Normalize
    # maps features to this range. Without bounds, imputation values are
    # unconstrained and can drift far from the valid input range.
    bounds = torch.stack(
        [
            torch.zeros(d, dtype=dtype, device=torch_device),
            torch.ones(d, dtype=dtype, device=torch_device),
        ]
    )
    kwargs: dict[str, Any] = {
        "feature_indices": feature_indices,
        "d": d,
        "task_feature_index": task_feature_index,
        "bounds": bounds,
        "device": torch_device,
        "dtype": dtype,
    }
    # Caller-supplied options take precedence over the computed defaults.
    kwargs.update(input_transform_options)
    return kwargs
5 changes: 4 additions & 1 deletion ax/generators/torch/botorch_modular/surrogate.py
Original file line number Diff line number Diff line change
Expand Up @@ -742,7 +742,10 @@ def fit(
# the feature names from the search space digest. Otherwise we only
# keep the model within self._submodels as it may be models fitted on
# auxiliary data such as the preference model for BOPE
if set(dataset.feature_names) == feature_names_set:
if set(dataset.feature_names) == feature_names_set or (
isinstance(dataset, MultiTaskDataset)
and set(dataset.feature_names).issubset(feature_names_set)
):
models.append(model)
outcome_names.extend(dataset.outcome_names)

Expand Down
22 changes: 8 additions & 14 deletions ax/generators/torch/botorch_modular/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -274,8 +274,7 @@ def choose_model_class(
dataset: The dataset on which the model will be fitted.
search_space_digest: The digest of the search space the model will be
fitted within.
specified_model_class: If provided, this model class will be used unless
overridden for specific cases (e.g., heterogeneous datasets).
specified_model_class: If provided, this model class will be used.

Returns:
A BoTorch `Model` class.
Expand All @@ -295,23 +294,18 @@ def choose_model_class(
"Multi-task multi-fidelity optimization not yet supported."
)

# Check for heterogeneous multi-task datasets & override model class if needed.
# Check for heterogeneous multi-task datasets. If a model class was
# explicitly specified, respect it; otherwise default to HeterogeneousMTGP.
if (
search_space_digest.task_features
and isinstance(dataset, MultiTaskDataset)
and dataset.has_heterogeneous_features
):
if (
specified_model_class is not None
and specified_model_class is not HeterogeneousMTGP
):
logger.warning(
f"Detected heterogeneous features in MultiTaskDataset. "
f"Overriding specified model class {specified_model_class.__name__} "
f"with HeterogeneousMTGP for transfer learning with "
f"heterogeneous search spaces."
)
model_class = HeterogeneousMTGP
model_class = (
specified_model_class
if specified_model_class is not None
else HeterogeneousMTGP
)
logger.debug(f"Chose BoTorch model class: {model_class}.")
return model_class

Expand Down
Loading
Loading