diff --git a/Python/shapeworks/shapeworks/__init__.py b/Python/shapeworks/shapeworks/__init__.py index 8737f9288e..c3faa384af 100644 --- a/Python/shapeworks/shapeworks/__init__.py +++ b/Python/shapeworks/shapeworks/__init__.py @@ -5,10 +5,10 @@ from shapeworks_py import * from .conversion import sw2vtkImage, sw2vtkMesh from .plot import plot_meshes, plot_volumes, plot_meshes_volumes_mix, add_mesh_to_plotter, add_volume_to_plotter, plot_mesh_contour,plot_pca_metrics,\ -pca_loadings_violinplot,plot_mode_line,visualize_reconstruction,lda_plot +pca_loadings_violinplot,plot_mode_line,visualize_reconstruction,lda_plot,dwd_plot from .utils import num_subplots, positive_factors, save_images, get_file_with_ext, find_reference_image_index, find_reference_mesh_index, load_mesh from .data import get_file_list, sample_images, sample_meshes -from .stats import compute_pvalues_for_group_difference,lda +from .stats import compute_pvalues_for_group_difference,lda,dwd_loadings from .network_analysis import NetworkAnalysis from .portal import download_dataset from .shape_scalars import run_mbpls diff --git a/Python/shapeworks/shapeworks/plot.py b/Python/shapeworks/shapeworks/plot.py index ccfa5bc26b..305253e146 100644 --- a/Python/shapeworks/shapeworks/plot.py +++ b/Python/shapeworks/shapeworks/plot.py @@ -481,7 +481,7 @@ def lda_plot(group1_x,group2_x,group1_pdf,group2_pdf,group1_map,group2_map,lda_d group2_num = len(group2_map) plt.plot(group1_x, group1_pdf, label = labels[0] + ' PDF',linewidth=10) plt.plot(group2_x, group2_pdf, label = labels[1] + ' PDF',linewidth=10) - + plt.scatter(group1_map, 0.01*np.ones((group1_num)), s=330, label = labels[0] + ' Shape Mappings', edgecolors='black',linewidths=5) plt.scatter(group2_map, 0.01*np.ones((group2_num)), s=330, label = labels[1] + ' Shape Mappings', edgecolors='black',linewidths=5) plt.ylabel("Probability Density") @@ -491,4 +491,25 @@ def lda_plot(group1_x,group2_x,group1_pdf,group2_pdf,group1_map,group2_map,lda_d plt.close(fig) print("Figure saved in directory -" + lda_dir) - print() \ No newline at end of file + print() + +def dwd_plot(group1_x,group2_x,group1_pdf,group2_pdf,group1_map,group2_map,dwd_dir,labels): + + plt.figure(dpi=50,figsize=(14,14)) + fig = plt.gcf() + plt.rcParams['font.size'] = '20' + group1_num = len(group1_map) + group2_num = len(group2_map) + plt.plot(group1_x, group1_pdf, label = labels[0] + ' PDF',linewidth=10) + plt.plot(group2_x, group2_pdf, label = labels[1] + ' PDF',linewidth=10) + + plt.scatter(group1_map, 0.01*np.ones((group1_num)), s=330, label = labels[0] + ' Shape Mappings', edgecolors='black',linewidths=5) + plt.scatter(group2_map, 0.01*np.ones((group2_num)), s=330, label = labels[1] + ' Shape Mappings', edgecolors='black',linewidths=5) + plt.ylabel("Probability Density") + plt.xlabel('Shape mapping to DWD discrimination of variation between population means') + plt.legend(loc='upper right') + plt.savefig(dwd_dir+"/DWD.png") + plt.close(fig) + + print("Figure saved in directory -" + dwd_dir) + print() diff --git a/Python/shapeworks/shapeworks/stats.py b/Python/shapeworks/shapeworks/stats.py index 12bdd18c8e..dbc71be727 100644 --- a/Python/shapeworks/shapeworks/stats.py +++ b/Python/shapeworks/shapeworks/stats.py @@ -71,33 +71,49 @@ def compute_pvalues_for_group_difference_data(group_0_data, group_1_data, permut def normalize(subj_map, group1_mean_map, group2_mean_map): - slope = (2.0 / (group2_mean_map - group1_mean_map)) + denom = group2_mean_map - group1_mean_map + if abs(denom) < 1e-12: + return 0.0 + slope = 2.0 / denom subj_diff = subj_map - group1_mean_map subj_map_normalized = slope * subj_diff - 1 return subj_map_normalized def lda_loadings(group1_data, group2_data): - group1_num = np.shape(group1_data)[1] - group2_num = np.shape(group2_data)[1] - combined_data = np.concatenate((group1_data, group2_data), axis=1) group1_mean = np.mean(group1_data, axis=1) group2_mean = np.mean(group2_data, axis=1) - overall_mean = np.mean(combined_data, axis=1) - diffVect = group1_mean - group2_mean + return _project_and_pdf(diffVect, group1_data, group2_data, combined_data) + + +def _project_and_pdf(diffVect, group1_data, group2_data, combined_data): + """Shared logic for projecting groups onto a discriminant direction and fitting PDFs. + + Args: + diffVect: Discriminant direction vector (features,) + group1_data: PCA loadings for group 1 (features x samples) + group2_data: PCA loadings for group 2 (features x samples) + combined_data: Concatenation of group1_data and group2_data (features x all_samples) + + Returns: 6-tuple (group1_x, group2_x, group1_pdf, group2_pdf, group1_map, group2_map) + """ + group1_num = group1_data.shape[1] + group2_num = group2_data.shape[1] + + group1_mean = np.mean(group1_data, axis=1) + group2_mean = np.mean(group2_data, axis=1) + overall_mean = np.mean(combined_data, axis=1) + group1_mean_diff = group1_mean - overall_mean group2_mean_diff = group2_mean - overall_mean group1_mean_map = np.dot(diffVect, group1_mean_diff) group2_mean_map = np.dot(diffVect, group2_mean_diff) - group1_mean_map_normalized = normalize(group1_mean_map, group1_mean_map, group2_mean_map) - group2_mean_map_normalized = normalize(group2_mean_map, group1_mean_map, group2_mean_map) - group1_map = np.zeros((group1_num,)) group2_map = np.zeros((group2_num,)) @@ -117,6 +133,13 @@ def lda_loadings(group1_data, group2_data): group1_map_std = group1_map.std() group2_map_std = group2_map.std() + # Guard against zero std (all samples project to same point) + min_std = 1e-6 + if group1_map_std < min_std: + group1_map_std = min_std + if group2_map_std < min_std: + group2_map_std = min_std + group1_x = np.linspace(group1_map_mean - 6, group1_map_mean + 6, num=300) group2_x = np.linspace(group2_map_mean - 6, group2_map_mean + 6, num=300) @@ -125,6 +148,31 @@ def lda_loadings(group1_data, group2_data): return group1_x, group2_x, group1_pdf, group2_pdf, group1_map, group2_map +def dwd_loadings(group1_data, group2_data): + from dwd.gen_dwd import GenDWD + group1_num = np.shape(group1_data)[1] + group2_num = np.shape(group2_data)[1] + + if group1_num < 2 or group2_num < 2: + raise ValueError(f"DWD requires at least 2 samples per group (got {group1_num} and {group2_num})") + + combined_data = np.concatenate((group1_data, group2_data), axis=1) + + # Fit GenDWD (samples x features) + X = combined_data.T + y = np.array([1]*group1_num + [-1]*group2_num) + + try: + model = GenDWD(lambd=1.0) + model.fit(X, y) + except Exception as e: + raise RuntimeError(f"DWD fitting failed: {e}") from e + + diffVect = model.coef_.flatten() + + return _project_and_pdf(diffVect, group1_data, group2_data, combined_data) + + def lda(data): group_id = data["group_ids"].unique() group1_idxs = data.index[data['group_ids'] == 0].tolist() diff --git a/Studio/Analysis/AnalysisTool.cpp b/Studio/Analysis/AnalysisTool.cpp index da70b3ed7b..7a5e08541c 100644 --- a/Studio/Analysis/AnalysisTool.cpp +++ b/Studio/Analysis/AnalysisTool.cpp @@ -11,6 +11,7 @@ #include #include #include +#include #include #include #include @@ -168,6 +169,12 @@ AnalysisTool::AnalysisTool(Preferences& prefs) : preferences_(prefs) { connect(group_lda_job_.data(), &StatsGroupLDAJob::progress, this, &AnalysisTool::handle_lda_progress); connect(group_lda_job_.data(), &StatsGroupLDAJob::finished, this, &AnalysisTool::handle_lda_complete); + ui_->dwd_graph->hide(); + ui_->dwd_hint_label->hide(); + group_dwd_job_ = QSharedPointer::create(); + connect(group_dwd_job_.data(), &StatsGroupDWDJob::progress, this, &AnalysisTool::handle_dwd_progress); + connect(group_dwd_job_.data(), &StatsGroupDWDJob::finished, this, &AnalysisTool::handle_dwd_complete); + connect(ui_->show_difference_to_mean, &QPushButton::clicked, this, &AnalysisTool::show_difference_to_mean_clicked); connect(ui_->group_analysis_combo, qOverload(&QComboBox::currentIndexChanged), this, @@ -1620,6 +1627,24 @@ void AnalysisTool::update_lda_graph() { } } +//--------------------------------------------------------------------------- +void AnalysisTool::update_dwd_graph() { + if (groups_active()) { + if (!dwd_computed_ && !group_dwd_job_running_) { + group_dwd_job_running_ = true; + ui_->dwd_label->show(); + ui_->dwd_progress->setValue(0); + ui_->dwd_progress->setMaximum(0); + ui_->dwd_progress->update(); + group_dwd_job_->set_stats(stats_); + app_->get_py_worker()->run_job(group_dwd_job_); + } + } else { + ui_->dwd_graph->setVisible(false); + ui_->dwd_hint_label->setVisible(false); + } +} + //--------------------------------------------------------------------------- void AnalysisTool::update_difference_particles() { if (!stats_ready_) { @@ -1679,7 +1704,10 @@ void AnalysisTool::group_changed() { stats_ready_ = false; group_pvalue_job_ = nullptr; lda_computed_ = false; + dwd_computed_ = false; compute_stats(); + // Re-trigger LDA/DWD if currently visible + group_analysis_combo_changed(); } //--------------------------------------------------------------------------- @@ -1909,12 +1937,64 @@ void AnalysisTool::handle_lda_complete() { QString left_group = ui_->group_left->currentText(); QString right_group = ui_->group_right->currentText(); + if (!group_lda_job_->succeeded()) { + ui_->lda_graph->setVisible(false); + if (left_group == right_group) { + ui_->lda_hint_label->setText("LDA requires two distinct groups."); + } else { + ui_->lda_hint_label->setText("LDA computation failed. Check log for details."); + } + ui_->lda_hint_label->setVisible(true); + QTimer::singleShot(0, this, &AnalysisTool::resize_tab_to_current); + return; + } + group_lda_job_->plot(ui_->lda_graph, left_group, right_group); ui_->lda_graph->setVisible(true); ui_->lda_hint_label->setVisible(true); QTimer::singleShot(0, this, &AnalysisTool::resize_tab_to_current); } +//--------------------------------------------------------------------------- +void AnalysisTool::handle_dwd_progress(double progress) { + if (progress > 0) { + ui_->dwd_progress->setMaximum(100); + } else { + ui_->dwd_progress->setMaximum(0); + } + ui_->dwd_progress_widget->setVisible(progress < 1); + ui_->dwd_progress->setValue(progress * 100); + ui_->dwd_progress->update(); +} + +//--------------------------------------------------------------------------- +void AnalysisTool::handle_dwd_complete() { + ui_->dwd_progress_widget->setVisible(false); + ui_->dwd_label->setVisible(false); + group_dwd_job_running_ = false; + dwd_computed_ = true; + + QString left_group = ui_->group_left->currentText(); + QString right_group = ui_->group_right->currentText(); + + if (!group_dwd_job_->succeeded()) { + ui_->dwd_graph->setVisible(false); + if (left_group == right_group) { + ui_->dwd_hint_label->setText("DWD requires two distinct groups."); + } else { + ui_->dwd_hint_label->setText("DWD computation failed. Check log for details."); + } + ui_->dwd_hint_label->setVisible(true); + QTimer::singleShot(0, this, &AnalysisTool::resize_tab_to_current); + return; + } + + group_dwd_job_->plot(ui_->dwd_graph, left_group, right_group); + ui_->dwd_graph->setVisible(true); + ui_->dwd_hint_label->setVisible(true); + QTimer::singleShot(0, this, &AnalysisTool::resize_tab_to_current); +} + void AnalysisTool::handle_network_analysis_progress(int progress) { if (progress > 0) { ui_->network_progress->setMaximum(100); @@ -1960,6 +2040,9 @@ void AnalysisTool::group_analysis_combo_changed() { if (ui_->group_analysis_stacked_widget->currentWidget() == ui_->lda_page) { update_lda_graph(); } + if (ui_->group_analysis_stacked_widget->currentWidget() == ui_->dwd_page) { + update_dwd_graph(); + } } // Recalculate tab height since analysis content changed QTimer::singleShot(0, this, &AnalysisTool::resize_tab_to_current); diff --git a/Studio/Analysis/AnalysisTool.h b/Studio/Analysis/AnalysisTool.h index c6f0dd8511..69ea107644 100644 --- a/Studio/Analysis/AnalysisTool.h +++ b/Studio/Analysis/AnalysisTool.h @@ -28,6 +28,7 @@ class ShapeWorksStudioApp; class GroupPvalueJob; class NetworkAnalysisJob; class StatsGroupLDAJob; +class StatsGroupDWDJob; class ParticleAreaPanel; class ShapeScalarPanel; @@ -37,7 +38,7 @@ class AnalysisTool : public QWidget { public: using AlignmentType = Analyze::AlignmentType; - enum GroupAnalysisType { None = 0, Pvalues = 1, NetworkAnalysis = 2, LDA = 3 }; + enum GroupAnalysisType { None = 0, Pvalues = 1, NetworkAnalysis = 2, LDA = 3, DWD = 4 }; enum McaMode { Vanilla, Within, Between }; @@ -187,6 +188,9 @@ class AnalysisTool : public QWidget { void handle_lda_progress(double progress); void handle_lda_complete(); + void handle_dwd_progress(double progress); + void handle_dwd_complete(); + void handle_network_analysis_progress(int progress); void handle_network_analysis_complete(); @@ -243,6 +247,7 @@ class AnalysisTool : public QWidget { void handle_pca_group_list_item_changed(); void update_lda_graph(); + void update_dwd_graph(); void update_difference_particles(); @@ -294,10 +299,13 @@ class AnalysisTool : public QWidget { QSharedPointer group_pvalue_job_; QSharedPointer group_lda_job_; + QSharedPointer group_dwd_job_; QSharedPointer network_analysis_job_; bool group_lda_job_running_ = false; bool lda_computed_ = false; + bool group_dwd_job_running_ = false; + bool dwd_computed_ = false; bool block_group_change_ = false; ParticleAreaPanel* particle_area_panel_{nullptr}; diff --git a/Studio/Analysis/AnalysisTool.ui b/Studio/Analysis/AnalysisTool.ui index cb2237ba5a..c3f3d6c62b 100644 --- a/Studio/Analysis/AnalysisTool.ui +++ b/Studio/Analysis/AnalysisTool.ui @@ -891,6 +891,77 @@ QWidget#particles_panel { + + + + 0 + + + 0 + + + 0 + + + 0 + + + + + + 0 + 0 + + + + + 50 + 50 + + + + + + + + (Resize or pop-out panel for a larger graph) + + + Qt::AlignCenter + + + true + + + + + + + + + + + 0 + 0 + + + + 0 + + + + + + + Computing DWD... + + + + + + + + @@ -922,6 +993,11 @@ QWidget#particles_panel { LDA + + + DWD + + diff --git a/Studio/CMakeLists.txt b/Studio/CMakeLists.txt index 44e1bb29e0..7257d13b70 100644 --- a/Studio/CMakeLists.txt +++ b/Studio/CMakeLists.txt @@ -107,6 +107,7 @@ SET(STUDIO_JOB_SRCS Job/NetworkAnalysisJob.cpp Job/ParticleNormalEvaluationJob.cpp Job/StatsGroupLDAJob.cpp + Job/StatsGroupDWDJob.cpp Job/ShapeScalarJob.cpp ) @@ -116,6 +117,7 @@ SET(STUDIO_JOB_MOC_HDRS Job/ParticleAreaJob.h Job/ParticleNormalEvaluationJob.h Job/StatsGroupLDAJob.h + Job/StatsGroupDWDJob.h Job/ShapeScalarJob.h ) diff --git a/Studio/Job/StatsGroupDWDJob.cpp b/Studio/Job/StatsGroupDWDJob.cpp new file mode 100644 index 0000000000..d7f9911180 --- /dev/null +++ b/Studio/Job/StatsGroupDWDJob.cpp @@ -0,0 +1,154 @@ +#include +#include +#include +namespace py = pybind11; +using namespace pybind11::literals; // to bring in the `_a` literal + +#include +#include +#include +#include + +namespace shapeworks { + +//--------------------------------------------------------------------------- +StatsGroupDWDJob::StatsGroupDWDJob() {} + +//--------------------------------------------------------------------------- +void StatsGroupDWDJob::set_stats(ParticleShapeStatistics stats) { stats_ = stats; } + +//--------------------------------------------------------------------------- +void StatsGroupDWDJob::run() { + succeeded_ = false; + Q_EMIT progress(0.1); + stats_.principal_component_projections(); + auto pca_loadings = stats_.get_pca_loadings(); + Q_EMIT progress(0.2); + + auto& group_ids = stats_.GroupID(); + + int num_samples = pca_loadings.rows(); + + Eigen::MatrixXd group_1_data; + Eigen::MatrixXd group_2_data; + + int group_1_count = std::count(group_ids.begin(), group_ids.end(), 1); + int group_2_count = num_samples - group_1_count; + if (group_1_count == 0 || group_2_count == 0) { + return; + } + + group_1_data.resize(group_1_count, pca_loadings.cols()); + group_2_data.resize(group_2_count, pca_loadings.cols()); + + int group_1_idx = 0; + int group_2_idx = 0; + for (int i = 0; i < num_samples; i++) { + if (group_ids[i] == 1) { + group_1_data.row(group_1_idx++) = pca_loadings.row(i); + } else { + group_2_data.row(group_2_idx++) = pca_loadings.row(i); + } + } + + try { + py::module sw = py::module::import("shapeworks"); + py::object dwd_loadings = sw.attr("stats").attr("dwd_loadings"); + Q_EMIT progress(0.5); + + using ResultType = + std::tuple; + ResultType result = dwd_loadings(group_1_data.transpose(), group_2_data.transpose()).cast(); + + group1_x_ = std::get<0>(result); + group2_x_ = std::get<1>(result); + group1_pdf_ = std::get<2>(result); + group2_pdf_ = std::get<3>(result); + group1_map_ = std::get<4>(result); + group2_map_ = std::get<5>(result); + } catch (const std::exception& e) { + SW_ERROR("DWD computation failed: {}", e.what()); + succeeded_ = false; + return; + } + + succeeded_ = true; + Q_EMIT progress(1.0); +} + +//--------------------------------------------------------------------------- +QString StatsGroupDWDJob::name() { return "Group DWD"; } + +//--------------------------------------------------------------------------- +void StatsGroupDWDJob::plot(JKQTPlotter* plot, QString group_1_name, QString group_2_name) { + JKQTPDatastore* ds = plot->getDatastore(); + ds->clear(); + plot->clearGraphs(); + + QString title = "DWD"; + + auto draw_line_plot = [&](Eigen::MatrixXd x, Eigen::MatrixXd y, QString name, QColor color) { + QVector xv, yv; + for (int i = 0; i < x.size(); i++) { + xv << x(i); + yv << y(i); + } + + QString x_label = name + " PDF"; + QString y_label = name + " y"; + + size_t column_x = ds->addCopiedColumn(xv, x_label); + size_t column_y = ds->addCopiedColumn(yv, y_label); + + JKQTPXYLineGraph* graph = new JKQTPXYLineGraph(plot); + graph->setColor(color); + graph->setSymbolType(JKQTPNoSymbol); + graph->setXColumn(column_x); + graph->setYColumn(column_y); + graph->setTitle(name + " PDF"); + plot->addGraph(graph); + }; + + draw_line_plot(group1_x_, group1_pdf_, group_1_name, QColor(239, 133, 54)); + draw_line_plot(group2_x_, group2_pdf_, group_2_name, Qt::blue); + + auto draw_scatter_plot = [&](Eigen::MatrixXd map, QString name, QColor color) { + QVector x, y; + for (int i = 0; i < map.size(); i++) { + x << map(i); + y << 0.01; + } + + int column_x = ds->addCopiedColumn(x, name + "scatter x"); + int column_y = ds->addCopiedColumn(y, name + "scatter y"); + + auto scatter = new JKQTPXYParametrizedScatterGraph(plot); + scatter->setColor(color); + scatter->setXColumn(column_x); + scatter->setYColumn(column_y); + scatter->setTitle(name + " Shape Mappings"); + plot->addGraph(scatter); + }; + + draw_scatter_plot(group1_map_, group_1_name, QColor(239, 133, 54)); + draw_scatter_plot(group2_map_, group_2_name, Qt::blue); + + plot->getPlotter()->setUseAntiAliasingForGraphs(true); + plot->getPlotter()->setUseAntiAliasingForSystem(true); + plot->getPlotter()->setUseAntiAliasingForText(true); + plot->getPlotter()->setPlotLabelFontSize(18); + plot->getPlotter()->setPlotLabel("\\textbf{" + title + "}"); + plot->getPlotter()->setDefaultTextSize(14); + plot->getPlotter()->setShowKey(true); + + plot->getXAxis()->setAxisLabel("Shape mapping to DWD discrimination of variation between population means"); + plot->getXAxis()->setLabelFontSize(8); + plot->getYAxis()->setAxisLabel("Probability Density"); + plot->getYAxis()->setLabelFontSize(14); + + plot->clearAllMouseWheelActions(); + plot->setMousePositionShown(false); + plot->setMinimumSize(250, 250); + plot->zoomToFit(); +} +} // namespace shapeworks diff --git a/Studio/Job/StatsGroupDWDJob.h b/Studio/Job/StatsGroupDWDJob.h new file mode 100644 index 0000000000..647fe88e39 --- /dev/null +++ b/Studio/Job/StatsGroupDWDJob.h @@ -0,0 +1,29 @@ +#pragma once +#include +#include + +class JKQTPlotter; + +namespace shapeworks { + +class StatsGroupDWDJob : public Job { + Q_OBJECT + public: + StatsGroupDWDJob(); + + void set_stats(ParticleShapeStatistics stats); + + void run() override; + + QString name() override; + + void plot(JKQTPlotter* plot, QString group_1_name, QString group_2_name); + + bool succeeded() const { return succeeded_; } + + private: + bool succeeded_ = false; + ParticleShapeStatistics stats_; + Eigen::MatrixXd group1_x_, group2_x_, group1_pdf_, group2_pdf_, group1_map_, group2_map_; +}; +} // namespace shapeworks diff --git a/Studio/Job/StatsGroupLDAJob.cpp b/Studio/Job/StatsGroupLDAJob.cpp index 8fdb48f2fc..fa709059c1 100644 --- a/Studio/Job/StatsGroupLDAJob.cpp +++ b/Studio/Job/StatsGroupLDAJob.cpp @@ -5,6 +5,7 @@ namespace py = pybind11; using namespace pybind11::literals; // to bring in the `_a` literal #include +#include #include #include @@ -18,8 +19,9 @@ void StatsGroupLDAJob::set_stats(ParticleShapeStatistics stats) { stats_ = stats //--------------------------------------------------------------------------- void StatsGroupLDAJob::run() { + succeeded_ = false; Q_EMIT progress(0.1); - stats_.principal_component_projections(); + stats_.principal_component_projections(); auto pca_loadings = stats_.get_pca_loadings(); Q_EMIT progress(0.2); @@ -49,21 +51,27 @@ void StatsGroupLDAJob::run() { } } - py::module sw = py::module::import("shapeworks"); - py::object lda_loadings = sw.attr("stats").attr("lda_loadings"); - Q_EMIT progress(0.5); - - using ResultType = - std::tuple; - ResultType result = lda_loadings(group_1_data.transpose(), group_2_data.transpose()).cast(); - - group1_x_ = std::get<0>(result); - group2_x_ = std::get<1>(result); - group1_pdf_ = std::get<2>(result); - group2_pdf_ = std::get<3>(result); - group1_map_ = std::get<4>(result); - group2_map_ = std::get<5>(result); + try { + py::module sw = py::module::import("shapeworks"); + py::object lda_loadings = sw.attr("stats").attr("lda_loadings"); + Q_EMIT progress(0.5); + + using ResultType = + std::tuple; + ResultType result = lda_loadings(group_1_data.transpose(), group_2_data.transpose()).cast(); + + group1_x_ = std::get<0>(result); + group2_x_ = std::get<1>(result); + group1_pdf_ = std::get<2>(result); + group2_pdf_ = std::get<3>(result); + group1_map_ = std::get<4>(result); + group2_map_ = std::get<5>(result); + } catch (const std::exception& e) { + SW_ERROR("LDA computation failed: {}", e.what()); + return; + } + succeeded_ = true; Q_EMIT progress(1.0); } diff --git a/Studio/Job/StatsGroupLDAJob.h b/Studio/Job/StatsGroupLDAJob.h index 4a5bc0f9b7..6393e2c026 100644 --- a/Studio/Job/StatsGroupLDAJob.h +++ b/Studio/Job/StatsGroupLDAJob.h @@ -21,7 +21,10 @@ class StatsGroupLDAJob : public Job { void plot(JKQTPlotter* plot, QString group_1_name, QString group_2_name); + bool succeeded() const { return succeeded_; } + private: + bool succeeded_ = false; ParticleShapeStatistics stats_; Eigen::MatrixXd group1_x_, group2_x_, group1_pdf_, group2_pdf_, group1_map_, group2_map_; }; diff --git a/python_requirements.txt b/python_requirements.txt index dad41a764f..e89b2c0444 100644 --- a/python_requirements.txt +++ b/python_requirements.txt @@ -130,6 +130,8 @@ docopt==0.6.2 # via # -r requirements.in # grip +dwd==1.0.5 + # via -r requirements.in entrypoints==0.4 # via -r requirements.in et-xmlfile==2.0.0