|
| 1 | +#ifndef STAN_MATH_PRIM_PROB_WIENER4_LCCDF_HPP |
| 2 | +#define STAN_MATH_PRIM_PROB_WIENER4_LCCDF_HPP |
| 3 | + |
| 4 | +#include <stan/math/prim/prob/wiener4_lcdf.hpp> |
| 5 | + |
| 6 | +namespace stan { |
| 7 | +namespace math { |
| 8 | +namespace internal { |
| 9 | + |
| 10 | +/** |
| 11 | + * Log of probability of reaching the upper bound in diffusion process |
| 12 | + * |
| 13 | + * @tparam T_a type of boundary |
| 14 | + * @tparam T_w type of relative starting point |
| 15 | + * @tparam T_v type of drift rate |
| 16 | + * |
| 17 | + * @param a The boundary separation |
| 18 | + * @param w The relative starting point |
| 19 | + * @param v The drift rate |
| 20 | + * @return log probability to reach the upper bound |
| 21 | + */ |
| 22 | +template <typename T_a, typename T_w, typename T_v> |
| 23 | +inline auto log_wiener_prob_hit_upper(const T_a& a, const T_v& v, |
| 24 | + const T_w& w) { |
| 25 | + using ret_t = return_type_t<T_a, T_w, T_v>; |
| 26 | + const auto neg_v = -v; |
| 27 | + const auto one_m_w = 1.0 - w; |
| 28 | + if (fabs(v) == 0.0) { |
| 29 | + return ret_t(log(w)); |
| 30 | + } |
| 31 | + const auto exponent = 2.0 * v * a * w; |
| 32 | + // This branch is for numeric stability |
| 33 | + if (exponent < 0) { |
| 34 | + return ret_t(log1m_exp(exponent) |
| 35 | + - log_diff_exp(2.0 * neg_v * a * one_m_w, exponent)); |
| 36 | + } else { |
| 37 | + return ret_t(log1m_exp(-exponent) - log1m_exp(2.0 * neg_v * a)); |
| 38 | + } |
| 39 | +} |
| 40 | + |
| 41 | +/** |
| 42 | + * Calculate parts of the partial derivatives for wiener_prob_grad_a and |
| 43 | + * wiener_prob_grad_v (on log-scale) |
| 44 | + * |
| 45 | + * @tparam T_a type of boundary |
| 46 | + * @tparam T_w type of relative starting point |
| 47 | + * @tparam T_v type of drift rate |
| 48 | + * |
| 49 | + * @param a The boundary separation |
| 50 | + * @param w The relative starting point |
| 51 | + * @param v The drift rate |
| 52 | + * @return 'ans' term |
| 53 | + */ |
| 54 | +template <typename T_a, typename T_w, typename T_v> |
| 55 | +inline auto wiener_prob_derivative_term(const T_a& a, const T_v& v, |
| 56 | + const T_w& w) noexcept { |
| 57 | + using ret_t = return_type_t<T_a, T_w, T_v>; |
| 58 | + const auto exponent_m1 = log1m(1.1 * 1.0e-8); |
| 59 | + const auto neg_v = -v; |
| 60 | + const auto one_m_w = 1 - w; |
| 61 | + int sign_v = neg_v < 0 ? 1 : -1; |
| 62 | + const auto two_a_neg_v = 2.0 * a * neg_v; |
| 63 | + const auto exponent_with_1mw = sign_v * two_a_neg_v * w; |
| 64 | + const auto exponent = sign_v * two_a_neg_v; |
| 65 | + const auto exponent_with_w = two_a_neg_v * one_m_w; |
| 66 | + // truncating longer calculations, for numerical stability |
| 67 | + if (unlikely((exponent_with_1mw >= exponent_m1) |
| 68 | + || ((exponent_with_w >= exponent_m1) && (sign_v == 1)) |
| 69 | + || (exponent >= exponent_m1) || neg_v == 0)) { |
| 70 | + return ret_t(-one_m_w); |
| 71 | + } |
| 72 | + ret_t ans; |
| 73 | + ret_t diff_term; |
| 74 | + const auto log_w = log(one_m_w); |
| 75 | + if (neg_v < 0) { |
| 76 | + ans = LOG_TWO + exponent_with_1mw - log1m_exp(exponent_with_1mw); |
| 77 | + diff_term = log1m_exp(exponent_with_w) - log1m_exp(exponent); |
| 78 | + } else if (neg_v > 0) { |
| 79 | + ans = LOG_TWO - log1m_exp(exponent_with_1mw); |
| 80 | + diff_term = log_diff_exp(exponent_with_1mw, exponent) - log1m_exp(exponent); |
| 81 | + } |
| 82 | + if (log_w > diff_term) { |
| 83 | + ans = sign_v * exp(ans + log_diff_exp(log_w, diff_term)); |
| 84 | + } else { |
| 85 | + ans = -sign_v * exp(ans + log_diff_exp(diff_term, log_w)); |
| 86 | + } |
| 87 | + if (unlikely(!is_scal_finite(ans))) { |
| 88 | + return ret_t(NEGATIVE_INFTY); |
| 89 | + } |
| 90 | + return ans; |
| 91 | +} |
| 92 | + |
| 93 | +/** |
| 94 | + * Calculate wiener4 ccdf (natural-scale) |
| 95 | + * |
| 96 | + * @param y The reaction time in seconds |
| 97 | + * @param a The boundary separation |
| 98 | + * @param v The relative starting point |
| 99 | + * @param w The drift rate |
| 100 | + * @param log_err The log error tolerance in the computation of the number |
| 101 | + * of terms for the infinite sums |
| 102 | + * @return ccdf |
| 103 | + */ |
| 104 | +template <typename T_y, typename T_a, typename T_w, typename T_v, |
| 105 | + typename T_err> |
| 106 | +inline auto wiener4_ccdf(const T_y& y, const T_a& a, const T_v& v, const T_w& w, |
| 107 | + T_err log_err = log(1e-12)) noexcept { |
| 108 | + const auto prob_hit_upper = exp(log_wiener_prob_hit_upper(a, v, w)); |
| 109 | + const auto cdf |
| 110 | + = internal::wiener4_distribution<GradientCalc::ON>(y, a, v, w, log_err); |
| 111 | + return prob_hit_upper - cdf; |
| 112 | +} |
| 113 | + |
| 114 | +/** |
| 115 | + * Calculate derivative of the wiener4 ccdf w.r.t. 'a' (natural-scale) |
| 116 | + * |
| 117 | + * @param y The reaction time in seconds |
| 118 | + * @param a The boundary separation |
| 119 | + * @param v The relative starting point |
| 120 | + * @param w The drift rate |
| 121 | + * @param cdf The CDF value |
| 122 | + * @param log_err The log error tolerance in the computation of the number |
| 123 | + * of terms for the infinite sums |
| 124 | + * @return Gradient with respect to a |
| 125 | + */ |
| 126 | +template <typename T_y, typename T_a, typename T_w, typename T_v, |
| 127 | + typename T_cdf, typename T_err> |
| 128 | +inline auto wiener4_ccdf_grad_a(const T_y& y, const T_a& a, const T_v& v, |
| 129 | + const T_w& w, T_cdf&& cdf, |
| 130 | + T_err log_err = log(1e-12)) noexcept { |
| 131 | + using ret_t = return_type_t<T_a, T_w, T_v>; |
| 132 | + |
| 133 | + // derivative of the wiener probability w.r.t. 'a' (on log-scale) |
| 134 | + auto prob_grad_a = -wiener_prob_derivative_term(a, v, w) * v; |
| 135 | + if (!is_scal_finite(prob_grad_a)) { |
| 136 | + prob_grad_a = ret_t(NEGATIVE_INFTY); |
| 137 | + } |
| 138 | + const auto log_prob_hit_upper = log_wiener_prob_hit_upper(a, v, w); |
| 139 | + const auto cdf_grad_a = wiener4_cdf_grad_a(y, a, v, w, cdf, log_err); |
| 140 | + return prob_grad_a * exp(log_prob_hit_upper) - cdf_grad_a; |
| 141 | +} |
| 142 | + |
| 143 | +/** |
| 144 | + * Calculate derivative of the wiener4 ccdf w.r.t. 'v' (natural-scale) |
| 145 | + * |
| 146 | + * @param y The reaction time in seconds |
| 147 | + * @param a The boundary separation |
| 148 | + * @param v The relative starting point |
| 149 | + * @param w The drift rate |
| 150 | + * @param cdf The CDF value |
| 151 | + * @param log_err The log error tolerance in the computation of the number |
| 152 | + * of terms for the infinite sums |
| 153 | + * @return Gradient with respect to v |
| 154 | + */ |
| 155 | +template <typename T_y, typename T_a, typename T_w, typename T_v, |
| 156 | + typename T_cdf, typename T_err> |
| 157 | +inline auto wiener4_ccdf_grad_v(const T_y& y, const T_a& a, const T_v& v, |
| 158 | + const T_w& w, T_cdf&& cdf, |
| 159 | + T_err log_err = log(1e-12)) noexcept { |
| 160 | + using ret_t = return_type_t<T_a, T_w, T_v>; |
| 161 | + const auto log_prob_hit_upper = log_wiener_prob_hit_upper(a, v, w); |
| 162 | + // derivative of the wiener probability w.r.t. 'v' (on log-scale) |
| 163 | + auto prob_grad_v = -wiener_prob_derivative_term(a, v, w) * a; |
| 164 | + if (!is_scal_finite(fabs(prob_grad_v))) { |
| 165 | + prob_grad_v = ret_t(NEGATIVE_INFTY); |
| 166 | + } |
| 167 | + |
| 168 | + const auto cdf_grad_v = wiener4_cdf_grad_v(y, a, v, w, cdf, log_err); |
| 169 | + return prob_grad_v * exp(log_prob_hit_upper) - cdf_grad_v; |
| 170 | +} |
| 171 | + |
| 172 | +/** |
| 173 | + * Calculate derivative of the wiener4 ccdf w.r.t. 'w' (natural-scale) |
| 174 | + * |
| 175 | + * @param y The reaction time in seconds |
| 176 | + * @param a The boundary separation |
| 177 | + * @param v The relative starting point |
| 178 | + * @param w The drift rate |
| 179 | + * @param cdf The CDF value |
| 180 | + * @param log_err The log error tolerance in the computation of the number |
| 181 | + * of terms for the infinite sums |
| 182 | + * @return Gradient with respect to w |
| 183 | + */ |
| 184 | +template <typename T_y, typename T_a, typename T_w, typename T_v, |
| 185 | + typename T_cdf, typename T_err> |
| 186 | +inline auto wiener4_ccdf_grad_w(const T_y& y, const T_a& a, const T_v& v, |
| 187 | + const T_w& w, T_cdf&& cdf, |
| 188 | + T_err log_err = log(1e-12)) noexcept { |
| 189 | + using ret_t = return_type_t<T_a, T_w, T_v>; |
| 190 | + const auto log_prob_hit_upper = log_wiener_prob_hit_upper(a, v, w); |
| 191 | + // derivative of the wiener probability w.r.t. 'v' (on log-scale) |
| 192 | + const auto exponent = -sign(v) * 2.0 * v * a * w; |
| 193 | + auto prob_grad_w |
| 194 | + = (v != 0) ? exp(LOG_TWO + log(fabs(v)) + log(a) - log1m_exp(exponent)) |
| 195 | + : ret_t(1 / w); |
| 196 | + if (v > 0) { |
| 197 | + prob_grad_w *= exp(exponent); |
| 198 | + } |
| 199 | + |
| 200 | + const auto cdf_grad_w = wiener4_cdf_grad_w(y, a, v, w, cdf, log_err); |
| 201 | + return prob_grad_w * exp(log_prob_hit_upper) - cdf_grad_w; |
| 202 | +} |
| 203 | + |
| 204 | +} // namespace internal |
| 205 | + |
| 206 | +/** |
| 207 | + * Log-CCDF for the 4-parameter Wiener distribution. |
| 208 | + * See 'wiener_full_lpdf' for more comprehensive documentation. |
| 209 | + * |
| 210 | + * @tparam T_y type of reaction time |
| 211 | + * @tparam T_a type of boundary |
| 212 | + * @tparam T_t0 type of non-decision time |
| 213 | + * @tparam T_w type of relative starting point |
| 214 | + * @tparam T_v type of drift rate |
| 215 | + * |
| 216 | + * @param y The reaction time in seconds |
| 217 | + * @param a The boundary separation |
| 218 | + * @param t0 The non-decision time |
| 219 | + * @param w The relative starting point |
| 220 | + * @param v The drift rate |
| 221 | + * @param precision_derivatives Level of precision in estimation |
| 222 | + * @return The log of the Wiener first passage time distribution with |
| 223 | + * the specified arguments for upper boundary responses |
| 224 | + */ |
| 225 | +template <bool propto = false, typename T_y, typename T_a, typename T_t0, |
| 226 | + typename T_w, typename T_v> |
| 227 | +inline auto wiener_lccdf(const T_y& y, const T_a& a, const T_t0& t0, |
| 228 | + const T_w& w, const T_v& v, |
| 229 | + const double& precision_derivatives) { |
| 230 | + using T_partials_return = partials_return_t<T_y, T_a, T_t0, T_w, T_v>; |
| 231 | + using ret_t = return_type_t<T_y, T_a, T_t0, T_w, T_v>; |
| 232 | + using T_y_ref = ref_type_if_t<!is_constant<T_y>::value, T_y>; |
| 233 | + using T_a_ref = ref_type_if_t<!is_constant<T_a>::value, T_a>; |
| 234 | + using T_t0_ref = ref_type_if_t<!is_constant<T_t0>::value, T_t0>; |
| 235 | + using T_w_ref = ref_type_if_t<!is_constant<T_w>::value, T_w>; |
| 236 | + using T_v_ref = ref_type_if_t<!is_constant<T_v>::value, T_v>; |
| 237 | + using internal::GradientCalc; |
| 238 | + |
| 239 | + T_y_ref y_ref = y; |
| 240 | + T_a_ref a_ref = a; |
| 241 | + T_t0_ref t0_ref = t0; |
| 242 | + T_w_ref w_ref = w; |
| 243 | + T_v_ref v_ref = v; |
| 244 | + |
| 245 | + auto y_val = to_ref(as_value_column_array_or_scalar(y_ref)); |
| 246 | + auto a_val = to_ref(as_value_column_array_or_scalar(a_ref)); |
| 247 | + auto v_val = to_ref(as_value_column_array_or_scalar(v_ref)); |
| 248 | + auto w_val = to_ref(as_value_column_array_or_scalar(w_ref)); |
| 249 | + auto t0_val = to_ref(as_value_column_array_or_scalar(t0_ref)); |
| 250 | + |
| 251 | + static constexpr const char* function_name = "wiener4_lccdf"; |
| 252 | + if (size_zero(y, a, t0, w, v)) { |
| 253 | + return ret_t(0.0); |
| 254 | + } |
| 255 | + |
| 256 | + if (!include_summand<propto, T_y, T_a, T_t0, T_w, T_v>::value) { |
| 257 | + return ret_t(0.0); |
| 258 | + } |
| 259 | + |
| 260 | + check_consistent_sizes(function_name, "Random variable", y, |
| 261 | + "Boundary separation", a, "Drift rate", v, |
| 262 | + "A-priori bias", w, "Nondecision time", t0); |
| 263 | + check_positive_finite(function_name, "Random variable", y_val); |
| 264 | + check_positive_finite(function_name, "Boundary separation", a_val); |
| 265 | + check_finite(function_name, "Drift rate", v_val); |
| 266 | + check_less(function_name, "A-priori bias", w_val, 1); |
| 267 | + check_greater(function_name, "A-priori bias", w_val, 0); |
| 268 | + check_nonnegative(function_name, "Nondecision time", t0_val); |
| 269 | + check_finite(function_name, "Nondecision time", t0_val); |
| 270 | + |
| 271 | + const size_t N = max_size(y, a, t0, w, v); |
| 272 | + |
| 273 | + scalar_seq_view<T_y_ref> y_vec(y_ref); |
| 274 | + scalar_seq_view<T_a_ref> a_vec(a_ref); |
| 275 | + scalar_seq_view<T_t0_ref> t0_vec(t0_ref); |
| 276 | + scalar_seq_view<T_w_ref> w_vec(w_ref); |
| 277 | + scalar_seq_view<T_v_ref> v_vec(v_ref); |
| 278 | + const size_t N_y_t0 = max_size(y, t0); |
| 279 | + |
| 280 | + for (size_t i = 0; i < N_y_t0; ++i) { |
| 281 | + if (y_vec[i] <= t0_vec[i]) { |
| 282 | + std::stringstream msg; |
| 283 | + msg << ", but must be greater than nondecision time = " << t0_vec[i]; |
| 284 | + std::string msg_str(msg.str()); |
| 285 | + throw_domain_error(function_name, "Random variable", y_vec[i], " = ", |
| 286 | + msg_str.c_str()); |
| 287 | + } |
| 288 | + } |
| 289 | + |
| 290 | + // for precs. 1e-6, 1e-12, see Hartmann et al. (2021), Henrich et al. (2023) |
| 291 | + const auto log_error_cdf = log(1e-6); |
| 292 | + const auto log_error_derivative = log(precision_derivatives); |
| 293 | + const T_partials_return log_error_absolute = log(1e-12); |
| 294 | + T_partials_return lccdf = 0.0; |
| 295 | + auto ops_partials |
| 296 | + = make_partials_propagator(y_ref, a_ref, t0_ref, w_ref, v_ref); |
| 297 | + |
| 298 | + const double LOG_FOUR = std::log(4.0); |
| 299 | + |
| 300 | + // calculate distribution and partials |
| 301 | + for (size_t i = 0; i < N; i++) { |
| 302 | + const auto y_value = y_vec.val(i); |
| 303 | + const auto a_value = a_vec.val(i); |
| 304 | + const auto t0_value = t0_vec.val(i); |
| 305 | + const auto w_value = w_vec.val(i); |
| 306 | + const auto v_value = v_vec.val(i); |
| 307 | + |
| 308 | + const T_partials_return cdf |
| 309 | + = internal::estimate_with_err_check<4, 0, GradientCalc::OFF, |
| 310 | + GradientCalc::OFF>( |
| 311 | + [](auto&&... args) { |
| 312 | + return internal::wiener4_distribution<GradientCalc::ON>(args...); |
| 313 | + }, |
| 314 | + log_error_cdf - LOG_TWO, y_value - t0_value, a_value, v_value, |
| 315 | + w_value, log_error_absolute); |
| 316 | + |
| 317 | + const auto prob_hit_upper |
| 318 | + = exp(internal::log_wiener_prob_hit_upper(a_value, v_value, w_value)); |
| 319 | + const auto ccdf = prob_hit_upper - cdf; |
| 320 | + const auto log_ccdf_single_value = log(ccdf); |
| 321 | + |
| 322 | + lccdf += log_ccdf_single_value; |
| 323 | + |
| 324 | + const auto new_est_err |
| 325 | + = log_ccdf_single_value + log_error_derivative - LOG_FOUR; |
| 326 | + |
| 327 | + if (!is_constant_all<T_y>::value || !is_constant_all<T_t0>::value) { |
| 328 | + const auto deriv_y = internal::estimate_with_err_check<5, 0>( |
| 329 | + [](auto&&... args) { |
| 330 | + return internal::wiener5_density<GradientCalc::ON>(args...); |
| 331 | + }, |
| 332 | + new_est_err, y_value - t0_value, a_value, v_value, w_value, 0.0, |
| 333 | + log_error_absolute); |
| 334 | + if (!is_constant_all<T_y>::value) { |
| 335 | + partials<0>(ops_partials)[i] = -deriv_y / ccdf; |
| 336 | + } |
| 337 | + if (!is_constant_all<T_t0>::value) { |
| 338 | + partials<2>(ops_partials)[i] = deriv_y / ccdf; |
| 339 | + } |
| 340 | + } |
| 341 | + if (!is_constant_all<T_a>::value) { |
| 342 | + partials<1>(ops_partials)[i] |
| 343 | + = internal::estimate_with_err_check<5, 0>( |
| 344 | + [](auto&&... args) { |
| 345 | + return internal::wiener4_ccdf_grad_a(args...); |
| 346 | + }, |
| 347 | + new_est_err, y_value - t0_value, a_value, v_value, w_value, cdf, |
| 348 | + log_error_absolute) |
| 349 | + / ccdf; |
| 350 | + } |
| 351 | + if (!is_constant_all<T_w>::value) { |
| 352 | + partials<3>(ops_partials)[i] |
| 353 | + = internal::estimate_with_err_check<5, 0>( |
| 354 | + [](auto&&... args) { |
| 355 | + return internal::wiener4_ccdf_grad_w(args...); |
| 356 | + }, |
| 357 | + new_est_err, y_value - t0_value, a_value, v_value, w_value, cdf, |
| 358 | + log_error_absolute) |
| 359 | + / ccdf; |
| 360 | + } |
| 361 | + if (!is_constant_all<T_v>::value) { |
| 362 | + partials<4>(ops_partials)[i] |
| 363 | + = internal::wiener4_ccdf_grad_v(y_value - t0_value, a_value, v_value, |
| 364 | + w_value, cdf, log_error_absolute) |
| 365 | + / ccdf; |
| 366 | + } |
| 367 | + } // for loop |
| 368 | + return ops_partials.build(lccdf); |
| 369 | +} |
| 370 | +} // namespace math |
| 371 | +} // namespace stan |
| 372 | +#endif |
0 commit comments