diff --git a/afqinsight/datasets.py b/afqinsight/datasets.py index 275b1267..e4d915cd 100755 --- a/afqinsight/datasets.py +++ b/afqinsight/datasets.py @@ -662,7 +662,7 @@ def from_study(study, verbose=None): dataset_kwargs = { "sarica": { "dwi_metrics": ["md", "fa"], - "target_cols": ["class"], + "target_cols": ["class", "age"], "label_encode_cols": ["class"], }, "weston-havens": {"dwi_metrics": ["md", "fa"], "target_cols": ["Age"]}, diff --git a/afqinsight/parametric.py b/afqinsight/parametric.py index 0eef914f..aceac8c9 100644 --- a/afqinsight/parametric.py +++ b/afqinsight/parametric.py @@ -1,4 +1,4 @@ -"""Perform linear modeling at leach node along the tract.""" +"""Perform linear modeling at each node along the tract.""" import numpy as np import pandas as pd @@ -11,11 +11,11 @@ def node_wise_regression( afq_dataset, tract, - metric, formula, - group="group", + group=None, lme=False, rand_eff="subjectID", + impute="median", ): """Model group differences using node-wise regression along the length of the tract. @@ -26,13 +26,10 @@ def node_wise_regression( ---------- afq_dataset: AFQDataset Loaded AFQDataset object + tract: str String specifying the tract to model - metric: str - String specifying which diffusion metric to use as an outcome - eg. 'fa' - formula: str An R-style formula specifying the regression model to fit at each node. This can take the form @@ -46,6 +43,9 @@ def node_wise_regression( mixed-effects models. If using anything other than the default value, this column must be present in the 'target_cols' of the AFQDataset object + impute: str or None, default='median' + String specifying the imputation strategy to use for missing data. 
+ Returns ------- @@ -53,13 +53,13 @@ def node_wise_regression( A dictionary with the following key-value pairs: {'tract': tract, - 'reference_coefs': coefs_default, - 'group_coefs': coefs_treat, - 'reference_CI': cis_default, - 'group_CI': cis_treat, - 'pvals': pvals, - 'reject_idx': reject_idx, - 'model_fits': fits} + 'reference_coefs': coefs_default, + 'group_coefs': coefs_treat, + 'reference_CI': cis_default, + 'group_CI': cis_treat, + 'pvals': pvals, + 'reject_idx': reject_idx, + 'model_fits': fits} tract: str The tract described by this dictionary @@ -72,7 +72,7 @@ def node_wise_regression( group_coefs: list of floats A list of beta-weights representing the average group effect metric for the treatment group on a diffusion metric at a given location - along the tract + along the tract. If group is None, this will be a list of zeros. reference_CI: np.array of np.array A numpy array containing a series of numpy arrays indicating the @@ -82,7 +82,8 @@ def node_wise_regression( group_CI: np.array of np.array A numpy array containing a series of numpy arrays indicating the 95% confidence interval around the estimated beta-weight of the - treatment effect at a given location along the tract + treatment effect at a given location along the tract. If group is + None, this will be an array of zeros. 
pvals: list of floats A list of p-values testing whether or not the beta-weight of the @@ -96,8 +97,13 @@ def node_wise_regression( A list of the statsmodels object fit along the length of the nodes """ - X = SimpleImputer(strategy="median").fit_transform(afq_dataset.X) - afq_dataset.target_cols[0] = group + if impute is not None: + X = SimpleImputer(strategy=impute).fit_transform(afq_dataset.X) + + if group is not None: + afq_dataset.target_cols[0] = group + + metric = formula.split("~")[0].strip() tract_data = ( pd.DataFrame(columns=afq_dataset.feature_names, data=X) @@ -106,12 +112,13 @@ def node_wise_regression( ) pvals = np.zeros(tract_data.shape[-1]) + pvals_corrected = np.zeros(tract_data.shape[-1]) coefs_default = np.zeros(tract_data.shape[-1]) coefs_treat = np.zeros(tract_data.shape[-1]) cis_default = np.zeros((tract_data.shape[-1], 2)) cis_treat = np.zeros((tract_data.shape[-1], 2)) + reject = np.zeros(tract_data.shape[-1], dtype=bool) fits = {} - # Loop through each node and fit model for ii, column in enumerate(tract_data.columns): # fit linear mixed-effects model @@ -125,7 +132,6 @@ def node_wise_regression( model = smf.mixedlm(formula, this, groups=rand_eff) fit = model.fit() - fits[column] = fit # fit OLS model else: @@ -135,31 +141,76 @@ def node_wise_regression( model = OLS.from_formula(formula, this) fit = model.fit() - fits[column] = fit - + fits[ii] = fit # pull out coefficients, CIs, and p-values from our model coefs_default[ii] = fit.params.filter(regex="Intercept", axis=0).iloc[0] - coefs_treat[ii] = fit.params.filter(regex=group, axis=0).iloc[0] - - cis_default[ii] = ( - fit.conf_int(alpha=0.05).filter(regex="Intercept", axis=0).values - ) - cis_treat[ii] = fit.conf_int(alpha=0.05).filter(regex=group, axis=0).values - pvals[ii] = fit.pvalues.filter(regex=group, axis=0).iloc[0] - - # Correct p-values for multiple comparisons - reject, pval_corrected, _, _ = multipletests(pvals, alpha=0.05, method="fdr_bh") - reject_idx = np.where(reject) - - 
tract_dict = { - "tract": tract, - "reference_coefs": coefs_default, - "group_coefs": coefs_treat, - "reference_CI": cis_default, - "group_CI": cis_treat, - "pvals": pvals, - "reject_idx": reject_idx, - "model_fits": fits, - } - - return tract_dict + + if group is not None: + coefs_treat[ii] = fit.params.filter(regex=group, axis=0).iloc[0] + + cis_default[ii] = ( + fit.conf_int(alpha=0.05).filter(regex="Intercept", axis=0).values + ) + cis_treat[ii] = fit.conf_int(alpha=0.05).filter(regex=group, axis=0).values + pvals[ii] = fit.pvalues.filter(regex=group, axis=0).iloc[0] + + # Correct p-values for multiple comparisons + reject, pvals_corrected, _, _ = multipletests( + pvals, alpha=0.05, method="fdr_bh" + ) + + reject = np.where(reject, 1, 0) + + return pd.DataFrame( + { + "reference_coefs": coefs_default, + "group_coefs": coefs_treat, + "reference_CI_lb": cis_default[:, 0], + "reference_CI_ub": cis_default[:, 1], + "group_CI_lb": cis_treat[:, 0], + "group_CI_ub": cis_treat[:, 1], + "pvals": pvals, + "pvals_corrected": pvals_corrected, + "reject_idx": reject, + } + ), fits + + +class RegressionResults(object): + def __init__(self, kwargs): + self.tract = kwargs.get("tract", None) + self.reference_coefs = kwargs.get("reference_coefs", None) + self.group_coefs = kwargs.get("group_coefs", None) + self.reference_ci = kwargs.get("reference_ci", None) + self.group_ci = kwargs.get("group_ci", None) + self.pvals = kwargs.get("pvals", None) + self.pvals_corrected = kwargs.get("pvals_corrected", None) + self.reject_idx = kwargs.get("reject_idx", None) + self.model_fits = kwargs.get("model_fits", None) + + +class NodeWiseRegression(object): + def __init__(self, formula, lme=False): + self.formula = formula + self.lme = lme + + def fit(self, dataset, tracts, group=None, rand_eff="subjectID"): + self.result_ = {} + for tract in tracts: + self.result_[tract] = node_wise_regression( + dataset, + tract, + self.formula, + lme=self.lme, + group=group, + rand_eff=rand_eff, + ) + 
self.is_fitted = True + return self + + def predict(self, dataset, tract, metric, group="group", rand_eff="subjectID"): + if not self.is_fitted: + raise ValueError("Model not fitted yet. Please call fit() method first.") + result = self.result_.get(tract, None) + if result is None: + raise ValueError(f"Tract {tract} not found in the fitted model.") diff --git a/afqinsight/tests/test_datasets.py b/afqinsight/tests/test_datasets.py index e5cf31bb..6e957884 100644 --- a/afqinsight/tests/test_datasets.py +++ b/afqinsight/tests/test_datasets.py @@ -358,7 +358,7 @@ def test_from_study(study): "n_subjects": 48, "n_features": 4000, "n_groups": 40, - "target_cols": ["class"], + "target_cols": ["class", "age"], }, "weston-havens": { "n_subjects": 77, diff --git a/afqinsight/tests/test_parametric.py b/afqinsight/tests/test_parametric.py new file mode 100644 index 00000000..8c5e8fe2 --- /dev/null +++ b/afqinsight/tests/test_parametric.py @@ -0,0 +1,43 @@ +import numpy as np + +from afqinsight import AFQDataset +from afqinsight.parametric import NodeWiseRegression, node_wise_regression + + +def test_node_wise_regression(): + # Store results + group_dict = {} + group_age_dict = {} + age_dict = {} + + data = AFQDataset.from_study("sarica") + tracts = ["Right Corticospinal", "Right SLF"] + for tract in tracts: + for lme in [True, False]: + # Run different versions of this: with age, without age, only with + # age: + + group_dict[tract] = node_wise_regression( + data, tract, "fa ~ C(group)", lme=lme, group="group" + ) + group_age_dict[tract] = node_wise_regression( + data, tract, "fa ~ C(group) + age", lme=lme, group="group" + ) + age_dict[tract] = node_wise_regression(data, tract, "fa ~ age", lme=lme) + + assert group_dict[tract]["pvals"].shape == (100,) + assert group_age_dict[tract]["pvals"].shape == (100,) + assert age_dict[tract]["pvals"].shape == (100,) + + assert np.any(group_dict["Right Corticospinal"]["pvals_corrected"] < 0.05) + assert np.all(group_dict["Right 
SLF"]["pvals_corrected"] > 0.05) + assert np.any(group_age_dict["Right Corticospinal"]["pvals_corrected"] < 0.05) + assert np.all(group_age_dict["Right SLF"]["pvals_corrected"] > 0.05) + + +def test_NodeWiseRegression(): + data = AFQDataset.from_study("sarica") + tracts = ["Left Corticospinal", "Left SLF"] + for lme in [True, False]: + model = NodeWiseRegression("fa ~ C(group) + age", lme=lme) + model.fit(data, tracts, group="group") diff --git a/examples/plot_als_classification.py b/examples/plot_als_classification.py index 8277cf02..ca2b2e4b 100644 --- a/examples/plot_als_classification.py +++ b/examples/plot_als_classification.py @@ -54,6 +54,7 @@ X = afqdata.X y = afqdata.y.astype(float) # SGL expects float targets +is_als = y[:, 0] groups = afqdata.groups feature_names = afqdata.feature_names group_names = afqdata.group_names @@ -117,7 +118,7 @@ # scikit-learn functions scores = cross_validate( - pipe, X, y, cv=5, return_train_score=True, return_estimator=True + pipe, X, is_als, cv=5, return_train_score=True, return_estimator=True ) # Display results diff --git a/examples/plot_als_comparison.py b/examples/plot_als_comparison.py index 2e804f33..7c78b565 100644 --- a/examples/plot_als_comparison.py +++ b/examples/plot_als_comparison.py @@ -72,12 +72,12 @@ # Loop through the data and generate plots -for i, tract in enumerate(tracts): +for ii, tract in enumerate(tracts): # fit node-wise regression for each tract based on model formula - tract_dict = node_wise_regression(afqdata, tract, "fa", "fa ~ C(group)") + tract_dict = node_wise_regression(afqdata, tract, "fa ~ C(group)", group="group") - row = i // num_cols - col = i % num_cols + row = ii // num_cols + col = ii % num_cols axes[row][col].set_title(tract)