diff --git a/.github/workflows/publish.yml b/.github/workflows/publish.yml index 4c0f474..483c77c 100644 --- a/.github/workflows/publish.yml +++ b/.github/workflows/publish.yml @@ -33,7 +33,7 @@ jobs: python -m build - name: Upload distributions - uses: actions/upload-artifact@v4 + uses: actions/upload-artifact@v4 with: name: release-dists path: dist/ @@ -59,7 +59,7 @@ jobs: steps: - name: Retrieve release distributions - uses: actions/download-artifact@v4 + uses: actions/download-artifact@v4 with: name: release-dists path: dist/ diff --git a/.readthedocs.yaml b/.readthedocs.yaml index fc64a37..638b505 100644 --- a/.readthedocs.yaml +++ b/.readthedocs.yaml @@ -19,4 +19,6 @@ sphinx: # See https://docs.readthedocs.io/en/stable/guides/reproducible-builds.html python: install: - - requirements: docs/requirements.txt \ No newline at end of file + - requirements: docs/requirements.txt + - method: pip + path: . \ No newline at end of file diff --git a/README.md b/README.md index 5972e06..6329603 100644 --- a/README.md +++ b/README.md @@ -29,7 +29,6 @@ The primary API, `SEQuential` uses a dataclass system to handle function input. From the user side, this amounts to creating a dataclass, `SEQopts`, and then feeding this into `SEQuential`. If you forgot to add something at class instantiation, you can, in some cases, add them when you call their respective class method. 
```python -import polars as pl from pySEQTarget import SEQuential, SEQopts from pySEQTarget.data import load_data diff --git a/docs/conf.py b/docs/conf.py index c366c95..4a3648d 100644 --- a/docs/conf.py +++ b/docs/conf.py @@ -10,9 +10,10 @@ import sys from datetime import date -version = importlib.metadata.version("pySEQTarget") -if not version: - version = "0.12.2" +try: + version = importlib.metadata.version("pySEQTarget") +except importlib.metadata.PackageNotFoundError: + version = "unknown" sys.path.insert(0, os.path.abspath("../")) project = "pySEQTarget" diff --git a/docs/requirements.txt b/docs/requirements.txt index bcbda13..91237a9 100644 --- a/docs/requirements.txt +++ b/docs/requirements.txt @@ -3,11 +3,3 @@ piccolo_theme sphinx-autodoc-typehints sphinx-copybutton myst-parser -numpy -polars -tqdm -statsmodels -matplotlib -pyarrow -lifelines -pySEQTarget \ No newline at end of file diff --git a/pySEQTarget/SEQopts.py b/pySEQTarget/SEQopts.py index a6d3982..45085f2 100644 --- a/pySEQTarget/SEQopts.py +++ b/pySEQTarget/SEQopts.py @@ -9,7 +9,7 @@ class SEQopts: """ Parameter builder for ``pySEQTarget.SEQuential`` analysis - :param bootstrap_nboot: Number of bootstraps to preform + :param bootstrap_nboot: Number of bootstraps to perform :type bootstrap_nboot: int :param bootstrap_sample: Subsampling proportion of ID-Trials gathered for each bootstrapping iteration :type bootstrap_sample: float @@ -51,7 +51,7 @@ class SEQopts: :param indicator_baseline: How to indicate baseline columns in models :type indicator_baseline: str :param indicator_squared: How to indicate squared columns in models - :type indicator_baseline: str + :type indicator_squared: str :param km_curves: Boolean to create survival, risk, and incidence (if applicable) estimates :type km_curves: bool :param ncores: Number of cores to use if running in parallel diff --git a/pySEQTarget/SEQuential.py b/pySEQTarget/SEQuential.py index 7e3b551..22efef8 100644 --- a/pySEQTarget/SEQuential.py +++ 
b/pySEQTarget/SEQuential.py @@ -176,7 +176,7 @@ def bootstrap(self, **kwargs) -> None: "bootstrap_nboot", "bootstrap_sample", "bootstrap_CI", - "bootstrap_method", + "bootstrap_CI_method", } for key, value in kwargs.items(): if key in allowed: @@ -274,7 +274,7 @@ def survival(self, **kwargs) -> None: if key in allowed: setattr(self, key, val) else: - raise ValueError(f"Unknown or misplaced arugment: {key}") + raise ValueError(f"Unknown or misplaced argument: {key}") if not hasattr(self, "outcome_model") or not self.outcome_model: raise ValueError( @@ -315,7 +315,7 @@ def plot(self, **kwargs) -> None: if key in allowed: setattr(self, key, val) else: - raise ValueError(f"Unknown or misplaced arugment: {key}") + raise ValueError(f"Unknown or misplaced argument: {key}") self.km_graph = _survival_plot(self) def collect(self) -> SEQoutput: diff --git a/pySEQTarget/error/__init__.py b/pySEQTarget/error/__init__.py index fb19cae..bf1d614 100644 --- a/pySEQTarget/error/__init__.py +++ b/pySEQTarget/error/__init__.py @@ -1,7 +1,9 @@ +from ._check_separation import _check_separation from ._data_checker import _data_checker from ._param_checker import _param_checker __all__ = [ + "_check_separation", "_data_checker", "_param_checker", ] diff --git a/pySEQTarget/error/_check_separation.py b/pySEQTarget/error/_check_separation.py new file mode 100644 index 0000000..c30b5fc --- /dev/null +++ b/pySEQTarget/error/_check_separation.py @@ -0,0 +1,23 @@ +import warnings + +import numpy as np + + +def _check_separation(model_fit, label="model"): + """ + Check for perfect or quasi-complete separation in a fitted logistic regression model. + Issues a warning if large (|coef| > 25) or non-finite coefficients are detected, + as these are reliable indicators of separation in logistic regression. 
+ """ + params = np.array(model_fit.params).flatten() + + has_large = np.any(np.abs(params) > 25) + has_nonfinite = np.any(~np.isfinite(params)) + + if has_large or has_nonfinite: + warnings.warn( + f"Possible perfect or quasi-complete separation detected in {label}. " + "The resulting weights may be unreliable.", + UserWarning, + stacklevel=2, + ) diff --git a/pySEQTarget/initialization/_outcome.py b/pySEQTarget/initialization/_outcome.py index 4ec0fc9..cd693f5 100644 --- a/pySEQTarget/initialization/_outcome.py +++ b/pySEQTarget/initialization/_outcome.py @@ -6,11 +6,11 @@ def _outcome(self) -> str: ["followup*dose", f"followup*dose{self.indicator_squared}"] ) - if self.hazard or not self.km_curves: + if self.hazard_estimate or not self.km_curves: interaction = interaction_dose = None tv_bas = ( - "+".join([f"{v}_bas" for v in self.time_varying_cols]) + "+".join([f"{v}{self.indicator_baseline}" for v in self.time_varying_cols]) if self.time_varying_cols else None ) diff --git a/pySEQTarget/weighting/_weight_fit.py b/pySEQTarget/weighting/_weight_fit.py index 9a362ec..e6b39d0 100644 --- a/pySEQTarget/weighting/_weight_fit.py +++ b/pySEQTarget/weighting/_weight_fit.py @@ -1,6 +1,8 @@ import statsmodels.api as sm import statsmodels.formula.api as smf +from ..error._check_separation import _check_separation + def _get_subset_for_level( self, WDT, level_idx, level, tx_lag_col, exclude_followup_zero=False @@ -41,9 +43,14 @@ def _fit_pair( WDT = WDT[WDT[_eligible_col] == 1] for rhs, out in zip(formula_attr, output_attrs): + if len(WDT[outcome].unique()) < 2: + setattr(self, out, None) + continue formula = f"{outcome}~{rhs}" model = smf.glm(formula, WDT, family=sm.families.Binomial()) - setattr(self, out, model.fit(disp=0, method=self.weight_fit_method)) + fitted = model.fit(disp=0, method=self.weight_fit_method) + _check_separation(fitted, label=out.replace("_model", "").replace("_", " ")) + setattr(self, out, fitted) def _fit_LTFU(self, WDT): @@ -91,12 +98,16 @@ def 
_fit_numerator(self, WDT): is_binary = sorted(self.treatment_level) == [0, 1] and self.method == "censoring" for i, level in enumerate(self.treatment_level): DT_subset = _get_subset_for_level(self, WDT, i, level, tx_lag_col) + if len(DT_subset[predictor].unique()) < 2: + fits.append(None) + continue # Use logit for binary 0/1 censoring, mnlogit otherwise if is_binary: model = smf.logit(formula, DT_subset) else: model = smf.mnlogit(formula, DT_subset) model_fit = model.fit(disp=0, method=self.weight_fit_method) + _check_separation(model_fit, label=f"numerator (level {level})") fits.append(model_fit) self.numerator_model = fits @@ -125,12 +136,16 @@ def _fit_denominator(self, WDT): DT_subset = _get_subset_for_level( self, WDT, i, level, "tx_lag", exclude_followup_zero=exclude_followup_zero ) + if len(DT_subset[predictor].unique()) < 2: + fits.append(None) + continue # Use logit for binary 0/1 censoring, mnlogit otherwise if is_binary: model = smf.logit(formula, DT_subset) else: model = smf.mnlogit(formula, DT_subset) model_fit = model.fit(disp=0, method=self.weight_fit_method) + _check_separation(model_fit, label=f"denominator (level {level})") fits.append(model_fit) self.denominator_model = fits diff --git a/pySEQTarget/weighting/_weight_pred.py b/pySEQTarget/weighting/_weight_pred.py index 7b1c20c..c865e2e 100644 --- a/pySEQTarget/weighting/_weight_pred.py +++ b/pySEQTarget/weighting/_weight_pred.py @@ -170,33 +170,42 @@ def _weight_predict(self, WDT): if self.cense_colname is not None: cense_num_model = self._offloader.load_model(self.cense_numerator_model) cense_denom_model = self._offloader.load_model(self.cense_denominator_model) - p_num = _predict_model(self, cense_num_model, WDT).flatten() - p_denom = _predict_model(self, cense_denom_model, WDT).flatten() - WDT = WDT.with_columns( - [ - pl.Series("cense_numerator", p_num), - pl.Series("cense_denominator", p_denom), - ] - ).with_columns( - (pl.col("cense_numerator") / 
pl.col("cense_denominator")).alias("_cense") - ) + if cense_num_model is not None and cense_denom_model is not None: + p_num = _predict_model(self, cense_num_model, WDT).flatten() + p_denom = _predict_model(self, cense_denom_model, WDT).flatten() + WDT = WDT.with_columns( + [ + pl.Series("cense_numerator", p_num), + pl.Series("cense_denominator", p_denom), + ] + ).with_columns( + (pl.col("cense_numerator") / pl.col("cense_denominator")).alias( + "_cense" + ) + ) + else: + WDT = WDT.with_columns(pl.lit(1.0).alias("_cense")) else: WDT = WDT.with_columns(pl.lit(1.0).alias("_cense")) if self.visit_colname is not None: visit_num_model = self._offloader.load_model(self.visit_numerator_model) visit_denom_model = self._offloader.load_model(self.visit_denominator_model) - p_num = _predict_model(self, visit_num_model, WDT).flatten() - p_denom = _predict_model(self, visit_denom_model, WDT).flatten() - - WDT = WDT.with_columns( - [ - pl.Series("visit_numerator", p_num), - pl.Series("visit_denominator", p_denom), - ] - ).with_columns( - (pl.col("visit_numerator") / pl.col("visit_denominator")).alias("_visit") - ) + if visit_num_model is not None and visit_denom_model is not None: + p_num = _predict_model(self, visit_num_model, WDT).flatten() + p_denom = _predict_model(self, visit_denom_model, WDT).flatten() + WDT = WDT.with_columns( + [ + pl.Series("visit_numerator", p_num), + pl.Series("visit_denominator", p_denom), + ] + ).with_columns( + (pl.col("visit_numerator") / pl.col("visit_denominator")).alias( + "_visit" + ) + ) + else: + WDT = WDT.with_columns(pl.lit(1.0).alias("_visit")) else: WDT = WDT.with_columns(pl.lit(1.0).alias("_visit")) diff --git a/pyproject.toml b/pyproject.toml index 99330f5..f480e67 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -4,7 +4,7 @@ build-backend = "setuptools.build_meta" [project] name = "pySEQTarget" -version = "0.12.2" +version = "0.12.3" description = "Sequentially Nested Target Trial Emulation" readme = "README.md" license = {text = 
"MIT"} diff --git a/tests/test_check_separation.py b/tests/test_check_separation.py new file mode 100644 index 0000000..bb24ac8 --- /dev/null +++ b/tests/test_check_separation.py @@ -0,0 +1,69 @@ +import warnings +from types import SimpleNamespace + +import numpy as np +import pytest + +from pySEQTarget.error._check_separation import _check_separation + + +def _mock_model(params): + """Create a minimal mock model with a .params attribute.""" + return SimpleNamespace(params=np.array(params)) + + +def test_no_warning_for_normal_coefficients(): + model = _mock_model([0.5, -1.2, 3.0, -0.01]) + with warnings.catch_warnings(): + warnings.simplefilter("error") + _check_separation(model) # should not raise + + +def test_warns_for_large_positive_coefficient(): + model = _mock_model([0.5, 26.0]) + with pytest.warns(UserWarning, match="separation"): + _check_separation(model) + + +def test_warns_for_large_negative_coefficient(): + model = _mock_model([0.5, -30.0]) + with pytest.warns(UserWarning, match="separation"): + _check_separation(model) + + +def test_warns_for_inf_coefficient(): + model = _mock_model([1.0, np.inf]) + with pytest.warns(UserWarning, match="separation"): + _check_separation(model) + + +def test_warns_for_neg_inf_coefficient(): + model = _mock_model([1.0, -np.inf]) + with pytest.warns(UserWarning, match="separation"): + _check_separation(model) + + +def test_warns_for_nan_coefficient(): + model = _mock_model([1.0, np.nan]) + with pytest.warns(UserWarning, match="separation"): + _check_separation(model) + + +def test_boundary_coefficient_does_not_warn(): + # Exactly 25 should not trigger (threshold is strictly > 25) + model = _mock_model([25.0, -25.0]) + with warnings.catch_warnings(): + warnings.simplefilter("error") + _check_separation(model) + + +def test_label_appears_in_warning(): + model = _mock_model([100.0]) + with pytest.warns(UserWarning, match="censoring numerator"): + _check_separation(model, label="censoring numerator") + + +def 
test_default_label_appears_in_warning(): + model = _mock_model([np.inf]) + with pytest.warns(UserWarning, match="model"): + _check_separation(model) diff --git a/tests/test_no_variation.py b/tests/test_no_variation.py new file mode 100644 index 0000000..86a607b --- /dev/null +++ b/tests/test_no_variation.py @@ -0,0 +1,147 @@ +from types import SimpleNamespace + +import numpy as np +import pandas as pd + +from pySEQTarget.weighting._weight_fit import (_fit_denominator, + _fit_numerator, _fit_pair) + + +def _mock_self(**overrides): + attrs = { + "weight_preexpansion": False, + "excused": False, + "method": "censoring", + "treatment_col": "tx", + "indicator_baseline": "_bas", + "numerator": "1", + "denominator": "1", + "treatment_level": [0, 1], + "weight_fit_method": "newton", + "weight_lag_condition": True, + "weight_eligible_colnames": [None, None], + "excused_colnames": [None, None], + "cense_colname": "ltfu", + } + attrs.update(overrides) + return SimpleNamespace(**attrs) + + +def _make_data(n=60): + """Both treatment levels have variation in the predictor.""" + block = np.array([0] * (n // 2) + [1] * (n // 2)) + return pd.concat( + [ + pd.DataFrame( + { + "tx": block, + "tx_lag": np.zeros(n, int), + "followup": np.arange(1, n + 1), + } + ), + pd.DataFrame( + { + "tx": block, + "tx_lag": np.ones(n, int), + "followup": np.arange(1, n + 1), + } + ), + ], + ignore_index=True, + ) + + +def _make_data_no_variation_level0(n=60): + """Level-0 subset (tx_lag=0) has all tx=0; level-1 subset has mixed tx.""" + block = np.array([0] * (n // 2) + [1] * (n // 2)) + return pd.concat( + [ + pd.DataFrame( + { + "tx": np.zeros(n, int), + "tx_lag": np.zeros(n, int), + "followup": np.arange(1, n + 1), + } + ), + pd.DataFrame( + { + "tx": block, + "tx_lag": np.ones(n, int), + "followup": np.arange(1, n + 1), + } + ), + ], + ignore_index=True, + ) + + +# ── _fit_numerator ──────────────────────────────────────────────────────────── + + +def 
test_fit_numerator_stores_none_when_no_variation(): + obj = _mock_self() + _fit_numerator(obj, _make_data_no_variation_level0()) + assert obj.numerator_model[0] is None + assert obj.numerator_model[1] is not None + + +def test_fit_numerator_fits_when_variation_exists(): + obj = _mock_self() + _fit_numerator(obj, _make_data()) + assert obj.numerator_model[0] is not None + assert obj.numerator_model[1] is not None + + +# ── _fit_denominator ────────────────────────────────────────────────────────── + + +def test_fit_denominator_stores_none_when_no_variation(): + obj = _mock_self() + _fit_denominator(obj, _make_data_no_variation_level0()) + assert obj.denominator_model[0] is None + assert obj.denominator_model[1] is not None + + +def test_fit_denominator_fits_when_variation_exists(): + obj = _mock_self() + _fit_denominator(obj, _make_data()) + assert obj.denominator_model[0] is not None + assert obj.denominator_model[1] is not None + + +# ── _fit_pair (cense / visit models) ───────────────────────────────────────── + + +def test_fit_pair_stores_none_when_no_variation(): + n = 60 + df = pd.DataFrame({"ltfu": np.zeros(n, int), "followup": np.arange(n)}) + obj = _mock_self() + _fit_pair( + obj, + df, + "cense_colname", + ["1", "1"], + ["cense_numerator_model", "cense_denominator_model"], + ) + assert obj.cense_numerator_model is None + assert obj.cense_denominator_model is None + + +def test_fit_pair_fits_when_variation_exists(): + n = 60 + df = pd.DataFrame( + { + "ltfu": np.array([0] * (n // 2) + [1] * (n // 2)), + "followup": np.arange(n), + } + ) + obj = _mock_self() + _fit_pair( + obj, + df, + "cense_colname", + ["1", "1"], + ["cense_numerator_model", "cense_denominator_model"], + ) + assert obj.cense_numerator_model is not None + assert obj.cense_denominator_model is not None