-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathheatmap_visualizer.py
More file actions
173 lines (149 loc) · 5.75 KB
/
heatmap_visualizer.py
File metadata and controls
173 lines (149 loc) · 5.75 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
import plotly.graph_objects as go
import pandas as pd
import numpy as np
from typing import Optional, List
from scipy.cluster.hierarchy import linkage, fcluster
from scipy.spatial.distance import pdist
# Fixed seed for NumPy's global random number generator. It is applied once,
# as a side effect of importing this module; the stated intent is to keep the
# clustering used by the heatmap deterministic across runs.
RANDOM_SEED: int = 42
np.random.seed(seed=RANDOM_SEED)
def _split_pair(pair: str) -> Optional[tuple]:
    """Split an "A vs B" pair label into its two stripped article names.

    Returns None when the label does not contain the " vs " separator. Only
    the first separator is used, so a label with several separators yields
    ("A", "B vs C") instead of raising on unpacking.
    """
    if ' vs ' not in pair:
        return None
    left, right = pair.split(' vs ', 1)
    return left.strip(), right.strip()


def _cluster_order(similarity_matrix: np.ndarray) -> np.ndarray:
    """Return the row/column permutation that groups clustered articles."""
    n_items = similarity_matrix.shape[0]
    if n_items < 2:
        # Nothing to cluster; linkage() rejects an empty distance matrix.
        return np.arange(n_items)

    # Pairs missing from the input are NaN in the matrix. linkage() only
    # accepts finite distances, so NaN is treated as 0.0 for the purpose of
    # computing distances (the plotted matrix itself is left untouched).
    features = np.where(np.isnan(similarity_matrix), 0.0, similarity_matrix)
    linkage_matrix = linkage(pdist(features), method='ward')
    clusters = fcluster(linkage_matrix, t=0, criterion="inconsistent")

    # A stable sort keeps articles that share a cluster id in their incoming
    # (alphabetical) order, so the layout is reproducible.
    return np.argsort(clusters, kind='stable')


def create_similarity_heatmap(
    analysis_df: pd.DataFrame,
    metric: str = 'semantic_z_score',
    title: Optional[str] = None,
    colorscale: str = 'Electric',
    width: int = 800,
    height: int = 600,
    zmin: Optional[float] = 0,
    zmax: Optional[float] = 1.0
) -> go.Figure:
    """
    Create a heatmap visualization of similarity metrics between article pairs.

    Article names are parsed from the 'pair' column ("Article1 vs Article2"),
    a symmetric matrix is filled from the '<metric>_avg_score' column, the
    diagonal is set to 1.0, and rows/columns are reordered by hierarchical
    clustering (Ward linkage) so that similar articles sit next to each other.

    Args:
        analysis_df: DataFrame with pairwise analysis results. Must contain a
            'pair' column and '<metric>_avg_score'. The hover text for cells
            that have data additionally reads '<metric>_median_score',
            '<metric>_z_score', '<metric>_baseline_mean' and
            '<metric>_baseline_std'.
        metric: Column-name prefix of the similarity metric to visualize
        title: Custom title for the heatmap; a default is derived from
            ``metric`` when None
        colorscale: Plotly color scale name
        width: Figure width in pixels
        height: Figure height in pixels
        zmin: Lower bound of the color range, passed to ``go.Heatmap``
        zmax: Upper bound of the color range, passed to ``go.Heatmap``

    Returns:
        Plotly Figure object

    Raises:
        KeyError: If 'pair' or a required '<metric>_*' column is missing.

    Notes:
        Pairs absent from ``analysis_df`` stay NaN in the plotted matrix and
        get a "No data" hover label. For clustering only, they are treated as
        0.0. With fewer than two articles, clustering is skipped.
    """
    avg_column = f'{metric}_avg_score'
    median_column = f'{metric}_median_score'
    z_column = f'{metric}_z_score'
    baseline_mean_column = f'{metric}_baseline_mean'
    baseline_std_column = f'{metric}_baseline_std'
    metric_title = metric.replace('_', ' ').title()

    # Extract unique article names from the 'pair' column
    pairs = analysis_df['pair'].tolist()
    split_pairs = [_split_pair(pair) for pair in pairs]
    names = set()
    for raw_pair, parts in zip(pairs, split_pairs):
        if parts is None:
            # Labels without the separator are kept as a single article
            names.add(raw_pair.strip())
        else:
            names.update(parts)
    articles = sorted(names)
    n_articles = len(articles)
    position = {name: k for k, name in enumerate(articles)}

    # Fill the symmetric similarity matrix; remember the first row seen for
    # each ordered pair so hover text can be built without rescanning the
    # DataFrame for every cell.
    similarity_matrix = np.full((n_articles, n_articles), np.nan)
    first_row_for_pair = {}
    for parts, (_, row) in zip(split_pairs, analysis_df.iterrows()):
        if parts is None:
            continue
        i = position[parts[0]]
        j = position[parts[1]]
        score = row[avg_column]
        similarity_matrix[i, j] = score
        similarity_matrix[j, i] = score
        first_row_for_pair.setdefault(parts, row)

    # Set diagonal to 1.0 (self-similarity)
    np.fill_diagonal(similarity_matrix, 1.0)

    # Reorder rows and columns by hierarchical clustering
    order = _cluster_order(similarity_matrix)
    similarity_matrix = similarity_matrix[np.ix_(order, order)]
    articles = [articles[k] for k in order]

    # Create hover text matrix
    hover_text = []
    for i, article_i in enumerate(articles):
        row_text = []
        for j, article_j in enumerate(articles):
            if i == j:
                row_text.append(f"<b>{article_i}</b><br>Self-similarity")
                continue

            # The pair may have been recorded in either direction
            row = first_row_for_pair.get((article_i, article_j))
            if row is None:
                row = first_row_for_pair.get((article_j, article_i))
            if row is None:
                row_text.append(
                    f"<b>{article_i} vs {article_j}</b><br>No data"
                )
                continue

            sim_val = similarity_matrix[i, j]
            text = (f"<b>{article_i} vs {article_j}</b><br>{metric_title}: "
                    f"{sim_val:.4f}")
            text += f"<br>median: {row[median_column]:.4f}"
            text += f"<br>z-score: {row[z_column]:.3f} "
            text += (f"<br>baseline mean: {row[baseline_mean_column]:.4f} "
                     f"± {row[baseline_std_column]:.4f}")
            row_text.append(text)
        hover_text.append(row_text)

    # Create the heatmap
    fig = go.Figure(data=go.Heatmap(
        z=similarity_matrix,
        x=articles,
        y=articles,
        text=hover_text,
        hoverinfo='text',
        colorscale=colorscale,
        zmin=zmin,
        zmax=zmax,
        showscale=True,
        colorbar=dict(
            title=metric_title
        )
    ))

    # Update layout
    if title is None:
        title = f"Similarity Heatmap: {metric_title}"
    tick_positions = list(range(len(articles)))
    fig.update_layout(
        title=dict(
            text=title,
            x=0.5,
            font=dict(size=16)
        ),
        width=width,
        height=height,
        xaxis=dict(
            tickangle=45,
            tickmode='array',
            ticktext=articles,
            tickvals=tick_positions
        ),
        yaxis=dict(
            tickmode='array',
            ticktext=articles,
            tickvals=tick_positions
        ),
        hovermode='closest'
    )
    return fig