
Add post-fit confidence intervals to BaseTLearner via store_bootstraps and return_ci#886

Open
aman-coder03 wants to merge 4 commits into uber:master from aman-coder03:feature/predict-return-ci

Conversation

Contributor

aman-coder03 commented Mar 7, 2026

Proposed changes

Currently, CIs are only available via fit_predict(..., return_ci=True), which re-runs bootstrap resampling every time. This makes it impossible to get CIs on new data after a model is already trained, which is a common need in production where you train once and score repeatedly.

How

Added store_bootstraps=False to fit(). When set to True, it trains a bootstrap ensemble and stores it in self.bootstrap_models_. Added return_ci=False to predict(). When set to True, it scores the stored ensemble and returns (te, te_lower, te_upper) with no retraining needed:

learner.fit(X_train, treatment_train, y_train, store_bootstraps=True, n_bootstraps=200)
tau, lb, ub = learner.predict(X_test, return_ci=True)

All existing calls to fit(), predict(), and fit_predict() are completely unchanged, since the new parameters are opt-in only.
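Conceptually, the post-fit CI path boils down to taking percentile bounds over the stored ensemble's predictions. A minimal NumPy sketch of that idea (percentile_ci is an illustrative helper, not the PR's actual code):

```python
import numpy as np

def percentile_ci(boot_preds, alpha=0.05):
    """Percentile confidence bounds from bootstrap predictions.

    boot_preds: shape (n_bootstraps, n_samples), one row of treatment
    effect predictions per stored bootstrap model. Illustrative helper,
    not the PR's actual implementation.
    """
    lower = np.percentile(boot_preds, 100 * alpha / 2, axis=0)
    upper = np.percentile(boot_preds, 100 * (1 - alpha / 2), axis=0)
    return lower, upper

# toy ensemble: 200 bootstrap replicates for 5 test rows
rng = np.random.default_rng(0)
boot_preds = rng.normal(loc=1.0, scale=0.1, size=(200, 5))
te_lower, te_upper = percentile_ci(boot_preds)
assert (te_lower <= te_upper).all()
```

Scoring the stored ensemble this way is what lets predict(return_ci=True) skip retraining entirely.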

Scope

This PR covers BaseTLearner as the reference implementation. I plan to extend the same pattern to the remaining learners once the design is confirmed by the maintainers.
Closes #885

Types of changes

What types of changes does your code introduce to CausalML?
Put an x in the boxes that apply

  • Bugfix (non-breaking change which fixes an issue)
  • New feature (non-breaking change which adds functionality)
  • Breaking change (fix or feature that would cause existing functionality to not work as expected)
  • Documentation Update (if none of the other choices apply)

Checklist

Put an x in the boxes that apply. You can also fill these out after creating the PR. If you're unsure about any of them, don't hesitate to ask. We're here to help! This is simply a reminder of what we are going to look for before merging your code.

  • I have read the CONTRIBUTING doc
  • I have signed the CLA
  • Lint and unit tests pass locally with my changes
  • I have added tests that prove my fix is effective or that my feature works
  • I have added necessary documentation (if appropriate)
  • Any dependent changes have been merged and published in downstream modules

Further comments

If this is a relatively large or complex change, kick off the discussion by explaining why you chose the solution you did and what alternatives you considered, etc. This PR template is adopted from appium.

jeongyoonlee (Collaborator) left a comment


Thanks for the PR! The motivation — scoring new unseen data with CIs without re-fitting — is a real use case that the existing fit_predict() bootstrap doesn't cover. A few items to address:

Blocking

  1. AttributeError bug: self.bootstrap_models_ is never initialized in __init__. If a user calls predict(return_ci=True) on a freshly constructed (or non-bootstrap-fitted) instance, it raises AttributeError instead of the intended ValueError. Fix: initialize self.bootstrap_models_ = None in __init__ and use the explicit check self.bootstrap_models_ is None instead of the falsy check not self.bootstrap_models_.

  2. No reproducibility — The bootstrap loop uses np.random.choice with no seed. The existing BaseDRLearner.bootstrap() already supports a seed pattern — please follow that for consistency.

  3. ci_quantile vs ate_alpha inconsistency — The rest of the codebase uses self.ate_alpha (set at __init__ time) to control CI width. This PR introduces a separate ci_quantile defaulting to 0.05, so the same learner produces different CI widths from fit_predict() vs predict(return_ci=True) unless the user remembers to align them. Recommendation: use self.ate_alpha by default, with an optional override parameter.

  4. Duplicated bootstrap logic — The PR reimplements ~20 lines of resampling/fitting that BaseLearner.bootstrap() already handles. Consider reusing or extending the existing infrastructure to reduce maintenance burden.

  5. return_ci + return_components interaction — If both are True, return_components is silently ignored. Should either raise an error or support both.

  6. Flaky test assertion: assert (lb <= tau_pred).all() and (tau_pred <= ub).all() is not guaranteed for bootstrap percentile CIs — the full-sample point estimate can legitimately fall outside the bootstrap percentile interval. Consider checking that a high fraction (e.g., 95%) of predictions fall within bounds, or just check shapes and that lb <= ub.

  7. Test conventions — Please use RANDOM_SEED and CONTROL_NAME from tests/const.py per project conventions, and use pytest.raises(ValueError) instead of try/except.
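A minimal sketch of the fix for blocking item 1 (the class and method bodies are stand-ins, not causalml's actual BaseTLearner):

```python
class BaseTLearner:
    """Minimal stand-in sketching the reviewer's suggested fix; names
    are taken from the PR, but this is not the real implementation."""

    def __init__(self):
        # initialize here so predict(return_ci=True) can check explicitly
        self.bootstrap_models_ = None

    def predict(self, X, return_ci=False):
        if return_ci and self.bootstrap_models_ is None:
            # explicit None check: `not self.bootstrap_models_` would also
            # be True for an empty list, masking a different failure mode
            raise ValueError(
                "No stored bootstrap models. "
                "Call fit(..., store_bootstraps=True) first."
            )
        return X  # placeholder for the real prediction logic
```

With the attribute initialized up front, the error path is the intended ValueError regardless of how the instance was fitted.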
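Blocking item 2 comes down to threading a seed into the resampling; a sketch with a hypothetical helper (this is not causalml's actual bootstrap() signature):

```python
import numpy as np

def bootstrap_indices(n_samples, n_bootstraps, bootstrap_size=None, seed=None):
    """Draw reproducible bootstrap resampling indices.

    Hypothetical helper mirroring the seed pattern the review points to.
    """
    size = n_samples if bootstrap_size is None else bootstrap_size
    rng = np.random.default_rng(seed)
    # one row of with-replacement indices per bootstrap replicate
    return rng.choice(n_samples, size=(n_bootstraps, size), replace=True)

idx_a = bootstrap_indices(100, 10, seed=42)
idx_b = bootstrap_indices(100, 10, seed=42)
assert (idx_a == idx_b).all()  # same seed, identical resamples
```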
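For blocking item 3, the recommended default can be expressed as a simple fallback; the class and method names below are illustrative, not causalml's actual API:

```python
class LearnerSketch:
    """Illustrative stand-in showing CI width defaulting to ate_alpha,
    with an optional per-call override."""

    def __init__(self, ate_alpha=0.05):
        self.ate_alpha = ate_alpha

    def _resolve_ci_alpha(self, ci_alpha=None):
        # predict(return_ci=True) would use this so its CI width matches
        # fit_predict() unless the caller explicitly overrides it
        return self.ate_alpha if ci_alpha is None else ci_alpha

learner = LearnerSketch(ate_alpha=0.10)
assert learner._resolve_ci_alpha() == 0.10    # matches fit_predict() width
assert learner._resolve_ci_alpha(0.01) == 0.01  # explicit override
```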
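Blocking items 6 and 7 can be combined into a sketch of a more robust test; the data is synthetic and check_ci_sanity is a hypothetical helper:

```python
import numpy as np
import pytest

def check_ci_sanity(tau_pred, lb, ub, min_coverage=0.95):
    """Bootstrap percentile CIs need not contain the full-sample point
    estimate at every observation, so assert ordering plus a high
    coverage fraction instead of exact containment."""
    assert (lb <= ub).all()
    coverage = np.mean((lb <= tau_pred) & (tau_pred <= ub))
    assert coverage >= min_coverage

# synthetic bootstrap replicates stand in for real learner output
rng = np.random.default_rng(0)
boot = rng.normal(size=(500, 1000))            # (n_bootstraps, n_samples)
tau_pred = rng.normal(scale=0.05, size=1000)   # full-sample point estimate
lb = np.percentile(boot, 2.5, axis=0)
ub = np.percentile(boot, 97.5, axis=0)
check_ci_sanity(tau_pred, lb, ub)

# pytest.raises pins down the expected failure more precisely than try/except
with pytest.raises(ValueError):
    raise ValueError("no stored bootstrap models")
```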

Non-blocking suggestions

  • Memory cost — Default 1000 bootstraps storing full deep-copied model pairs can be significant (e.g., 5 treatments × 1000 bootstraps = 10k model objects). Consider documenting this or lowering the default.

  • BaseTClassifier not updated: BaseTClassifier overrides predict() with its own implementation, so passing return_ci=True to a classifier will raise a TypeError. Worth documenting or handling.

  • Docstrings — The new parameters (store_bootstraps, n_bootstraps, bootstrap_size, return_ci, ci_quantile) should be documented in the fit() and predict() docstrings.

  • Architecture note — You mentioned planning to extend this to other learners. Since the store-and-reuse pattern is generic, it might be worth designing the base class integration upfront (even if implementation is phased) to avoid divergent implementations across learners.

@jeongyoonlee jeongyoonlee added the enhancement New feature or request label Mar 13, 2026

Development

Successfully merging this pull request may close these issues.

[Enhancement] Add return_ci support to predict() for post-fit confidence interval inference in meta-learners
