# Add post-fit confidence intervals to BaseTLearner via `store_bootstraps` and `return_ci` (#886)
aman-coder03 wants to merge 4 commits into `uber:master`.
**jeongyoonlee** left a comment:
Thanks for the PR! The motivation — scoring new unseen data with CIs without re-fitting — is a real use case that the existing fit_predict() bootstrap doesn't cover. A few items to address:
### Blocking

- **`AttributeError` bug** — `self.bootstrap_models_` is never initialized in `__init__`. If a user calls `predict(return_ci=True)` on a freshly constructed (or non-bootstrap-fitted) instance, it raises `AttributeError` instead of the intended `ValueError`. Fix: initialize `self.bootstrap_models_ = None` in `__init__` and use `self.bootstrap_models_ is None` (explicit check) instead of `not self.bootstrap_models_` (falsy check).
- **No reproducibility** — The bootstrap loop uses `np.random.choice` with no seed. The existing `BaseDRLearner.bootstrap()` already supports a seed pattern — please follow it for consistency.
- **`ci_quantile` vs `ate_alpha` inconsistency** — The rest of the codebase uses `self.ate_alpha` (set at `__init__` time) to control CI width. This PR introduces a separate `ci_quantile` defaulting to `0.05`, so the same learner produces different CI widths from `fit_predict()` vs `predict(return_ci=True)` unless the user remembers to align them. Recommendation: use `self.ate_alpha` by default, with an optional override parameter.
- **Duplicated bootstrap logic** — The PR reimplements ~20 lines of resampling/fitting that `BaseLearner.bootstrap()` already handles. Consider reusing or extending the existing infrastructure to reduce the maintenance burden.
- **`return_ci` + `return_components` interaction** — If both are `True`, `return_components` is silently ignored. Either raise an error or support both.
- **Flaky test assertion** — `assert (lb <= tau_pred).all() and (tau_pred <= ub).all()` is not guaranteed for bootstrap percentile CIs: the full-sample point estimate can legitimately fall outside the bootstrap percentile interval. Consider checking that a high fraction (e.g., 95%) of predictions fall within bounds, or just check shapes and that `lb <= ub`.
- **Test conventions** — Please use `RANDOM_SEED` and `CONTROL_NAME` from `tests/const.py` per project conventions, and use `pytest.raises(ValueError)` instead of try/except.
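The initialization and reproducibility points above can be sketched together. This is a hypothetical illustration, not the actual CausalML code: the class name `TLearnerSketch` and helper structure are invented here, while `store_bootstraps`, `n_bootstraps`, `bootstrap_size`, and `bootstrap_models_` follow the names discussed in this PR.

```python
import numpy as np
from copy import deepcopy

class TLearnerSketch:
    """Illustrative T-learner showing the requested fixes (not CausalML's API)."""

    def __init__(self, learner, ate_alpha=0.05):
        self.model_c = deepcopy(learner)
        self.model_t = deepcopy(learner)
        self.ate_alpha = ate_alpha
        self.bootstrap_models_ = None  # fix: always defined, so later checks can use `is None`

    def fit(self, X, treatment, y, store_bootstraps=False,
            n_bootstraps=100, bootstrap_size=None, seed=42):
        # Fit the full-sample control/treatment models first.
        self.model_c.fit(X[treatment == 0], y[treatment == 0])
        self.model_t.fit(X[treatment == 1], y[treatment == 1])
        if store_bootstraps:
            rng = np.random.default_rng(seed)  # fix: seeded, reproducible resampling
            size = bootstrap_size or len(X)
            self.bootstrap_models_ = []
            for _ in range(n_bootstraps):
                idx = rng.choice(len(X), size=size, replace=True)
                Xb, tb, yb = X[idx], treatment[idx], y[idx]
                mc, mt = deepcopy(self.model_c), deepcopy(self.model_t)
                mc.fit(Xb[tb == 0], yb[tb == 0])
                mt.fit(Xb[tb == 1], yb[tb == 1])
                self.bootstrap_models_.append((mc, mt))
        return self
```

With `bootstrap_models_` set in `__init__`, `predict(return_ci=True)` can raise the intended `ValueError` via an explicit `if self.bootstrap_models_ is None:` check rather than crashing with `AttributeError`.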
### Non-blocking suggestions

- **Memory cost** — The default of 1000 bootstraps storing full deep-copied model pairs can be significant (e.g., 5 treatments × 1000 bootstraps = 10k model objects). Consider documenting this or lowering the default.
- **`BaseTClassifier` not updated** — `BaseTClassifier` overrides `predict()` with its own implementation, so passing `return_ci=True` to a classifier will raise a `TypeError`. Worth documenting or handling.
- **Docstrings** — The new parameters (`store_bootstraps`, `n_bootstraps`, `bootstrap_size`, `return_ci`, `ci_quantile`) should be documented in the `fit()` and `predict()` docstrings.
- **Architecture note** — You mentioned planning to extend this to other learners. Since the store-and-reuse pattern is generic, it may be worth designing the base-class integration upfront (even if the implementation is phased) to avoid divergent implementations across learners.
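Pulling the `ate_alpha` recommendation and the `ValueError` behavior together, the scoring path could look roughly like the sketch below. This is a hedged illustration under the PR's proposed design, not CausalML's actual implementation; `model` is assumed to expose `ate_alpha`, fitted `model_c`/`model_t`, and a `bootstrap_models_` list of fitted `(control, treatment)` pairs.

```python
import numpy as np

def predict_with_ci(model, X, alpha=None):
    """Score X with percentile CIs from a stored bootstrap ensemble (sketch)."""
    if model.bootstrap_models_ is None:
        raise ValueError("Call fit(store_bootstraps=True) before predict(return_ci=True).")
    # Default to the learner's own ate_alpha so CI widths match fit_predict().
    alpha = model.ate_alpha if alpha is None else alpha
    te = model.model_t.predict(X) - model.model_c.predict(X)
    # One treatment-effect estimate per stored bootstrap replica.
    boots = np.stack([mt.predict(X) - mc.predict(X)
                      for mc, mt in model.bootstrap_models_])
    lb = np.percentile(boots, 100 * alpha / 2, axis=0)
    ub = np.percentile(boots, 100 * (1 - alpha / 2), axis=0)
    return te, lb, ub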
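A quick back-of-envelope check of the memory-cost point: each stored bootstrap replica of a T-learner holds a (control, treatment) model pair per treatment group. The numbers below mirror the example in the review and are illustrative only.

```python
n_treatment_groups = 5   # example from the review
n_bootstraps = 1000      # the PR's proposed default
models_per_replica = 2   # one control + one treatment model
total_models = n_treatment_groups * n_bootstraps * models_per_replica
print(total_models)  # 10000 model objects kept alive on the learner
```

Whether that is acceptable depends on the base learner's size, which is why documenting the cost or lowering the default seems worthwhile.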
### Proposed changes

Currently, CIs are only available via `fit_predict(..., return_ci=True)`, which re-runs bootstrap resampling every time. This makes it impossible to get CIs on new data after a model is already trained, which is a common need in production where you train once and score repeatedly.

#### How

Added `store_bootstraps=False` to `fit()`. When set to `True`, it trains a bootstrap ensemble and stores it in `self.bootstrap_models_`. Added `return_ci=False` to `predict()`. When set to `True`, it scores the stored ensemble and returns `(te, te_lower, te_upper)` with no retraining needed. All existing calls to `fit()`, `predict()`, and `fit_predict()` are completely unchanged, since the new parameters are opt-in only.

#### Scope

This PR covers `BaseTLearner` as the reference implementation. I plan to extend the same pattern to the remaining learners once the design is confirmed by maintainers.

Closes #885