-
Notifications
You must be signed in to change notification settings - Fork 1
Expand file tree
/
Copy pathutils.py
More file actions
131 lines (108 loc) · 3.84 KB
/
utils.py
File metadata and controls
131 lines (108 loc) · 3.84 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
# _*_ coding: UTF-8 _*_
# Author LBK
import random
from tqdm import tqdm
import torch
import time
from datetime import timedelta
import pickle as pkl
import os
PAD, CLS = '[PAD]', '[CLS]'
def load_dataset(file_path, config):
    """
    Load a tab-separated "content<TAB>label" file and turn every line into
    BERT-style model inputs.

    :param file_path: path to a UTF-8 text file, one "content\tlabel" per line
    :param config: object providing `tokenizer` (with .tokenize and
                   .convert_tokens_to_ids) and `pad_size` (int; a truthy value
                   pads/truncates every sequence to that length)
    :return: shuffled list of 4-tuples (token_ids, label, seq_len, mask)
    """
    contents = []
    with open(file_path, 'r', encoding='utf-8') as f:
        for line in tqdm(f):
            line = line.strip()
            if not line:
                continue
            # rsplit guards against tab characters inside the content itself:
            # only the LAST tab separates content from label (plain split('\t')
            # raised ValueError on such lines).
            content, label = line.rsplit('\t', 1)
            token = config.tokenizer.tokenize(content)
            token = [CLS] + token
            seq_len = len(token)
            mask = []
            token_ids = config.tokenizer.convert_tokens_to_ids(token)
            pad_size = config.pad_size
            if pad_size:
                if len(token) < pad_size:
                    # attention mask: 1 for real tokens, 0 for padding
                    mask = [1] * len(token_ids) + [0] * (pad_size - len(token))
                    token_ids = token_ids + ([0] * (pad_size - len(token)))
                else:
                    # too long: truncate and report pad_size as the length
                    mask = [1] * pad_size
                    token_ids = token_ids[:pad_size]
                    seq_len = pad_size
            contents.append((token_ids, int(label), seq_len, mask))
    random.shuffle(contents)
    return contents
def build_dataset(config):
    """
    Build (or load from the pickle cache) the train / dev / test datasets.

    Each dataset is a list of (token_ids, label, seq_len, mask) tuples as
    produced by load_dataset.

    :param config: object providing `datasetpkl` (cache file path),
                   `train_path`, `dev_path`, `test_path`, plus whatever
                   load_dataset needs (tokenizer, pad_size)
    :return: (train, dev, test)
    """
    if os.path.exists(config.datasetpkl):
        # Cache hit: reuse previously tokenized datasets.
        # NOTE(review): pickle.load is unsafe on untrusted files — acceptable
        # here only because the cache is produced locally by this function.
        with open(config.datasetpkl, 'rb') as f:
            dataset = pkl.load(f)
        train = dataset['train']
        dev = dataset['dev']
        test = dataset['test']
    else:
        train = load_dataset(config.train_path, config)
        dev = load_dataset(config.dev_path, config)
        test = load_dataset(config.test_path, config)
        dataset = {'train': train, 'dev': dev, 'test': test}
        # `with` closes the handle deterministically; the original leaked the
        # file object returned by open() in both branches.
        with open(config.datasetpkl, 'wb') as f:
            pkl.dump(dataset, f)
    return train, dev, test
class DatasetIterator(object):
    """
    Batched iterator over a dataset of (token_ids, label, seq_len, mask)
    tuples, yielding ((x, seq_len, mask), y) LongTensor batches on `device`.
    Resets itself after a full pass, so it can be iterated every epoch.
    """

    def __init__(self, dataset, batch_size, device):
        """
        :param dataset: list of (token_ids, label, seq_len, mask) tuples
        :param batch_size: number of samples per batch
        :param device: torch device the batch tensors are moved to
        """
        self.dataset = dataset
        self.batch_size = batch_size
        self.index = 0  # index of the next batch to serve
        self.device = device
        self.n_batches = len(dataset) // batch_size
        # A final short batch exists when the dataset size is not an exact
        # multiple of the batch size.  (Fixed: the original tested
        # `len(dataset) % n_batches`, which silently drops the leftover
        # samples — e.g. 10 samples with batch_size 4 gives 10 % 2 == 0 —
        # and raises ZeroDivisionError when len(dataset) < batch_size.)
        self.residue = len(dataset) % batch_size != 0

    def _to_tensor(self, datas):
        """Stack a list of samples into LongTensors on the target device."""
        x = torch.LongTensor([item[0] for item in datas]).to(self.device)        # token ids
        y = torch.LongTensor([item[1] for item in datas]).to(self.device)        # labels
        seq_len = torch.LongTensor([item[2] for item in datas]).to(self.device)  # true sequence lengths
        mask = torch.LongTensor([item[3] for item in datas]).to(self.device)     # attention masks
        return (x, seq_len, mask), y

    def __next__(self):
        if self.residue and self.index == self.n_batches:
            # Final short batch: everything after the last full batch.
            batch = self.dataset[self.index * self.batch_size:]
            self.index += 1
            return self._to_tensor(batch)
        if self.index >= self.n_batches:
            # Fixed: the original used `index > n_batches`, which in the
            # no-residue case served one extra EMPTY batch before stopping,
            # contradicting __len__.
            self.index = 0  # reset so the iterator is reusable next epoch
            raise StopIteration
        batch = self.dataset[self.index * self.batch_size: (self.index + 1) * self.batch_size]
        self.index += 1
        return self._to_tensor(batch)

    def __iter__(self):
        return self

    def __len__(self):
        # Number of batches one pass yields, counting the short final batch.
        return self.n_batches + 1 if self.residue else self.n_batches
def build_iterator(dataset, config):
    """
    Wrap *dataset* in a DatasetIterator driven by the config.

    :param dataset: list of (token_ids, label, seq_len, mask) tuples
    :param config: object providing `batch_size` and `device`
    :return: a DatasetIterator over the dataset
    """
    # Renamed local: the original bound the builtin name `iter`.
    iterator = DatasetIterator(dataset, config.batch_size, config.device)
    return iterator
def get_time_dif(start_time):
    """Return the wall-clock time elapsed since *start_time* (a time.time()
    timestamp), rounded to whole seconds as a datetime.timedelta."""
    elapsed = time.time() - start_time
    return timedelta(seconds=int(round(elapsed)))