-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathdataset_loader.py
More file actions
164 lines (135 loc) · 5.52 KB
/
dataset_loader.py
File metadata and controls
164 lines (135 loc) · 5.52 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
import pandas as pd
from typing import List, Dict, Optional
import json
# Optional dependency: the HuggingFace ``datasets`` package is only needed for
# loading from the Hub. DATASETS_AVAILABLE records whether the import worked so
# loaders can degrade gracefully instead of failing at import time.
try:
    from datasets import load_dataset
except ImportError:
    DATASETS_AVAILABLE = False
    print("Warning: datasets package not available. Install with: pip install datasets")
else:
    DATASETS_AVAILABLE = True
class DatasetLoader:
    """Load HotpotQA and FEVER examples and extract passages for indexing.

    Examples come either from a local file (a JSON array, or JSON Lines when
    the path ends in ``.jsonl``) or from the HuggingFace Hub through
    ``datasets.load_dataset`` when the ``datasets`` package is importable.
    """

    # Callers use the conventional split name "dev"; the HuggingFace copies of
    # these datasets name that split differently, so translate before loading.
    # NOTE(review): names taken from the HF "hotpot_qa" / "fever" dataset
    # cards — re-check if the hub copies change.
    _HOTPOTQA_SPLIT_ALIASES = {"dev": "validation"}
    _FEVER_SPLIT_ALIASES = {
        "v1.0": {"dev": "labelled_dev"},
        "v2.0": {"dev": "validation"},
    }

    def __init__(self, cache_dir: str = ".cache"):
        """Initialize dataset loader.

        Args:
            cache_dir: Directory passed to ``load_dataset`` as its cache location.
        """
        self.cache_dir = cache_dir

    @staticmethod
    def _read_local(path: str) -> List[Dict]:
        """Read examples from a local JSON array file or a ``.jsonl`` file."""
        with open(path, "r", encoding="utf-8") as f:
            if path.endswith(".jsonl"):
                # One JSON object per line (the format FEVER is distributed
                # in); blank lines are ignored.
                return [json.loads(line) for line in f if line.strip()]
            return json.load(f)

    @staticmethod
    def _take(items, max_samples: Optional[int]) -> List[Dict]:
        """Return at most ``max_samples`` items from ``items`` as a list.

        ``None`` means "no limit". ``0`` or a negative value yields an empty
        list (previously 0 was mistaken for "no limit").
        """
        from itertools import islice

        if max_samples is None:
            return list(items)
        return list(islice(items, max(max_samples, 0)))

    @staticmethod
    def _hotpotqa_example(item: Dict) -> Dict:
        """Normalize one HuggingFace HotpotQA row into the loader's example dict."""
        return {
            "id": item.get("id", ""),
            "question": item.get("question", ""),
            "answer": item.get("answer", ""),
            "context": item.get("context", {}),
            "supporting_facts": item.get("supporting_facts", []),
            "type": item.get("type", ""),
            "level": item.get("level", ""),
        }

    @staticmethod
    def _fever_example(item: Dict) -> Dict:
        """Normalize one HuggingFace FEVER row into the loader's example dict."""
        # NOTE(review): the HF v1.0 rows appear to be one row per evidence item
        # with flat fields (evidence_wiki_url, evidence_sentence_id, ...) rather
        # than an "evidence" list; if so, "evidence" stays [] here — confirm.
        return {
            "id": item.get("id", 0),
            "claim": item.get("claim", ""),
            # Label string as stored in the dataset (e.g. SUPPORTS / REFUTES).
            "label": item.get("label", ""),
            "evidence": item.get("evidence", []),
            "annotated_evidence": item.get("annotated_evidence", []),
        }

    def load_hotpotqa(
        self,
        split: str = "dev",
        max_samples: Optional[int] = None,
        path: Optional[str] = None,
        config: str = "fullwiki",
    ) -> List[Dict]:
        """
        Load HotpotQA dataset.

        Args:
            split: Dataset split (train, dev, test). "dev" is translated to the
                HuggingFace split name "validation".
            max_samples: Maximum number of samples to load (None = all, 0 = none).
            path: Optional path to a local JSON / ``.jsonl`` file. Its records
                are returned as-is, without normalization.
            config: HuggingFace configuration name (e.g. "fullwiki", "distractor").

        Returns:
            List of examples with 'question', 'answer', 'context', etc.
            An empty list if ``datasets`` is unavailable or loading fails.
        """
        if path:
            return self._take(self._read_local(path), max_samples)
        if not DATASETS_AVAILABLE:
            print("Warning: datasets package not available. Returning empty list.")
            return []
        hf_split = self._HOTPOTQA_SPLIT_ALIASES.get(split, split)
        try:
            dataset = load_dataset("hotpot_qa", config, split=hf_split, cache_dir=self.cache_dir)
            return self._take((self._hotpotqa_example(item) for item in dataset), max_samples)
        except Exception as e:
            # Deliberate best-effort: report and fall back to an empty list.
            print(f"Error loading HotpotQA: {e}")
            return []

    def load_fever(
        self,
        split: str = "dev",
        max_samples: Optional[int] = None,
        path: Optional[str] = None,
        config: str = "v1.0",
    ) -> List[Dict]:
        """
        Load FEVER dataset.

        Args:
            split: Dataset split (train, dev, test). "dev" is translated to the
                HuggingFace split name for the chosen config
                ("labelled_dev" for v1.0, "validation" for v2.0).
            max_samples: Maximum number of samples to load (None = all, 0 = none).
            path: Optional path to a local JSON / ``.jsonl`` file. Its records
                are returned as-is, without normalization.
            config: HuggingFace configuration name; the hub dataset requires one.

        Returns:
            List of examples with 'claim', 'label', 'evidence', etc.
            An empty list if ``datasets`` is unavailable or loading fails.
        """
        if path:
            return self._take(self._read_local(path), max_samples)
        if not DATASETS_AVAILABLE:
            print("Warning: datasets package not available. Returning empty list.")
            return []
        hf_split = self._FEVER_SPLIT_ALIASES.get(config, {}).get(split, split)
        try:
            dataset = load_dataset("fever", config, split=hf_split, cache_dir=self.cache_dir)
            return self._take((self._fever_example(item) for item in dataset), max_samples)
        except Exception as e:
            # Deliberate best-effort: report and fall back to an empty list.
            print(f"Error loading FEVER: {e}")
            return []

    @staticmethod
    def _hotpotqa_context_pairs(context) -> List:
        """Return (title, sentences) pairs from a HotpotQA ``context`` value.

        Accepts both the HuggingFace columnar dict ({'title': [...],
        'sentences': [[...], ...]}) and the original list-of-[title, sentences]
        format found in the raw JSON files. Anything else yields no pairs.
        """
        if isinstance(context, dict):
            titles = context.get("title", [])
            sentences_list = context.get("sentences", [])
            if len(titles) != len(sentences_list):
                # Malformed: titles and sentence lists cannot be paired reliably.
                return []
            return list(zip(titles, sentences_list))
        if isinstance(context, list):
            return [
                (pair[0], pair[1])
                for pair in context
                if isinstance(pair, (list, tuple)) and len(pair) == 2
            ]
        return []

    def prepare_passages_from_hotpotqa(self, examples: List[Dict]) -> List[str]:
        """
        Extract passages from HotpotQA examples for indexing.

        Each context paragraph becomes one "title: sentence sentence ..." string;
        paragraphs with no text are skipped.

        Returns:
            List of passage strings
        """
        passages = []
        for example in examples:
            for title, sentences in self._hotpotqa_context_pairs(example.get("context")):
                text = " ".join(sentences)
                if text:
                    passages.append(f"{title}: {text}")
        return passages

    def prepare_passages_from_fever(self, examples: List[Dict]) -> List[str]:
        """
        Extract passages from FEVER examples for indexing.

        Only evidence items that are dicts with a non-empty 'sentence' field
        contribute; other evidence shapes (e.g. raw [id, id, page, line] lists)
        are skipped.

        Returns:
            List of passage strings
        """
        passages = []
        for example in examples:
            # ``or []`` also covers an explicit None value.
            for evid_group in example.get("evidence") or []:
                for evid_item in evid_group:
                    if isinstance(evid_item, dict):
                        sentence = evid_item.get("sentence", "")
                        if sentence:
                            passages.append(sentence)
        return passages