From b77f04f4d9136d58e7dda129e4bce216c10a9a12 Mon Sep 17 00:00:00 2001 From: BaptisteDE Date: Mon, 30 Jun 2025 16:22:45 +0200 Subject: [PATCH] =?UTF-8?q?=E2=9A=A1=EF=B8=8F=20ExpressionCombine=20readab?= =?UTF-8?q?ility=20and=20test?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- tests/test_processing.py | 16 ++++++++++++++++ tide/processing.py | 10 +++++----- 2 files changed, 21 insertions(+), 5 deletions(-) diff --git a/tests/test_processing.py b/tests/test_processing.py index 6326829..518d340 100644 --- a/tests/test_processing.py +++ b/tests/test_processing.py @@ -831,6 +831,22 @@ def test_combiner(self): assert res.shape == (3, 6) check_feature_names_out(combiner, res) + combiner_cond = ExpressionCombine( + columns_dict={ + "T1": "Text__°C__outdoor", + }, + expression="(T1 > 10) * 1", + result_column_name="where_test_01__hvac", + ) + + res = combiner_cond.fit_transform(test_df.copy()) + check_feature_names_out(combiner_cond, res) + np.testing.assert_almost_equal( + res["where_test_01__hvac"], + [0, 0, 0], + decimal=1, + ) + @patch("tide.base.get_oikolab_df", side_effect=mock_get_oikolab_df) def test_fill_oiko_meteo(self, mock_get_oikolab): data = pd.read_csv( diff --git a/tide/processing.py b/tide/processing.py index 742e784..7667114 100644 --- a/tide/processing.py +++ b/tide/processing.py @@ -2044,11 +2044,11 @@ def _fit_implementation(self, X, y=None): self.feature_names_out_.append(self.result_column_name) def _transform_implementation(self, X: pd.Series | pd.DataFrame): - exp = self.expression - for key, val in self.columns_dict.items(): - exp = exp.replace(key, f'X["{val}"]') - - X.loc[:, self.result_column_name] = pd.eval(exp, target=X) + X.loc[:, self.result_column_name] = pd.eval( + self.expression, + target=X, + local_dict={var: X[col] for var, col in self.columns_dict.items()}, + ) return X[self.feature_names_out_]