|
1 | 1 | # Code taken from https://botorch.org/tutorials/turbo_1 |
| 2 | +from __future__ import annotations |
| 3 | + |
2 | 4 | import numpy as np |
3 | 5 | from poli.core.abstract_black_box import AbstractBlackBox |
| 6 | +from typing import List |
4 | 7 |
|
5 | 8 | from poli_baselines.core.step_by_step_solver import StepByStepSolver |
6 | 9 | import math |
|
29 | 32 |
|
30 | 33 |
|
31 | 34 | class TurboWrapper(StepByStepSolver): |
32 | | - def __init__(self, black_box: AbstractBlackBox, x0: np.ndarray, y0: np.ndarray): |
| 35 | + def __init__( |
| 36 | + self, |
| 37 | + black_box: AbstractBlackBox, |
| 38 | + x0: np.ndarray, |
| 39 | + y0: np.ndarray, |
| 40 | + bounds: np.ndarray | None = None, |
| 41 | + ): |
| 42 | + """ |
| 43 | +
|
| 44 | + Parameters |
| 45 | + ---------- |
| 46 | + black_box |
| 47 | + x0 |
| 48 | + y0 |
| 49 | + bounds: |
| 50 | + array of shape Dx2 where D is the dimensionality |
| 51 | + The first row contains the lower bounds on x, the last row contains the upper bounds. |
| 52 | + """ |
33 | 53 | super().__init__(black_box, x0, y0) |
34 | | - self.X_turbo = torch.tensor(x0) |
| 54 | + assert x0.shape[0] > 1 |
| 55 | + |
| 56 | + if bounds is None: |
| 57 | + bounds = np.array([[x0.min() - 1.0, x0.max() + 1.0]] * x0.shape[1]) |
| 58 | + |
| 59 | + assert bounds.shape[1] == 2 |
| 60 | + assert bounds.shape[0] == x0.shape[1] |
| 61 | + assert np.all(bounds[:, 1] >= bounds[:, 0]) |
| 62 | + bounds[:, 1] -= bounds[:, 0] |
| 63 | + |
| 64 | + def make_transforms(): |
| 65 | + to_turbo = lambda X: (X - bounds[:, 0]) / bounds[:, 1] |
| 66 | + from_turbo = lambda X: X * bounds[:, 1] + bounds[:, 0] |
| 67 | + return to_turbo, from_turbo |
| 68 | + |
| 69 | + self.to_turbo, self.from_turbo = make_transforms() |
| 70 | + self.X_turbo = torch.tensor(self.to_turbo(x0)) |
35 | 71 | self.Y_turbo = torch.tensor(y0) |
36 | 72 | self.batch_size = 1 |
37 | 73 | dim = x0.shape[1] |
@@ -70,14 +106,14 @@ def next_candidate(self) -> np.ndarray: |
70 | 106 | raw_samples=RAW_SAMPLES, |
71 | 107 | acqf="ts", |
72 | 108 | ) |
73 | | - return X_next |
| 109 | + return self.from_turbo(X_next.numpy()) |
74 | 110 |
|
75 | 111 | def post_update(self, x: np.ndarray, y: np.ndarray) -> None: |
76 | 112 | """ |
77 | 113 | This method is called after the history is updated. |
78 | 114 | """ |
79 | 115 | Y_next = torch.tensor(y) |
80 | | - X_next = torch.tensor(x) |
| 116 | + X_next = torch.tensor(self.to_turbo(x)) |
81 | 117 |
|
82 | 118 | # Update state |
83 | 119 | self.state = update_state(state=self.state, Y_next=Y_next) |
|
0 commit comments