Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
130 changes: 93 additions & 37 deletions stan/math/mix/functor/laplace_likelihood.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
#include <stan/math/mix/functor/conditional_copy_and_promote.hpp>
#include <stan/math/prim/functor.hpp>
#include <stan/math/prim/fun.hpp>
#include <type_traits>

namespace stan {
namespace math {
Expand All @@ -16,6 +17,41 @@ namespace math {
namespace laplace_likelihood {

namespace internal {

/**
 * Type trait to detect if a likelihood functor `F` provides a custom
 * `diff` method that computes the gradient and negative Hessian
 * analytically, avoiding the cost of embedded reverse-mode autodiff.
 *
 * A functor with a custom `diff` method should provide:
 *   auto diff(theta, hessian_block_size, args...) const
 * returning std::pair<gradient, sparse_hessian>.
 *
 * @tparam F Functor type being probed.
 * @tparam Args Types of the trailing arguments the caller would forward to
 *   `diff` after `theta` and `hessian_block_size`. Probing with the real
 *   argument pack matters: a functor whose `diff` requires extra data
 *   (e.g. `diff(theta, block_size, eta, y, y_index, mean)`) is not callable
 *   as `diff(theta, int)` alone, so a fixed two-argument probe would report
 *   `false` and the analytic fast path would silently never be taken.
 *   Probing with an empty `Args` pack reproduces the previous two-argument
 *   behavior, so existing `has_custom_diff_v<F>` uses remain valid.
 */
template <typename F, typename = void, typename... Args>
struct has_custom_diff : std::false_type {};

template <typename F, typename... Args>
struct has_custom_diff<
    F,
    std::void_t<decltype(std::declval<const F&>().diff(
        std::declval<const Eigen::VectorXd&>(), 1, std::declval<Args>()...))>,
    Args...> : std::true_type {};

template <typename F, typename... Args>
inline constexpr bool has_custom_diff_v
    = has_custom_diff<F, void, Args...>::value;

/**
 * Type trait to detect if a likelihood functor `F` provides a custom
 * `third_diff` method for the third derivative w.r.t. theta.
 *
 * @tparam F Functor type being probed.
 * @tparam Args Types of the trailing arguments the caller would forward to
 *   `third_diff` after `theta`. As with `has_custom_diff`, probing with the
 *   real argument pack is essential: a `third_diff(theta, eta, y, y_index,
 *   mean)` overload is not callable as `third_diff(theta)` alone, so a
 *   one-argument probe would report `false` and the analytic path would
 *   never be selected. An empty `Args` pack reproduces the previous
 *   single-argument behavior, keeping `has_custom_third_diff_v<F>` valid.
 */
template <typename F, typename = void, typename... Args>
struct has_custom_third_diff : std::false_type {};

template <typename F, typename... Args>
struct has_custom_third_diff<
    F,
    std::void_t<decltype(std::declval<const F&>().third_diff(
        std::declval<const Eigen::VectorXd&>(), std::declval<Args>()...))>,
    Args...> : std::true_type {};

template <typename F, typename... Args>
inline constexpr bool has_custom_third_diff_v
    = has_custom_third_diff<F, void, Args...>::value;
/**
* @tparam F A functor with `operator()(Args&&...)` returning a scalar
* @tparam Theta A class assignable to an Eigen vector type
Expand Down Expand Up @@ -158,6 +194,8 @@ inline auto block_hessian(F&& f, Theta&& theta,
* `theta` and `args...`
* @note If `Args` contains \ref var types then their adjoints will be
* calculated as a side effect.
* @note If `F` provides a custom `diff` method, it will be used instead
* of the generic autodiff path for better performance.
* @tparam F A functor with `operator()(Args&&...)` returning a scalar
* @tparam Theta A class assignable to an Eigen vector type
* @tparam Stream Type of stream for messages.
Expand All @@ -174,40 +212,50 @@ template <typename F, typename Theta, typename Stream, typename... Args,
require_eigen_vector_vt<std::is_arithmetic, Theta>* = nullptr>
inline auto diff(F&& f, Theta&& theta, const Eigen::Index hessian_block_size,
Stream* msgs, Args&&... args) {
using Eigen::Dynamic;
using Eigen::Matrix;
const Eigen::Index theta_size = theta.size();
auto theta_gradient = [&theta, &f, &msgs](auto&&... args) {
nested_rev_autodiff nested;
Matrix<var, Dynamic, 1> theta_var = theta;
var f_var = f(theta_var, args..., msgs);
grad(f_var.vi_);
return theta_var.adj().eval();
}(args...);
if (hessian_block_size == 1) {
auto v = Eigen::VectorXd::Ones(theta_size);
Eigen::VectorXd hessian_v = Eigen::VectorXd::Zero(theta_size);
hessian_times_vector(f, hessian_v, std::forward<Theta>(theta), std::move(v),
value_of(args)..., msgs);
Eigen::SparseMatrix<double> hessian_theta(theta_size, theta_size);
hessian_theta.reserve(Eigen::VectorXi::Constant(theta_size, 1));
for (Eigen::Index i = 0; i < theta_size; i++) {
hessian_theta.insert(i, i) = hessian_v(i);
}
return std::make_pair(std::move(theta_gradient), (-hessian_theta).eval());
using F_t = std::decay_t<F>;
if constexpr (has_custom_diff_v<F_t>) {
// Use the functor's specialized analytic derivatives
return f.diff(std::forward<Theta>(theta), hessian_block_size,
std::forward<Args>(args)...);
} else {
return std::make_pair(
std::move(theta_gradient),
(-hessian_block_diag(f, std::forward<Theta>(theta), hessian_block_size,
value_of(args)..., msgs))
.eval());
// Fall back to generic autodiff
using Eigen::Dynamic;
using Eigen::Matrix;
const Eigen::Index theta_size = theta.size();
auto theta_gradient = [&theta, &f, &msgs](auto&&... args) {
nested_rev_autodiff nested;
Matrix<var, Dynamic, 1> theta_var = theta;
var f_var = f(theta_var, args..., msgs);
grad(f_var.vi_);
return theta_var.adj().eval();
}(args...);
if (hessian_block_size == 1) {
auto v = Eigen::VectorXd::Ones(theta_size);
Eigen::VectorXd hessian_v = Eigen::VectorXd::Zero(theta_size);
hessian_times_vector(f, hessian_v, std::forward<Theta>(theta),
std::move(v), value_of(args)..., msgs);
Eigen::SparseMatrix<double> hessian_theta(theta_size, theta_size);
hessian_theta.reserve(Eigen::VectorXi::Constant(theta_size, 1));
for (Eigen::Index i = 0; i < theta_size; i++) {
hessian_theta.insert(i, i) = hessian_v(i);
}
return std::make_pair(std::move(theta_gradient), (-hessian_theta).eval());
} else {
return std::make_pair(
std::move(theta_gradient),
(-hessian_block_diag(f, std::forward<Theta>(theta),
hessian_block_size, value_of(args)..., msgs))
.eval());
}
}
}

/**
* Compute third order derivative of `f` wrt `theta` and `args...`
* @note If `Args` contains \ref var types then their adjoints will be
* calculated as a side effect.
* @note If `F` provides a custom `third_diff` method, it will be used
* instead of the generic `fvar<fvar<var>>` autodiff path.
* @tparam F A functor with `operator()(Args&&...)` returning a scalar
* @tparam Theta A class assignable to an Eigen vector type
* @tparam Stream Type of stream for messages.
Expand All @@ -221,18 +269,26 @@ template <typename F, typename Theta, typename Stream, typename... Args,
require_eigen_vector_t<Theta>* = nullptr>
inline Eigen::VectorXd third_diff(F&& f, Theta&& theta, Stream&& msgs,
Args&&... args) {
nested_rev_autodiff nested;
const Eigen::Index theta_size = theta.size();
arena_t<Eigen::Matrix<var, Eigen::Dynamic, 1>> theta_var
= std::forward<Theta>(theta);
arena_t<Eigen::Matrix<fvar<fvar<var>>, Eigen::Dynamic, 1>> theta_ffvar(
theta_size);
for (Eigen::Index i = 0; i < theta_size; ++i) {
theta_ffvar(i) = fvar<fvar<var>>(fvar<var>(theta_var(i), 1.0), 1.0);
using F_t = std::decay_t<F>;
if constexpr (has_custom_third_diff_v<F_t>) {
// Use the functor's specialized analytic third derivative
return f.third_diff(std::forward<Theta>(theta),
std::forward<Args>(args)...);
} else {
// Fall back to generic fvar<fvar<var>> autodiff
nested_rev_autodiff nested;
const Eigen::Index theta_size = theta.size();
arena_t<Eigen::Matrix<var, Eigen::Dynamic, 1>> theta_var
= std::forward<Theta>(theta);
arena_t<Eigen::Matrix<fvar<fvar<var>>, Eigen::Dynamic, 1>> theta_ffvar(
theta_size);
for (Eigen::Index i = 0; i < theta_size; ++i) {
theta_ffvar(i) = fvar<fvar<var>>(fvar<var>(theta_var(i), 1.0), 1.0);
}
fvar<fvar<var>> ftheta_ffvar = f(theta_ffvar, args..., msgs);
grad(ftheta_ffvar.d_.d_.vi_);
return theta_var.adj().eval();
}
fvar<fvar<var>> ftheta_ffvar = f(theta_ffvar, args..., msgs);
grad(ftheta_ffvar.d_.d_.vi_);
return theta_var.adj().eval();
}

/**
Expand Down
109 changes: 109 additions & 0 deletions stan/math/mix/prob/laplace_marginal_neg_binomial_2_log_lpmf.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -53,6 +53,115 @@ struct neg_binomial_2_log_likelihood {
elt_multiply(multiply(n_per_group, eta),
subtract(log_eta, lse))));
}

/**
 * Analytic gradient and negative Hessian of the neg_binomial_2_log
 * likelihood with respect to the latent Gaussian variable, bypassing
 * nested reverse-mode autodiff.
 *
 * @param theta Latent Gaussian variable (double).
 * @param hessian_block_size Size of each diagonal block (typically 1).
 * @param eta Dispersion parameter (scalar or 1-element vector).
 * @param y Observed counts.
 * @param y_index Group index for each observation (1-based).
 * @param mean Mean offset for theta.
 * @return pair of (gradient, negative Hessian) as (VectorXd, SparseMatrix).
 */
template <typename Mean>
inline auto diff(const Eigen::VectorXd& theta, int hessian_block_size,
                 const Eigen::VectorXd& eta, const std::vector<int>& y,
                 const std::vector<int>& y_index, Mean&& mean) const {
  const int num_groups = theta.size();
  const double dispersion = eta(0);

  // Accumulate per-group observation counts and sums of counts.
  Eigen::VectorXd count_sums = Eigen::VectorXd::Zero(num_groups);
  Eigen::VectorXd obs_per_group = Eigen::VectorXd::Zero(num_groups);
  for (size_t obs = 0; obs < y.size(); ++obs) {
    const int g = y_index[obs] - 1;
    obs_per_group(g) += 1;
    count_sums(g) += y[obs];
  }

  // Shift theta by the (scalar or vector) mean offset.
  Eigen::VectorXd shifted = add(theta, value_of(mean));
  Eigen::VectorXd exp_neg_shifted = exp(-shifted);

  // sums + eta * n, and the shared denominator 1 + eta * exp(-theta).
  Eigen::VectorXd sums_plus_n_eta = count_sums + dispersion * obs_per_group;
  Eigen::VectorXd denom
      = (1.0 + dispersion * exp_neg_shifted.array()).matrix();

  // d/dtheta log p = sums - (sums + eta * n) / (1 + eta * exp(-theta))
  Eigen::VectorXd gradient
      = (count_sums.array() - sums_plus_n_eta.array() / denom.array())
            .matrix();

  // -d^2/dtheta^2 log p =
  //   eta * (sums + eta * n) * exp(-theta) / (1 + eta * exp(-theta))^2
  Eigen::VectorXd neg_hessian_diag
      = (dispersion
         * (sums_plus_n_eta.array()
            * (exp_neg_shifted.array() / (denom.array() * denom.array()))))
            .matrix();

  // Only diagonal entries are inserted; reserve per requested block size.
  Eigen::SparseMatrix<double> neg_hessian(num_groups, num_groups);
  neg_hessian.reserve(
      Eigen::VectorXi::Constant(num_groups, hessian_block_size));
  for (int g = 0; g < num_groups; ++g) {
    neg_hessian.insert(g, g) = neg_hessian_diag(g);
  }

  return std::make_pair(std::move(gradient), std::move(neg_hessian));
}

/**
 * Analytic third derivative of the neg_binomial_2_log likelihood with
 * respect to theta, replacing the fvar<fvar<var>> autodiff sweep.
 *
 * Closed form:
 *   d^3/dtheta^3 log p(y | theta, eta) =
 *     -(sums + eta * n) * eta * exp(theta) * (eta - exp(theta))
 *       / (eta + exp(theta))^3
 *
 * @param theta Latent Gaussian variable (double).
 * @param eta Dispersion parameter (scalar or 1-element vector).
 * @param y Observed counts.
 * @param y_index Group index for each observation (1-based).
 * @param mean Mean offset for theta.
 * @return Third derivative as a VectorXd.
 */
template <typename Mean>
inline Eigen::VectorXd third_diff(const Eigen::VectorXd& theta,
                                  const Eigen::VectorXd& eta,
                                  const std::vector<int>& y,
                                  const std::vector<int>& y_index,
                                  Mean&& mean) const {
  const int num_groups = theta.size();
  const double dispersion = eta(0);

  // Accumulate per-group observation counts and sums of counts.
  Eigen::VectorXd count_sums = Eigen::VectorXd::Zero(num_groups);
  Eigen::VectorXd obs_per_group = Eigen::VectorXd::Zero(num_groups);
  for (size_t obs = 0; obs < y.size(); ++obs) {
    const int g = y_index[obs] - 1;
    obs_per_group(g) += 1;
    count_sums(g) += y[obs];
  }

  // Shift theta by the (scalar or vector) mean offset.
  Eigen::VectorXd shifted = add(theta, value_of(mean));
  Eigen::VectorXd exp_shifted = exp(shifted);

  // (eta + exp(theta))^3, built as (d * d) * d to match a squared-then-
  // cubed sequence of elementwise products.
  Eigen::ArrayXd denom = dispersion + exp_shifted.array();
  Eigen::ArrayXd denom_cubed = (denom * denom) * denom;

  // -(sums + eta * n) * eta * exp(theta) * (eta - exp(theta))
  //   / (eta + exp(theta))^3
  return (-((count_sums.array() + dispersion * obs_per_group.array())
            * dispersion)
          * (exp_shifted.array()
             * ((dispersion - exp_shifted.array()) / denom_cubed)))
      .matrix();
}
};

/**
Expand Down