Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions .github/workflows/publish.yml
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@ jobs:
python -m build

- name: Upload distributions
uses: actions/upload-artifact@v4
uses: actions/upload-artifact@v7
with:
name: release-dists
path: dist/
Expand All @@ -59,7 +59,7 @@ jobs:

steps:
- name: Retrieve release distributions
uses: actions/download-artifact@v4
uses: actions/download-artifact@v8
with:
name: release-dists
path: dist/
Expand Down
4 changes: 3 additions & 1 deletion .readthedocs.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -19,4 +19,6 @@ sphinx:
# See https://docs.readthedocs.io/en/stable/guides/reproducible-builds.html
python:
install:
- requirements: docs/requirements.txt
- requirements: docs/requirements.txt
- method: pip
path: .
1 change: 0 additions & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,6 @@ The primary API, `SEQuential` uses a dataclass system to handle function input.
From the user side, this amounts to creating a dataclass, `SEQopts`, and then feeding this into `SEQuential`. If you forgot to add something at class instantiation, you can, in some cases, add them when you call their respective class method.

```python
import polars as pl
from pySEQTarget import SEQuential, SEQopts
from pySEQTarget.data import load_data

Expand Down
7 changes: 4 additions & 3 deletions docs/conf.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,9 +10,10 @@
import sys
from datetime import date

version = importlib.metadata.version("pySEQTarget")
if not version:
version = "0.12.2"
try:
version = importlib.metadata.version("pySEQTarget")
except importlib.metadata.PackageNotFoundError:
version = "unknown"
sys.path.insert(0, os.path.abspath("../"))

project = "pySEQTarget"
Expand Down
8 changes: 0 additions & 8 deletions docs/requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -3,11 +3,3 @@ piccolo_theme
sphinx-autodoc-typehints
sphinx-copybutton
myst-parser
numpy
polars
tqdm
statsmodels
matplotlib
pyarrow
lifelines
pySEQTarget
4 changes: 2 additions & 2 deletions pySEQTarget/SEQopts.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@ class SEQopts:
"""
Parameter builder for ``pySEQTarget.SEQuential`` analysis

:param bootstrap_nboot: Number of bootstraps to preform
:param bootstrap_nboot: Number of bootstraps to perform
:type bootstrap_nboot: int
:param bootstrap_sample: Subsampling proportion of ID-Trials gathered for each bootstrapping iteration
:type bootstrap_sample: float
Expand Down Expand Up @@ -51,7 +51,7 @@ class SEQopts:
:param indicator_baseline: How to indicate baseline columns in models
:type indicator_baseline: str
:param indicator_squared: How to indicate squared columns in models
:type indicator_baseline: str
:type indicator_squared: str
:param km_curves: Boolean to create survival, risk, and incidence (if applicable) estimates
:type km_curves: bool
:param ncores: Number of cores to use if running in parallel
Expand Down
6 changes: 3 additions & 3 deletions pySEQTarget/SEQuential.py
Original file line number Diff line number Diff line change
Expand Up @@ -176,7 +176,7 @@ def bootstrap(self, **kwargs) -> None:
"bootstrap_nboot",
"bootstrap_sample",
"bootstrap_CI",
"bootstrap_method",
"bootstrap_CI_method",
}
for key, value in kwargs.items():
if key in allowed:
Expand Down Expand Up @@ -274,7 +274,7 @@ def survival(self, **kwargs) -> None:
if key in allowed:
setattr(self, key, val)
else:
raise ValueError(f"Unknown or misplaced arugment: {key}")
raise ValueError(f"Unknown or misplaced argument: {key}")

if not hasattr(self, "outcome_model") or not self.outcome_model:
raise ValueError(
Expand Down Expand Up @@ -315,7 +315,7 @@ def plot(self, **kwargs) -> None:
if key in allowed:
setattr(self, key, val)
else:
raise ValueError(f"Unknown or misplaced arugment: {key}")
raise ValueError(f"Unknown or misplaced argument: {key}")
self.km_graph = _survival_plot(self)

def collect(self) -> SEQoutput:
Expand Down
2 changes: 2 additions & 0 deletions pySEQTarget/error/__init__.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,9 @@
from ._check_separation import _check_separation
from ._data_checker import _data_checker
from ._param_checker import _param_checker

__all__ = [
"_check_separation",
"_data_checker",
"_param_checker",
]
23 changes: 23 additions & 0 deletions pySEQTarget/error/_check_separation.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,23 @@
import warnings

import numpy as np


def _check_separation(model_fit, label="model"):
"""
Check for perfect or quasi-complete separation in a fitted logistic regression model.
Issues a warning if large (|coef| > 25) or non-finite coefficients are detected,
as these are reliable indicators of separation in logistic regression.
"""
params = np.array(model_fit.params).flatten()

has_large = np.any(np.abs(params) > 25)
has_nonfinite = np.any(~np.isfinite(params))

if has_large or has_nonfinite:
warnings.warn(
f"Possible perfect or quasi-complete separation detected in {label}. "
"The resulting weights may be unreliable.",
UserWarning,
stacklevel=2,
)
4 changes: 2 additions & 2 deletions pySEQTarget/initialization/_outcome.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,11 +6,11 @@ def _outcome(self) -> str:
["followup*dose", f"followup*dose{self.indicator_squared}"]
)

if self.hazard or not self.km_curves:
if self.hazard_estimate or not self.km_curves:
interaction = interaction_dose = None

tv_bas = (
"+".join([f"{v}_bas" for v in self.time_varying_cols])
"+".join([f"{v}{self.indicator_baseline}" for v in self.time_varying_cols])
if self.time_varying_cols
else None
)
Expand Down
17 changes: 16 additions & 1 deletion pySEQTarget/weighting/_weight_fit.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
import statsmodels.api as sm
import statsmodels.formula.api as smf

from ..error._check_separation import _check_separation


def _get_subset_for_level(
self, WDT, level_idx, level, tx_lag_col, exclude_followup_zero=False
Expand Down Expand Up @@ -41,9 +43,14 @@ def _fit_pair(
WDT = WDT[WDT[_eligible_col] == 1]

for rhs, out in zip(formula_attr, output_attrs):
if len(WDT[outcome].unique()) < 2:
setattr(self, out, None)
continue
formula = f"{outcome}~{rhs}"
model = smf.glm(formula, WDT, family=sm.families.Binomial())
setattr(self, out, model.fit(disp=0, method=self.weight_fit_method))
fitted = model.fit(disp=0, method=self.weight_fit_method)
_check_separation(fitted, label=out.replace("_model", "").replace("_", " "))
setattr(self, out, fitted)


def _fit_LTFU(self, WDT):
Expand Down Expand Up @@ -91,12 +98,16 @@ def _fit_numerator(self, WDT):
is_binary = sorted(self.treatment_level) == [0, 1] and self.method == "censoring"
for i, level in enumerate(self.treatment_level):
DT_subset = _get_subset_for_level(self, WDT, i, level, tx_lag_col)
if len(DT_subset[predictor].unique()) < 2:
fits.append(None)
continue
# Use logit for binary 0/1 censoring, mnlogit otherwise
if is_binary:
model = smf.logit(formula, DT_subset)
else:
model = smf.mnlogit(formula, DT_subset)
model_fit = model.fit(disp=0, method=self.weight_fit_method)
_check_separation(model_fit, label=f"numerator (level {level})")
fits.append(model_fit)

self.numerator_model = fits
Expand Down Expand Up @@ -125,12 +136,16 @@ def _fit_denominator(self, WDT):
DT_subset = _get_subset_for_level(
self, WDT, i, level, "tx_lag", exclude_followup_zero=exclude_followup_zero
)
if len(DT_subset[predictor].unique()) < 2:
fits.append(None)
continue
# Use logit for binary 0/1 censoring, mnlogit otherwise
if is_binary:
model = smf.logit(formula, DT_subset)
else:
model = smf.mnlogit(formula, DT_subset)
model_fit = model.fit(disp=0, method=self.weight_fit_method)
_check_separation(model_fit, label=f"denominator (level {level})")
fits.append(model_fit)

self.denominator_model = fits
Expand Down
51 changes: 30 additions & 21 deletions pySEQTarget/weighting/_weight_pred.py
Original file line number Diff line number Diff line change
Expand Up @@ -170,33 +170,42 @@ def _weight_predict(self, WDT):
if self.cense_colname is not None:
cense_num_model = self._offloader.load_model(self.cense_numerator_model)
cense_denom_model = self._offloader.load_model(self.cense_denominator_model)
p_num = _predict_model(self, cense_num_model, WDT).flatten()
p_denom = _predict_model(self, cense_denom_model, WDT).flatten()
WDT = WDT.with_columns(
[
pl.Series("cense_numerator", p_num),
pl.Series("cense_denominator", p_denom),
]
).with_columns(
(pl.col("cense_numerator") / pl.col("cense_denominator")).alias("_cense")
)
if cense_num_model is not None and cense_denom_model is not None:
p_num = _predict_model(self, cense_num_model, WDT).flatten()
p_denom = _predict_model(self, cense_denom_model, WDT).flatten()
WDT = WDT.with_columns(
[
pl.Series("cense_numerator", p_num),
pl.Series("cense_denominator", p_denom),
]
).with_columns(
(pl.col("cense_numerator") / pl.col("cense_denominator")).alias(
"_cense"
)
)
else:
WDT = WDT.with_columns(pl.lit(1.0).alias("_cense"))
else:
WDT = WDT.with_columns(pl.lit(1.0).alias("_cense"))

if self.visit_colname is not None:
visit_num_model = self._offloader.load_model(self.visit_numerator_model)
visit_denom_model = self._offloader.load_model(self.visit_denominator_model)
p_num = _predict_model(self, visit_num_model, WDT).flatten()
p_denom = _predict_model(self, visit_denom_model, WDT).flatten()

WDT = WDT.with_columns(
[
pl.Series("visit_numerator", p_num),
pl.Series("visit_denominator", p_denom),
]
).with_columns(
(pl.col("visit_numerator") / pl.col("visit_denominator")).alias("_visit")
)
if visit_num_model is not None and visit_denom_model is not None:
p_num = _predict_model(self, visit_num_model, WDT).flatten()
p_denom = _predict_model(self, visit_denom_model, WDT).flatten()
WDT = WDT.with_columns(
[
pl.Series("visit_numerator", p_num),
pl.Series("visit_denominator", p_denom),
]
).with_columns(
(pl.col("visit_numerator") / pl.col("visit_denominator")).alias(
"_visit"
)
)
else:
WDT = WDT.with_columns(pl.lit(1.0).alias("_visit"))
else:
WDT = WDT.with_columns(pl.lit(1.0).alias("_visit"))

Expand Down
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@ build-backend = "setuptools.build_meta"

[project]
name = "pySEQTarget"
version = "0.12.2"
version = "0.12.3"
description = "Sequentially Nested Target Trial Emulation"
readme = "README.md"
license = {text = "MIT"}
Expand Down
69 changes: 69 additions & 0 deletions tests/test_check_separation.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,69 @@
import warnings
from types import SimpleNamespace

import numpy as np
import pytest

from pySEQTarget.error._check_separation import _check_separation


def _mock_model(params):
"""Create a minimal mock model with a .params attribute."""
return SimpleNamespace(params=np.array(params))


def test_no_warning_for_normal_coefficients():
model = _mock_model([0.5, -1.2, 3.0, -0.01])
with warnings.catch_warnings():
warnings.simplefilter("error")
_check_separation(model) # should not raise


def test_warns_for_large_positive_coefficient():
model = _mock_model([0.5, 26.0])
with pytest.warns(UserWarning, match="separation"):
_check_separation(model)


def test_warns_for_large_negative_coefficient():
model = _mock_model([0.5, -30.0])
with pytest.warns(UserWarning, match="separation"):
_check_separation(model)


def test_warns_for_inf_coefficient():
model = _mock_model([1.0, np.inf])
with pytest.warns(UserWarning, match="separation"):
_check_separation(model)


def test_warns_for_neg_inf_coefficient():
model = _mock_model([1.0, -np.inf])
with pytest.warns(UserWarning, match="separation"):
_check_separation(model)


def test_warns_for_nan_coefficient():
model = _mock_model([1.0, np.nan])
with pytest.warns(UserWarning, match="separation"):
_check_separation(model)


def test_boundary_coefficient_does_not_warn():
# Exactly 25 should not trigger (threshold is strictly > 25)
model = _mock_model([25.0, -25.0])
with warnings.catch_warnings():
warnings.simplefilter("error")
_check_separation(model)


def test_label_appears_in_warning():
model = _mock_model([100.0])
with pytest.warns(UserWarning, match="censoring numerator"):
_check_separation(model, label="censoring numerator")


def test_default_label_appears_in_warning():
model = _mock_model([np.inf])
with pytest.warns(UserWarning, match="model"):
_check_separation(model)
Loading