-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathtrain.py
More file actions
77 lines (59 loc) · 2.77 KB
/
train.py
File metadata and controls
77 lines (59 loc) · 2.77 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
# -*- coding: utf-8 -*-
import tensorflow as tf
import numpy as np
from config import Config
from data_loader import load_formatted_batch_from_txt
from model import build_graph
def make_feed_dict(graph_dict, batch):
    """Build the ``feed_dict`` for one training step.

    Each entry in ``graph_dict`` named ``*_ph`` (presumably a placeholder
    created by ``build_graph`` -- confirm in model.py) is paired with the
    array stored under the corresponding key of ``batch``.

    Args:
        graph_dict: mapping of graph handles as returned by ``build_graph``.
        batch: mapping of arrays as returned by
            ``load_formatted_batch_from_txt``.

    Returns:
        A dict keyed by ``graph_dict`` entries with the matching ``batch``
        values, inserted in the order of the pair table below.

    Raises:
        KeyError: if either mapping lacks one of the expected keys.
    """
    # (graph_dict key, batch key). Most batch keys are the placeholder name
    # without the "_ph" suffix; hour and weekday carry their value range in
    # the batch key instead.
    feed_pairs = (
        ("labels_ph", "labels"),
        ("item_ids_ph", "item_ids"),
        ("seq_mask_ph", "seq_mask"),
        ("hour_ph", "hour_0_23"),
        ("weekday_ph", "weekday_0_6"),
        ("delta_t_hours_ph", "delta_t_hours"),
        ("seq_time_seconds_ph", "seq_time_seconds"),
        ("prev_day_segment_stat_ph", "prev_day_segment_stat"),
        ("hist30_segment_avg_stat_ph", "hist30_segment_avg_stat"),
        ("user_active_degree_ph", "user_active_degree"),
        ("is_lowactive_period_ph", "is_lowactive_period"),
        ("is_live_streamer_ph", "is_live_streamer"),
        ("is_video_author_ph", "is_video_author"),
        ("follow_user_num_range_ph", "follow_user_num_range"),
        ("fans_user_num_range_ph", "fans_user_num_range"),
        ("register_days_range_ph", "register_days_range"),
        ("onehot_feat1_ph", "onehot_feat1"),
        ("onehot_feat2_ph", "onehot_feat2"),
        ("onehot_feat3_ph", "onehot_feat3"),
        ("target_video_id_ph", "target_video_id"),
        ("target_author_id_ph", "target_author_id"),
        ("target_video_type_ph", "target_video_type"),
        ("target_music_type_ph", "target_music_type"),
    )
    return {graph_dict[ph_key]: batch[data_key] for ph_key, data_key in feed_pairs}
def _iter_batches(file_obj, config):
    """Yield successive formatted batches read from ``file_obj``.

    Calls ``load_formatted_batch_from_txt`` with the configured batch size and
    sequence length, and stops as soon as it returns ``None``, which the
    training loop treats as end of file.
    """
    while True:
        batch = load_formatted_batch_from_txt(
            file_obj=file_obj,
            batch_size=config.batch_size,
            seq_len=config.seq_len
        )
        if batch is None:
            return
        yield batch


def main(log_every=100):
    """Train the model for a single pass over ``config.train_file``.

    Resets the default graph, builds the model via ``build_graph``, initializes
    all global variables, then runs ``train_op`` once per batch until the
    loader reports end of file.

    Args:
        log_every: print the loss every ``log_every`` steps, starting at step
            0. Defaults to 100 (the previously hard-coded interval).

    Raises:
        ValueError: if ``log_every`` is not a positive integer.
    """
    if log_every < 1:
        raise ValueError("log_every must be a positive integer")
    config = Config()
    # NOTE(review): tf.reset_default_graph / tf.Session are TF 1.x graph-mode
    # APIs (available as tf.compat.v1.* under TF 2.x) -- confirm the pinned
    # TensorFlow version.
    tf.reset_default_graph()
    graph_dict = build_graph(config)
    with tf.Session() as sess:
        sess.run(tf.global_variables_initializer())
        with open(config.train_file, "r") as f:
            for step, batch in enumerate(_iter_batches(f, config)):
                # Fetch only what is used: the loss (for logging) and train_op
                # (for its parameter update). Previously "probs" and
                # "macro_beta" were also fetched every step and discarded.
                loss_val, _ = sess.run(
                    [graph_dict["loss"], graph_dict["train_op"]],
                    feed_dict=make_feed_dict(graph_dict, batch)
                )
                if step % log_every == 0:
                    print("step =", step, "loss =", float(loss_val))
            print("End of file.")
        # NOTE(review): no checkpoint is written, so the trained variables are
        # discarded when this session closes -- add a tf.train.Saver if the
        # model is meant to be evaluated or reused later.


if __name__ == "__main__":
    # Run training only when executed as a script, not on import.
    main()