diff --git a/gpflowopt/acquisition/acquisition.py b/gpflowopt/acquisition/acquisition.py
index 4d3b58d..2caefcf 100644
--- a/gpflowopt/acquisition/acquisition.py
+++ b/gpflowopt/acquisition/acquisition.py
@@ -396,27 +396,31 @@ class MCMCAcquistion(AcquisitionSum):
     """
     Apply MCMC over the hyperparameters of an acquisition function (= over the hyperparameters of the contained models).
 
-    The models passed into an object of this class are optimized with MLE, and then further sampled with HMC.
-    These hyperparameter samples are then set in copies of the acquisition.
+    The models passed into an object of this class are optimized with MLE (fast burn-in), and then further sampled with
+    HMC. These hyperparameter samples are then set in copies of the acquisition.
 
     For evaluating the underlying acquisition function, the predictions of the acquisition copies are averaged.
     """
 
     def __init__(self, acquisition, n_slices, **kwargs):
         assert isinstance(acquisition, Acquisition)
         assert n_slices > 0
-
-        copies = [copy.deepcopy(acquisition) for _ in range(n_slices - 1)]
-        for c in copies:
-            c.optimize_restarts = 0
-        # the call to the constructor of the parent classes, will optimize acquisition, so it obtains the MLE solution.
-        super(MCMCAcquistion, self).__init__([acquisition] + copies)
+        super(MCMCAcquistion, self).__init__([acquisition]*n_slices)
+        self._needs_new_copies = True
         self._sample_opt = kwargs
 
     def _optimize_models(self):
         # Optimize model #1
         self.operands[0]._optimize_models()
 
+        # Copy it again if needed due to changed free state
+        if self._needs_new_copies:
+            new_copies = [copy.deepcopy(self.operands[0]) for _ in range(len(self.operands) - 1)]
+            for c in new_copies:
+                c.optimize_restarts = 0
+            self.operands = ParamList([self.operands[0]] + new_copies)
+            self._needs_new_copies = False
+
         # Draw samples using HMC
         # Sample each model of the acquisition function - results in a list of 2D ndarrays.
         hypers = np.hstack([model.sample(len(self.operands), **self._sample_opt) for model in self.models])
@@ -440,3 +444,13 @@ def set_data(self, X, Y):
     def build_acquisition(self, Xcand):
         # Average the predictions of the copies.
         return 1. / len(self.operands) * super(MCMCAcquistion, self).build_acquisition(Xcand)
+
+    def _kill_autoflow(self):
+        """
+        Flag for recreation on next optimize.
+
+        Following the recompilation of models, the free state might have changed. This means updating the samples can
+        cause inconsistencies and errors.
+        """
+        super(MCMCAcquistion, self)._kill_autoflow()
+        self._needs_new_copies = True
diff --git a/gpflowopt/bo.py b/gpflowopt/bo.py
index 8071104..749a7c8 100644
--- a/gpflowopt/bo.py
+++ b/gpflowopt/bo.py
@@ -16,12 +16,40 @@
 import numpy as np
 from scipy.optimize import OptimizeResult
+import tensorflow as tf
+from gpflow.gpr import GPR
 
 from .acquisition import Acquisition, MCMCAcquistion
-from .optim import Optimizer, SciPyOptimizer
-from .objective import ObjectiveWrapper
 from .design import Design, EmptyDesign
+from .objective import ObjectiveWrapper
+from .optim import Optimizer, SciPyOptimizer
 from .pareto import non_dominated_sort
+from .models import ModelWrapper
+
+
+def jitchol_callback(models):
+    """
+    Increase the likelihood variance in case of Cholesky failures.
+
+    This is similar to the use of jitchol in GPy. Default callback for BayesianOptimizer.
+    Only usable on GPR models; other types are ignored.
+ """ + for m in np.atleast_1d(models): + if isinstance(m, ModelWrapper): + jitchol_callback(m.wrapped) # pragma: no cover + + if not isinstance(m, GPR): + continue + + s = m.get_free_state() + eKdiag = np.mean(np.diag(m.kern.compute_K_symm(m.X.value))) + for e in [0] + [10**ex for ex in range(-6,-1)]: + try: + m.likelihood.variance = m.likelihood.variance.value + e * eKdiag + m.optimize(maxiter=5) + break + except tf.errors.InvalidArgumentError: # pragma: no cover + m.set_state(s) class BayesianOptimizer(Optimizer): @@ -32,7 +60,8 @@ class BayesianOptimizer(Optimizer): Additionally, it is configured with a separate optimizer for the acquisition function. """ - def __init__(self, domain, acquisition, optimizer=None, initial=None, scaling=True, hyper_draws=None): + def __init__(self, domain, acquisition, optimizer=None, initial=None, scaling=True, hyper_draws=None, + callback=jitchol_callback): """ :param Domain domain: The optimization space. :param Acquisition acquisition: The acquisition function to optimize over the domain. @@ -51,6 +80,12 @@ def __init__(self, domain, acquisition, optimizer=None, initial=None, scaling=Tr are obtained using Hamiltonian MC. (see `GPflow documentation `_ for details) for each model. The acquisition score is computed for each draw, and averaged. + :param callable callback: (optional) this function or object will be called, after the + data of all models has been updated with all models as retrieved by acquisition.models as argument without + the wrapping model handling any scaling . This allows custom model optimization strategies to be implemented. + All manipulations of GPflow models are permitted. Combined with the optimize_restarts parameter of + :class:`~.Acquisition` this allows several scenarios: do the optimization manually from the callback + (optimize_restarts equals 0), or choose the starting point + some random restarts (optimize_restarts > 0). """ assert isinstance(acquisition, Acquisition) assert hyper_draws is None or hyper_draws > 0 @@ -69,6 +104,8 @@ def __init__(self, domain, acquisition, optimizer=None, initial=None, scaling=Tr initial = initial or EmptyDesign(domain) self.set_initial(initial.generate()) + self._model_callback = callback + @Optimizer.domain.setter def domain(self, dom): assert self.domain.size == dom.size @@ -86,6 +123,8 @@ def _update_model_data(self, newX, newY): assert self.acquisition.data[0].shape[1] == newX.shape[-1] assert self.acquisition.data[1].shape[1] == newY.shape[-1] assert newX.shape[0] == newY.shape[0] + if newX.size == 0: + return X = np.vstack((self.acquisition.data[0], newX)) Y = np.vstack((self.acquisition.data[1], newY)) self.acquisition.set_data(X, Y) @@ -174,7 +213,6 @@ def _optimize(self, fx, n_iter): :param n_iter: number of iterations to run :return: OptimizeResult object """ - assert isinstance(fx, ObjectiveWrapper) # Evaluate and add the initial design (if any) @@ -190,6 +228,10 @@ def inverse_acquisition(x): # Optimization loop for i in range(n_iter): + # If a callback is specified, and acquisition has the setup flag enabled (indicating an upcoming + # compilation), run the callback. 
+            if self._model_callback and self.acquisition._needs_setup:
+                self._model_callback([m.wrapped for m in self.acquisition.models])
             result = self.optimizer.optimize(inverse_acquisition)
             self._update_model_data(result.x, fx(result.x))
 
diff --git a/testing/test_acquisition.py b/testing/test_acquisition.py
index aa8f29c..cc81e6c 100644
--- a/testing/test_acquisition.py
+++ b/testing/test_acquisition.py
@@ -146,7 +146,6 @@ def test_object_integrity(self, acquisition):
         for oper in acquisition.operands:
             self.assertTrue(isinstance(oper, gpflowopt.acquisition.Acquisition),
                             msg="All operands should be an acquisition object")
-
         self.assertTrue(all(isinstance(m, gpflowopt.models.ModelWrapper) for m in acquisition.models))
 
     @parameterized.expand(list(zip(aggregations)))
@@ -218,9 +217,23 @@ def test_marginalized_score(self, acquisition):
         ei_mcmc = acquisition.evaluate(Xt)
         np.testing.assert_almost_equal(ei_mle, ei_mcmc, decimal=5)
 
-    @parameterized.expand(list(zip([aggregations[2]])))
-    def test_mcmc_acq_models(self, acquisition):
+    def test_mcmc_acq(self):
+        acquisition = gpflowopt.acquisition.MCMCAcquistion(
+            gpflowopt.acquisition.ExpectedImprovement(create_parabola_model(domain)), 10)
+        for oper in acquisition.operands:
+            self.assertListEqual(acquisition.models, oper.models)
+            self.assertEqual(acquisition.operands[0], oper)
+        self.assertTrue(acquisition._needs_new_copies)
+        acquisition._optimize_models()
         self.assertListEqual(acquisition.models, acquisition.operands[0].models)
+        for oper in acquisition.operands[1:]:
+            self.assertNotEqual(acquisition.operands[0], oper)
+        self.assertFalse(acquisition._needs_new_copies)
+        acquisition._setup()
+        Xt = np.random.rand(20, 2) * 2 - 1
+        ei_mle = acquisition.operands[0].evaluate(Xt)
+        ei_mcmc = acquisition.evaluate(Xt)
+        np.testing.assert_almost_equal(ei_mle, ei_mcmc, decimal=5)
 
 
 class TestJointAcquisition(unittest.TestCase):
diff --git a/testing/test_optimizers.py b/testing/test_optimizers.py
index 8271099..5b13a18 100644
--- a/testing/test_optimizers.py
+++ b/testing/test_optimizers.py
@@ -214,8 +214,8 @@ def test_optimize_multi_objective(self):
         result = optimizer.optimize(vlmop2, n_iter=2)
         self.assertTrue(result.success)
         self.assertEqual(result.nfev, 2, "Only 2 evaluations permitted")
-        self.assertTupleEqual(result.x.shape, (9, 2))
-        self.assertTupleEqual(result.fun.shape, (9, 2))
+        self.assertTupleEqual(result.x.shape, (7, 2))
+        self.assertTupleEqual(result.fun.shape, (7, 2))
         _, dom = gpflowopt.pareto.non_dominated_sort(result.fun)
         self.assertTrue(np.all(dom==0))
 
@@ -288,6 +288,71 @@ def test_mcmc(self):
         self.assertTrue(np.allclose(result.x, 0), msg="Optimizer failed to find optimum")
         self.assertTrue(np.allclose(result.fun, 0), msg="Incorrect function value returned")
 
+    def test_callback(self):
+        class DummyCallback(object):
+            def __init__(self):
+                self.counter = 0
+
+            def __call__(self, models):
+                self.counter += 1
+
+        c = DummyCallback()
+        optimizer = gpflowopt.BayesianOptimizer(self.domain, self.acquisition, callback=c)
+        result = optimizer.optimize(lambda X: parabola2d(X)[0], n_iter=2)
+        self.assertEqual(c.counter, 2)
+
+    def test_callback_recompile(self):
+        class DummyCallback(object):
+            def __init__(self):
+                self.recompile = False
+
+            def __call__(self, models):
+                c = np.random.randint(2, 10)
+                models[0].kern.variance.prior = gpflow.priors.Gamma(c, 1./c)
+                self.recompile = models[0]._needs_recompile
+
+        c = DummyCallback()
+        optimizer = gpflowopt.BayesianOptimizer(self.domain, self.acquisition, callback=c)
+        self.acquisition.evaluate(np.zeros((1,2)))  # Make sure it is run and set up, so the first iteration skips recompilation
+        result = optimizer.optimize(lambda X: parabola2d(X)[0], n_iter=1)
+        self.assertFalse(c.recompile)
+        result = optimizer.optimize(lambda X: parabola2d(X)[0], n_iter=1)
+        self.assertTrue(c.recompile)
+        self.assertFalse(self.acquisition.models[0]._needs_recompile)
+
+    def test_callback_recompile_mcmc(self):
+        class DummyCallback(object):
+            def __init__(self):
+                self.no_models = 0
+
+            def __call__(self, models):
+                c = np.random.randint(2, 10)
+                models[0].kern.variance.prior = gpflow.priors.Gamma(c, 1. / c)
+                self.no_models = len(models)
+
+        c = DummyCallback()
+        optimizer = gpflowopt.BayesianOptimizer(self.domain, self.acquisition, hyper_draws=5, callback=c)
+        opers = optimizer.acquisition.operands
+        result = optimizer.optimize(lambda X: parabola2d(X)[0], n_iter=1)
+        self.assertEqual(c.no_models, 1)
+        self.assertEqual(id(opers[0]), id(optimizer.acquisition.operands[0]))
+        for op1, op2 in zip(opers[1:], optimizer.acquisition.operands[1:]):
+            self.assertNotEqual(id(op1), id(op2))
+        opers = optimizer.acquisition.operands
+        result = optimizer.optimize(lambda X: parabola2d(X)[0], n_iter=1)
+        self.assertEqual(id(opers[0]), id(optimizer.acquisition.operands[0]))
+        for op1, op2 in zip(opers[1:], optimizer.acquisition.operands[1:]):
+            self.assertNotEqual(id(op1), id(op2))
+
+    def test_nongpr_model(self):
+        design = gpflowopt.design.LatinHyperCube(16, self.domain)
+        X, Y = design.generate(), parabola2d(design.generate())[0]
+        m = gpflow.vgp.VGP(X, Y, gpflow.kernels.RBF(2, ARD=True), likelihood=gpflow.likelihoods.Gaussian())
+        acq = gpflowopt.acquisition.ExpectedImprovement(m)
+        optimizer = gpflowopt.BayesianOptimizer(self.domain, acq)
+        result = optimizer.optimize(lambda X: parabola2d(X)[0], n_iter=1)
+        self.assertTrue(result.success)
+
 
 class TestSilentOptimization(unittest.TestCase):
     @contextmanager
@@ -323,3 +388,4 @@ def _optimize(self, objective):
         opt.optimize(None)
         output = out.getvalue().strip()
         self.assertEqual(output, '')
+
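
Usage note (not part of the patch): a minimal sketch of the callback hook introduced above,
assuming the GPflow 0.x-era API used throughout this diff. The log_hypers helper and the
parabola objective are illustrative only; passing callback=None disables the default
jitchol_callback entirely.

    import numpy as np
    import gpflow
    import gpflowopt

    def log_hypers(models):
        # Called once per BO iteration, right before the models would be recompiled.
        # `models` are the raw GPflow models, without any scaling wrappers.
        for m in np.atleast_1d(models):
            print(m.kern)

    def parabola(X):
        # Simple objective for illustration.
        return np.sum(X ** 2, axis=1, keepdims=True)

    domain = gpflowopt.domain.ContinuousParameter('x1', -1, 1) + \
             gpflowopt.domain.ContinuousParameter('x2', -1, 1)
    X = gpflowopt.design.LatinHyperCube(11, domain).generate()
    model = gpflow.gpr.GPR(X, parabola(X), gpflow.kernels.RBF(2, ARD=True))
    acquisition = gpflowopt.acquisition.ExpectedImprovement(model)

    # Use the custom callback instead of the default jitchol_callback.
    optimizer = gpflowopt.BayesianOptimizer(domain, acquisition, callback=log_hypers)
    result = optimizer.optimize(parabola, n_iter=10)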