Skip to content

Commit b0e4883

Browse files
Merge pull request #3042 from Franzi2114/feature/issue-2966-Add-7-parameter-DDM-CDF-and-CCDF
Feature/issue 2966 add 7 parameter ddm cdf and ccdf
2 parents 85c147e + e190fe0 commit b0e4883

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

42 files changed

+4878
-284
lines changed

stan/math/prim/prob.hpp

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -311,7 +311,11 @@
311311
#include <stan/math/prim/prob/weibull_rng.hpp>
312312
#include <stan/math/prim/prob/wiener5_lpdf.hpp>
313313
#include <stan/math/prim/prob/wiener_lpdf.hpp>
314+
#include <stan/math/prim/prob/wiener4_lcdf.hpp>
315+
#include <stan/math/prim/prob/wiener4_lccdf.hpp>
314316
#include <stan/math/prim/prob/wiener_full_lpdf.hpp>
317+
#include <stan/math/prim/prob/wiener_full_lcdf.hpp>
318+
#include <stan/math/prim/prob/wiener_full_lccdf.hpp>
315319
#include <stan/math/prim/prob/wishart_cholesky_lpdf.hpp>
316320
#include <stan/math/prim/prob/wishart_cholesky_rng.hpp>
317321
#include <stan/math/prim/prob/wishart_lpdf.hpp>
Lines changed: 372 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,372 @@
1+
#ifndef STAN_MATH_PRIM_PROB_WIENER4_LCCDF_HPP
2+
#define STAN_MATH_PRIM_PROB_WIENER4_LCCDF_HPP
3+
4+
#include <stan/math/prim/prob/wiener4_lcdf.hpp>
5+
6+
namespace stan {
7+
namespace math {
8+
namespace internal {
9+
10+
/**
11+
* Log of probability of reaching the upper bound in diffusion process
12+
*
13+
* @tparam T_a type of boundary
14+
* @tparam T_w type of relative starting point
15+
* @tparam T_v type of drift rate
16+
*
17+
* @param a The boundary separation
18+
* @param w The relative starting point
19+
* @param v The drift rate
20+
* @return log probability to reach the upper bound
21+
*/
22+
template <typename T_a, typename T_w, typename T_v>
23+
inline auto log_wiener_prob_hit_upper(const T_a& a, const T_v& v,
24+
const T_w& w) {
25+
using ret_t = return_type_t<T_a, T_w, T_v>;
26+
const auto neg_v = -v;
27+
const auto one_m_w = 1.0 - w;
28+
if (fabs(v) == 0.0) {
29+
return ret_t(log(w));
30+
}
31+
const auto exponent = 2.0 * v * a * w;
32+
// This branch is for numeric stability
33+
if (exponent < 0) {
34+
return ret_t(log1m_exp(exponent)
35+
- log_diff_exp(2.0 * neg_v * a * one_m_w, exponent));
36+
} else {
37+
return ret_t(log1m_exp(-exponent) - log1m_exp(2.0 * neg_v * a));
38+
}
39+
}
40+
41+
/**
42+
* Calculate parts of the partial derivatives for wiener_prob_grad_a and
43+
* wiener_prob_grad_v (on log-scale)
44+
*
45+
* @tparam T_a type of boundary
46+
* @tparam T_w type of relative starting point
47+
* @tparam T_v type of drift rate
48+
*
49+
* @param a The boundary separation
50+
* @param w The relative starting point
51+
* @param v The drift rate
52+
* @return 'ans' term
53+
*/
54+
template <typename T_a, typename T_w, typename T_v>
55+
inline auto wiener_prob_derivative_term(const T_a& a, const T_v& v,
56+
const T_w& w) noexcept {
57+
using ret_t = return_type_t<T_a, T_w, T_v>;
58+
const auto exponent_m1 = log1m(1.1 * 1.0e-8);
59+
const auto neg_v = -v;
60+
const auto one_m_w = 1 - w;
61+
int sign_v = neg_v < 0 ? 1 : -1;
62+
const auto two_a_neg_v = 2.0 * a * neg_v;
63+
const auto exponent_with_1mw = sign_v * two_a_neg_v * w;
64+
const auto exponent = sign_v * two_a_neg_v;
65+
const auto exponent_with_w = two_a_neg_v * one_m_w;
66+
// truncating longer calculations, for numerical stability
67+
if (unlikely((exponent_with_1mw >= exponent_m1)
68+
|| ((exponent_with_w >= exponent_m1) && (sign_v == 1))
69+
|| (exponent >= exponent_m1) || neg_v == 0)) {
70+
return ret_t(-one_m_w);
71+
}
72+
ret_t ans;
73+
ret_t diff_term;
74+
const auto log_w = log(one_m_w);
75+
if (neg_v < 0) {
76+
ans = LOG_TWO + exponent_with_1mw - log1m_exp(exponent_with_1mw);
77+
diff_term = log1m_exp(exponent_with_w) - log1m_exp(exponent);
78+
} else if (neg_v > 0) {
79+
ans = LOG_TWO - log1m_exp(exponent_with_1mw);
80+
diff_term = log_diff_exp(exponent_with_1mw, exponent) - log1m_exp(exponent);
81+
}
82+
if (log_w > diff_term) {
83+
ans = sign_v * exp(ans + log_diff_exp(log_w, diff_term));
84+
} else {
85+
ans = -sign_v * exp(ans + log_diff_exp(diff_term, log_w));
86+
}
87+
if (unlikely(!is_scal_finite(ans))) {
88+
return ret_t(NEGATIVE_INFTY);
89+
}
90+
return ans;
91+
}
92+
93+
/**
94+
* Calculate wiener4 ccdf (natural-scale)
95+
*
96+
* @param y The reaction time in seconds
97+
* @param a The boundary separation
98+
* @param v The relative starting point
99+
* @param w The drift rate
100+
* @param log_err The log error tolerance in the computation of the number
101+
* of terms for the infinite sums
102+
* @return ccdf
103+
*/
104+
template <typename T_y, typename T_a, typename T_w, typename T_v,
105+
typename T_err>
106+
inline auto wiener4_ccdf(const T_y& y, const T_a& a, const T_v& v, const T_w& w,
107+
T_err log_err = log(1e-12)) noexcept {
108+
const auto prob_hit_upper = exp(log_wiener_prob_hit_upper(a, v, w));
109+
const auto cdf
110+
= internal::wiener4_distribution<GradientCalc::ON>(y, a, v, w, log_err);
111+
return prob_hit_upper - cdf;
112+
}
113+
114+
/**
115+
* Calculate derivative of the wiener4 ccdf w.r.t. 'a' (natural-scale)
116+
*
117+
* @param y The reaction time in seconds
118+
* @param a The boundary separation
119+
* @param v The relative starting point
120+
* @param w The drift rate
121+
* @param cdf The CDF value
122+
* @param log_err The log error tolerance in the computation of the number
123+
* of terms for the infinite sums
124+
* @return Gradient with respect to a
125+
*/
126+
template <typename T_y, typename T_a, typename T_w, typename T_v,
127+
typename T_cdf, typename T_err>
128+
inline auto wiener4_ccdf_grad_a(const T_y& y, const T_a& a, const T_v& v,
129+
const T_w& w, T_cdf&& cdf,
130+
T_err log_err = log(1e-12)) noexcept {
131+
using ret_t = return_type_t<T_a, T_w, T_v>;
132+
133+
// derivative of the wiener probability w.r.t. 'a' (on log-scale)
134+
auto prob_grad_a = -wiener_prob_derivative_term(a, v, w) * v;
135+
if (!is_scal_finite(prob_grad_a)) {
136+
prob_grad_a = ret_t(NEGATIVE_INFTY);
137+
}
138+
const auto log_prob_hit_upper = log_wiener_prob_hit_upper(a, v, w);
139+
const auto cdf_grad_a = wiener4_cdf_grad_a(y, a, v, w, cdf, log_err);
140+
return prob_grad_a * exp(log_prob_hit_upper) - cdf_grad_a;
141+
}
142+
143+
/**
144+
* Calculate derivative of the wiener4 ccdf w.r.t. 'v' (natural-scale)
145+
*
146+
* @param y The reaction time in seconds
147+
* @param a The boundary separation
148+
* @param v The relative starting point
149+
* @param w The drift rate
150+
* @param cdf The CDF value
151+
* @param log_err The log error tolerance in the computation of the number
152+
* of terms for the infinite sums
153+
* @return Gradient with respect to v
154+
*/
155+
template <typename T_y, typename T_a, typename T_w, typename T_v,
156+
typename T_cdf, typename T_err>
157+
inline auto wiener4_ccdf_grad_v(const T_y& y, const T_a& a, const T_v& v,
158+
const T_w& w, T_cdf&& cdf,
159+
T_err log_err = log(1e-12)) noexcept {
160+
using ret_t = return_type_t<T_a, T_w, T_v>;
161+
const auto log_prob_hit_upper = log_wiener_prob_hit_upper(a, v, w);
162+
// derivative of the wiener probability w.r.t. 'v' (on log-scale)
163+
auto prob_grad_v = -wiener_prob_derivative_term(a, v, w) * a;
164+
if (!is_scal_finite(fabs(prob_grad_v))) {
165+
prob_grad_v = ret_t(NEGATIVE_INFTY);
166+
}
167+
168+
const auto cdf_grad_v = wiener4_cdf_grad_v(y, a, v, w, cdf, log_err);
169+
return prob_grad_v * exp(log_prob_hit_upper) - cdf_grad_v;
170+
}
171+
172+
/**
173+
* Calculate derivative of the wiener4 ccdf w.r.t. 'w' (natural-scale)
174+
*
175+
* @param y The reaction time in seconds
176+
* @param a The boundary separation
177+
* @param v The relative starting point
178+
* @param w The drift rate
179+
* @param cdf The CDF value
180+
* @param log_err The log error tolerance in the computation of the number
181+
* of terms for the infinite sums
182+
* @return Gradient with respect to w
183+
*/
184+
template <typename T_y, typename T_a, typename T_w, typename T_v,
185+
typename T_cdf, typename T_err>
186+
inline auto wiener4_ccdf_grad_w(const T_y& y, const T_a& a, const T_v& v,
187+
const T_w& w, T_cdf&& cdf,
188+
T_err log_err = log(1e-12)) noexcept {
189+
using ret_t = return_type_t<T_a, T_w, T_v>;
190+
const auto log_prob_hit_upper = log_wiener_prob_hit_upper(a, v, w);
191+
// derivative of the wiener probability w.r.t. 'v' (on log-scale)
192+
const auto exponent = -sign(v) * 2.0 * v * a * w;
193+
auto prob_grad_w
194+
= (v != 0) ? exp(LOG_TWO + log(fabs(v)) + log(a) - log1m_exp(exponent))
195+
: ret_t(1 / w);
196+
if (v > 0) {
197+
prob_grad_w *= exp(exponent);
198+
}
199+
200+
const auto cdf_grad_w = wiener4_cdf_grad_w(y, a, v, w, cdf, log_err);
201+
return prob_grad_w * exp(log_prob_hit_upper) - cdf_grad_w;
202+
}
203+
204+
} // namespace internal
205+
206+
/**
207+
* Log-CCDF for the 4-parameter Wiener distribution.
208+
* See 'wiener_full_lpdf' for more comprehensive documentation.
209+
*
210+
* @tparam T_y type of reaction time
211+
* @tparam T_a type of boundary
212+
* @tparam T_t0 type of non-decision time
213+
* @tparam T_w type of relative starting point
214+
* @tparam T_v type of drift rate
215+
*
216+
* @param y The reaction time in seconds
217+
* @param a The boundary separation
218+
* @param t0 The non-decision time
219+
* @param w The relative starting point
220+
* @param v The drift rate
221+
* @param precision_derivatives Level of precision in estimation
222+
* @return The log of the Wiener first passage time distribution with
223+
* the specified arguments for upper boundary responses
224+
*/
225+
template <bool propto = false, typename T_y, typename T_a, typename T_t0,
226+
typename T_w, typename T_v>
227+
inline auto wiener_lccdf(const T_y& y, const T_a& a, const T_t0& t0,
228+
const T_w& w, const T_v& v,
229+
const double& precision_derivatives) {
230+
using T_partials_return = partials_return_t<T_y, T_a, T_t0, T_w, T_v>;
231+
using ret_t = return_type_t<T_y, T_a, T_t0, T_w, T_v>;
232+
using T_y_ref = ref_type_if_t<!is_constant<T_y>::value, T_y>;
233+
using T_a_ref = ref_type_if_t<!is_constant<T_a>::value, T_a>;
234+
using T_t0_ref = ref_type_if_t<!is_constant<T_t0>::value, T_t0>;
235+
using T_w_ref = ref_type_if_t<!is_constant<T_w>::value, T_w>;
236+
using T_v_ref = ref_type_if_t<!is_constant<T_v>::value, T_v>;
237+
using internal::GradientCalc;
238+
239+
T_y_ref y_ref = y;
240+
T_a_ref a_ref = a;
241+
T_t0_ref t0_ref = t0;
242+
T_w_ref w_ref = w;
243+
T_v_ref v_ref = v;
244+
245+
auto y_val = to_ref(as_value_column_array_or_scalar(y_ref));
246+
auto a_val = to_ref(as_value_column_array_or_scalar(a_ref));
247+
auto v_val = to_ref(as_value_column_array_or_scalar(v_ref));
248+
auto w_val = to_ref(as_value_column_array_or_scalar(w_ref));
249+
auto t0_val = to_ref(as_value_column_array_or_scalar(t0_ref));
250+
251+
static constexpr const char* function_name = "wiener4_lccdf";
252+
if (size_zero(y, a, t0, w, v)) {
253+
return ret_t(0.0);
254+
}
255+
256+
if (!include_summand<propto, T_y, T_a, T_t0, T_w, T_v>::value) {
257+
return ret_t(0.0);
258+
}
259+
260+
check_consistent_sizes(function_name, "Random variable", y,
261+
"Boundary separation", a, "Drift rate", v,
262+
"A-priori bias", w, "Nondecision time", t0);
263+
check_positive_finite(function_name, "Random variable", y_val);
264+
check_positive_finite(function_name, "Boundary separation", a_val);
265+
check_finite(function_name, "Drift rate", v_val);
266+
check_less(function_name, "A-priori bias", w_val, 1);
267+
check_greater(function_name, "A-priori bias", w_val, 0);
268+
check_nonnegative(function_name, "Nondecision time", t0_val);
269+
check_finite(function_name, "Nondecision time", t0_val);
270+
271+
const size_t N = max_size(y, a, t0, w, v);
272+
273+
scalar_seq_view<T_y_ref> y_vec(y_ref);
274+
scalar_seq_view<T_a_ref> a_vec(a_ref);
275+
scalar_seq_view<T_t0_ref> t0_vec(t0_ref);
276+
scalar_seq_view<T_w_ref> w_vec(w_ref);
277+
scalar_seq_view<T_v_ref> v_vec(v_ref);
278+
const size_t N_y_t0 = max_size(y, t0);
279+
280+
for (size_t i = 0; i < N_y_t0; ++i) {
281+
if (y_vec[i] <= t0_vec[i]) {
282+
std::stringstream msg;
283+
msg << ", but must be greater than nondecision time = " << t0_vec[i];
284+
std::string msg_str(msg.str());
285+
throw_domain_error(function_name, "Random variable", y_vec[i], " = ",
286+
msg_str.c_str());
287+
}
288+
}
289+
290+
// for precs. 1e-6, 1e-12, see Hartmann et al. (2021), Henrich et al. (2023)
291+
const auto log_error_cdf = log(1e-6);
292+
const auto log_error_derivative = log(precision_derivatives);
293+
const T_partials_return log_error_absolute = log(1e-12);
294+
T_partials_return lccdf = 0.0;
295+
auto ops_partials
296+
= make_partials_propagator(y_ref, a_ref, t0_ref, w_ref, v_ref);
297+
298+
const double LOG_FOUR = std::log(4.0);
299+
300+
// calculate distribution and partials
301+
for (size_t i = 0; i < N; i++) {
302+
const auto y_value = y_vec.val(i);
303+
const auto a_value = a_vec.val(i);
304+
const auto t0_value = t0_vec.val(i);
305+
const auto w_value = w_vec.val(i);
306+
const auto v_value = v_vec.val(i);
307+
308+
const T_partials_return cdf
309+
= internal::estimate_with_err_check<4, 0, GradientCalc::OFF,
310+
GradientCalc::OFF>(
311+
[](auto&&... args) {
312+
return internal::wiener4_distribution<GradientCalc::ON>(args...);
313+
},
314+
log_error_cdf - LOG_TWO, y_value - t0_value, a_value, v_value,
315+
w_value, log_error_absolute);
316+
317+
const auto prob_hit_upper
318+
= exp(internal::log_wiener_prob_hit_upper(a_value, v_value, w_value));
319+
const auto ccdf = prob_hit_upper - cdf;
320+
const auto log_ccdf_single_value = log(ccdf);
321+
322+
lccdf += log_ccdf_single_value;
323+
324+
const auto new_est_err
325+
= log_ccdf_single_value + log_error_derivative - LOG_FOUR;
326+
327+
if (!is_constant_all<T_y>::value || !is_constant_all<T_t0>::value) {
328+
const auto deriv_y = internal::estimate_with_err_check<5, 0>(
329+
[](auto&&... args) {
330+
return internal::wiener5_density<GradientCalc::ON>(args...);
331+
},
332+
new_est_err, y_value - t0_value, a_value, v_value, w_value, 0.0,
333+
log_error_absolute);
334+
if (!is_constant_all<T_y>::value) {
335+
partials<0>(ops_partials)[i] = -deriv_y / ccdf;
336+
}
337+
if (!is_constant_all<T_t0>::value) {
338+
partials<2>(ops_partials)[i] = deriv_y / ccdf;
339+
}
340+
}
341+
if (!is_constant_all<T_a>::value) {
342+
partials<1>(ops_partials)[i]
343+
= internal::estimate_with_err_check<5, 0>(
344+
[](auto&&... args) {
345+
return internal::wiener4_ccdf_grad_a(args...);
346+
},
347+
new_est_err, y_value - t0_value, a_value, v_value, w_value, cdf,
348+
log_error_absolute)
349+
/ ccdf;
350+
}
351+
if (!is_constant_all<T_w>::value) {
352+
partials<3>(ops_partials)[i]
353+
= internal::estimate_with_err_check<5, 0>(
354+
[](auto&&... args) {
355+
return internal::wiener4_ccdf_grad_w(args...);
356+
},
357+
new_est_err, y_value - t0_value, a_value, v_value, w_value, cdf,
358+
log_error_absolute)
359+
/ ccdf;
360+
}
361+
if (!is_constant_all<T_v>::value) {
362+
partials<4>(ops_partials)[i]
363+
= internal::wiener4_ccdf_grad_v(y_value - t0_value, a_value, v_value,
364+
w_value, cdf, log_error_absolute)
365+
/ ccdf;
366+
}
367+
} // for loop
368+
return ops_partials.build(lccdf);
369+
}
370+
} // namespace math
371+
} // namespace stan
372+
#endif

0 commit comments

Comments
 (0)