diff --git a/numpy_ml/tests/test_nonparametric.py b/numpy_ml/tests/test_nonparametric.py index 9e2ec7e..4119c0c 100644 --- a/numpy_ml/tests/test_nonparametric.py +++ b/numpy_ml/tests/test_nonparametric.py @@ -109,7 +109,7 @@ def test_gp_regression(N=15): preds, _ = gp.predict(X_test) gold_preds = gold.predict(X_test) - np.testing.assert_almost_equal(preds, gold_preds) + np.testing.assert_almost_equal(preds.squeeze(), gold_preds.squeeze()) mll = gp.marginal_log_likelihood() gold_mll = gold.log_marginal_likelihood()