Skip to content

Commit f3d8205

Browse files
committed
Minor fixes everywhere
1 parent 6a105d8 commit f3d8205

6 files changed

Lines changed: 663 additions & 75 deletions

File tree

configs/validation.yaml

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -14,11 +14,11 @@ synthetic:
1414

1515
# RFI type counts (how many instances of each type to generate)
1616
rfi_type_counts:
17-
broadband_persistent: 4 # ~20% coverage
18-
narrowband_persistent: 10 # ~2.5% coverage
19-
narrowband_bursty: 15 # ~15% coverage
20-
broadband_bursty: 1 # ~2% coverage
21-
narrowband_intermittent: 1 # ~1% coverage
17+
broadband_persistent: 5 # ~20% coverage
18+
narrowband_persistent: 20 # ~2.5% coverage
19+
narrowband_bursty: 20 # ~15% coverage
20+
broadband_bursty: 2 # ~2% coverage
21+
narrowband_intermittent: 5 # ~1% coverage
2222
# Total: ~40% coverage
2323

2424
processing:
Lines changed: 293 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,293 @@
1+
#!/usr/bin/env python3
2+
"""
3+
Evaluate SAM-RFI on du Toit et al. (2024) HERA and LOFAR datasets.
4+
5+
Usage:
6+
python scripts/evaluate_dutoit_datasets.py \
7+
--hera-path /mnt/Data/Data/SAM-RFI/HERA_28-03-2023_all.pkl \
8+
--hera-aof-path /mnt/Data/Data/SAM-RFI/HERA_AOF_20-07-2023_all.pkl \
9+
--lofar-path /mnt/Data/Data/SAM-RFI/LOFAR_Full_RFI_dataset.pkl \
10+
--output-dir ./dutoit_evaluation
11+
"""
12+
13+
import argparse
14+
import json
15+
import pickle
16+
from pathlib import Path
17+
18+
import matplotlib.pyplot as plt
19+
import numpy as np
20+
from rfi_toolbox.evaluation import evaluate_segmentation
21+
from tqdm import tqdm
22+
23+
from samrfi.inference import RFIPredictor
24+
25+
26+
def load_dutoit_dataset(pkl_path):
    """Load a du Toit et al. (2024) dataset pickle.

    The pickle holds a 4-element list:
    ``[train_images, train_masks, test_images, test_masks]``.

    Args:
        pkl_path: path to the ``.pkl`` file.

    Returns:
        Dict with keys ``train_images``, ``train_masks``, ``test_images``,
        ``test_masks`` mapped to the corresponding pickle entries.
    """
    # NOTE(review): pickle.load executes arbitrary code — only open trusted
    # dataset files.
    with open(pkl_path, "rb") as fh:
        payload = pickle.load(fh)

    split_keys = ("train_images", "train_masks", "test_images", "test_masks")
    return {key: payload[i] for i, key in enumerate(split_keys)}
39+
40+
def evaluate_model_on_dataset(predictor, images, ground_truth, dataset_name, model_name):
    """Evaluate one SAM model on a dataset, one sample at a time.

    Processing per-baseline keeps peak memory low at the cost of a Python
    loop over samples.

    Args:
        predictor: RFIPredictor exposing ``predict_array``.
        images: sequence of (512, 512, C) arrays — assumed, TODO confirm.
        ground_truth: matching sequence of (512, 512, 1) mask arrays.
        dataset_name: label used only for progress output.
        model_name: label used only for progress output.

    Returns:
        Dict mapping metric name -> list of per-sample metric values.
    """
    metric_names = ("iou", "precision", "recall", "f1", "dice")
    per_sample = []

    print(f" Evaluating {model_name} on {dataset_name} ({len(images)} samples)...")

    for sample_idx in tqdm(range(len(images)), desc=f" {model_name}"):
        sample = images[sample_idx]

        if sample.shape[-1] == 2:
            # Two channels carry (magnitude, phase) — rebuild the complex
            # visibility (HERA-style format).
            mag = sample[..., 0]
            ph = sample[..., 1]
            vis = mag * np.exp(1j * ph)
        else:
            # Single-channel magnitude (LOFAR-style) — promote to complex dtype.
            vis = sample[..., 0].astype(np.complex64)

        # predict_array expects a 4-D (baseline, pol, time, freq) array.
        batch = vis[np.newaxis, np.newaxis, :, :]
        predicted = predictor.predict_array(batch, patch_size=1024, threshold=None)
        mask_pred = predicted[0, 0, :, :]

        mask_true = ground_truth[sample_idx][..., 0]  # drop trailing channel axis

        per_sample.append(evaluate_segmentation(mask_pred, mask_true))

    # Pivot the list of per-sample dicts into one list per metric.
    return {name: [m[name] for m in per_sample] for name in metric_names}
85+
def plot_metrics(results, output_dir):
    """Generate one box-plot comparison figure per dataset.

    Each figure is a 2x2 grid (IoU / precision / recall / F1); every SAM
    model variant present in ``results[dataset]`` contributes one box.
    Figures are written to ``<dataset>_comparison.png`` in *output_dir*.

    Args:
        results: ``{dataset: {model: {metric: [per-sample values]}}}``.
        output_dir: directory (str or Path) the PNGs are saved into.
    """
    output_dir = Path(output_dir)

    datasets = list(results.keys())
    models = ["tiny", "small", "base_plus", "large"]
    metrics = ["iou", "precision", "recall", "f1"]

    colors = {
        "tiny": "tab:blue",
        "small": "tab:orange",
        "base_plus": "tab:green",
        "large": "tab:red",
    }

    for dataset in datasets:
        fig, axes = plt.subplots(2, 2, figsize=(14, 10))
        axes = axes.flatten()

        for idx, metric in enumerate(metrics):
            ax = axes[idx]

            for model in models:
                if model in results[dataset]:
                    values = results[dataset][model][metric]

                    # One box per model at the model's fixed x slot, so a
                    # missing model leaves a gap instead of shifting others.
                    positions = [models.index(model)]
                    bp = ax.boxplot(
                        [values], positions=positions, widths=0.6, patch_artist=True, showmeans=True
                    )
                    bp["boxes"][0].set_facecolor(colors[model])
                    bp["boxes"][0].set_alpha(0.6)

            ax.set_xticks(range(len(models)))
            ax.set_xticklabels(models)
            ax.set_ylabel(metric.upper())
            ax.set_title(f"{metric.upper()} Distribution", fontweight="bold")
            ax.grid(True, alpha=0.3)

        plt.suptitle(f"{dataset} - SAM Model Comparison", fontsize=14, fontweight="bold")
        plt.tight_layout()

        output_path = output_dir / f"{dataset}_comparison.png"
        plt.savefig(output_path, dpi=150, bbox_inches="tight")
        print(f" ✓ Saved: {output_path}")
        plt.close()
136+
def generate_summary_table(results, output_dir):
    """Build, print, and save a fixed-width summary table (mean ± std per metric).

    Args:
        results: ``{dataset: {model: {metric: [per-sample values]}}}``.
        output_dir: directory (str or Path) where ``summary_table.txt`` goes.

    Returns:
        The rendered table as one string.
    """
    output_dir = Path(output_dir)

    datasets = list(results.keys())
    models = ["tiny", "small", "base_plus", "large"]
    metrics = ["iou", "precision", "recall", "f1"]
    col_width = 18  # must match the header column widths below

    table = []
    table.append("=" * 100)
    table.append("SAM-RFI Evaluation on du Toit et al. (2024) Datasets")
    table.append("=" * 100)

    for dataset in datasets:
        table.append(f"\n{dataset.upper()}")
        table.append("-" * 100)
        table.append(
            f"{'Metric':<12} | {'tiny':<18} | {'small':<18} | {'base_plus':<18} | {'large':<18}"
        )
        table.append("-" * 100)

        for metric in metrics:
            row = f"{metric.UPPER():<12}" if False else f"{metric.upper():<12}"
            for model in models:
                if model in results[dataset]:
                    values = results[dataset][model][metric]
                    mean_val = np.mean(values)
                    std_val = np.std(values)
                    cell = f"{mean_val:.4f} ± {std_val:.4f}"
                else:
                    cell = "N/A"
                # Bug fix: pad EVERY cell to col_width — previously only the
                # 'N/A' cells were padded, so value cells (15 chars) broke the
                # column alignment whenever a model was missing.
                row += f" | {cell:<{col_width}}"
            table.append(row)

    table.append("=" * 100)

    table_text = "\n".join(table)
    print("\n" + table_text)

    # Save to file
    output_path = output_dir / "summary_table.txt"
    with open(output_path, "w") as f:
        f.write(table_text)
    print(f"\n✓ Saved summary table: {output_path}")

    return table_text
183+
def main():
    """CLI entry point.

    Loads the three du Toit et al. (2024) datasets, evaluates every SAM
    model variant on each, then writes per-sample metrics (JSON), box-plot
    figures, and a fixed-width summary table to the output directory.
    """
    parser = argparse.ArgumentParser(
        description="Evaluate SAM-RFI on du Toit datasets",
        # RawDescriptionHelpFormatter keeps the module docstring's usage
        # example formatted verbatim in --help output.
        formatter_class=argparse.RawDescriptionHelpFormatter,
        epilog=__doc__,
    )

    parser.add_argument("--hera-path", required=True, help="HERA dataset pickle")
    parser.add_argument("--hera-aof-path", required=True, help="HERA AOFlagger dataset pickle")
    parser.add_argument("--lofar-path", required=True, help="LOFAR dataset pickle")
    parser.add_argument("--output-dir", default="./dutoit_evaluation", help="Output directory")
    parser.add_argument("--device", default="cuda", help="Device (cuda/cpu)")
    parser.add_argument(
        "--use-test-set", action="store_true", help="Use test set (default: train set)"
    )

    args = parser.parse_args()

    output_dir = Path(args.output_dir)
    output_dir.mkdir(parents=True, exist_ok=True)

    # Load datasets — all three are held in memory at once, so peak RSS is
    # roughly the sum of the pickle sizes quoted below.
    print(f"\n{'='*70}")
    print("Loading du Toit Datasets")
    print(f"{'='*70}")

    print("Loading HERA dataset (3.1GB)...")
    hera = load_dutoit_dataset(args.hera_path)
    print(" ✓ Loaded HERA")

    print("Loading HERA_AOF dataset (3.1GB)...")
    hera_aof = load_dutoit_dataset(args.hera_aof_path)
    print(" ✓ Loaded HERA_AOF")

    print("Loading LOFAR dataset (9.3GB)...")
    lofar = load_dutoit_dataset(args.lofar_path)
    print(" ✓ Loaded LOFAR")

    # Pick which split's images/masks to evaluate on.
    split = "test" if args.use_test_set else "train"
    print(f"Using {split} set")
    print(f" HERA: {len(hera[f'{split}_images'])} samples")
    print(f" HERA_AOF: {len(hera_aof[f'{split}_images'])} samples")
    print(f" LOFAR: {len(lofar[f'{split}_images'])} samples")

    # dataset name -> (images, ground-truth masks) for the chosen split
    datasets = {
        "HERA": (hera[f"{split}_images"], hera[f"{split}_masks"]),
        "HERA_AOF": (hera_aof[f"{split}_images"], hera_aof[f"{split}_masks"]),
        "LOFAR": (lofar[f"{split}_images"], lofar[f"{split}_masks"]),
    }

    # Evaluate all models; results[dataset][model] -> metric dict
    models = ["tiny", "small", "base_plus", "large"]
    results = {dataset_name: {} for dataset_name in datasets.keys()}

    print(f"\n{'='*70}")
    print("Evaluating SAM Models")
    print(f"{'='*70}")

    for model_name in models:
        print(f"\n[{model_name.upper()}]")
        # NOTE(review): "polarimetic" looks like a typo for "polarimetric" —
        # confirm the actual hub repo id before changing it.
        model_path = f"polarimetic/sam-rfi/{model_name}"

        try:
            predictor = RFIPredictor(
                model_path=model_path, sam_checkpoint=model_name, device=args.device
            )

            for dataset_name, (images, masks) in datasets.items():
                metrics = evaluate_model_on_dataset(
                    predictor, images, masks, dataset_name, model_name
                )
                results[dataset_name][model_name] = metrics

        except Exception as e:
            # Best-effort: a model that fails to load or predict (e.g. missing
            # checkpoint, OOM) is reported and skipped; the summary table
            # shows N/A for it.
            print(f" ✗ Error with {model_name}: {e}")
            continue

    # Save results
    print(f"\n{'='*70}")
    print("Saving Results")
    print(f"{'='*70}")

    results_path = output_dir / "results.json"
    with open(results_path, "w") as f:
        # Convert to serializable format (numpy scalars -> plain floats)
        json_results = {}
        for dataset, models_data in results.items():
            json_results[dataset] = {}
            for model, metrics in models_data.items():
                json_results[dataset][model] = {
                    k: [float(v) for v in vals] for k, vals in metrics.items()
                }
        json.dump(json_results, f, indent=2)

    print(f"✓ Saved metrics: {results_path}")

    # Generate plots
    plot_metrics(results, output_dir)

    # Generate summary table
    generate_summary_table(results, output_dir)

    print(f"\n{'='*70}")
    print("✓ Evaluation Complete")
    print(f"{'='*70}")
    print(f"Results saved to: {output_dir}")
    print(f"{'='*70}\n")


if __name__ == "__main__":
    main()

0 commit comments

Comments
 (0)