Skip to content

Commit 10b0f2a

Browse files
committed
Fixed bug in R BART when predicting y_hat only for models with non-default RFX IDs
1 parent b63faed commit 10b0f2a

File tree

3 files changed

+719
-12
lines changed

3 files changed

+719
-12
lines changed

R/bart.R

Lines changed: 8 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -2134,17 +2134,15 @@ predict.bartmodel <- function(
21342134
X <- preprocessPredictionData(X, train_set_metadata)
21352135

21362136
# Recode group IDs to integer vector (if passed as, for example, a vector of county names, etc...)
2137-
if (predict_rfx) {
2138-
if (!is.null(rfx_group_ids)) {
2139-
rfx_unique_group_ids <- object$rfx_unique_group_ids
2140-
group_ids_factor <- factor(rfx_group_ids, levels = rfx_unique_group_ids)
2141-
if (sum(is.na(group_ids_factor)) > 0) {
2142-
stop(
2143-
"All random effect group labels provided in rfx_group_ids must have been present in rfx_group_ids_train"
2144-
)
2145-
}
2146-
rfx_group_ids <- as.integer(group_ids_factor)
2137+
if (!is.null(rfx_group_ids)) {
2138+
rfx_unique_group_ids <- object$rfx_unique_group_ids
2139+
group_ids_factor <- factor(rfx_group_ids, levels = rfx_unique_group_ids)
2140+
if (sum(is.na(group_ids_factor)) > 0) {
2141+
stop(
2142+
"All random effect group labels provided in rfx_group_ids must have been present in rfx_group_ids_train"
2143+
)
21472144
}
2145+
rfx_group_ids <- as.integer(group_ids_factor)
21482146
}
21492147

21502148
# Handle RFX model specification

R/bcf.R

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -3040,7 +3040,6 @@ predict.bcfmodel <- function(
30403040
X <- preprocessPredictionData(X, train_set_metadata)
30413041

30423042
# Recode group IDs to integer vector (if passed as, for example, a vector of county names, etc...)
3043-
has_rfx <- FALSE
30443043
if (!is.null(rfx_group_ids)) {
30453044
rfx_unique_group_ids <- object$rfx_unique_group_ids
30463045
group_ids_factor <- factor(rfx_group_ids, levels = rfx_unique_group_ids)
@@ -3050,7 +3049,6 @@ predict.bcfmodel <- function(
30503049
)
30513050
}
30523051
rfx_group_ids <- as.integer(group_ids_factor)
3053-
has_rfx <- TRUE
30543052
}
30553053

30563054
# Handle RFX model specification

0 commit comments

Comments
 (0)