Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
29 changes: 29 additions & 0 deletions gem/interpreter.py
Original file line number Diff line number Diff line change
Expand Up @@ -315,6 +315,35 @@ def _evaluate_indexsum(e, self):
return Result(val.arr.sum(axis=idx), rfids)


@_evaluate.register(gem.FlexiblyIndexed)
def _evaluate_flexiblyindexed(e, self):
"""Flexibly indexed first slices and then reshapes."""
val = self(e.children[0])
assert len(val.fids) == 0

idx = []
axes = []
for offset, idxs in e.dim2idxs:
if isinstance(offset, gem.Node):
offset = self(offset)
if len(idxs) == 0:
idx.append(offset)
continue

indices, strides = zip(*idxs)
strides = tuple(self(stride) if isinstance(stride, gem.Node) else stride for stride in strides)
assert all(isinstance(i, gem.Index) for i in indices)
last = sum(((i.extent-1) * stride for i, stride in zip(indices, strides)), offset)
idx.append(slice(offset, last + 1))
ndim = len(axes)
axes.extend(sorted(range(ndim, ndim + len(strides)), key=lambda i: strides[i], reverse=True))

fids = e.index_ordering()
shape = tuple(i.extent for i in fids)
arr = val[idx].reshape(numpy.asarray(shape)[axes]).transpose(numpy.argsort(axes))
return Result(arr, fids)


@_evaluate.register(gem.ListTensor)
def _evaluate_listtensor(e, self):
"""List tensors just turn into arrays."""
Expand Down
Loading