-
Notifications
You must be signed in to change notification settings - Fork 2
Expand file tree
/
Copy pathutils.py
More file actions
80 lines (72 loc) · 2.44 KB
/
utils.py
File metadata and controls
80 lines (72 loc) · 2.44 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
#!/usr/bin/env python
# -*- coding: utf-8 -*-
# @File : utils.py
# @Author: LauTrueYes
# @Date : 2020/12/27
from tqdm import tqdm
import torch
import time
from datetime import timedelta
from torch.utils.data import TensorDataset, DataLoader
# Special-token strings. CLS is prepended to every token sequence in
# load_dataset; PAD is not referenced anywhere in the visible part of this file.
PAD, CLS = '[PAD]', '[CLS]'
def load_dataset(file_path, config):
    """
    Read a tab-separated ``text<TAB>label`` file and encode every line.

    Blank lines are skipped. The text is tokenized with ``config.tokenizer``,
    prefixed with the ``[CLS]`` token and converted to ids. When
    ``config.pad_size`` is truthy each sequence is padded with id 0 or
    truncated to exactly ``pad_size`` entries; otherwise sequences keep their
    natural length.

    :param file_path: path of the UTF-8 encoded data file
    :param config: object providing ``tokenizer`` (with ``tokenize`` and
        ``convert_tokens_to_ids``) and ``pad_size``
    :return: list of ``(token_ids, mask, label)`` tuples; ``mask`` holds 1 for
        real tokens and 0 for padding and has the same length as ``token_ids``
    :raises ValueError: if a line does not consist of exactly two
        tab-separated fields, or the label field is not an integer
    """
    contents = []
    pad_size = config.pad_size  # loop-invariant, so read it once
    with open(file_path, 'r', encoding='utf-8') as f:
        for line in tqdm(f):
            line = line.strip()
            if not line:
                continue
            content, original_label = line.split('\t')
            # Labels are stored in the file as integer class ids.
            label = int(original_label)
            tokens = [CLS] + config.tokenizer.tokenize(content)
            token_ids = config.tokenizer.convert_tokens_to_ids(tokens)
            # Start from "every position is a real token" so the mask is also
            # correct when no padding is configured (it used to stay empty).
            mask = [1] * len(token_ids)
            if pad_size:
                missing = pad_size - len(token_ids)
                if missing > 0:
                    # Too short: extend ids with 0 and mark the tail as padding.
                    mask = mask + [0] * missing
                    token_ids = token_ids + [0] * missing
                else:
                    # Long enough: cut both lists down to pad_size.
                    mask = mask[:pad_size]
                    token_ids = token_ids[:pad_size]
            contents.append((token_ids, mask, label))
    return contents
def build_dataset(config):
    """
    Load the train, dev and test splits named by ``config``.

    Each split is read with ``load_dataset`` from ``config.train_path``,
    ``config.dev_path`` and ``config.test_path``, in that order.

    :param config: object carrying the three path attributes plus whatever
        ``load_dataset`` itself reads from it
    :return: tuple ``(train, dev, test)`` of the three ``load_dataset`` results
    """
    split_attrs = ('train_path', 'dev_path', 'test_path')
    # Look each path up lazily so attribute access and loading stay interleaved
    # in the same order as three sequential calls.
    return tuple(load_dataset(getattr(config, attr), config) for attr in split_attrs)
def build_data_loader(dataset, config, *, shuffle=False):
    """
    Wrap an encoded dataset in a batched ``DataLoader``.

    The three columns of ``dataset`` are converted to ``LongTensor`` objects
    and moved to ``config.device`` before being wrapped, so the whole dataset
    is placed on that device up front.

    :param dataset: sequence of items indexable at positions 0, 1 and 2
        (token ids, mask, label), e.g. the result of ``load_dataset``
    :param config: object providing ``device`` and ``batch_size``
    :param shuffle: forwarded to ``DataLoader``; defaults to ``False`` so
        existing callers keep the original fixed ordering
    :return: ``DataLoader`` yielding ``(token_ids, mask, label)`` batches
    """
    token_ids = [item[0] for item in dataset]
    mask = [item[1] for item in dataset]
    label_ids = [item[2] for item in dataset]
    tensor_set = TensorDataset(
        torch.LongTensor(token_ids).to(config.device),
        torch.LongTensor(mask).to(config.device),
        torch.LongTensor(label_ids).to(config.device),
    )
    # Named `loader` rather than `iter` to avoid shadowing the builtin.
    loader = DataLoader(tensor_set, batch_size=config.batch_size, shuffle=shuffle)
    return loader
def get_time_dif(start_time):
    """
    Return the time elapsed since ``start_time``.

    :param start_time: earlier timestamp in seconds, on the same clock as
        ``time.time()``
    :return: ``datetime.timedelta`` of the elapsed time, rounded to whole seconds
    """
    elapsed_seconds = time.time() - start_time
    whole_seconds = int(round(elapsed_seconds))
    return timedelta(seconds=whole_seconds)