diff --git a/gem/interpreter.py b/gem/interpreter.py index 13eeb44a2..eb247ed0a 100644 --- a/gem/interpreter.py +++ b/gem/interpreter.py @@ -315,6 +315,35 @@ def _evaluate_indexsum(e, self): return Result(val.arr.sum(axis=idx), rfids) +@_evaluate.register(gem.FlexiblyIndexed) +def _evaluate_flexiblyindexed(e, self): + """Flexibly indexed first slices and then reshapes.""" + val = self(e.children[0]) + assert len(val.fids) == 0 + + idx = [] + axes = [] + for offset, idxs in e.dim2idxs: + if isinstance(offset, gem.Node): + offset = self(offset) + if len(idxs) == 0: + idx.append(offset) + continue + + indices, strides = zip(*idxs) + strides = tuple(self(stride) if isinstance(stride, gem.Node) else stride for stride in strides) + assert all(isinstance(i, gem.Index) for i in indices) + last = sum(((i.extent-1) * stride for i, stride in zip(indices, strides)), offset) + idx.append(slice(offset, last + 1)) + ndim = len(axes) + axes.extend(sorted(range(ndim, ndim + len(strides)), key=lambda i: strides[i], reverse=True)) + + fids = e.index_ordering() + shape = tuple(i.extent for i in fids) + arr = val[idx].reshape(numpy.asarray(shape)[axes]).transpose(numpy.argsort(axes)) + return Result(arr, fids) + + @_evaluate.register(gem.ListTensor) def _evaluate_listtensor(e, self): """List tensors just turn into arrays."""