Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
216 changes: 216 additions & 0 deletions easy_pipeline/easy.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,216 @@
#!/usr/bin/env python
# coding=utf-8

"""
High-level pipeline wrapper for hyperscanning analysis.
"""

from collections import namedtuple
import matplotlib.pyplot as plt
import numpy as np
from . import analyses
from . import stats as hyn_stats
from . import viz

class Results:
"""
Class to store and display results from the hyperscanning pipeline.

Attributes
----------
connectivity : dict
Dictionary storing connectivity matrices.
Keys are formatted as "{metric}_{band}".
stats : dict
Dictionary storing statistical results (if computed).
Keys are formatted as "{metric}_{band}".
canvas : dict
Dictionary storing generated matplotlib figures.
Keys are formatted as "{metric}_{band}_conn" or "{metric}_{band}_stats".
"""
def __init__(self):
self.connectivity = {}
self.stats = {}
self.canvas = {}

def summary(self):
"""Prints a summary of the available results."""
print("Hypyp Analysis Results Summary")
print("==============================\n")

print(f"Connectivity metrics computed: {len(self.connectivity)}")
for key in self.connectivity.keys():
print(f"- {key}: {self.connectivity[key].shape}")

print(f"\nStatistical tests performed: {len(self.stats)}")
for key in self.stats.keys():
print(f"- {key}")

print(f"\nPlots generated: {len(self.canvas)}")
for key in self.canvas.keys():
print(f"- {key}")


def run_tests(data):
"""
Checks validity of the input data (channels, sampling rate, epochs, etc).

Args:
data (list): List of 2 MNE Epochs objects.

Returns:
bool: True if passes, False/Raises error otherwise.
"""
if len(data) != 2:
raise ValueError("Data must be a list of EXACTLY 2 MNE Epochs objects.")

# test electrodes setting
first_set = data[0].ch_names
all_same_set = all(epo.ch_names == first_set for epo in data)
if not all_same_set:
print('⚠️ channels don\'t have the same names for every subject')
return False

# test sampling rate
first_sfreq = data[0].info['sfreq']
all_same_sfreq = all(epo.info['sfreq'] == first_sfreq for epo in data)
if not all_same_sfreq:
print('⚠️ sampling rate isn\'t the same for every subject')
return False

# test epoch number
first_n_meas = len(data[0])
all_same_n_meas = all(len(epo) == first_n_meas for epo in data)
if not all_same_n_meas:
print('⚠️ number of epochs isn\'t the same for every subject')
return False


return True


def run_pipeline(data: list,
metrics: list = ['plv'],
freq_bands: dict = {'Alpha': [8, 12]},
stats: list = None,
plot: bool = True) -> Results:
"""
Run the full pipeline for hyperscanning analysis on a single dyad.

Args:
data (list): 2 Epochs objects from the MNE library.
metrics (list): Metrics to compute, e.g. ["plv", "ccorr"].
freq_bands (dict): Frequency bands to compute, e.g. {"alpha": [8, 12], "beta": [13, 30]}.
stats (list, optional): Stats to compute, e.g. ["perm"]. Defaults to None.
plot (bool, optional): Generate figures. Defaults to True.

Returns:
Results: A Results object containing connectivity matrices, stats, and plots.
"""

results = Results()

# Validation
if not run_tests(data):
raise ValueError("Input data validation failed. Check warnings.")

sampling_rate = data[0].info['sfreq']

# Determine if we need to keep epochs for statistics
# If stats are requested, we set epochs_average=False
compute_stats = stats is not None and len(stats) > 0
epochs_average = not compute_stats

# iterate over metrics
for metric in metrics:
# iterate over frequency bands
for band_name, band_range in freq_bands.items():
print(f"Computing {metric} for {band_name} band...")

# Extract data arrays from Epochs (pair_connectivity expects arrays)
data_list = [d.get_data() for d in data]

# Compute connectivity
# If stats are requested, we keep epochs (epochs_average=False)
con_matrix = analyses.pair_connectivity(
data=data_list,
sampling_rate=sampling_rate,
frequencies={band_name: band_range},
mode=metric,
epochs_average=epochs_average
)

# Store connectivity result (removing frequency dimension of size 1)
con_result = con_matrix[0]

# Always store the AVERAGED connectivity for user access
# If epochs were preserved for stats, average them now for storage
if not epochs_average:
con_mean = np.mean(con_result, axis=0)
else:
con_mean = con_result
results.connectivity[f"{metric}_{band_name}"] = con_mean

# --- Statistics ---
if compute_stats and ('perm' in stats or 'ttest' in stats):
print(f"Running statistics for {metric} {band_name}...")

# Reshape for statsCond: (n_epochs, n_tests, n_freq)
# We treat every connection in the matrix as a test
n_epochs = con_result.shape[0]
n_tests = con_result.shape[1] * con_result.shape[2]
data_for_stats = con_result.reshape(n_epochs, n_tests, 1)

# Run permutation t-test against 0
T_obs, p_values, H0, adj_p, T_obs_plot = hyn_stats.statsCond(
data=data_for_stats,
epochs=data[0],
n_permutations=1000,
alpha=0.05
)

results.stats[f"{metric}_{band_name}"] = {
"T_obs": T_obs,
"p_values": p_values,
"adj_p": adj_p,
"T_obs_plot": T_obs_plot
}

# --- Plotting ---
if plot:
# Get number of channels for extracting inter-brain block
n_ch = len(data[0].ch_names)

# Decide what to plot: Stats (if available) or Connectivity
if compute_stats and ('perm' in stats or 'ttest' in stats):
# T_obs_plot contains T-values for significant connections (others are 0)
# Reshape to full matrix first
full_matrix = T_obs_plot.reshape(con_result.shape[1], con_result.shape[2])
# Extract inter-brain block (same as for connectivity)
con_to_plot = full_matrix[0:n_ch, n_ch:2*n_ch]
subtitle = f"Stats: {metric} - {band_name}"
threshold = 'auto'
else:
# Use the averaged connectivity (already computed above)
# Extract inter-brain block
con_to_plot = con_mean[0:n_ch, n_ch:2*n_ch]
subtitle = f"Connectivity: {metric} - {band_name}"
threshold = 'auto'

# Create the plot using the high-level function
ax = viz.viz_2D_topomap_inter(
data[0],
data[1],
con_to_plot,
threshold=threshold,
steps=10,
lab=True
)

# Get figure from axis to save it
fig = ax.figure
plt.title(subtitle)

results.canvas[f"{metric}_{band_name}"] = fig

return results
306 changes: 306 additions & 0 deletions easy_pipeline/test_easy.ipynb

Large diffs are not rendered by default.