-
Notifications
You must be signed in to change notification settings - Fork 2
Expand file tree
/
Copy pathgather.py
More file actions
23 lines (16 loc) · 856 Bytes
/
gather.py
File metadata and controls
23 lines (16 loc) · 856 Bytes
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
import torch
def topk_by(values, vdim, keys, kdim, k):
    """Gather the slices of ``values`` that correspond to the ``k`` largest ``keys``.

    The ``k`` largest entries of ``keys`` along ``kdim`` are located (in
    descending key order) and their positions are used to gather from
    ``values`` along ``vdim``. Any number of dimensions after ``vdim`` in
    ``values`` (including none) is supported; the selected positions are
    broadcast across them.

    Args:
        values: Tensor to gather from.
        vdim: Dimension of ``values`` to select along; may be negative.
        keys: Tensor of scores to rank. It is expected to line up with the
            leading dimensions of ``values`` up to and including ``vdim``
            (not validated here).
        kdim: Dimension of ``keys`` to rank along.
        k: Number of top entries to keep.

    Returns:
        The gathered tensor. Its shape is the shape of the top-k index
        tensor followed by the dimensions of ``values`` after ``vdim``.
    """
    # Resolve a negative vdim up front: slicing with ``vdim + 1`` would
    # otherwise select the wrong dimensions (``-1 + 1`` slices from 0).
    if vdim < 0:
        vdim += values.dim()
    trailing = tuple(values.size()[vdim + 1:])

    indices = keys.topk(k=k, dim=kdim, sorted=True).indices
    lead = tuple(indices.size())
    # gather() needs an index with the same rank as ``values``: append one
    # singleton axis per trailing dimension, then broadcast across them.
    indices = indices.reshape(lead + (1,) * len(trailing)).expand(lead + trailing)
    return values.gather(dim=vdim, index=indices)
def topk_and_index_by(values, vdim, keys, kdim, k):
    """Gather the top-``k`` slices of ``values`` by ``keys`` and return the indices used.

    The ``k`` largest entries of ``keys`` along ``kdim`` are located (in
    descending key order) and their positions are used to gather from
    ``values`` along ``vdim``. Any number of dimensions after ``vdim`` in
    ``values`` (including none) is supported; the selected positions are
    broadcast across them.

    Args:
        values: Tensor to gather from.
        vdim: Dimension of ``values`` to select along; may be negative.
        keys: Tensor of scores to rank. It is expected to line up with the
            leading dimensions of ``values`` up to and including ``vdim``
            (not validated here).
        kdim: Dimension of ``keys`` to rank along.
        k: Number of top entries to keep.

    Returns:
        A ``(values_topk, indices)`` tuple. ``indices`` is the broadcast
        index tensor that was passed to ``gather`` and therefore has the
        same shape as ``values_topk``.
    """
    # Resolve a negative vdim up front: slicing with ``vdim + 1`` would
    # otherwise select the wrong dimensions (``-1 + 1`` slices from 0).
    if vdim < 0:
        vdim += values.dim()
    trailing = tuple(values.size()[vdim + 1:])

    indices = keys.topk(k=k, dim=kdim, sorted=True).indices
    lead = tuple(indices.size())
    # gather() needs an index with the same rank as ``values``: append one
    # singleton axis per trailing dimension, then broadcast across them.
    indices = indices.reshape(lead + (1,) * len(trailing)).expand(lead + trailing)
    values_topk = values.gather(dim=vdim, index=indices)
    return values_topk, indices
if __name__ == '__main__':
    # Smoke run: a 2 x 4 x 3 tensor whose rows (dim 1) are ranked by their
    # sum over dim 2; the two highest-scoring rows of each batch are
    # gathered. The result is computed but not used.
    a = torch.tensor(
        [
            [[1.0, 0.0, 3.0], [0.0, 0.0, 2.0], [3.0, 3.0, 3.0], [4.0, 4.0, 4.0]],
            [[-1.0, 0.3, 0.23], [1.0, 0.0, -2.0], [-3.0, -0.3, -3.0], [0.44, -0.44, 4.04]],
        ]
    )
    keys = torch.sum(a, dim=2)
    topk_by(values=a, vdim=1, keys=keys, kdim=1, k=2)