Skip to content

Commit b335c56

Browse files
ryan-odearemlapmot
andauthored
Offloading (#30)
* fix to compevent models * prepare gitignore * add joblib * add offload (and docs for visit) * create offloader * setup offloading in primary API * offloader to init * add boot idx * add intake from unloaded models * setup weight offloading * weight offload to init * add weight offloading to SEQ + weight inloading * test offload * bump version * setup offloading for dataframes * offload original DT while bootstrapping * adjust test nboot * skipping compevent tests * formatted * Use smf.logit() for binary treatment vars When: * treatment_level=[0, 1] (binary control vs treatment) * method="censoring" * weighted=True * Handle intercept-only formula when numerator is "1" or empty * Allow specifying fitting method * Obtain expected category levels from fitted model * Improve handling of categories for predictions * Account for NaNs in predicted probs * Make survival preds safe * Move _safe_predict into a helper file --------- Co-authored-by: Tom Palmer <remlapmot@hotmail.com>
1 parent b4765f1 commit b335c56

18 files changed

Lines changed: 352 additions & 44 deletions

.gitignore

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -166,3 +166,6 @@ cython_debug/
166166

167167
# uv lock file
168168
uv.lock
169+
170+
# offloaded data files (offload test)
171+
_seq_models/

pySEQTarget/SEQopts.py

Lines changed: 17 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
import multiprocessing
2+
import os
23
from dataclasses import dataclass, field
34
from typing import List, Literal, Optional
45

@@ -18,7 +19,7 @@ class SEQopts:
1819
:type bootstrap_CI_method: str
1920
:param cense_colname: Column name for censoring effect (LTFU, etc.)
2021
:type cense_colname: str
21-
:param cense_denominator: Override to specify denominator patsy formula for censoring models
22+
:param cense_denominator: Override to specify denominator patsy formula for censoring models; "1" or "" indicate intercept only model
2223
:type cense_denominator: Optional[str] or None
2324
:param cense_numerator: Override to specify numerator patsy formula for censoring models
2425
:type cense_numerator: Optional[str] or None
@@ -54,8 +55,12 @@ class SEQopts:
5455
:type km_curves: bool
5556
:param ncores: Number of cores to use if running in parallel
5657
:type ncores: int
57-
:param numerator: Override to specify the outcome patsy formula for numerator models
58+
:param numerator: Override to specify the outcome patsy formula for numerator models; "1" or "" indicate intercept only model
5859
:type numerator: str
60+
:param offload: Boolean to offload intermediate model data to disk
61+
:type offload: bool
62+
:param offload_dir: Directory to offload intermediate model data
63+
:type offload_dir: str
5964
:param parallel: Boolean to run model fitting in parallel
6065
:type parallel: bool
6166
:param plot_colors: List of colors for KM plots, if applicable
@@ -80,8 +85,12 @@ class SEQopts:
8085
:type treatment_level: List[int]
8186
:param trial_include: Boolean to force trial values into model covariates
8287
:type trial_include: bool
88+
:param visit_colname: Column name specifying visit number
89+
:type visit_colname: str
8390
:param weight_eligible_colnames: List of column names of length treatment_level to identify which rows are eligible for weight fitting
8491
:type weight_eligible_colnames: List[str]
92+
:param weight_fit_method: The fitting method to be used ["newton", "bfgs", "lbfgs", "nm"], default "newton"
93+
:type weight_fit_method: str
8594
:param weight_min: Minimum weight
8695
:type weight_min: float
8796
:param weight_max: Maximum weight
@@ -120,6 +129,8 @@ class SEQopts:
120129
km_curves: bool = False
121130
ncores: int = multiprocessing.cpu_count()
122131
numerator: Optional[str] = None
132+
offload: bool = False
133+
offload_dir: str = "_seq_models"
123134
parallel: bool = False
124135
plot_colors: List[str] = field(
125136
default_factory=lambda: ["#F8766D", "#00BFC4", "#555555"]
@@ -136,6 +147,7 @@ class SEQopts:
136147
trial_include: bool = True
137148
visit_colname: str = None
138149
weight_eligible_colnames: List[str] = field(default_factory=lambda: [])
150+
weight_fit_method: Literal["newton", "bfgs", "lbfgs", "nm"] = "newton"
139151
weight_min: float = 0.0
140152
weight_max: float = None
141153
weight_lag_condition: bool = True
@@ -195,3 +207,6 @@ def __post_init__(self):
195207
attr = getattr(self, i)
196208
if attr is not None and not isinstance(attr, list):
197209
setattr(self, i, "".join(attr.split()))
210+
211+
if self.offload:
212+
os.makedirs(self.offload_dir, exist_ok=True)

pySEQTarget/SEQuential.py

Lines changed: 16 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -12,15 +12,15 @@
1212
_subgroup_fit)
1313
from .error import _data_checker, _param_checker
1414
from .expansion import _binder, _diagnostics, _dynamic, _random_selection
15-
from .helpers import _col_string, _format_time, bootstrap_loop
15+
from .helpers import Offloader, _col_string, _format_time, bootstrap_loop
1616
from .initialization import (_cense_denominator, _cense_numerator,
1717
_denominator, _numerator, _outcome)
1818
from .plot import _survival_plot
1919
from .SEQopts import SEQopts
2020
from .SEQoutput import SEQoutput
2121
from .weighting import (_fit_denominator, _fit_LTFU, _fit_numerator,
22-
_fit_visit, _weight_bind, _weight_predict,
23-
_weight_setup, _weight_stats)
22+
_fit_visit, _offload_weights, _weight_bind,
23+
_weight_predict, _weight_setup, _weight_stats)
2424

2525

2626
class SEQuential:
@@ -84,6 +84,8 @@ def __init__(
8484
np.random.RandomState(self.seed) if self.seed is not None else np.random
8585
)
8686

87+
self._offloader = Offloader(enabled=self.offload, dir=self.offload_dir)
88+
8789
if self.covariates is None:
8890
self.covariates = _outcome(self)
8991

@@ -201,6 +203,9 @@ def fit(self) -> None:
201203
raise ValueError(
202204
"Bootstrap sampling not found. Please run the 'bootstrap' method before fitting with bootstrapping."
203205
)
206+
boot_idx = None
207+
if hasattr(self, "_current_boot_idx"):
208+
boot_idx = self._current_boot_idx
204209

205210
if self.weighted:
206211
WDT = _weight_setup(self)
@@ -217,6 +222,9 @@ def fit(self) -> None:
217222
_fit_numerator(self, WDT)
218223
_fit_denominator(self, WDT)
219224

225+
if self.offload:
226+
_offload_weights(self, boot_idx)
227+
220228
WDT = pl.from_pandas(WDT)
221229
WDT = _weight_predict(self, WDT)
222230
_weight_bind(self, WDT)
@@ -244,6 +252,11 @@ def fit(self) -> None:
244252
self.weighted,
245253
"weight",
246254
)
255+
if self.offload:
256+
offloaded_models = {}
257+
for key, model in models.items():
258+
offloaded_models[key] = self._offloader.save_model(model, key, boot_idx)
259+
return offloaded_models
247260
return models
248261

249262
def survival(self, **kwargs) -> None:

pySEQTarget/analysis/_hazard.py

Lines changed: 9 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,8 @@
44
import polars as pl
55
from lifelines import CoxPHFitter
66

7+
from ..helpers._predict_model import _safe_predict
8+
79

810
def _calculate_hazard(self):
911
if self.subgroup_colname is None:
@@ -93,8 +95,10 @@ def _hazard_handler(self, data, idx, boot_idx, rng):
9395
else:
9496
model_dict = self.outcome_model[boot_idx]
9597

96-
outcome_model = model_dict["outcome"]
97-
ce_model = model_dict.get("compevent", None) if self.compevent_colname else None
98+
outcome_model = self._offloader.load_model(model_dict["outcome"])
99+
ce_model = None
100+
if self.compevent_colname and "compevent" in model_dict:
101+
ce_model = self._offloader.load_model(model_dict["compevent"])
98102

99103
all_treatments = []
100104
for val in self.treatment_level:
@@ -103,13 +107,14 @@ def _hazard_handler(self, data, idx, boot_idx, rng):
103107
)
104108

105109
tmp_pd = tmp.to_pandas()
106-
outcome_prob = outcome_model.predict(tmp_pd)
110+
outcome_prob = _safe_predict(outcome_model, tmp_pd)
107111
outcome_sim = rng.binomial(1, outcome_prob)
108112

109113
tmp = tmp.with_columns([pl.Series("outcome", outcome_sim)])
110114

111115
if ce_model is not None:
112-
ce_prob = ce_model.predict(tmp_pd)
116+
ce_tmp_pd = tmp.to_pandas()
117+
ce_prob = _safe_predict(ce_model, ce_tmp_pd)
113118
ce_sim = rng.binomial(1, ce_prob)
114119
tmp = tmp.with_columns([pl.Series("ce", ce_sim)])
115120

pySEQTarget/analysis/_survival_pred.py

Lines changed: 7 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,7 @@
11
import polars as pl
22

3+
from ..helpers._predict_model import _safe_predict
4+
35

46
def _get_outcome_predictions(self, TxDT, idx=None):
57
data = TxDT.to_pandas()
@@ -9,9 +11,12 @@ def _get_outcome_predictions(self, TxDT, idx=None):
911

1012
for boot_model in self.outcome_model:
1113
model_dict = boot_model[idx] if idx is not None else boot_model
12-
predictions["outcome"].append(model_dict["outcome"].predict(data))
14+
outcome_model = self._offloader.load_model(model_dict["outcome"])
15+
predictions["outcome"].append(_safe_predict(outcome_model, data.copy()))
16+
1317
if self.compevent_colname is not None:
14-
predictions["compevent"].append(model_dict["compevent"].predict(data))
18+
compevent_model = self._offloader.load_model(model_dict["compevent"])
19+
predictions["compevent"].append(_safe_predict(compevent_model, data.copy()))
1520

1621
return predictions
1722

pySEQTarget/helpers/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
from ._bootstrap import bootstrap_loop as bootstrap_loop
22
from ._col_string import _col_string as _col_string
33
from ._format_time import _format_time as _format_time
4+
from ._offloader import Offloader as Offloader
45
from ._output_files import _build_md as _build_md
56
from ._output_files import _build_pdf as _build_pdf
67
from ._pad import _pad as _pad

pySEQTarget/helpers/_bootstrap.py

Lines changed: 18 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -39,7 +39,10 @@ def _bootstrap_worker(obj, method_name, original_DT, i, seed, args, kwargs):
3939
obj._rng = (
4040
np.random.RandomState(seed + i) if seed is not None else np.random.RandomState()
4141
)
42+
original_DT = obj._offloader.load_dataframe(original_DT)
4243
obj.DT = _prepare_boot_data(obj, original_DT, i)
44+
del original_DT
45+
obj._current_boot_idx = i + 1
4346

4447
# Disable bootstrapping to prevent recursion
4548
obj.bootstrap_nboot = 0
@@ -60,6 +63,7 @@ def wrapper(self, *args, **kwargs):
6063
results = []
6164
original_DT = self.DT
6265

66+
self._current_boot_idx = None
6367
full = method(self, *args, **kwargs)
6468
results.append(full)
6569

@@ -71,17 +75,20 @@ def wrapper(self, *args, **kwargs):
7175
seed = getattr(self, "seed", None)
7276
method_name = method.__name__
7377

78+
original_DT_ref = self._offloader.save_dataframe(original_DT, "_DT")
79+
7480
if getattr(self, "parallel", False):
7581
original_rng = getattr(self, "_rng", None)
7682
self._rng = None
83+
self.DT = None
7784

7885
with ProcessPoolExecutor(max_workers=ncores) as executor:
7986
futures = [
8087
executor.submit(
8188
_bootstrap_worker,
8289
self,
8390
method_name,
84-
original_DT,
91+
original_DT_ref,
8592
i,
8693
seed,
8794
args,
@@ -95,13 +102,21 @@ def wrapper(self, *args, **kwargs):
95102
results.append(j.result())
96103

97104
self._rng = original_rng
105+
self.DT = self._offloader.load_dataframe(original_DT_ref)
98106
else:
107+
original_DT_ref = self._offloader.save_dataframe(original_DT, "_DT")
108+
del original_DT
99109
for i in tqdm(range(nboot), desc="Bootstrapping..."):
100-
self.DT = _prepare_boot_data(self, original_DT, i)
110+
self._current_boot_idx = i + 1
111+
tmp = self._offloader.load_dataframe(original_DT_ref)
112+
self.DT = _prepare_boot_data(self, tmp, i)
113+
del tmp
114+
self.bootstrap_nboot = 0
101115
boot_fit = method(self, *args, **kwargs)
102116
results.append(boot_fit)
103117

104-
self.DT = original_DT
118+
self.bootstrap_nboot = nboot
119+
self.DT = self._offloader.load_dataframe(original_DT_ref)
105120

106121
end = time.perf_counter()
107122
self._model_time = _format_time(start, end)
Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,15 @@
1+
def _fix_categories_for_predict(model, newdata):
2+
"""
3+
Fix categorical column ordering in newdata to match what the model expects.
4+
"""
5+
if hasattr(model, 'model') and hasattr(model.model, 'data') and hasattr(model.model.data, 'design_info'):
6+
design_info = model.model.data.design_info
7+
for factor, factor_info in design_info.factor_infos.items():
8+
if factor_info.type == 'categorical':
9+
col_name = factor.name()
10+
if col_name in newdata.columns:
11+
expected_categories = list(factor_info.categories)
12+
newdata[col_name] = newdata[col_name].astype(str)
13+
newdata[col_name] = newdata[col_name].astype('category')
14+
newdata[col_name] = newdata[col_name].cat.set_categories(expected_categories)
15+
return newdata

pySEQTarget/helpers/_offloader.py

Lines changed: 53 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,53 @@
1+
from pathlib import Path
2+
from typing import Any, Optional, Union
3+
4+
import joblib
5+
import polars as pl
6+
7+
8+
class Offloader:
9+
"""Manages disk-based storage for models and intermediate data"""
10+
11+
def __init__(self, enabled: bool, dir: str, compression: int = 3):
12+
self.enabled = enabled
13+
self.dir = Path(dir)
14+
self.compression = compression
15+
16+
def save_model(
17+
self, model: Any, name: str, boot_idx: Optional[int] = None
18+
) -> Union[Any, str]:
19+
"""Save a fitted model to disk and return a reference"""
20+
if not self.enabled:
21+
return model
22+
23+
filename = (
24+
f"{name}_boot{boot_idx}.pkl" if boot_idx is not None else f"{name}.pkl"
25+
)
26+
filepath = self.dir / filename
27+
28+
joblib.dump(model, filepath, compress=self.compression)
29+
30+
return str(filepath)
31+
32+
def load_model(self, ref: Union[Any, str]) -> Any:
33+
if not self.enabled or not isinstance(ref, str):
34+
return ref
35+
36+
return joblib.load(ref)
37+
38+
def save_dataframe(self, df: pl.DataFrame, name: str) -> Union[pl.DataFrame, str]:
39+
if not self.enabled:
40+
return df
41+
42+
filename = f"{name}.parquet"
43+
filepath = self.dir / filename
44+
45+
df.write_parquet(filepath, compression="zstd")
46+
47+
return str(filepath)
48+
49+
def load_dataframe(self, ref: Union[pl.DataFrame, str]) -> pl.DataFrame:
50+
if not self.enabled or not isinstance(ref, str):
51+
return ref
52+
53+
return pl.read_parquet(ref)
Lines changed: 49 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,57 @@
1+
import warnings
2+
13
import numpy as np
24

5+
from ._fix_categories import _fix_categories_for_predict
6+
7+
8+
def _safe_predict(model, data, clip_probs=True):
9+
"""
10+
Predict with category fix fallback if needed.
11+
12+
Parameters
13+
----------
14+
model : statsmodels model
15+
Fitted model object
16+
data : pandas DataFrame
17+
Data to predict on
18+
clip_probs : bool
19+
If True, clip probabilities to [0, 1] and replace NaN with 0.5
20+
"""
21+
data = data.copy()
22+
23+
try:
24+
probs = model.predict(data)
25+
except Exception as e:
26+
if "mismatching levels" in str(e):
27+
data = _fix_categories_for_predict(model, data)
28+
probs = model.predict(data)
29+
else:
30+
raise
31+
32+
if clip_probs:
33+
probs = np.array(probs)
34+
if np.any(np.isnan(probs)):
35+
warnings.warn("NaN values in predicted probabilities, replacing with 0.5")
36+
probs = np.where(np.isnan(probs), 0.5, probs)
37+
probs = np.clip(probs, 0, 1)
38+
39+
return probs
40+
341

442
def _predict_model(self, model, newdata):
543
newdata = newdata.to_pandas()
44+
45+
# Original behavior - convert fixed_cols to category
646
for col in self.fixed_cols:
747
if col in newdata.columns:
848
newdata[col] = newdata[col].astype("category")
9-
return np.array(model.predict(newdata))
49+
50+
try:
51+
return np.array(model.predict(newdata))
52+
except Exception as e:
53+
if "mismatching levels" in str(e):
54+
newdata = _fix_categories_for_predict(model, newdata)
55+
return np.array(model.predict(newdata))
56+
else:
57+
raise

0 commit comments

Comments
 (0)