diff --git a/bigframes/ml/model_selection.py b/bigframes/ml/model_selection.py index 6eba4f81c2..cc3086dca4 100644 --- a/bigframes/ml/model_selection.py +++ b/bigframes/ml/model_selection.py @@ -117,6 +117,9 @@ def _stratify_split(df: bpd.DataFrame, stratify: bpd.Series) -> List[bpd.DataFra else: joined_df_train, joined_df_test = _stratify_split(joined_df, stratify) + joined_df_train = joined_df_train.cache() + joined_df_test = joined_df_test.cache() + results = [] for array in arrays: columns = array.name if isinstance(array, bpd.Series) else array.columns diff --git a/tests/system/small/ml/test_model_selection.py b/tests/system/small/ml/test_model_selection.py index ebce6e405a..992a884f64 100644 --- a/tests/system/small/ml/test_model_selection.py +++ b/tests/system/small/ml/test_model_selection.py @@ -46,6 +46,24 @@ def test_train_test_split_default_correct_shape(df_fixture, request): assert y_test.shape == (86, 1) +def test_train_test_split_default_unordered_same_index( + unordered_session, penguins_pandas_df_default_index +): + df = unordered_session.read_pandas(penguins_pandas_df_default_index) + X = df[ + [ + "species", + "island", + "culmen_length_mm", + ] + ] + y = df[["body_mass_g"]] + X_train, X_test, y_train, y_test = model_selection.train_test_split(X, y) + + pd.testing.assert_index_equal(X_train.to_pandas().index, y_train.to_pandas().index) + pd.testing.assert_index_equal(X_test.to_pandas().index, y_test.to_pandas().index) + + def test_train_test_split_series_default_correct_shape(penguins_df_default_index): X = penguins_df_default_index[["species"]] y = penguins_df_default_index["body_mass_g"]