From d688549b192ff1aba79d82cbd897a509d4928e50 Mon Sep 17 00:00:00 2001 From: Jeong-Yoon Lee Date: Fri, 6 Mar 2026 13:54:48 -0800 Subject: [PATCH 1/5] Add .worktrees/ to .gitignore Co-Authored-By: Claude Opus 4.6 --- .gitignore | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/.gitignore b/.gitignore index 2d446e8a..584da151 100644 --- a/.gitignore +++ b/.gitignore @@ -24,4 +24,6 @@ uv.lock # Claude Code artifacts CLAUDE.md -.claude/ \ No newline at end of file +.claude/.worktrees/ +.worktrees/ +.worktrees/ From 0a55cd2092427fe0d07740da6088ffe64c9ca170 Mon Sep 17 00:00:00 2001 From: Jeong-Yoon Lee Date: Fri, 6 Mar 2026 16:10:10 -0800 Subject: [PATCH 2/5] Fix CausalRandomForestRegressor predicting inf from division by zero (#589) Guard against zero treatment/control counts in CausalMSE and TTest criterion functions. When a tree split creates a child node with no treatment or no control observations, the variance formula `var/count` produces infinity. Now skips impurity contribution for that treatment group (zero impurity), preventing the splitter from favoring degenerate splits. Affected methods: - CausalMSE.node_impurity() - CausalMSE.children_impurity() - TTest.children_impurity() Co-Authored-By: Claude Opus 4.6 --- causalml/inference/tree/causal/_criterion.pyx | 25 +++++++++++++------ 1 file changed, 17 insertions(+), 8 deletions(-) diff --git a/causalml/inference/tree/causal/_criterion.pyx b/causalml/inference/tree/causal/_criterion.pyx index 7c49599f..81992044 100755 --- a/causalml/inference/tree/causal/_criterion.pyx +++ b/causalml/inference/tree/causal/_criterion.pyx @@ -463,7 +463,8 @@ cdef class CausalMSE(CausalRegressionCriterion): tr_var = self.state.node.outcome_var(tr_group_idx) tr_count = self.state.node.count_1d[tr_group_idx] - impurity += (tr_var / tr_count + ct_var / ct_count) - node_tau * node_tau + if tr_count > 0 and ct_count > 0: + impurity += (tr_var / tr_count + ct_var / ct_count) - node_tau * node_tau impurity /= (self.n_outputs - 1) impurity += self.get_groups_penalty(self.state.node) @@ -500,8 +501,10 @@ cdef class CausalMSE(CausalRegressionCriterion): left_tr_var = self.state.left.outcome_var(tr_group_idx) left_tr_count = self.state.left.count_1d[tr_group_idx] - impurity_right[0] += (right_tr_var / right_tr_count + right_ct_var / right_ct_count) - right_tau * right_tau - impurity_left[0] += (left_tr_var / left_tr_count + left_ct_var / left_ct_count) - left_tau * left_tau + if right_tr_count > 0 and right_ct_count > 0: + impurity_right[0] += (right_tr_var / right_tr_count + right_ct_var / right_ct_count) - right_tau * right_tau + if left_tr_count > 0 and left_ct_count > 0: + impurity_left[0] += (left_tr_var / left_tr_count + left_ct_var / left_ct_count) - left_tau * left_tau impurity_right[0] /= (self.n_outputs - 1) impurity_left[0] /= (self.n_outputs - 1) @@ -577,16 +580,22 @@ cdef class TTest(CausalRegressionCriterion): left_tr_var = self.state.left.outcome_var(tr_group_idx) left_tr_count = self.state.left.count_1d[tr_group_idx] - denom_left = sqrt(left_tr_var / left_tr_count + left_ct_var / left_ct_count) - denom_right = sqrt(right_tr_var / right_tr_count + right_ct_var / right_ct_count) + denom_left = 0.0 + denom_right = 0.0 + if left_tr_count > 0 and left_ct_count > 0: + denom_left = sqrt(left_tr_var / left_tr_count + left_ct_var / left_ct_count) + if right_tr_count > 0 and right_ct_count > 0: + denom_right = sqrt(right_tr_var / right_tr_count + right_ct_var / right_ct_count) if denom_left > 0.: t_left_sum += left_tau / denom_left if denom_right > 0.: t_right_sum += right_tau / denom_right - + # Per-treatment squared difference in taus between sides - inv_n_sum = (1.0 / right_tr_count + 1.0 / right_ct_count + - 1.0 / left_tr_count + 1.0 / left_ct_count) + inv_n_sum = 0.0 + if right_tr_count > 0 and right_ct_count > 0 and left_tr_count > 0 and left_ct_count > 0: + inv_n_sum = (1.0 / right_tr_count + 1.0 / right_ct_count + + 1.0 / left_tr_count + 1.0 / left_ct_count) # Pooled variance across four cells (left/right × tr/ct) pooled_var_t = 0.0 From 9caf6c2ff364a741b89b54947d7d7785212e120a Mon Sep 17 00:00:00 2001 From: Jeong-Yoon Lee Date: Fri, 6 Mar 2026 16:31:38 -0800 Subject: [PATCH 3/5] Add regression test for inf predictions with sparse groups (#589) Test that CausalRandomForestRegressor.predict() returns finite values when imbalanced data causes zero-count treatment/control nodes. Co-Authored-By: Claude Opus 4.6 --- tests/test_causal_trees.py | 24 ++++++++++++++++++++++++ 1 file changed, 24 insertions(+) diff --git a/tests/test_causal_trees.py b/tests/test_causal_trees.py index 1543deca..07e557e2 100644 --- a/tests/test_causal_trees.py +++ b/tests/test_causal_trees.py @@ -275,3 +275,27 @@ def test_unbiased_sampling_error( crforest_test_var = crforest.calculate_error(X_train=X_train, X_test=X_test) assert (crforest_test_var > 0).all() assert crforest_test_var.shape[0] == y_test.shape[0] + + +def test_CausalRandomForestRegressor_no_inf_predictions(): + """Test that CausalRandomForestRegressor does not predict inf values + when some tree splits have zero-count treatment/control groups (#589).""" + np.random.seed(RANDOM_SEED) + n = 100 + X = np.random.randn(n, 5) + # Heavily imbalanced: very few treated samples so tree splits + # can produce nodes with zero treatment count + treatment = np.array([0] * 90 + [1] * 10) + y = np.random.randn(n) + + model = CausalRandomForestRegressor( + criterion="causal_mse", + control_name=0, + n_estimators=10, + min_samples_leaf=1, + random_state=RANDOM_SEED, + ) + model.fit(X=X, treatment=treatment, y=y) + preds = model.predict(X=X) + + assert np.all(np.isfinite(preds)), "Predictions contain inf or NaN values" From da1e25f517ab1ffab7a7d1ad4af189810c49d026 Mon Sep 17 00:00:00 2001 From: Jeong-Yoon Lee Date: Fri, 6 Mar 2026 16:38:20 -0800 Subject: [PATCH 4/5] Add ttest criterion regression test for inf predictions (#589) Co-Authored-By: Claude Opus 4.6 --- tests/test_causal_trees.py | 23 +++++++++++++++++++++++ 1 file changed, 23 insertions(+) diff --git a/tests/test_causal_trees.py b/tests/test_causal_trees.py index 07e557e2..46d9a8dd 100644 --- a/tests/test_causal_trees.py +++ b/tests/test_causal_trees.py @@ -299,3 +299,26 @@ def test_CausalRandomForestRegressor_no_inf_predictions(): preds = model.predict(X=X) assert np.all(np.isfinite(preds)), "Predictions contain inf or NaN values" + + +def test_CausalRandomForestRegressor_no_inf_predictions_ttest(): + """Test that CausalRandomForestRegressor with criterion='ttest' does not + predict inf values when some tree splits have zero-count + treatment/control groups (#589).""" + np.random.seed(RANDOM_SEED) + n = 100 + X = np.random.randn(n, 5) + treatment = np.array([0] * 90 + [1] * 10) + y = np.random.randn(n) + + model = CausalRandomForestRegressor( + criterion="ttest", + control_name=0, + n_estimators=10, + min_samples_leaf=1, + random_state=RANDOM_SEED, + ) + model.fit(X=X, treatment=treatment, y=y) + preds = model.predict(X=X) + + assert np.all(np.isfinite(preds)), "Predictions contain inf or NaN values" From 1099300a1e2c6ffd612e7d6892b371b0b90d6b9c Mon Sep 17 00:00:00 2001 From: Jeong-Yoon Lee Date: Fri, 6 Mar 2026 16:47:02 -0800 Subject: [PATCH 5/5] Fix ttest criterion name: 'ttest' -> 't_test' Co-Authored-By: Claude Opus 4.6 --- tests/test_causal_trees.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/test_causal_trees.py b/tests/test_causal_trees.py index 46d9a8dd..8371df92 100644 --- a/tests/test_causal_trees.py +++ b/tests/test_causal_trees.py @@ -312,7 +312,7 @@ def test_CausalRandomForestRegressor_no_inf_predictions_ttest(): y = np.random.randn(n) model = CausalRandomForestRegressor( - criterion="ttest", + criterion="t_test", control_name=0, n_estimators=10, min_samples_leaf=1,