|
63 | 63 | y_test = y[ind_test] |
64 | 64 |
|
65 | 65 |
|
# Create 4-dim data
# Fixture for the N-dim input tests: synthetic image-like arrays of shape
# (n_samples, 8, 8, 3). Seeded for reproducibility; `n_classes` is defined
# earlier in this file. No 4-dim test target is needed because the tests
# only compare predictions, not test-set accuracy.
np.random.seed(42)
X_train_4d = np.random.normal(size=(400, 8, 8, 3))
X_test_4d = np.random.normal(size=(100, 8, 8, 3))
y_train_4d = np.random.randint(n_classes, size=400)

# Reshape 4-dim to 2-dim
# Flattened (n_samples, 192) views of the same data, used as the reference
# input for plain 2-dim estimators in the equivalence checks below.
X_train_4d_unrolled = X_train_4d.reshape(X_train_4d.shape[0], -1)
X_test_4d_unrolled = X_test_4d.reshape(X_test_4d.shape[0], -1)
| 75 | + |
| 76 | +#------------------------------------------------------------------------------ |
| 77 | +#------------------------------------------------------------------------------ |
| 78 | + |
class LogisticRegressionUnrolled(LogisticRegression):
    """
    Helper estimator for tests related to N-dim input.

    Accepts an N-dim array and flattens it to 2-dim
    (n_samples, -1) before delegating to ``LogisticRegression``.
    """
    @staticmethod
    def _unroll(X):
        # Collapse all trailing dimensions into a single feature axis.
        return X.reshape(X.shape[0], -1)

    def fit(self, X, y):
        return super(LogisticRegressionUnrolled, self).fit(self._unroll(X), y)

    def predict(self, X):
        return super(LogisticRegressionUnrolled, self).predict(self._unroll(X))

    def predict_proba(self, X):
        return super(LogisticRegressionUnrolled, self).predict_proba(self._unroll(X))
66 | 93 | #------------------------------------------------------------------------------- |
67 | 94 | #------------------------------------------------------------------------------- |
68 | 95 |
|
@@ -775,7 +802,48 @@ def test_oof_pred_bag_mode_proba_2_models(self): |
775 | 802 |
|
776 | 803 | assert_array_equal(S_train_1, S_train_3) |
777 | 804 | assert_array_equal(S_test_1, S_test_3) |
| 805 | + |
| 806 | + def test_N_dim_input(self): |
| 807 | + """ |
| 808 | + This is `test_oof_pred_bag_mode` function with `LogisticRegressionUnrolled` estimator |
| 809 | + """ |
| 810 | + S_test_temp = np.zeros((X_test_4d_unrolled.shape[0], n_folds)) |
| 811 | + # Usind StratifiedKFold because by defauld cross_val_predict uses StratifiedKFold |
| 812 | + kf = StratifiedKFold(n_splits = n_folds, shuffle = False, random_state = 0) |
| 813 | + for fold_counter, (tr_index, te_index) in enumerate(kf.split(X_train_4d_unrolled, y_train_4d)): |
| 814 | + # Split data and target |
| 815 | + X_tr = X_train_4d_unrolled[tr_index] |
| 816 | + y_tr = y_train_4d[tr_index] |
| 817 | + X_te = X_train_4d_unrolled[te_index] |
| 818 | + y_te = y_train_4d[te_index] |
| 819 | + model = LogisticRegression(random_state=0, solver='liblinear', multi_class='ovr') |
| 820 | + _ = model.fit(X_tr, y_tr) |
| 821 | + S_test_temp[:, fold_counter] = model.predict(X_test_4d_unrolled) |
| 822 | + S_test_1 = st.mode(S_test_temp, axis = 1)[0] |
778 | 823 |
|
| 824 | + model = LogisticRegression(random_state=0, solver='liblinear', multi_class='ovr') |
| 825 | + S_train_1 = cross_val_predict(model, X_train_4d_unrolled, y = y_train_4d, cv = n_folds, |
| 826 | + n_jobs = 1, verbose = 0, method = 'predict').reshape(-1, 1) |
| 827 | + |
| 828 | + models = [LogisticRegressionUnrolled(random_state=0, solver='liblinear', multi_class='ovr')] |
| 829 | + S_train_2, S_test_2 = stacking(models, X_train_4d, y_train_4d, X_test_4d, |
| 830 | + regression = False, n_folds = n_folds, shuffle = False, save_dir=temp_dir, |
| 831 | + mode = 'oof_pred_bag', random_state = 0, verbose = 0, stratified = True) |
| 832 | + |
| 833 | + # Load OOF from file |
| 834 | + # Normally if cleaning is performed there is only one .npy file at given moment |
| 835 | + # But if we have no cleaning there may be more then one file so we take the latest |
| 836 | + file_name = sorted(glob.glob(os.path.join(temp_dir, '*.npy')))[-1] # take the latest file |
| 837 | + S = np.load(file_name) |
| 838 | + S_train_3 = S[0] |
| 839 | + S_test_3 = S[1] |
| 840 | + |
| 841 | + assert_array_equal(S_train_1, S_train_2) |
| 842 | + assert_array_equal(S_test_1, S_test_2) |
| 843 | + |
| 844 | + assert_array_equal(S_train_1, S_train_3) |
| 845 | + assert_array_equal(S_test_1, S_test_3) |
| 846 | + |
779 | 847 | #------------------------------------------------------------------------------- |
780 | 848 | #------------------------------------------------------------------------------- |
781 | 849 |
|
|
0 commit comments