Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
108 changes: 72 additions & 36 deletions src/frontend/Semantic_error.ml
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,8 @@ module TypeError = struct
| IllTypedLaplaceMarginal of string * bool * UnsizedType.argumentlist
| LaplaceCompatibilityIssue of string
| IlltypedLaplaceTooMany of string * int
| IlltypedLaplaceHessianBlockSize of
string * (UnsizedType.autodifftype * UnsizedType.t) option
| IlltypedLaplaceTolArgs of string * SignatureMismatch.function_mismatch
| AmbiguousFunctionPromotion of
string
Expand Down Expand Up @@ -80,14 +82,38 @@ module TypeError = struct
| 1 -> "first element of the control parameter tuple (initial guess)"
| 2 -> "second element of the control parameter tuple (tolerance)"
| 3 -> "third element of the control parameter tuple (max_num_steps)"
| 4 -> "fourth element of the control parameter tuple (hessian_block_size)"
| 5 -> "fifth element of the control parameter tuple (solver)"
| 6 ->
"sixth element of the control parameter tuple (max_steps_line_search)"
| 7 -> "seventh element of the control parameter tuple (allow_fallthrough)"
| 4 -> "fourth element of the control parameter tuple (solver)"
| 5 ->
"fifth element of the control parameter tuple (max_steps_line_search)"
| 6 -> "sixth element of the control parameter tuple (allow_fallthrough)"
| n ->
Fmt.str "%a element of the control parameter tuple" (Fmt.ordinal ()) n

let generic_laplace_usage info ppf (name, supplied) =
let req = Stan_math_signatures.laplace_helper_param_types name in
let is_helper = not @@ List.is_empty req in
let pp_lik_args ppf =
if is_helper then Fmt.(list ~sep:comma UnsizedType.pp_fun_arg) ppf req
else Fmt.pf ppf "(vector, T_l%t) => real,@ tuple(T_l%t)" ellipsis ellipsis
in
let pp_laplace_tols ppf =
if String.is_substring ~substring:"_tol" name then
Fmt.pf ppf ", %a"
Fmt.(list ~sep:comma UnsizedType.pp_fun_arg)
Stan_math_signatures.laplace_tolerance_argument_types in
let pp_supplied_tys ppf =
if List.is_empty supplied then Fmt.nop ppf ()
else
Fmt.pf ppf "@ However, we received the types:@ @[<hov 2>(%a)@]"
Fmt.(list ~sep:comma UnsizedType.pp_fun_arg)
supplied in
Fmt.pf ppf
"@[<v>Ill-typed arguments supplied to function %a.@ The valid signature \
of this function is@ @[<hov 2>%s(%t,@ data int,@ (T_k%t) => matrix,@ \
tuple(T_k%t)%t)@]%t@ @[%a@]@]"
quoted name name pp_lik_args ellipsis ellipsis pp_laplace_tols
pp_supplied_tys info ()

let rec expected_types : UnsizedType.t Common.Nonempty_list.t Fmt.t =
let ust = expected_style UnsizedType.pp in
fun ppf l ->
Expand Down Expand Up @@ -193,39 +219,24 @@ module TypeError = struct
details
Fmt.(list ~sep:comma (expected_style UnsizedType.pp_fun_arg))
expected
| IllTypedLaplaceMarginal (name, early, supplied) ->
let req = Stan_math_signatures.laplace_helper_param_types name in
let is_helper = not @@ List.is_empty req in
let info =
if early then
| IllTypedLaplaceMarginal (name, true, supplied) ->
let info ppf () =
Fmt.text ppf
"We were unable to start more in-depth checking. Please ensure you \
are passing enough arguments and that the first argument is a \
function."
else
let n = if is_helper then List.length req else 2 in
Fmt.str
"Typechecking failed after checking the first %d arguments. \
Please ensure you are passing enough arguments and that the %a \
is a function."
n (Fmt.ordinal ()) (n + 1) in
let pp_lik_args ppf =
if is_helper then Fmt.(list ~sep:comma UnsizedType.pp_fun_arg) ppf req
else
Fmt.pf ppf "(vector, T_l%t) => real,@ tuple(T_l%t)" ellipsis
ellipsis in
let pp_laplace_tols ppf =
if String.is_substring ~substring:"_tol" name then
Fmt.pf ppf ", %a"
Fmt.(list ~sep:comma UnsizedType.pp_fun_arg)
Stan_math_signatures.laplace_tolerance_argument_types in
Fmt.pf ppf
"@[<v>Ill-typed arguments supplied to function %a.@ The valid \
signature of this function is@ @[<hov 2>%s(%t,@ vector,@ (T_k%t) => \
matrix,@ tuple(T_k%t)%t)@]@ However, we received the types:@ @[<hov \
2>(%a)@]@ @[%a@]@]"
quoted name name pp_lik_args ellipsis ellipsis pp_laplace_tols
Fmt.(list ~sep:comma UnsizedType.pp_fun_arg)
supplied Fmt.text info
function." in
generic_laplace_usage info ppf (name, supplied)
| IllTypedLaplaceMarginal (name, false, supplied) ->
let req = Stan_math_signatures.laplace_helper_param_types name in
let is_helper = not @@ List.is_empty req in
let info ppf () =
let n = (if is_helper then List.length req else 2) + 1 in
Fmt.pf ppf
"Typechecking failed after checking the first %d arguments.@ \
Please ensure you are passing enough arguments and that the %a is \
a function."
n (Fmt.ordinal ()) (n + 1) in
generic_laplace_usage info ppf (name, supplied)
| LaplaceCompatibilityIssue banned_function ->
Fmt.pf ppf
"The function %a, called by this likelihood function,@ does not \
Expand All @@ -239,6 +250,28 @@ module TypeError = struct
"Only a single tuple of control parameters is expected."
else if n_args = 1 then "Did you mean to call the _tol version?"
else "Did you mean to call the _tol version with a tuple of these?")
| IlltypedLaplaceHessianBlockSize (name, None) ->
let info ppf () =
Fmt.pf ppf
"@[<hov>Missing the hessian block size (data-only %a) and \
remaining arguments.@]"
(expected_style UnsizedType.pp)
UInt in
generic_laplace_usage info ppf (name, [])
| IlltypedLaplaceHessianBlockSize (name, Some (DataOnly, ty)) ->
Fmt.pf ppf
"@[<hov>The hessian block size argument to %a must be a data-only \
%a.%a@]"
quoted name
(expected_style UnsizedType.pp)
UInt found_type ty
| IlltypedLaplaceHessianBlockSize (name, Some (_, ty)) ->
Fmt.pf ppf
"@[<hov>The hessian block size argument to %a must be a data-only \
%a.%a@ %a@]"
quoted name
(expected_style UnsizedType.pp)
UInt found_type ty SignatureMismatch.data_only_msg ()
| IlltypedLaplaceTolArgs (name, ArgNumMismatch (_, 0)) ->
Fmt.pf ppf
"Missing control parameter tuple at the end of the call to %a.@ \
Expand Down Expand Up @@ -777,6 +810,9 @@ let laplace_compatibility loc banned_function =
let illtyped_laplace_extra_args loc name args =
(loc, TypeError (TypeError.IlltypedLaplaceTooMany (name, args)))

let illtyped_laplace_hessian_block_size_arg loc name arg_ty =
(loc, TypeError (TypeError.IlltypedLaplaceHessianBlockSize (name, arg_ty)))

let illtyped_laplace_tolerance_args loc name mismatch =
(loc, TypeError (TypeError.IlltypedLaplaceTolArgs (name, mismatch)))

Expand Down
6 changes: 6 additions & 0 deletions src/frontend/Semantic_error.mli
Original file line number Diff line number Diff line change
Expand Up @@ -95,6 +95,12 @@ val illtyped_laplace_generic :
val laplace_compatibility : Location_span.t -> string -> t
val illtyped_laplace_extra_args : Location_span.t -> string -> int -> t

val illtyped_laplace_hessian_block_size_arg :
Location_span.t
-> string
-> (UnsizedType.autodifftype * UnsizedType.t) option
-> t

val illtyped_laplace_tolerance_args :
Location_span.t -> string -> SignatureMismatch.function_mismatch -> t

Expand Down
20 changes: 19 additions & 1 deletion src/frontend/Typechecker.ml
Original file line number Diff line number Diff line change
Expand Up @@ -864,6 +864,23 @@ and check_laplace_fn ~is_cond_dist loc cf tenv id tes =
(UnsizedType.ReturnType UReal) in
([lik_fun; lik_tupl], tes)
| _ -> generic_failure ~early:true () in
(* check hessian block size *)
let hbs_arg, rest =
let loc =
match List.last lik_args with
| Some e -> {e.emeta.loc with begin_loc= e.emeta.loc.end_loc}
| None -> loc in
match rest with
| hbs :: rest ->
let hbs_ty = arg_type hbs in
if hbs_ty <> UnsizedType.(DataOnly, UInt) then
Semantic_error.illtyped_laplace_hessian_block_size_arg hbs.emeta.loc
id.name (Some hbs_ty)
|> error
else (hbs, rest)
| _ ->
Semantic_error.illtyped_laplace_hessian_block_size_arg loc id.name None
|> error in
(* Check the remaining arguments: initial guess, covariance, and tolerances *)
match rest with
| {expr= Variable cov_fun; _} :: cov_tupl :: control_args ->
Expand All @@ -875,7 +892,8 @@ and check_laplace_fn ~is_cond_dist loc cf tenv id tes =
probably require two more calls to
[check_function_callable_with_tuple] *)
verify_laplace_control_args loc id control_args;
let args = lik_args @ (cov_fun_type :: cov_tupl :: control_args) in
let args =
lik_args @ (hbs_arg :: cov_fun_type :: cov_tupl :: control_args) in
let return_type =
if String.is_suffix id.name ~suffix:"_rng" then UnsizedType.UVector
else UnsizedType.UReal in
Expand Down
9 changes: 6 additions & 3 deletions src/middle/UnsizedType.ml
Original file line number Diff line number Diff line change
Expand Up @@ -65,6 +65,8 @@ let rec wind_array_type = function
| typ, 0 -> typ
| typ, n -> wind_array_type (UArray typ, n - 1)

let is_fun_type = function UFun _ | UMathLibraryFunction -> true | _ -> false

let rec pp ppf = function
| UInt -> Fmt.string ppf "int"
| UReal -> Fmt.string ppf "real"
Expand All @@ -91,7 +93,10 @@ let rec pp ppf = function

and pp_fun_arg ppf (ad_ty, unsized_ty) =
let open Fmt in
let pp_data = if' (equal_autodifftype ad_ty DataOnly) (any "data ") in
let pp_data =
if'
(equal_autodifftype ad_ty DataOnly && not (is_fun_type unsized_ty))
(any "data ") in
(pp_data ++ pp) ppf unsized_ty

and pp_returntype ppf = function
Expand Down Expand Up @@ -236,8 +241,6 @@ let is_eigen_type ut =
true
| _ -> false

let is_fun_type = function UFun _ | UMathLibraryFunction -> true | _ -> false

(** Detect if type contains an integer *)
let rec contains_int ut =
match ut with
Expand Down
11 changes: 11 additions & 0 deletions src/stan_math_signatures/Generate.ml
Original file line number Diff line number Diff line change
Expand Up @@ -1208,6 +1208,17 @@ let () =
, ReturnType UReal
, [UMatrix; UMatrix; UMatrix; UVector; UMatrix; UVector; UMatrix]
, AoS );
List.iter [UnsizedType.UInt; UVector] ~f:(fun t ->
add_unqualified
( "generate_laplace_options"
, ReturnType
(UTuple
[ UVector (* theta_0 *); UReal (* tolerance *)
; UInt (* max_num_steps *); UInt (* solver *)
; UInt (* max_steps_line_search *); UInt (* allow_fallthrough *)
])
, [t]
, AoS ));
add_unqualified
("gp_dot_prod_cov", ReturnType UMatrix, [UArray UReal; UReal], AoS);
add_unqualified
Expand Down
5 changes: 2 additions & 3 deletions src/stan_math_signatures/Stan_math_signatures.ml
Original file line number Diff line number Diff line change
Expand Up @@ -197,8 +197,7 @@ let laplace_helper_param_types name =
let laplace_tolerance_argument_types =
UnsizedType.
[ (AutoDiffable, UVector) (* theta_0 *); (DataOnly, UReal) (* tolerance *)
; (DataOnly, UInt) (* max_num_steps *)
; (DataOnly, UInt) (* hessian_block_size *); (DataOnly, UInt) (* solver *)
; (DataOnly, UInt) (* max_num_steps *); (DataOnly, UInt) (* solver *)
; (DataOnly, UInt) (* max_steps_line_search *)
; (DataOnly, UInt) (* allow_fallthrough *) ]

Expand All @@ -208,7 +207,7 @@ let is_special_function_name name =
|| is_embedded_laplace_fn name

let disallowed_second_order =
[ "algebra_solver"; "algebra_solver_newton"; "integrate_ode"
[ "algebra_solver"; "algebra_solver_newton"; "integrate_1d"; "integrate_ode"
; "integrate_ode_adams"; "integrate_ode_bdf"; "integrate_ode_rk45"; "map_rect"
; "hmm_marginal"; "hmm_hidden_state_prob" ]
|> String.Set.of_list
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -53,6 +53,6 @@ parameters {
}
model {

target += laplace_marginal(ll_function, (eta, log_ye, y),
target += laplace_marginal(ll_function, (eta, log_ye, y), 1,
K_function, (x, n_obs, alpha, rho));
}
Original file line number Diff line number Diff line change
Expand Up @@ -57,5 +57,5 @@ parameters {

generated quantities {
vector[n_obs] theta = laplace_latent_rng(ll_function, (eta, log_ye, y),
K_function, (x, n_obs, alpha, rho));
1, K_function, (x, n_obs, alpha, rho));
}
Original file line number Diff line number Diff line change
Expand Up @@ -49,5 +49,5 @@ parameters {

generated quantities {
vector[n_obs] theta = laplace_latent_rng(ll_function, (eta, log_ye, y),
K_function, (x, n_obs, alpha, rho));
1, K_function, (x, n_obs, alpha, rho));
}
Original file line number Diff line number Diff line change
Expand Up @@ -16,8 +16,9 @@ functions {
array[] int y) {
// observed count
return neg_binomial_2_lpmf(y | exp(log_ye + theta), eta) +
// integrate 1d is itself allowed, actually
// integrate 1d SHOULD be allowed,
// see https://github.com/stan-dev/math/pull/2929
// but there is a bug: https://github.com/stan-dev/math/issues/3280
integrate_1d(integrand, 0, 1, y, y, y);
}

Expand Down Expand Up @@ -48,5 +49,5 @@ parameters {

generated quantities {
vector[n_obs] theta = laplace_latent_rng(ll_function, (eta, log_ye, y),
K_function, (x, n_obs, alpha, rho));
1, K_function, (x, n_obs, alpha, rho));
}
Original file line number Diff line number Diff line change
Expand Up @@ -52,5 +52,5 @@ parameters {

generated quantities {
vector[n_obs] theta = laplace_latent_rng(ll_function, (eta, log_ye, y),
K_function, (x, n_obs, alpha, rho));
1, K_function, (x, n_obs, alpha, rho));
}
2 changes: 1 addition & 1 deletion test/integration/bad/embedded_laplace/bad_callback1.stan
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,6 @@ parameters {
real<lower=0> eta;
}
model {
target += laplace_marginal(ll_function, (eta, log_ye, y),
target += laplace_marginal(ll_function, (eta, log_ye, y), 1,
K_function, (x, n_obs, alpha, rho));
}
2 changes: 1 addition & 1 deletion test/integration/bad/embedded_laplace/bad_callback2.stan
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,6 @@ parameters {
real<lower=0> eta;
}
model {
target += laplace_marginal(ll_function_jacobian, (eta, log_ye, y),
target += laplace_marginal(ll_function_jacobian, (eta, log_ye, y), 1,
K_function, (x, n_obs, alpha, rho));
}
2 changes: 1 addition & 1 deletion test/integration/bad/embedded_laplace/bad_callback3.stan
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,6 @@ parameters {
}
model {
real ll_function;
target += laplace_marginal(ll_function, (eta, log_ye, y),
target += laplace_marginal(ll_function, (eta, log_ye, y), 1,
K_function, (x, n_obs, alpha, rho));
}
2 changes: 1 addition & 1 deletion test/integration/bad/embedded_laplace/bad_callback4.stan
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,6 @@ parameters {
}
model {

target += laplace_marginal(ll_function, (eta, log_ye, y),
target += laplace_marginal(ll_function, (eta, log_ye, y), 1,
K_function, (x, n_obs, alpha, rho));
}
2 changes: 1 addition & 1 deletion test/integration/bad/embedded_laplace/bad_callback5.stan
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,6 @@ parameters {
real<lower=0> eta;
}
model {
target += laplace_marginal(ll_function, (eta, log_ye, y),
target += laplace_marginal(ll_function, (eta, log_ye, y), 1,
K_function, (x, n_obs, alpha, rho));
}
2 changes: 1 addition & 1 deletion test/integration/bad/embedded_laplace/bad_forward1.stan
Original file line number Diff line number Diff line change
Expand Up @@ -35,5 +35,5 @@ parameters {
model {

target += laplace_marginal(ll_function, {2.0},
K_function, (x, n_obs, alpha, rho));
1, K_function, (x, n_obs, alpha, rho));
}
2 changes: 1 addition & 1 deletion test/integration/bad/embedded_laplace/bad_forward2.stan
Original file line number Diff line number Diff line change
Expand Up @@ -35,5 +35,5 @@ parameters {
model {

target += laplace_marginal(ll_function, (eta, y, log_ye),
K_function, (x, n_obs, alpha, rho));
1, K_function, (x, n_obs, alpha, rho));
}
2 changes: 1 addition & 1 deletion test/integration/bad/embedded_laplace/bad_forward3.stan
Original file line number Diff line number Diff line change
Expand Up @@ -34,5 +34,5 @@ parameters {
}
model {
target += laplace_marginal(ll_function, (eta, y),
K_function, (x, n_obs, alpha, rho));
1, K_function, (x, n_obs, alpha, rho));
}
2 changes: 1 addition & 1 deletion test/integration/bad/embedded_laplace/bad_forward4.stan
Original file line number Diff line number Diff line change
Expand Up @@ -34,5 +34,5 @@ parameters {
}
model {
target += laplace_marginal(ll_function, (eta, log_ye, y),
K_function, (x, alpha, rho, rho, rho));
1, K_function, (x, alpha, rho, rho, rho));
}
2 changes: 1 addition & 1 deletion test/integration/bad/embedded_laplace/bad_forward5.stan
Original file line number Diff line number Diff line change
Expand Up @@ -35,5 +35,5 @@ parameters {

generated quantities {
vector[n_obs] theta = laplace_latent_rng(ll_function, (eta, log_ye, y),
K_function, (x, n_obs, alpha, {rho}));
1, K_function, (x, n_obs, alpha, {rho}));
}
2 changes: 1 addition & 1 deletion test/integration/bad/embedded_laplace/bad_forward6.stan
Original file line number Diff line number Diff line change
Expand Up @@ -47,5 +47,5 @@ parameters {

generated quantities {
vector[n_obs] theta = laplace_latent_rng(ll_function, (1,1, log_ye, y),
K_function, (x, n_obs, alpha, rho));
1, K_function, (x, n_obs, alpha, rho));
}
2 changes: 1 addition & 1 deletion test/integration/bad/embedded_laplace/bad_forward9.stan
Original file line number Diff line number Diff line change
Expand Up @@ -34,5 +34,5 @@ parameters {
}
model {
target += laplace_marginal(ll_function, (eta, log_ye, y),
K_function, (x, n_obs, alpha, rho));
1, K_function, (x, n_obs, alpha, rho));
}
Loading