Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions Python/shapeworks/shapeworks/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,10 +5,10 @@
from shapeworks_py import *
from .conversion import sw2vtkImage, sw2vtkMesh
from .plot import plot_meshes, plot_volumes, plot_meshes_volumes_mix, add_mesh_to_plotter, add_volume_to_plotter, plot_mesh_contour,plot_pca_metrics,\
pca_loadings_violinplot,plot_mode_line,visualize_reconstruction,lda_plot
pca_loadings_violinplot,plot_mode_line,visualize_reconstruction,lda_plot,dwd_plot
from .utils import num_subplots, positive_factors, save_images, get_file_with_ext, find_reference_image_index, find_reference_mesh_index, load_mesh
from .data import get_file_list, sample_images, sample_meshes
from .stats import compute_pvalues_for_group_difference,lda
from .stats import compute_pvalues_for_group_difference,lda,dwd_loadings
from .network_analysis import NetworkAnalysis
from .portal import download_dataset
from .shape_scalars import run_mbpls
25 changes: 23 additions & 2 deletions Python/shapeworks/shapeworks/plot.py
Original file line number Diff line number Diff line change
Expand Up @@ -481,7 +481,7 @@ def lda_plot(group1_x,group2_x,group1_pdf,group2_pdf,group1_map,group2_map,lda_d
group2_num = len(group2_map)
plt.plot(group1_x, group1_pdf, label = labels[0] + ' PDF',linewidth=10)
plt.plot(group2_x, group2_pdf, label = labels[1] + ' PDF',linewidth=10)

plt.scatter(group1_map, 0.01*np.ones((group1_num)), s=330, label = labels[0] + ' Shape Mappings', edgecolors='black',linewidths=5)
plt.scatter(group2_map, 0.01*np.ones((group2_num)), s=330, label = labels[1] + ' Shape Mappings', edgecolors='black',linewidths=5)
plt.ylabel("Probability Density")
Expand All @@ -491,4 +491,25 @@ def lda_plot(group1_x,group2_x,group1_pdf,group2_pdf,group1_map,group2_map,lda_d
plt.close(fig)

print("Figure saved in directory -" + lda_dir)
print()
print()

def dwd_plot(group1_x,group2_x,group1_pdf,group2_pdf,group1_map,group2_map,dwd_dir,labels):

plt.figure(dpi=50,figsize=(14,14))
fig = plt.gcf()
plt.rcParams['font.size'] = '20'
group1_num = len(group1_map)
group2_num = len(group2_map)
plt.plot(group1_x, group1_pdf, label = labels[0] + ' PDF',linewidth=10)
plt.plot(group2_x, group2_pdf, label = labels[1] + ' PDF',linewidth=10)

plt.scatter(group1_map, 0.01*np.ones((group1_num)), s=330, label = labels[0] + ' Shape Mappings', edgecolors='black',linewidths=5)
plt.scatter(group2_map, 0.01*np.ones((group2_num)), s=330, label = labels[1] + ' Shape Mappings', edgecolors='black',linewidths=5)
plt.ylabel("Probability Density")
plt.xlabel('Shape mapping to DWD discrimination of variation between population means')
plt.legend(loc='upper right')
plt.savefig(dwd_dir+"/DWD.png")
plt.close(fig)

print("Figure saved in directory -" + dwd_dir)
print()
66 changes: 57 additions & 9 deletions Python/shapeworks/shapeworks/stats.py
Original file line number Diff line number Diff line change
Expand Up @@ -71,33 +71,49 @@ def compute_pvalues_for_group_difference_data(group_0_data, group_1_data, permut


def normalize(subj_map, group1_mean_map, group2_mean_map):
slope = (2.0 / (group2_mean_map - group1_mean_map))
denom = group2_mean_map - group1_mean_map
if abs(denom) < 1e-12:
return 0.0
slope = 2.0 / denom
subj_diff = subj_map - group1_mean_map
subj_map_normalized = slope * subj_diff - 1
return subj_map_normalized


def lda_loadings(group1_data, group2_data):
group1_num = np.shape(group1_data)[1]
group2_num = np.shape(group2_data)[1]

combined_data = np.concatenate((group1_data, group2_data), axis=1)
group1_mean = np.mean(group1_data, axis=1)
group2_mean = np.mean(group2_data, axis=1)

overall_mean = np.mean(combined_data, axis=1)

diffVect = group1_mean - group2_mean

return _project_and_pdf(diffVect, group1_data, group2_data, combined_data)


def _project_and_pdf(diffVect, group1_data, group2_data, combined_data):
"""Shared logic for projecting groups onto a discriminant direction and fitting PDFs.

Args:
diffVect: Discriminant direction vector (features,)
group1_data: PCA loadings for group 1 (features x samples)
group2_data: PCA loadings for group 2 (features x samples)
combined_data: Concatenation of group1_data and group2_data (features x all_samples)

Returns: 6-tuple (group1_x, group2_x, group1_pdf, group2_pdf, group1_map, group2_map)
"""
group1_num = group1_data.shape[1]
group2_num = group2_data.shape[1]

group1_mean = np.mean(group1_data, axis=1)
group2_mean = np.mean(group2_data, axis=1)
overall_mean = np.mean(combined_data, axis=1)

group1_mean_diff = group1_mean - overall_mean
group2_mean_diff = group2_mean - overall_mean

group1_mean_map = np.dot(diffVect, group1_mean_diff)
group2_mean_map = np.dot(diffVect, group2_mean_diff)

group1_mean_map_normalized = normalize(group1_mean_map, group1_mean_map, group2_mean_map)
group2_mean_map_normalized = normalize(group2_mean_map, group1_mean_map, group2_mean_map)

group1_map = np.zeros((group1_num,))
group2_map = np.zeros((group2_num,))

Expand All @@ -117,6 +133,13 @@ def lda_loadings(group1_data, group2_data):
group1_map_std = group1_map.std()
group2_map_std = group2_map.std()

# Guard against zero std (all samples project to same point)
min_std = 1e-6
if group1_map_std < min_std:
group1_map_std = min_std
if group2_map_std < min_std:
group2_map_std = min_std

group1_x = np.linspace(group1_map_mean - 6, group1_map_mean + 6, num=300)
group2_x = np.linspace(group2_map_mean - 6, group2_map_mean + 6, num=300)

Expand All @@ -125,6 +148,31 @@ def lda_loadings(group1_data, group2_data):
return group1_x, group2_x, group1_pdf, group2_pdf, group1_map, group2_map


def dwd_loadings(group1_data, group2_data):
from dwd.gen_dwd import GenDWD
group1_num = np.shape(group1_data)[1]
group2_num = np.shape(group2_data)[1]

if group1_num < 2 or group2_num < 2:
raise ValueError(f"DWD requires at least 2 samples per group (got {group1_num} and {group2_num})")

combined_data = np.concatenate((group1_data, group2_data), axis=1)

# Fit GenDWD (samples x features)
X = combined_data.T
y = np.array([1]*group1_num + [-1]*group2_num)

try:
model = GenDWD(lambd=1.0)
model.fit(X, y)
except Exception as e:
raise RuntimeError(f"DWD fitting failed: {e}") from e

diffVect = model.coef_.flatten()

return _project_and_pdf(diffVect, group1_data, group2_data, combined_data)


def lda(data):
group_id = data["group_ids"].unique()
group1_idxs = data.index[data['group_ids'] == 0].tolist()
Expand Down
83 changes: 83 additions & 0 deletions Studio/Analysis/AnalysisTool.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
#include <Job/NetworkAnalysisJob.h>
#include <Job/ParticleNormalEvaluationJob.h>
#include <Job/StatsGroupLDAJob.h>
#include <Job/StatsGroupDWDJob.h>
#include <Libs/Application/Job/PythonWorker.h>
#include <Groom/GroomParameters.h>
#include <Logging.h>
Expand Down Expand Up @@ -168,6 +169,12 @@ AnalysisTool::AnalysisTool(Preferences& prefs) : preferences_(prefs) {
connect(group_lda_job_.data(), &StatsGroupLDAJob::progress, this, &AnalysisTool::handle_lda_progress);
connect(group_lda_job_.data(), &StatsGroupLDAJob::finished, this, &AnalysisTool::handle_lda_complete);

ui_->dwd_graph->hide();
ui_->dwd_hint_label->hide();
group_dwd_job_ = QSharedPointer<StatsGroupDWDJob>::create();
connect(group_dwd_job_.data(), &StatsGroupDWDJob::progress, this, &AnalysisTool::handle_dwd_progress);
connect(group_dwd_job_.data(), &StatsGroupDWDJob::finished, this, &AnalysisTool::handle_dwd_complete);

connect(ui_->show_difference_to_mean, &QPushButton::clicked, this, &AnalysisTool::show_difference_to_mean_clicked);

connect(ui_->group_analysis_combo, qOverload<int>(&QComboBox::currentIndexChanged), this,
Expand Down Expand Up @@ -1620,6 +1627,24 @@ void AnalysisTool::update_lda_graph() {
}
}

//---------------------------------------------------------------------------
void AnalysisTool::update_dwd_graph() {
if (groups_active()) {
if (!dwd_computed_ && !group_dwd_job_running_) {
group_dwd_job_running_ = true;
ui_->dwd_label->show();
ui_->dwd_progress->setValue(0);
ui_->dwd_progress->setMaximum(0);
ui_->dwd_progress->update();
group_dwd_job_->set_stats(stats_);
app_->get_py_worker()->run_job(group_dwd_job_);
}
} else {
ui_->dwd_graph->setVisible(false);
ui_->dwd_hint_label->setVisible(false);
}
}

//---------------------------------------------------------------------------
void AnalysisTool::update_difference_particles() {
if (!stats_ready_) {
Expand Down Expand Up @@ -1679,7 +1704,10 @@ void AnalysisTool::group_changed() {
stats_ready_ = false;
group_pvalue_job_ = nullptr;
lda_computed_ = false;
dwd_computed_ = false;
compute_stats();
// Re-trigger LDA/DWD if currently visible
group_analysis_combo_changed();
}

//---------------------------------------------------------------------------
Expand Down Expand Up @@ -1909,12 +1937,64 @@ void AnalysisTool::handle_lda_complete() {
QString left_group = ui_->group_left->currentText();
QString right_group = ui_->group_right->currentText();

if (!group_lda_job_->succeeded()) {
ui_->lda_graph->setVisible(false);
if (left_group == right_group) {
ui_->lda_hint_label->setText("LDA requires two distinct groups.");
} else {
ui_->lda_hint_label->setText("LDA computation failed. Check log for details.");
}
ui_->lda_hint_label->setVisible(true);
QTimer::singleShot(0, this, &AnalysisTool::resize_tab_to_current);
return;
}

group_lda_job_->plot(ui_->lda_graph, left_group, right_group);
ui_->lda_graph->setVisible(true);
ui_->lda_hint_label->setVisible(true);
QTimer::singleShot(0, this, &AnalysisTool::resize_tab_to_current);
}

//---------------------------------------------------------------------------
void AnalysisTool::handle_dwd_progress(double progress) {
if (progress > 0) {
ui_->dwd_progress->setMaximum(100);
} else {
ui_->dwd_progress->setMaximum(0);
}
ui_->dwd_progress_widget->setVisible(progress < 1);
ui_->dwd_progress->setValue(progress * 100);
ui_->dwd_progress->update();
}

//---------------------------------------------------------------------------
void AnalysisTool::handle_dwd_complete() {
ui_->dwd_progress_widget->setVisible(false);
ui_->dwd_label->setVisible(false);
group_dwd_job_running_ = false;
dwd_computed_ = true;

QString left_group = ui_->group_left->currentText();
QString right_group = ui_->group_right->currentText();

if (!group_dwd_job_->succeeded()) {
ui_->dwd_graph->setVisible(false);
if (left_group == right_group) {
ui_->dwd_hint_label->setText("DWD requires two distinct groups.");
} else {
ui_->dwd_hint_label->setText("DWD computation failed. Check log for details.");
}
ui_->dwd_hint_label->setVisible(true);
QTimer::singleShot(0, this, &AnalysisTool::resize_tab_to_current);
return;
}

group_dwd_job_->plot(ui_->dwd_graph, left_group, right_group);
ui_->dwd_graph->setVisible(true);
ui_->dwd_hint_label->setVisible(true);
QTimer::singleShot(0, this, &AnalysisTool::resize_tab_to_current);
}

void AnalysisTool::handle_network_analysis_progress(int progress) {
if (progress > 0) {
ui_->network_progress->setMaximum(100);
Expand Down Expand Up @@ -1960,6 +2040,9 @@ void AnalysisTool::group_analysis_combo_changed() {
if (ui_->group_analysis_stacked_widget->currentWidget() == ui_->lda_page) {
update_lda_graph();
}
if (ui_->group_analysis_stacked_widget->currentWidget() == ui_->dwd_page) {
update_dwd_graph();
}
}
// Recalculate tab height since analysis content changed
QTimer::singleShot(0, this, &AnalysisTool::resize_tab_to_current);
Expand Down
10 changes: 9 additions & 1 deletion Studio/Analysis/AnalysisTool.h
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@ class ShapeWorksStudioApp;
class GroupPvalueJob;
class NetworkAnalysisJob;
class StatsGroupLDAJob;
class StatsGroupDWDJob;
class ParticleAreaPanel;
class ShapeScalarPanel;

Expand All @@ -37,7 +38,7 @@ class AnalysisTool : public QWidget {
public:
using AlignmentType = Analyze::AlignmentType;

enum GroupAnalysisType { None = 0, Pvalues = 1, NetworkAnalysis = 2, LDA = 3 };
enum GroupAnalysisType { None = 0, Pvalues = 1, NetworkAnalysis = 2, LDA = 3, DWD = 4 };

enum McaMode { Vanilla, Within, Between };

Expand Down Expand Up @@ -187,6 +188,9 @@ class AnalysisTool : public QWidget {
void handle_lda_progress(double progress);
void handle_lda_complete();

void handle_dwd_progress(double progress);
void handle_dwd_complete();

void handle_network_analysis_progress(int progress);
void handle_network_analysis_complete();

Expand Down Expand Up @@ -243,6 +247,7 @@ class AnalysisTool : public QWidget {
void handle_pca_group_list_item_changed();

void update_lda_graph();
void update_dwd_graph();

void update_difference_particles();

Expand Down Expand Up @@ -294,10 +299,13 @@ class AnalysisTool : public QWidget {

QSharedPointer<GroupPvalueJob> group_pvalue_job_;
QSharedPointer<StatsGroupLDAJob> group_lda_job_;
QSharedPointer<StatsGroupDWDJob> group_dwd_job_;
QSharedPointer<NetworkAnalysisJob> network_analysis_job_;

bool group_lda_job_running_ = false;
bool lda_computed_ = false;
bool group_dwd_job_running_ = false;
bool dwd_computed_ = false;
bool block_group_change_ = false;

ParticleAreaPanel* particle_area_panel_{nullptr};
Expand Down
Loading
Loading