Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
240 changes: 206 additions & 34 deletions lib/Dialect/Rotom/IR/RotomAttributes.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
#include "mlir/include/mlir/IR/Attributes.h" // from @llvm-project
#include "mlir/include/mlir/IR/BuiltinAttributes.h" // from @llvm-project
#include "mlir/include/mlir/IR/Diagnostics.h" // from @llvm-project
#include "mlir/include/mlir/IR/OpImplementation.h" // from @llvm-project
#include "mlir/include/mlir/Support/LLVM.h" // from @llvm-project
#include "mlir/include/mlir/Support/LogicalResult.h" // from @llvm-project

Expand Down Expand Up @@ -102,47 +103,139 @@ static FailureOr<LayoutData> preprocessLayoutData(ArrayAttr dims, int64_t n,
data.pieceIndex.insert(data.pieceIndex.begin() + data.ctPrefixLen, gapIdx);
}

llvm::DenseSet<int64_t> seenDim;
Comment thread
asraa marked this conversation as resolved.
bool allUnique = true;
for (const DimAttr& d : data.traversalDims) {
if (seenDim.contains(d.getDim())) {
allUnique = false;
break;
return data;
}

static ParseResult parseDimTripleAfterLSquare(AsmParser& parser, int64_t& dim,
int64_t& size, int64_t& stride) {
return failure(parser.parseInteger(dim) || parser.parseColon() ||
parser.parseInteger(size) || parser.parseColon() ||
parser.parseInteger(stride) || parser.parseRSquare());
}

static ParseResult parseDimTriple(AsmParser& parser, int64_t& dim,
int64_t& size, int64_t& stride) {
return failure(parser.parseLSquare() ||
failed(parseDimTripleAfterLSquare(parser, dim, size, stride)));
}

static void printDimTriple(AsmPrinter& printer, DimAttr dim) {
printer << "[" << dim.getDim() << ":" << dim.getSize() << ":"
<< dim.getStride() << "]";
}

static ParseResult parseLayoutDims(AsmParser& parser,
Comment thread
asraa marked this conversation as resolved.
SmallVector<Attribute>& dims) {
if (parser.parseLSquare()) return failure();
if (succeeded(parser.parseOptionalRSquare())) return success();

while (true) {
if (succeeded(parser.parseOptionalLSquare())) {
int64_t dim;
int64_t size;
int64_t stride;
if (failed(parseDimTripleAfterLSquare(parser, dim, size, stride)))
return failure();
dims.push_back(DimAttr::get(parser.getContext(), dim, size, stride));
} else {
Attribute dim;
if (parser.parseAttribute(dim)) return failure();
if (!isa<DimAttr>(dim)) {
return parser.emitError(parser.getNameLoc())
<< "expected a #rotom.dim attribute";
}
dims.push_back(dim);
}
seenDim.insert(d.getDim());

if (succeeded(parser.parseOptionalComma())) continue;
return parser.parseRSquare();
}
if (allUnique && data.traversalDims.size() > 1) {
llvm::SmallVector<std::pair<int64_t, int64_t>> byDim;
byDim.reserve(data.traversalDims.size());
for (int64_t i = 0; i < static_cast<int64_t>(data.traversalDims.size());
++i) {
byDim.push_back({data.traversalDims[i].getDim(), i});
}
llvm::sort(byDim,
[](const auto& a, const auto& b) { return a.first < b.first; });

llvm::SmallVector<DimAttr> reorderedTraversal;
reorderedTraversal.reserve(data.traversalDims.size());
llvm::SmallVector<int64_t> oldToNew(data.traversalDims.size(), 0);
for (int64_t newIdx = 0; newIdx < static_cast<int64_t>(byDim.size());
++newIdx) {
const int64_t oldIdx = byDim[newIdx].second;
oldToNew[oldIdx] = newIdx;
reorderedTraversal.push_back(data.traversalDims[oldIdx]);
}

static ParseResult parseLayoutRolls(AsmParser& parser,
SmallVector<int64_t>& rolls) {
if (parser.parseLSquare()) return failure();
if (succeeded(parser.parseOptionalRSquare())) return success();

while (true) {
if (succeeded(parser.parseOptionalLParen())) {
int64_t from;
int64_t to;
if (parser.parseInteger(from) || parser.parseComma() ||
parser.parseInteger(to) || parser.parseRParen())
return failure();
rolls.push_back(from);
rolls.push_back(to);
} else {
int64_t value;
if (parser.parseInteger(value)) return failure();
rolls.push_back(value);
}
data.traversalDims = std::move(reorderedTraversal);

for (size_t p = 0; p < data.pieces.size(); ++p) {
if (data.pieces[p] == LayoutPieceKind::Traversal) {
data.pieceIndex[p] = oldToNew[data.pieceIndex[p]];
}
if (succeeded(parser.parseOptionalComma())) continue;
return parser.parseRSquare();
}
}

} // namespace

static LogicalResult verifyLayoutRolls(
ArrayAttr dims, DenseI64ArrayAttr rolls,
function_ref<InFlightDiagnostic()> emitError) {
if (!rolls) return success();
ArrayRef<int64_t> r = rolls.asArrayRef();
if (r.empty()) return success();
if (r.size() % 2 != 0) {
return emitError() << "rolls must contain an even number of integers "
"(pairs of dim indices)";
}

for (size_t i = 0; i < r.size(); i += 2) {
const int64_t ti = r[i];
const int64_t tj = r[i + 1];
if (ti == tj) {
return emitError() << "each roll must use two distinct dim indices";
}
if (ti < 0 || tj < 0 || ti >= static_cast<int64_t>(dims.size()) ||
tj >= static_cast<int64_t>(dims.size())) {
return emitError() << "roll dim index out of bounds for dims list";
}
auto di = dyn_cast<DimAttr>(dims[ti]);
auto dj = dyn_cast<DimAttr>(dims[tj]);
if (!di || !dj) {
return emitError() << "roll indices must refer to #rotom.dim entries";
}
if (di.getSize() != dj.getSize()) {
return emitError() << "rolled dims must have the same extent (size)";
}
if (di.isGap() || dj.isGap() || di.isReplicate() || dj.isReplicate()) {
return emitError() << "rolls may only reference non-sentinel traversal "
"dims (dim >= 0)";
}
}
return success();
}

return data;
void DimAttr::print(AsmPrinter& printer) const {
printer << "<";
printDimTriple(printer, *this);
printer << ">";
}

} // namespace
Attribute DimAttr::parse(AsmParser& parser, Type type) {
int64_t dim;
int64_t size;
int64_t stride;

if (parser.parseLess() || failed(parseDimTriple(parser, dim, size, stride)) ||
parser.parseGreater()) {
return {};
}

return DimAttr::getChecked(
[&]() { return parser.emitError(parser.getNameLoc()); },
parser.getContext(), dim, size, stride);
}

LogicalResult DimAttr::verify(function_ref<InFlightDiagnostic()> emitError,
int64_t dim, int64_t size, int64_t stride) {
Expand All @@ -158,13 +251,85 @@ LogicalResult DimAttr::verify(function_ref<InFlightDiagnostic()> emitError,
return success();
}

void LayoutAttr::print(AsmPrinter& printer) const {
printer << "<n = " << getN();

DenseI64ArrayAttr rolls = getRolls();
if (rolls && !rolls.asArrayRef().empty()) {
ArrayRef<int64_t> values = rolls.asArrayRef();
printer << ", rolls = [";
for (size_t i = 0; i < values.size(); i += 2) {
if (i != 0) printer << ", ";
printer << "(" << values[i] << ", " << values[i + 1] << ")";
}
printer << "]";
}

printer << ", dims = [";
llvm::interleaveComma(getDims(), printer, [&](Attribute attr) {
printDimTriple(printer, cast<DimAttr>(attr));
});
printer << "]>";
}

Attribute LayoutAttr::parse(AsmParser& parser, Type type) {
int64_t n;
SmallVector<int64_t> rolls;
SmallVector<Attribute> dims;

if (parser.parseLess()) return {};

if (succeeded(parser.parseOptionalKeyword("n"))) {
if (parser.parseEqual() || parser.parseInteger(n) || parser.parseComma()) {
return {};
}

if (succeeded(parser.parseOptionalKeyword("rolls"))) {
if (parser.parseEqual() || failed(parseLayoutRolls(parser, rolls)) ||
parser.parseComma()) {
return {};
}
}

if (parser.parseKeyword("dims") || parser.parseEqual() ||
failed(parseLayoutDims(parser, dims)) || parser.parseGreater()) {
return {};
}
} else if (succeeded(parser.parseOptionalKeyword("dims"))) {
if (parser.parseEqual() || failed(parseLayoutDims(parser, dims)) ||
parser.parseComma() || parser.parseKeyword("n") ||
parser.parseEqual() || parser.parseInteger(n)) {
return {};
}

if (succeeded(parser.parseOptionalComma())) {
if (parser.parseKeyword("rolls") || parser.parseEqual() ||
failed(parseLayoutRolls(parser, rolls))) {
return {};
}
}

if (parser.parseGreater()) return {};
} else {
parser.emitError(parser.getNameLoc())
<< "expected `n` or `dims` in rotom layout";
return {};
}

MLIRContext* context = parser.getContext();
return LayoutAttr::getChecked(
[&]() { return parser.emitError(parser.getNameLoc()); }, context,
ArrayAttr::get(context, dims), n, DenseI64ArrayAttr::get(context, rolls));
}

FailureOr<LayoutData> preprocessLayoutAttr(LayoutAttr layout) {
return preprocessLayoutData(layout.getDims(), layout.getN(),
layout.getContext());
}

LogicalResult LayoutAttr::verify(function_ref<InFlightDiagnostic()> emitError,
ArrayAttr dims, int64_t n) {
ArrayAttr dims, int64_t n,
DenseI64ArrayAttr rolls) {
if (n <= 0) {
return emitError() << "`n` must be > 0, got " << n;
}
Expand All @@ -173,6 +338,8 @@ LogicalResult LayoutAttr::verify(function_ref<InFlightDiagnostic()> emitError,
return emitError() << "`dims` must be an array of `#rotom.dim<...>`";
}

if (failed(verifyLayoutRolls(dims, rolls, emitError))) return failure();

MLIRContext* ctx = dims.getContext();
std::vector<DimAttr> ctDims;
std::vector<DimAttr> slotDims;
Expand Down Expand Up @@ -261,12 +428,17 @@ LogicalResult SeedAttr::verify(function_ref<InFlightDiagnostic()> emitError,
return emitError() << "seed layouts must be `rotom.layout` attributes";
}
if (failed(LayoutAttr::verify(emitError, layoutAttr.getDims(),
layoutAttr.getN())))
layoutAttr.getN(), layoutAttr.getRolls())))
return failure();
}
return success();
}

LayoutAttr LayoutAttr::get(MLIRContext* context, ArrayAttr dims, int64_t n) {
return get(context, dims, n,
DenseI64ArrayAttr::get(context, ArrayRef<int64_t>{}));
}

} // namespace rotom
} // namespace heir
} // namespace mlir
18 changes: 16 additions & 2 deletions lib/Dialect/Rotom/IR/RotomAttributes.td
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,6 @@ include "mlir/IR/OpAsmInterface.td"
class Rotom_Attr<string name, string attrMnemonic, list<Trait> traits = []>
: AttrDef<Rotom_Dialect, name, traits> {
let mnemonic = attrMnemonic;
let assemblyFormat = "`<` struct(params) `>`";
let genMnemonicAlias = 1;
}

Expand Down Expand Up @@ -39,6 +38,7 @@ def Rotom_DimAttr : Rotom_Attr<"Dim", "dim"> {
}];

let genVerifyDecl = 1;
let hasCustomAssemblyFormat = 1;
}

def Rotom_LayoutAttr : Rotom_Attr<"Layout", "layout"> {
Expand All @@ -54,14 +54,27 @@ def Rotom_LayoutAttr : Rotom_Attr<"Layout", "layout"> {
For tensor_ext materialization, the **first** entry in `dims` is the
ciphertext side of Rotom's `;` split (one piece); remaining entries are
in-slot. See [Section 4.2 of the Rotom paper](https://eprint.iacr.org/2025/1319.pdf).

Optional **rolls** encode a `roll(i,j)` metadata object: each pair `(i, j)`
indexes into the `dims` array (the flattened `ct_dims + slot_dims` list) and
uses modular addition to modify the indices of `dims[i]` by the indices of
`dims[j]`.
}];

let parameters = (ins
"::mlir::ArrayAttr":$dims,
"int64_t":$n
"int64_t":$n,
OptionalParameter<"::mlir::DenseI64ArrayAttr">:$rolls
Comment thread
asraa marked this conversation as resolved.
);

let genVerifyDecl = 1;
let hasCustomAssemblyFormat = 1;

let extraClassDeclaration = [{
/// Layout with no `roll(i,j)` metadata (empty rolls storage).
static ::mlir::heir::rotom::LayoutAttr get(::mlir::MLIRContext *context,
::mlir::ArrayAttr dims, int64_t n);
}];
}

def Rotom_SeedAttr : Rotom_Attr<"Seed", "seed"> {
Expand All @@ -75,6 +88,7 @@ def Rotom_SeedAttr : Rotom_Attr<"Seed", "seed"> {
);

let genVerifyDecl = 1;
let assemblyFormat = "`<` struct(params) `>`";
}

#endif // LIB_DIALECT_ROTOM_IR_ROTOMATTRIBUTES_TD_
Loading
Loading