Skip to content

Commit 096ceb6

Browse files
authored
Merge pull request #256 from StochasticTree/rfx-non-default-group-ids
Fix prediction bugs for models with "non-default" random effects IDs which rely on a `model_spec` specification
2 parents b63faed + 174fc27 commit 096ceb6

File tree

10 files changed

+1670
-16
lines changed

10 files changed

+1670
-16
lines changed

CHANGELOG.md

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,17 @@
11
# Changelog
22

3+
# stochtree 0.2.1.9000
4+
5+
## Bug Fixes
6+
7+
* Fix prediction bug for R BART models whose random effects group labels aren't straightforward `1:num_groups` integers, which occurred when only `y_hat` is requested ([#256](https://github.com/StochasticTree/stochtree/pull/256))
8+
39
# stochtree 0.2.1
410

511
## Bug Fixes
612

713
* Fix prediction bug for univariate random effects models in R ([#248](https://github.com/StochasticTree/stochtree/pull/248))
14+
* Fix prediction bug for Python BART and BCF models whose random effects group labels aren't straightforward `0:(num_groups-1)` integers ([#256](https://github.com/StochasticTree/stochtree/pull/256))
815

916
## Other Changes
1017

NEWS.md

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,15 @@
1+
# stochtree 0.2.1.9000
2+
3+
## Bug Fixes
4+
5+
* Fix prediction bug for R BART models whose random effects group labels aren't straightforward `1:num_groups` integers, which occurred when only `y_hat` is requested ([#256](https://github.com/StochasticTree/stochtree/pull/256))
6+
17
# stochtree 0.2.1
28

39
## Bug Fixes
410

511
* Fix prediction bug for univariate random effects models in R ([#248](https://github.com/StochasticTree/stochtree/pull/248))
12+
* Fix prediction bug for Python BART and BCF models whose random effects group labels aren't straightforward `0:(num_groups-1)` integers ([#256](https://github.com/StochasticTree/stochtree/pull/256))
613

714
## Other Changes
815

R/bart.R

Lines changed: 8 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -2134,17 +2134,15 @@ predict.bartmodel <- function(
21342134
X <- preprocessPredictionData(X, train_set_metadata)
21352135

21362136
# Recode group IDs to integer vector (if passed as, for example, a vector of county names, etc...)
2137-
if (predict_rfx) {
2138-
if (!is.null(rfx_group_ids)) {
2139-
rfx_unique_group_ids <- object$rfx_unique_group_ids
2140-
group_ids_factor <- factor(rfx_group_ids, levels = rfx_unique_group_ids)
2141-
if (sum(is.na(group_ids_factor)) > 0) {
2142-
stop(
2143-
"All random effect group labels provided in rfx_group_ids must have been present in rfx_group_ids_train"
2144-
)
2145-
}
2146-
rfx_group_ids <- as.integer(group_ids_factor)
2137+
if (!is.null(rfx_group_ids)) {
2138+
rfx_unique_group_ids <- object$rfx_unique_group_ids
2139+
group_ids_factor <- factor(rfx_group_ids, levels = rfx_unique_group_ids)
2140+
if (sum(is.na(group_ids_factor)) > 0) {
2141+
stop(
2142+
"All random effect group labels provided in rfx_group_ids must have been present in rfx_group_ids_train"
2143+
)
21472144
}
2145+
rfx_group_ids <- as.integer(group_ids_factor)
21482146
}
21492147

21502148
# Handle RFX model specification

R/bcf.R

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -3040,7 +3040,6 @@ predict.bcfmodel <- function(
30403040
X <- preprocessPredictionData(X, train_set_metadata)
30413041

30423042
# Recode group IDs to integer vector (if passed as, for example, a vector of county names, etc...)
3043-
has_rfx <- FALSE
30443043
if (!is.null(rfx_group_ids)) {
30453044
rfx_unique_group_ids <- object$rfx_unique_group_ids
30463045
group_ids_factor <- factor(rfx_group_ids, levels = rfx_unique_group_ids)
@@ -3050,7 +3049,6 @@ predict.bcfmodel <- function(
30503049
)
30513050
}
30523051
rfx_group_ids <- as.integer(group_ids_factor)
3053-
has_rfx <- TRUE
30543052
}
30553053

30563054
# Handle RFX model specification

src/py_stochtree.cpp

Lines changed: 15 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1573,6 +1573,18 @@ class RandomEffectsLabelMapperCpp {
15731573
StochTree::LabelMapper* GetLabelMapper() {
15741574
return rfx_label_mapper_.get();
15751575
}
1576+
int MapGroupIdToArrayIndex(int original_label) {
1577+
return rfx_label_mapper_->CategoryNumber(original_label);
1578+
}
1579+
py::array_t<int> MapMultipleGroupIdsToArrayIndices(py::array_t<int> original_labels) {
1580+
int output_size = original_labels.size();
1581+
auto result = py::array_t<int>(py::detail::any_container<py::ssize_t>({output_size}));
1582+
auto accessor = result.mutable_unchecked<1>();
1583+
for (int i = 0; i < output_size; i++) {
1584+
accessor(i) = rfx_label_mapper_->CategoryNumber(original_labels.at(i));
1585+
}
1586+
return result;
1587+
}
15761588

15771589
private:
15781590
std::unique_ptr<StochTree::LabelMapper> rfx_label_mapper_;
@@ -2410,7 +2422,9 @@ PYBIND11_MODULE(stochtree_cpp, m) {
24102422
.def("DumpJsonString", &RandomEffectsLabelMapperCpp::DumpJsonString)
24112423
.def("LoadFromJsonString", &RandomEffectsLabelMapperCpp::LoadFromJsonString)
24122424
.def("LoadFromJson", &RandomEffectsLabelMapperCpp::LoadFromJson)
2413-
.def("GetLabelMapper", &RandomEffectsLabelMapperCpp::GetLabelMapper);
2425+
.def("GetLabelMapper", &RandomEffectsLabelMapperCpp::GetLabelMapper)
2426+
.def("MapGroupIdToArrayIndex", &RandomEffectsLabelMapperCpp::MapGroupIdToArrayIndex)
2427+
.def("MapMultipleGroupIdsToArrayIndices", &RandomEffectsLabelMapperCpp::MapMultipleGroupIdsToArrayIndices);
24142428

24152429
py::class_<RandomEffectsModelCpp>(m, "RandomEffectsModelCpp")
24162430
.def(py::init<int, int>())

stochtree/bart.py

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1988,6 +1988,12 @@ def predict(
19881988
"Random effects basis has a different dimension than the basis used to train this model"
19891989
)
19901990

1991+
# Convert rfx_group_ids to their corresponding array position indices in the random effects parameter sample arrays
1992+
if rfx_group_ids is not None:
1993+
rfx_group_id_indices = self.rfx_container.map_group_ids_to_array_indices(
1994+
rfx_group_ids
1995+
)
1996+
19911997
# Random effects predictions
19921998
if predict_rfx or predict_rfx_intermediate:
19931999
if rfx_basis is not None:
@@ -2017,7 +2023,7 @@ def predict(
20172023
)
20182024
for i in range(n_train):
20192025
rfx_predictions_raw[i, 0, :] = rfx_beta_draws[
2020-
rfx_group_ids[i], :
2026+
rfx_group_id_indices[i], :
20212027
]
20222028
rfx_predictions = np.squeeze(rfx_predictions_raw[:, 0, :])
20232029

stochtree/bcf.py

Lines changed: 8 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -3053,6 +3053,12 @@ def predict(
30533053
raise ValueError(
30543054
"rfx_basis must have the same number of columns as the random effects basis used to sample this model"
30553055
)
3056+
3057+
# Convert rfx_group_ids to their corresponding array position indices in the random effects parameter sample arrays
3058+
if rfx_group_ids is not None:
3059+
rfx_group_id_indices = self.rfx_container.map_group_ids_to_array_indices(
3060+
rfx_group_ids
3061+
)
30563062

30573063
# Random effects predictions
30583064
if predict_rfx or predict_rfx_intermediate:
@@ -3073,14 +3079,14 @@ def predict(
30733079
)
30743080
for i in range(X.shape[0]):
30753081
rfx_predictions_raw[i, :, :] = rfx_beta_draws[
3076-
:, rfx_group_ids[i], :
3082+
:, rfx_group_id_indices[i], :
30773083
]
30783084
elif rfx_beta_draws.ndim == 2:
30793085
rfx_predictions_raw = np.empty(
30803086
shape=(X.shape[0], 1, rfx_beta_draws.shape[1])
30813087
)
30823088
for i in range(X.shape[0]):
3083-
rfx_predictions_raw[i, 0, :] = rfx_beta_draws[rfx_group_ids[i], :]
3089+
rfx_predictions_raw[i, 0, :] = rfx_beta_draws[rfx_group_id_indices[i], :]
30843090
else:
30853091
raise ValueError(
30863092
"Unexpected number of dimensions in extracted random effects samples"

stochtree/random_effects.py

Lines changed: 32 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -418,6 +418,38 @@ def extract_parameter_samples(self) -> dict[str, np.ndarray]:
418418
"sigma_samples": sigma_samples,
419419
}
420420
return output
421+
422+
def map_group_id_to_array_index(self, group_id: int) -> int:
423+
"""
424+
Map an integer-valued random effects group ID to its group's corresponding position in the arrays that store random effects parameter samples.
425+
426+
Parameters
427+
----------
428+
group_id : int
429+
Group identifier to be converted to an array position.
430+
431+
Returns
432+
-------
433+
int
434+
The position of `group_id` in the parameter sample arrays underlying the random effects container.
435+
"""
436+
return self.rfx_label_mapper_cpp.MapGroupIdToArrayIndex(group_id)
437+
438+
def map_group_ids_to_array_indices(self, group_ids: np.ndarray) -> np.ndarray:
439+
"""
440+
Map an array of integer-valued random effects group IDs to their groups' corresponding positions in the arrays that store random effects parameter samples.
441+
442+
Parameters
443+
----------
444+
group_ids : np.ndarray
445+
Array of group identifiers (integer-valued) to be converted to array positions.
446+
447+
Returns
448+
-------
449+
np.ndarray
450+
Numpy array containing, for each element of `group_ids`, that group's position in the parameter sample arrays underlying the random effects container.
451+
"""
452+
return self.rfx_label_mapper_cpp.MapMultipleGroupIdsToArrayIndices(group_ids)
421453

422454

423455
class RandomEffectsModel:

0 commit comments

Comments
 (0)