Skip to content

Commit 12bf37c

Browse files
committed
Deploying to r-dev from @ 53156f1 🚀
1 parent d6b018c commit 12bf37c

File tree

7 files changed

+125
-90
lines changed

7 files changed

+125
-90
lines changed

DESCRIPTION

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
Package: stochtree
22
Title: Stochastic Tree Ensembles (XBART and BART) for Supervised Learning and Causal Inference
3-
Version: 0.2.0.9000
3+
Version: 0.2.1
44
Authors@R:
55
c(
66
person("Drew", "Herren", email = "[email protected]", role = c("aut", "cre"), comment = c(ORCID = "0000-0003-4109-6611")),

NEWS.md

Lines changed: 4 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -1,17 +1,13 @@
1-
# stochtree (development version)
2-
3-
## New Features
4-
5-
## Computational Improvements
1+
# stochtree 0.2.1
62

73
## Bug Fixes
84

9-
* Predict random effects correctly in R for univariate random effects models ([#248](https://github.com/StochasticTree/stochtree/pull/248))
10-
11-
## Documentation Improvements
5+
* Fix prediction bug for univariate random effects models in R ([#248](https://github.com/StochasticTree/stochtree/pull/248))
126

137
## Other Changes
148

9+
* Encode expectations about which combinations of BART / BCF features work together and ensure warning ([#250](https://github.com/StochasticTree/stochtree/pull/250))
10+
1511
# stochtree 0.2.0
1612

1713
## New Features

R/bart.R

Lines changed: 10 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -835,6 +835,16 @@ bart <- function(
835835
}
836836
}
837837

838+
# Runtime checks for variance forest
839+
if (include_variance_forest) {
840+
if (sample_sigma2_global) {
841+
warning(
842+
"Global error variance will not be sampled with a heteroskedasticity forest"
843+
)
844+
sample_sigma2_global <- F
845+
}
846+
}
847+
838848
# Handle standardization, prior calibration, and initialization of forest
839849
# differently for binary and continuous outcomes
840850
if (probit_outcome_model) {
@@ -2124,7 +2134,6 @@ predict.bartmodel <- function(
21242134
X <- preprocessPredictionData(X, train_set_metadata)
21252135

21262136
# Recode group IDs to integer vector (if passed as, for example, a vector of county names, etc...)
2127-
has_rfx <- FALSE
21282137
if (predict_rfx) {
21292138
if (!is.null(rfx_group_ids)) {
21302139
rfx_unique_group_ids <- object$rfx_unique_group_ids
@@ -2135,7 +2144,6 @@ predict.bartmodel <- function(
21352144
)
21362145
}
21372146
rfx_group_ids <- as.integer(group_ids_factor)
2138-
has_rfx <- TRUE
21392147
}
21402148
}
21412149

R/bcf.R

Lines changed: 36 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -897,14 +897,7 @@ bcf <- function(
897897
# Handle multivariate treatment
898898
has_multivariate_treatment <- ncol(Z_train) > 1
899899
if (has_multivariate_treatment) {
900-
# Disable adaptive coding, internal propensity model, and
901-
# leaf scale sampling if treatment is multivariate
902-
if (adaptive_coding) {
903-
warning(
904-
"Adaptive coding is incompatible with multivariate treatment and will be ignored"
905-
)
906-
adaptive_coding <- FALSE
907-
}
900+
# Disable internal propensity model and leaf scale sampling if treatment is multivariate
908901
if (is.null(propensity_train)) {
909902
if (propensity_covariate != "none") {
910903
warning(
@@ -949,21 +942,31 @@ bcf <- function(
949942
}
950943
has_basis_rfx <- TRUE
951944
num_basis_rfx <- ncol(rfx_basis_train)
952-
} else if (rfx_model_spec == "intercept_only") {
953-
rfx_basis_train <- matrix(
954-
rep(1, nrow(X_train)),
955-
nrow = nrow(X_train),
956-
ncol = 1
957-
)
958-
has_basis_rfx <- TRUE
959-
num_basis_rfx <- 1
960945
} else if (rfx_model_spec == "intercept_plus_treatment") {
961-
rfx_basis_train <- cbind(
962-
rep(1, nrow(X_train)),
963-
Z_train
964-
)
965-
has_basis_rfx <- TRUE
966-
num_basis_rfx <- 1 + ncol(Z_train)
946+
if (has_multivariate_treatment) {
947+
warning(
948+
"Random effects `intercept_plus_treatment` specification is not currently implemented for multivariate treatments. This model will be fit under the `intercept_only` specification instead. Please provide a custom `rfx_basis_train` if you wish to have random slopes on multivariate treatment variables."
949+
)
950+
rfx_model_spec <- "intercept_only"
951+
}
952+
}
953+
if (is.null(rfx_basis_train)) {
954+
if (rfx_model_spec == "intercept_only") {
955+
rfx_basis_train <- matrix(
956+
rep(1, nrow(X_train)),
957+
nrow = nrow(X_train),
958+
ncol = 1
959+
)
960+
has_basis_rfx <- TRUE
961+
num_basis_rfx <- 1
962+
} else {
963+
rfx_basis_train <- cbind(
964+
rep(1, nrow(X_train)),
965+
Z_train
966+
)
967+
has_basis_rfx <- TRUE
968+
num_basis_rfx <- 1 + ncol(Z_train)
969+
}
967970
}
968971
num_rfx_groups <- length(unique(rfx_group_ids_train))
969972
num_rfx_components <- ncol(rfx_basis_train)
@@ -1021,15 +1024,21 @@ bcf <- function(
10211024
y_train <- as.matrix(y_train)
10221025
}
10231026

1024-
# Check whether treatment is binary (specifically 0-1 binary)
1025-
binary_treatment <- length(unique(Z_train)) == 2
1026-
if (binary_treatment) {
1027-
unique_treatments <- sort(unique(Z_train))
1028-
if (!(all(unique_treatments == c(0, 1)))) binary_treatment <- FALSE
1027+
# Check whether treatment is binary and univariate (specifically 0-1 binary)
1028+
binary_treatment <- FALSE
1029+
if (!has_multivariate_treatment) {
1030+
binary_treatment <- length(unique(Z_train)) == 2
1031+
if (binary_treatment) {
1032+
unique_treatments <- sort(unique(Z_train))
1033+
if (!(all(unique_treatments == c(0, 1)))) binary_treatment <- FALSE
1034+
}
10291035
}
10301036

10311037
# Adaptive coding will be ignored for continuous / ordered categorical treatments
10321038
if ((!binary_treatment) && (adaptive_coding)) {
1039+
warning(
1040+
"Adaptive coding is only compatible with binary (univariate) treatment and, as a result, will be ignored in sampling this model"
1041+
)
10331042
adaptive_coding <- FALSE
10341043
}
10351044

R/posterior_transformation.R

Lines changed: 64 additions & 42 deletions
Original file line numberDiff line numberDiff line change
@@ -409,19 +409,31 @@ compute_contrast_bart_model <- function(
409409
"rfx_group_ids_0 and rfx_group_ids_1 must be provided for this model"
410410
)
411411
}
412-
if ((has_rfx) && (is.null(rfx_basis_0) || is.null(rfx_basis_1))) {
413-
stop(
414-
"rfx_basis_0 and rfx_basis_1 must be provided for this model"
415-
)
416-
}
417-
if (
418-
(object$model_params$num_rfx_basis > 0) &&
419-
((ncol(rfx_basis_0) != object$model_params$num_rfx_basis) ||
420-
(ncol(rfx_basis_1) != object$model_params$num_rfx_basis))
421-
) {
422-
stop(
423-
"rfx_basis_0 and / or rfx_basis_1 have a different dimension than the basis used to train this model"
424-
)
412+
if (has_rfx) {
413+
if (object$model_params$rfx_model_spec == "custom") {
414+
if ((is.null(rfx_basis_0) || is.null(rfx_basis_1))) {
415+
stop(
416+
"A user-provided basis (`rfx_basis_0` and `rfx_basis_1`) must be provided when the model was sampled with a random effects model spec set to 'custom'"
417+
)
418+
}
419+
if (!is.matrix(rfx_basis_0) || !is.matrix(rfx_basis_1)) {
420+
stop("'rfx_basis_0' and 'rfx_basis_1' must be matrices")
421+
}
422+
if ((nrow(rfx_basis_0) != nrow(X)) || (nrow(rfx_basis_1) != nrow(X))) {
423+
stop(
424+
"'rfx_basis_0' and 'rfx_basis_1' must have the same number of rows as 'X'"
425+
)
426+
}
427+
if (
428+
(object$model_params$num_rfx_basis > 0) &&
429+
((ncol(rfx_basis_0) != object$model_params$num_rfx_basis) ||
430+
(ncol(rfx_basis_1) != object$model_params$num_rfx_basis))
431+
) {
432+
stop(
433+
"rfx_basis_0 and / or rfx_basis_1 have a different dimension than the basis used to train this model"
434+
)
435+
}
436+
}
425437
}
426438

427439
# Predict for the control arm
@@ -574,16 +586,22 @@ sample_bcf_posterior_predictive <- function(
574586
"'rfx_group_ids' must have the same length as the number of rows in 'X'"
575587
)
576588
}
577-
if (is.null(rfx_basis)) {
578-
stop(
579-
"'rfx_basis' must be provided in order to compute the requested intervals"
580-
)
581-
}
582-
if (!is.matrix(rfx_basis)) {
583-
stop("'rfx_basis' must be a matrix")
589+
590+
if (model_object$model_params$rfx_model_spec == "custom") {
591+
if (is.null(rfx_basis)) {
592+
stop(
593+
"A user-provided basis (`rfx_basis`) must be provided when the model was sampled with a random effects model spec set to 'custom'"
594+
)
595+
}
584596
}
585-
if (nrow(rfx_basis) != nrow(X)) {
586-
stop("'rfx_basis' must have the same number of rows as 'X'")
597+
598+
if (!is.null(rfx_basis)) {
599+
if (!is.matrix(rfx_basis)) {
600+
stop("'rfx_basis' must be a matrix")
601+
}
602+
if (nrow(rfx_basis) != nrow(X)) {
603+
stop("'rfx_basis' must have the same number of rows as 'X'")
604+
}
587605
}
588606
}
589607

@@ -735,16 +753,18 @@ sample_bart_posterior_predictive <- function(
735753
"'rfx_group_ids' must have the same length as the number of rows in 'X'"
736754
)
737755
}
738-
if (is.null(rfx_basis)) {
739-
stop(
740-
"'rfx_basis' must be provided in order to compute the requested intervals"
741-
)
742-
}
743-
if (!is.matrix(rfx_basis)) {
744-
stop("'rfx_basis' must be a matrix")
745-
}
746-
if (nrow(rfx_basis) != nrow(X)) {
747-
stop("'rfx_basis' must have the same number of rows as 'X'")
756+
if (model_object$model_params$rfx_model_spec == "custom") {
757+
if (is.null(rfx_basis)) {
758+
stop(
759+
"A user-provided basis (`rfx_basis`) must be provided when the model was sampled with a random effects model spec set to 'custom'"
760+
)
761+
}
762+
if (!is.matrix(rfx_basis)) {
763+
stop("'rfx_basis' must be a matrix")
764+
}
765+
if (nrow(rfx_basis) != nrow(X)) {
766+
stop("'rfx_basis' must have the same number of rows as 'X'")
767+
}
748768
}
749769
}
750770

@@ -1172,16 +1192,18 @@ compute_bart_posterior_interval <- function(
11721192
"'rfx_group_ids' must have the same length as the number of rows in 'X'"
11731193
)
11741194
}
1175-
if (is.null(rfx_basis)) {
1176-
stop(
1177-
"'rfx_basis' must be provided in order to compute the requested intervals"
1178-
)
1179-
}
1180-
if (!is.matrix(rfx_basis)) {
1181-
stop("'rfx_basis' must be a matrix")
1182-
}
1183-
if (nrow(rfx_basis) != nrow(X)) {
1184-
stop("'rfx_basis' must have the same number of rows as 'X'")
1195+
if (model_object$model_params$rfx_model_spec == "custom") {
1196+
if (is.null(rfx_basis)) {
1197+
stop(
1198+
"A user-provided basis (`rfx_basis`) must be provided when the model was sampled with a random effects model spec set to 'custom'"
1199+
)
1200+
}
1201+
if (!is.matrix(rfx_basis)) {
1202+
stop("'rfx_basis' must be a matrix")
1203+
}
1204+
if (nrow(rfx_basis) != nrow(X)) {
1205+
stop("'rfx_basis' must have the same number of rows as 'X'")
1206+
}
11851207
}
11861208
}
11871209

configure

Lines changed: 9 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
#! /bin/sh
22
# Guess values for system-dependent variables and create Makefiles.
3-
# Generated by GNU Autoconf 2.72 for stochtree 0.2.0.9000.
3+
# Generated by GNU Autoconf 2.72 for stochtree 0.2.1.
44
#
55
#
66
# Copyright (C) 1992-1996, 1998-2017, 2020-2023 Free Software Foundation,
@@ -600,8 +600,8 @@ MAKEFLAGS=
600600
# Identity of this package.
601601
PACKAGE_NAME='stochtree'
602602
PACKAGE_TARNAME='stochtree'
603-
PACKAGE_VERSION='0.2.0.9000'
604-
PACKAGE_STRING='stochtree 0.2.0.9000'
603+
PACKAGE_VERSION='0.2.1'
604+
PACKAGE_STRING='stochtree 0.2.1'
605605
PACKAGE_BUGREPORT=''
606606
PACKAGE_URL=''
607607

@@ -1204,7 +1204,7 @@ if test "$ac_init_help" = "long"; then
12041204
# Omit some internal or obsolete options to make the list less imposing.
12051205
# This message is too long to be a string in the A/UX 3.1 sh.
12061206
cat <<_ACEOF
1207-
'configure' configures stochtree 0.2.0.9000 to adapt to many kinds of systems.
1207+
'configure' configures stochtree 0.2.1 to adapt to many kinds of systems.
12081208
12091209
Usage: $0 [OPTION]... [VAR=VALUE]...
12101210
@@ -1266,7 +1266,7 @@ fi
12661266

12671267
if test -n "$ac_init_help"; then
12681268
case $ac_init_help in
1269-
short | recursive ) echo "Configuration of stochtree 0.2.0.9000:";;
1269+
short | recursive ) echo "Configuration of stochtree 0.2.1:";;
12701270
esac
12711271
cat <<\_ACEOF
12721272
@@ -1334,7 +1334,7 @@ fi
13341334
test -n "$ac_init_help" && exit $ac_status
13351335
if $ac_init_version; then
13361336
cat <<\_ACEOF
1337-
stochtree configure 0.2.0.9000
1337+
stochtree configure 0.2.1
13381338
generated by GNU Autoconf 2.72
13391339
13401340
Copyright (C) 2023 Free Software Foundation, Inc.
@@ -1371,7 +1371,7 @@ cat >config.log <<_ACEOF
13711371
This file contains any messages produced by compilers while
13721372
running configure, to aid debugging if configure makes a mistake.
13731373
1374-
It was created by stochtree $as_me 0.2.0.9000, which was
1374+
It was created by stochtree $as_me 0.2.1, which was
13751375
generated by GNU Autoconf 2.72. Invocation command line was
13761376
13771377
$ $0$ac_configure_args_raw
@@ -2380,7 +2380,7 @@ cat >>$CONFIG_STATUS <<\_ACEOF || ac_write_fail=1
23802380
# report actual input values of CONFIG_FILES etc. instead of their
23812381
# values after options handling.
23822382
ac_log="
2383-
This file was extended by stochtree $as_me 0.2.0.9000, which was
2383+
This file was extended by stochtree $as_me 0.2.1, which was
23842384
generated by GNU Autoconf 2.72. Invocation command line was
23852385
23862386
CONFIG_FILES = $CONFIG_FILES
@@ -2435,7 +2435,7 @@ ac_cs_config_escaped=`printf "%s\n" "$ac_cs_config" | sed "s/^ //; s/'/'\\\\\\\\
24352435
cat >>$CONFIG_STATUS <<_ACEOF || ac_write_fail=1
24362436
ac_cs_config='$ac_cs_config_escaped'
24372437
ac_cs_version="\\
2438-
stochtree config.status 0.2.0.9000
2438+
stochtree config.status 0.2.1
24392439
configured by $0, generated by GNU Autoconf 2.72,
24402440
with options \\"\$ac_cs_config\\"
24412441

configure.ac

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@
33
# https://github.com/microsoft/LightGBM/blob/master/R-package/configure.ac
44

55
AC_PREREQ(2.69)
6-
AC_INIT([stochtree], [0.2.0.9000], [], [stochtree], [])
6+
AC_INIT([stochtree], [0.2.1], [], [stochtree], [])
77
# Note: consider making version number dynamic as in
88
# https://github.com/microsoft/LightGBM/blob/195c26fc7b00eb0fec252dfe841e2e66d6833954/build-cran-package.sh
99

0 commit comments

Comments
 (0)