-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathlayers.py
More file actions
235 lines (194 loc) · 9.12 KB
/
layers.py
File metadata and controls
235 lines (194 loc) · 9.12 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
# -*- coding: utf-8 -*-
import math
import tensorflow as tf
def masked_average_pooling(seq_emb, seq_mask):
    """Average ``seq_emb`` over the sequence axis using ``seq_mask`` as weights.

    Args:
        seq_emb: [N, L, D] float32 tensor (it is multiplied by a float32 mask,
            so its dtype must be float32).
        seq_mask: [N, L] mask; cast to float32 and used as per-step weights,
            so a bool or 0/1 mask gives a plain masked mean.

    Returns:
        [N, D] tensor. Rows whose mask sums to zero come out as all zeros,
        because the divisor is floored at 1.
    """
    weights = tf.expand_dims(tf.cast(seq_mask, tf.float32), axis=-1)  # [N, L, 1]
    weighted_total = tf.reduce_sum(seq_emb * weights, axis=1)  # [N, D]
    # Floor the count at 1 so fully padded rows divide by 1 instead of 0.
    denom = tf.maximum(tf.reduce_sum(weights, axis=1), 1.0)  # [N, 1]
    return weighted_total / denom
def build_pairwise_time_gaps_from_seconds(seq_time_seconds):
    """Build the matrix of absolute time gaps between every pair of steps.

    Args:
        seq_time_seconds: [N, L] per-step timestamps.

    Returns:
        [N, L, L] tensor whose element [n, i, j] is
        |seq_time_seconds[n, i] - seq_time_seconds[n, j]|; dtype and units
        follow the input.
    """
    times = tf.convert_to_tensor(seq_time_seconds)
    # Broadcast a column view against a row view to cover all (i, j) pairs.
    as_column = times[:, :, tf.newaxis]  # [N, L, 1]
    as_row = times[:, tf.newaxis, :]  # [N, 1, L]
    return tf.abs(as_column - as_row)
def single_feature_embedding(feat_ids, vocab_size, emb_dim, scope):
    """Look up embeddings for ``feat_ids`` in a table owned by ``scope``.

    The table ``<scope>/emb_table`` is a float32 variable of shape
    [vocab_size, emb_dim] with Xavier initialisation. Because the scope is
    entered with ``reuse=tf.AUTO_REUSE``, every call that passes the same
    ``scope`` shares the same table.

    Args:
        feat_ids: ids to look up in the table.
        vocab_size: number of rows in the table.
        emb_dim: width of each embedding vector.
        scope: variable-scope name that owns the table.

    Returns:
        Result of ``tf.nn.embedding_lookup`` on the table with ``feat_ids``.
    """
    table_shape = [vocab_size, emb_dim]
    with tf.variable_scope(scope, reuse=tf.AUTO_REUSE):
        table = tf.get_variable(
            name="emb_table",
            shape=table_shape,
            dtype=tf.float32,
            initializer=tf.contrib.layers.xavier_initializer(),
        )
        looked_up = tf.nn.embedding_lookup(params=table, ids=feat_ids)
    return looked_up
def macro_temporal_encoding_personalized(hour_0_23,
                                         weekday_0_6,
                                         delta_t_hours,
                                         seq_mask,
                                         d_model,
                                         user_context=None,
                                         burst_value=None,
                                         avg_brust_value=None,
                                         scope="macro_temporal_encoding"):
    """Encode per-step temporal context as a user-gated embedding.

    Four scalar signals are built for every step:
      * recency ``exp(-lambda * delta_t_hours)`` with a learned rate,
      * ``sin(2*pi*hour/24)``,
      * ``sin(2*pi*weekday/7)``,
      * a burst term ``log1p(relu((burst - avg) / (avg + eps)))``, or zero.
    They are mixed with softmax weights predicted from ``user_context`` by a
    16-8-4 MLP. The mixture is then scaled along one learned vector ``r`` of
    width ``d_model``.

    Args:
        hour_0_23: [N, L] hour of day (cast to float32).
        weekday_0_6: [N, L] day of week (cast to float32).
        delta_t_hours: [N, L] elapsed time fed to the recency decay.
        seq_mask: [N, L]; cast to float32 and multiplied into the outputs.
        d_model: width of the returned embedding.
        user_context: [N, d_model] input to the gating MLP. Required, despite
            the ``None`` default (kept for signature compatibility).
        burst_value: [N, L] or None. When None, the burst signal is zero.
        avg_brust_value: [N, L] baseline for ``burst_value``. Required when
            ``burst_value`` is given. The misspelled keyword is part of the
            public signature and is kept for compatibility.
        scope: variable-scope name; variables are shared via AUTO_REUSE.

    Returns:
        (temporal_vec, m_t, beta):
            temporal_vec: [N, L, d_model] masked projection of ``m_t``.
            m_t: [N, L] masked mixture of the four signals.
            beta: [N, 4] softmax gate ordered (recency, hour, weekday, burst).

    Raises:
        ValueError: if ``user_context`` is None, or if ``burst_value`` is
            given without ``avg_brust_value``.
    """
    # Fail fast with a clear message. Previously these None values reached
    # tf.layers.dense / tf.cast and failed deep inside TensorFlow.
    if user_context is None:
        raise ValueError(
            "macro_temporal_encoding_personalized requires user_context "
            "to compute the personalized gate.")
    if burst_value is not None and avg_brust_value is None:
        raise ValueError(
            "avg_brust_value must be provided when burst_value is given.")

    with tf.variable_scope(scope, reuse=tf.AUTO_REUSE):
        seq_mask = tf.cast(seq_mask, tf.float32)
        hour_0_23 = tf.cast(hour_0_23, tf.float32)
        weekday_0_6 = tf.cast(weekday_0_6, tf.float32)
        delta_t_hours = tf.cast(delta_t_hours, tf.float32)

        # Recency: exp(-lambda * dt) with a learned, ReLU-clamped rate.
        # NOTE(review): if the raw variable goes negative, ReLU zeroes both the
        # rate and its gradient, so it cannot recover. Softplus would avoid
        # that but changes the parameterisation, so it is left as is.
        lambda_decay = tf.get_variable(
            "lambda_decay",
            shape=[],
            dtype=tf.float32,
            initializer=tf.constant_initializer(0.1))
        lambda_decay = tf.nn.relu(lambda_decay)
        recency = tf.exp(-lambda_decay * delta_t_hours)  # [N, L]

        # Periodic signals (sine only).
        # NOTE(review): a lone sine maps hour h and (12 - h) mod 24 to the same
        # value (e.g. 0h and 12h both give 0). A paired cosine would
        # disambiguate them, but adding one would change the model.
        hour_signal = tf.sin(2.0 * math.pi * hour_0_23 / 24.0)  # [N, L]
        weekday_signal = tf.sin(2.0 * math.pi * weekday_0_6 / 7.0)  # [N, L]

        # Burst: relative increase of burst_value over its baseline. Only
        # increases count (ReLU), and log1p compresses large ratios.
        if burst_value is None:
            burst_signal = tf.zeros_like(recency)
        else:
            prev_day_segment_stat = tf.cast(burst_value, tf.float32)  # [N, L]
            hist30_segment_avg_stat = tf.cast(avg_brust_value, tf.float32)  # [N, L]
            eps = 1e-6
            delta_ratio = ((prev_day_segment_stat - hist30_segment_avg_stat)
                           / (hist30_segment_avg_stat + eps))
            burst_signal = tf.log1p(tf.nn.relu(delta_ratio))  # [N, L]

        # Personalised softmax gate over the four signals.
        h1 = tf.layers.dense(user_context, 16, activation=tf.nn.relu,
                             name="beta_gate_dense_16")
        h2 = tf.layers.dense(h1, 8, activation=tf.nn.relu,
                             name="beta_gate_dense_8")
        beta_logits = tf.layers.dense(h2, 4, activation=None,
                                      name="beta_gate_dense_4")
        beta = tf.nn.softmax(beta_logits, axis=-1)  # [N, 4]
        # Each slice is [N, 1], so it broadcasts across the L axis.
        beta_delta, beta_hour, beta_week, beta_burst = tf.split(beta, 4, axis=1)

        m_t = (beta_delta * recency
               + beta_hour * hour_signal
               + beta_week * weekday_signal
               + beta_burst * burst_signal)
        m_t = m_t * seq_mask  # [N, L]

        # Scale one learned direction r (width d_model) by the scalar signal.
        r = tf.get_variable(
            "temporal_projection_r",
            shape=[d_model],
            dtype=tf.float32,
            initializer=tf.contrib.layers.xavier_initializer())
        temporal_vec = tf.expand_dims(m_t, axis=-1) * tf.reshape(r, [1, 1, d_model])
        # m_t is already masked, so this second multiply only changes the
        # result when seq_mask is not 0/1.
        temporal_vec = temporal_vec * tf.expand_dims(seq_mask, axis=-1)
    return temporal_vec, m_t, beta
def causal_attention(queries,
                     keys,
                     num_units=None,
                     num_output_units=None,
                     num_heads=8,
                     scope="causal_attention",
                     reuse=None,
                     query_masks=None,
                     key_masks=None,
                     linear_projection=True,
                     is_target_attention=False,
                     pairwise_time_gaps=None,
                     use_micro_bias=False,
                     micro_gamma=10.0,
                     micro_tau_init=1.0,
                     micro_alpha_init=0.1):
    """Masked multi-head scaled dot-product attention with an optional
    time-gated order bias.

    Args:
        queries: [N, T_q, C_q] tensor. T_q must be known statically, because
            it is used in reshapes and in ``tf.ones``/``tf.range``.
        keys: [N, T_k, C_k] tensor, also used as the values. T_k must be
            known statically.
        num_units: width of the Q/K projections. Defaults to C_q.
        num_output_units: width of the V projection. Defaults to ``num_units``.
        num_heads: number of heads. The Q/K/V widths must be divisible by it.
        scope: variable scope for the projection layers and bias variables.
        reuse: passed to ``tf.variable_scope``.
        query_masks: optional [N, T_q] mask. May be bool or numeric; numeric
            values are cast to bool, so non-zero means valid.
        key_masks: optional [N, T_k] mask, same convention.
        linear_projection: if False, queries/keys/keys are used directly as
            Q/K/V.
        is_target_attention: if True, K and V are reshaped to Q's batch size.
        pairwise_time_gaps: [N, T_q, T_k] gaps. Only used when
            ``use_micro_bias`` is True.
        use_micro_bias: if True, add
            ``-alpha * sigmoid((tau - gap) / micro_gamma) * |i - j|``
            to the logits before softmax.
        micro_gamma: temperature of the gate sigmoid.
        micro_tau_init: initial value of the learned ``tau`` (ReLU-clamped
            when used).
        micro_alpha_init: initial value of the learned ``alpha``
            (ReLU-clamped when used).

    Returns:
        (outputs, att_vec):
            outputs: [N, T_q, V_width] attended values, with heads
                concatenated back along the last axis.
            att_vec: [num_heads * N, T_q, T_k] attention weights, laid out
                head-major along axis 0.
    """
    with tf.variable_scope(scope, reuse=reuse):
        if num_units is None:
            num_units = queries.get_shape().as_list()[-1]
        query_len = queries.get_shape().as_list()[1]
        key_len = keys.get_shape().as_list()[1]
        value_units = num_output_units or num_units

        # Q/K/V projections: dense layers applied to flattened [N*T, C] inputs.
        if linear_projection:
            queries_2d = tf.reshape(queries, [-1, queries.get_shape().as_list()[-1]])
            keys_2d = tf.reshape(keys, [-1, keys.get_shape().as_list()[-1]])
            Q = tf.layers.dense(queries_2d, num_units, activation=None, name="Q")
            Q = tf.reshape(Q, [-1, query_len, num_units])
            K = tf.layers.dense(keys_2d, num_units, activation=None, name="K")
            K = tf.reshape(K, [-1, key_len, num_units])
            V = tf.layers.dense(keys_2d, value_units, activation=None, name="V")
            V = tf.reshape(V, [-1, key_len, value_units])
        else:
            Q = queries
            K = keys
            V = keys
        if is_target_attention:
            batch_size = tf.shape(Q)[0]
            K = tf.reshape(K, [batch_size, key_len, K.get_shape().as_list()[-1]])
            V = tf.reshape(V, [batch_size, key_len, V.get_shape().as_list()[-1]])

        # Split heads: [N, T, C] -> [h*N, T, C/h], head-major along axis 0.
        if num_heads > 1:
            Q_ = tf.concat(tf.split(Q, num_heads, axis=2), axis=0)
            K_ = tf.concat(tf.split(K, num_heads, axis=2), axis=0)
            V_ = tf.concat(tf.split(V, num_heads, axis=2), axis=0)
        else:
            Q_, K_, V_ = Q, K, V

        outputs = tf.matmul(Q_, K_, transpose_b=True)  # [h*N, T_q, T_k]
        outputs *= (K_.get_shape().as_list()[-1] ** (-0.5))

        # Optional time-gated order bias, added to the logits before softmax.
        if use_micro_bias and pairwise_time_gaps is not None:
            with tf.variable_scope("time_gated_order_bias", reuse=tf.AUTO_REUSE):
                alpha = tf.get_variable(
                    "alpha",
                    shape=[],
                    dtype=tf.float32,
                    initializer=tf.constant_initializer(micro_alpha_init))
                tau = tf.get_variable(
                    "tau",
                    shape=[],
                    dtype=tf.float32,
                    initializer=tf.constant_initializer(micro_tau_init))
                alpha = tf.nn.relu(alpha)
                tau = tf.nn.relu(tau)
                pairwise_time_gaps = tf.cast(pairwise_time_gaps, tf.float32)
                # gate tends to 1 when the gap is well below tau and to 0 when
                # it is well above.
                gate = tf.sigmoid((tau - pairwise_time_gaps) / micro_gamma)  # [N, T_q, T_k]
                q_idx = tf.range(query_len, dtype=tf.float32)
                k_idx = tf.range(key_len, dtype=tf.float32)
                rel_dist = tf.abs(tf.expand_dims(q_idx, 1) - tf.expand_dims(k_idx, 0))  # [T_q, T_k]
                rel_dist = tf.expand_dims(rel_dist, 0)  # [1, T_q, T_k]
                # Penalise positional distance, scaled by the time gate.
                micro_bias = -alpha * gate * rel_dist  # [N, T_q, T_k]
                if num_heads > 1:
                    # Repeat the per-example bias once per head, matching the
                    # head-major layout of `outputs`.
                    micro_bias = tf.tile(micro_bias, [num_heads, 1, 1])
                outputs = outputs + micro_bias

        batch_heads = tf.shape(outputs)[0]
        # NOTE(review): band_part(ones, 0, -1) keeps the UPPER triangle, so
        # query i may attend key j only when j >= i. That is causal only if
        # sequences are stored newest-first. For oldest-first sequences the
        # usual choice is band_part(ones, -1, 0). Confirm against the data
        # pipeline before changing it, because flipping it alters the model.
        causal_mask = tf.linalg.band_part(tf.ones([query_len, key_len], dtype=tf.bool), 0, -1)
        causal_mask = tf.tile(tf.expand_dims(causal_mask, 0), [batch_heads, 1, 1])
        if key_masks is not None:
            # Cast so numeric (e.g. float 0/1) masks work with logical_and,
            # which only accepts bool.
            key_masks_3d = tf.expand_dims(tf.cast(key_masks, tf.bool), axis=1)  # [N, 1, T_k]
            key_masks_3d = tf.tile(key_masks_3d, [num_heads, query_len, 1])
        else:
            key_masks_3d = tf.ones([batch_heads, query_len, key_len], dtype=tf.bool)
        if query_masks is not None:
            query_masks_3d = tf.reshape(tf.cast(query_masks, tf.bool), [-1, query_len, 1])
            query_masks_3d = tf.tile(query_masks_3d, [num_heads, 1, key_len])
        else:
            query_masks_3d = tf.ones([batch_heads, query_len, key_len], dtype=tf.bool)
        combined_mask = tf.logical_and(tf.logical_and(causal_mask, key_masks_3d), query_masks_3d)

        # Clip the real logits first, then mask. Masking after the clip keeps
        # masked entries at -2**32 + 1, so they get zero weight. Clipping after
        # the mask (the previous order) lifted them to -50, where they tied with
        # any valid logit that was also clipped to -50.
        outputs = tf.clip_by_value(outputs, -50.0, 50.0)
        paddings = tf.fill(tf.shape(outputs), tf.constant(-2**32 + 1, dtype=tf.float32))
        outputs = tf.where(combined_mask, outputs, paddings)
        # A row with no valid entry is all padding. Softmax turns it into a
        # uniform distribution, not NaN.
        outputs = tf.nn.softmax(outputs)
        att_vec = outputs
        outputs = tf.matmul(outputs, V_)
        if num_heads > 1:
            outputs = tf.concat(tf.split(outputs, num_heads, axis=0), axis=2)
    return outputs, att_vec