diff --git a/lib/Dialect/Rotom/IR/RotomAttributes.cpp b/lib/Dialect/Rotom/IR/RotomAttributes.cpp index 1129a3ce49..6703485d35 100644 --- a/lib/Dialect/Rotom/IR/RotomAttributes.cpp +++ b/lib/Dialect/Rotom/IR/RotomAttributes.cpp @@ -12,6 +12,7 @@ #include "mlir/include/mlir/IR/Attributes.h" // from @llvm-project #include "mlir/include/mlir/IR/BuiltinAttributes.h" // from @llvm-project #include "mlir/include/mlir/IR/Diagnostics.h" // from @llvm-project +#include "mlir/include/mlir/IR/OpImplementation.h" // from @llvm-project #include "mlir/include/mlir/Support/LLVM.h" // from @llvm-project #include "mlir/include/mlir/Support/LogicalResult.h" // from @llvm-project @@ -102,47 +103,139 @@ static FailureOr preprocessLayoutData(ArrayAttr dims, int64_t n, data.pieceIndex.insert(data.pieceIndex.begin() + data.ctPrefixLen, gapIdx); } - llvm::DenseSet seenDim; - bool allUnique = true; - for (const DimAttr& d : data.traversalDims) { - if (seenDim.contains(d.getDim())) { - allUnique = false; - break; + return data; +} + +static ParseResult parseDimTripleAfterLSquare(AsmParser& parser, int64_t& dim, + int64_t& size, int64_t& stride) { + return failure(parser.parseInteger(dim) || parser.parseColon() || + parser.parseInteger(size) || parser.parseColon() || + parser.parseInteger(stride) || parser.parseRSquare()); +} + +static ParseResult parseDimTriple(AsmParser& parser, int64_t& dim, + int64_t& size, int64_t& stride) { + return failure(parser.parseLSquare() || + failed(parseDimTripleAfterLSquare(parser, dim, size, stride))); +} + +static void printDimTriple(AsmPrinter& printer, DimAttr dim) { + printer << "[" << dim.getDim() << ":" << dim.getSize() << ":" + << dim.getStride() << "]"; +} + +static ParseResult parseLayoutDims(AsmParser& parser, + SmallVector& dims) { + if (parser.parseLSquare()) return failure(); + if (succeeded(parser.parseOptionalRSquare())) return success(); + + while (true) { + if (succeeded(parser.parseOptionalLSquare())) { + int64_t dim; + int64_t size; + int64_t stride; + if (failed(parseDimTripleAfterLSquare(parser, dim, size, stride))) + return failure(); + dims.push_back(DimAttr::get(parser.getContext(), dim, size, stride)); + } else { + Attribute dim; + if (parser.parseAttribute(dim)) return failure(); + if (!isa(dim)) { + return parser.emitError(parser.getNameLoc()) + << "expected a #rotom.dim attribute"; + } + dims.push_back(dim); } - seenDim.insert(d.getDim()); + + if (succeeded(parser.parseOptionalComma())) continue; + return parser.parseRSquare(); } - if (allUnique && data.traversalDims.size() > 1) { - llvm::SmallVector> byDim; - byDim.reserve(data.traversalDims.size()); - for (int64_t i = 0; i < static_cast(data.traversalDims.size()); - ++i) { - byDim.push_back({data.traversalDims[i].getDim(), i}); - } - llvm::sort(byDim, - [](const auto& a, const auto& b) { return a.first < b.first; }); - - llvm::SmallVector reorderedTraversal; - reorderedTraversal.reserve(data.traversalDims.size()); - llvm::SmallVector oldToNew(data.traversalDims.size(), 0); - for (int64_t newIdx = 0; newIdx < static_cast(byDim.size()); - ++newIdx) { - const int64_t oldIdx = byDim[newIdx].second; - oldToNew[oldIdx] = newIdx; - reorderedTraversal.push_back(data.traversalDims[oldIdx]); +} + +static ParseResult parseLayoutRolls(AsmParser& parser, + SmallVector& rolls) { + if (parser.parseLSquare()) return failure(); + if (succeeded(parser.parseOptionalRSquare())) return success(); + + while (true) { + if (succeeded(parser.parseOptionalLParen())) { + int64_t from; + int64_t to; + if (parser.parseInteger(from) || parser.parseComma() || + parser.parseInteger(to) || parser.parseRParen()) + return failure(); + rolls.push_back(from); + rolls.push_back(to); + } else { + int64_t value; + if (parser.parseInteger(value)) return failure(); + rolls.push_back(value); } - data.traversalDims = std::move(reorderedTraversal); - for (size_t p = 0; p < data.pieces.size(); ++p) { - if (data.pieces[p] == LayoutPieceKind::Traversal) { - data.pieceIndex[p] = oldToNew[data.pieceIndex[p]]; - } + if (succeeded(parser.parseOptionalComma())) continue; + return parser.parseRSquare(); + } +} + +} // namespace + +static LogicalResult verifyLayoutRolls( + ArrayAttr dims, DenseI64ArrayAttr rolls, + function_ref emitError) { + if (!rolls) return success(); + ArrayRef r = rolls.asArrayRef(); + if (r.empty()) return success(); + if (r.size() % 2 != 0) { + return emitError() << "rolls must contain an even number of integers " + "(pairs of dim indices)"; + } + + for (size_t i = 0; i < r.size(); i += 2) { + const int64_t ti = r[i]; + const int64_t tj = r[i + 1]; + if (ti == tj) { + return emitError() << "each roll must use two distinct dim indices"; + } + if (ti < 0 || tj < 0 || ti >= static_cast(dims.size()) || + tj >= static_cast(dims.size())) { + return emitError() << "roll dim index out of bounds for dims list"; + } + auto di = dyn_cast(dims[ti]); + auto dj = dyn_cast(dims[tj]); + if (!di || !dj) { + return emitError() << "roll indices must refer to #rotom.dim entries"; + } + if (di.getSize() != dj.getSize()) { + return emitError() << "rolled dims must have the same extent (size)"; + } + if (di.isGap() || dj.isGap() || di.isReplicate() || dj.isReplicate()) { + return emitError() << "rolls may only reference non-sentinel traversal " + "dims (dim >= 0)"; } } + return success(); +} - return data; +void DimAttr::print(AsmPrinter& printer) const { + printer << "<"; + printDimTriple(printer, *this); + printer << ">"; } -} // namespace +Attribute DimAttr::parse(AsmParser& parser, Type type) { + int64_t dim; + int64_t size; + int64_t stride; + + if (parser.parseLess() || failed(parseDimTriple(parser, dim, size, stride)) || + parser.parseGreater()) { + return {}; + } + + return DimAttr::getChecked( + [&]() { return parser.emitError(parser.getNameLoc()); }, + parser.getContext(), dim, size, stride); +} LogicalResult DimAttr::verify(function_ref emitError, int64_t dim, int64_t size, int64_t stride) { @@ -158,13 +251,85 @@ LogicalResult DimAttr::verify(function_ref emitError, return success(); } +void LayoutAttr::print(AsmPrinter& printer) const { + printer << " values = rolls.asArrayRef(); + printer << ", rolls = ["; + for (size_t i = 0; i < values.size(); i += 2) { + if (i != 0) printer << ", "; + printer << "(" << values[i] << ", " << values[i + 1] << ")"; + } + printer << "]"; + } + + printer << ", dims = ["; + llvm::interleaveComma(getDims(), printer, [&](Attribute attr) { + printDimTriple(printer, cast(attr)); + }); + printer << "]>"; +} + +Attribute LayoutAttr::parse(AsmParser& parser, Type type) { + int64_t n; + SmallVector rolls; + SmallVector dims; + + if (parser.parseLess()) return {}; + + if (succeeded(parser.parseOptionalKeyword("n"))) { + if (parser.parseEqual() || parser.parseInteger(n) || parser.parseComma()) { + return {}; + } + + if (succeeded(parser.parseOptionalKeyword("rolls"))) { + if (parser.parseEqual() || failed(parseLayoutRolls(parser, rolls)) || + parser.parseComma()) { + return {}; + } + } + + if (parser.parseKeyword("dims") || parser.parseEqual() || + failed(parseLayoutDims(parser, dims)) || parser.parseGreater()) { + return {}; + } + } else if (succeeded(parser.parseOptionalKeyword("dims"))) { + if (parser.parseEqual() || failed(parseLayoutDims(parser, dims)) || + parser.parseComma() || parser.parseKeyword("n") || + parser.parseEqual() || parser.parseInteger(n)) { + return {}; + } + + if (succeeded(parser.parseOptionalComma())) { + if (parser.parseKeyword("rolls") || parser.parseEqual() || + failed(parseLayoutRolls(parser, rolls))) { + return {}; + } + } + + if (parser.parseGreater()) return {}; + } else { + parser.emitError(parser.getNameLoc()) + << "expected `n` or `dims` in rotom layout"; + return {}; + } + + MLIRContext* context = parser.getContext(); + return LayoutAttr::getChecked( + [&]() { return parser.emitError(parser.getNameLoc()); }, context, + ArrayAttr::get(context, dims), n, DenseI64ArrayAttr::get(context, rolls)); +} + FailureOr preprocessLayoutAttr(LayoutAttr layout) { return preprocessLayoutData(layout.getDims(), layout.getN(), layout.getContext()); } LogicalResult LayoutAttr::verify(function_ref emitError, - ArrayAttr dims, int64_t n) { + ArrayAttr dims, int64_t n, + DenseI64ArrayAttr rolls) { if (n <= 0) { return emitError() << "`n` must be > 0, got " << n; } @@ -173,6 +338,8 @@ LogicalResult LayoutAttr::verify(function_ref emitError, return emitError() << "`dims` must be an array of `#rotom.dim<...>`"; } + if (failed(verifyLayoutRolls(dims, rolls, emitError))) return failure(); + MLIRContext* ctx = dims.getContext(); std::vector ctDims; std::vector slotDims; @@ -261,12 +428,17 @@ LogicalResult SeedAttr::verify(function_ref emitError, return emitError() << "seed layouts must be `rotom.layout` attributes"; } if (failed(LayoutAttr::verify(emitError, layoutAttr.getDims(), - layoutAttr.getN()))) + layoutAttr.getN(), layoutAttr.getRolls()))) return failure(); } return success(); } +LayoutAttr LayoutAttr::get(MLIRContext* context, ArrayAttr dims, int64_t n) { + return get(context, dims, n, + DenseI64ArrayAttr::get(context, ArrayRef{})); +} + } // namespace rotom } // namespace heir } // namespace mlir diff --git a/lib/Dialect/Rotom/IR/RotomAttributes.td b/lib/Dialect/Rotom/IR/RotomAttributes.td index 20cd06a6fb..b547d0ab07 100644 --- a/lib/Dialect/Rotom/IR/RotomAttributes.td +++ b/lib/Dialect/Rotom/IR/RotomAttributes.td @@ -10,7 +10,6 @@ include "mlir/IR/OpAsmInterface.td" class Rotom_Attr traits = []> : AttrDef { let mnemonic = attrMnemonic; - let assemblyFormat = "`<` struct(params) `>`"; let genMnemonicAlias = 1; } @@ -39,6 +38,7 @@ def Rotom_DimAttr : Rotom_Attr<"Dim", "dim"> { }]; let genVerifyDecl = 1; + let hasCustomAssemblyFormat = 1; } def Rotom_LayoutAttr : Rotom_Attr<"Layout", "layout"> { @@ -54,14 +54,27 @@ def Rotom_LayoutAttr : Rotom_Attr<"Layout", "layout"> { For tensor_ext materialization, the **first** entry in `dims` is the ciphertext side of Rotom's `;` split (one piece); remaining entries are in-slot. See [Section 4.2 of the Rotom paper](https://eprint.iacr.org/2025/1319.pdf). + + Optional **rolls** encode a `roll(i,j)` metadata object: each pair `(i, j)` + indexes into the `dims` array (the flattened `ct_dims + slot_dims` list) and + uses modular addition to modify the indices of `dims[i]` by the indices of + `dims[j]`. }]; let parameters = (ins "::mlir::ArrayAttr":$dims, - "int64_t":$n + "int64_t":$n, + OptionalParameter<"::mlir::DenseI64ArrayAttr">:$rolls ); let genVerifyDecl = 1; + let hasCustomAssemblyFormat = 1; + + let extraClassDeclaration = [{ + /// Layout with no `roll(i,j)` metadata (empty rolls storage). + static ::mlir::heir::rotom::LayoutAttr get(::mlir::MLIRContext *context, + ::mlir::ArrayAttr dims, int64_t n); + }]; } def Rotom_SeedAttr : Rotom_Attr<"Seed", "seed"> { @@ -75,6 +88,7 @@ def Rotom_SeedAttr : Rotom_Attr<"Seed", "seed"> { ); let genVerifyDecl = 1; + let assemblyFormat = "`<` struct(params) `>`"; } #endif // LIB_DIALECT_ROTOM_IR_ROTOMATTRIBUTES_TD_ diff --git a/lib/Dialect/Rotom/Utils/RotomTensorExtLayoutLowering.cpp b/lib/Dialect/Rotom/Utils/RotomTensorExtLayoutLowering.cpp index afae5319cf..bdc78a0528 100644 --- a/lib/Dialect/Rotom/Utils/RotomTensorExtLayoutLowering.cpp +++ b/lib/Dialect/Rotom/Utils/RotomTensorExtLayoutLowering.cpp @@ -2,6 +2,7 @@ #include #include +#include #include #include "lib/Dialect/Rotom/IR/RotomAttributes.h" @@ -10,19 +11,43 @@ #include "llvm/include/llvm/ADT/SmallVector.h" // from @llvm-project #include "llvm/include/llvm/ADT/StringRef.h" // from @llvm-project #include "llvm/include/llvm/Support/raw_ostream.h" // from @llvm-project +#include "mlir/include/mlir/IR/BuiltinAttributes.h" // from @llvm-project #include "mlir/include/mlir/Support/LLVM.h" // from @llvm-project #include "mlir/include/mlir/Support/LogicalResult.h" // from @llvm-project namespace mlir::heir::rotom { namespace { +/// Maps a `#rotom.dim` from the layout's `dims` list to its iterator index `i*` +/// after preprocessing (match logical axis, size, and stride). +static FailureOr traversalIndexForRotomDim( + const SmallVector& traversalDims, DimAttr want) { + for (int64_t i = 0; i < static_cast(traversalDims.size()); ++i) { + if (traversalDims[i].getDim() == want.getDim() && + traversalDims[i].getSize() == want.getSize() && + traversalDims[i].getStride() == want.getStride()) { + return i; + } + } + return failure(); +} + +static std::string modExpr(llvm::StringRef expr, int64_t mod) { + std::string out; + llvm::raw_string_ostream os(out); + os << "(" << expr << " - " << mod << " * floor((" << expr << ") / " << mod + << "))"; + return out; +} + static LogicalResult emitSegmentAddress( llvm::raw_ostream& os, bool& firstTerm, ArrayRef pieces, ArrayRef pieceIndex, const SmallVector& traversalDims, const SmallVector& gapDims, const SmallVector& replicationDims, int64_t numActiveTraversalComponents, size_t segStart, size_t segEnd, - bool foldGapVarsToZero) { + bool foldGapVarsToZero, ArrayRef rolls, ArrayAttr rotomDims, + bool isSlotLine) { llvm::SmallVector suffixCoeff(pieces.size(), 0); int64_t suffix = 1; for (size_t p = segEnd; p > segStart;) { @@ -40,39 +65,77 @@ static LogicalResult emitSegmentAddress( d = replicationDims[pieceIndex[p]]; break; } - suffix *= d.getSize(); } - auto emitTerm = [&](int64_t coeff, llvm::StringRef var) -> LogicalResult { + auto emitTerm = [&](int64_t coeff, llvm::StringRef expr) -> LogicalResult { if (coeff == 0) return failure(); - if (!firstTerm) os << " + "; - firstTerm = false; - if (coeff == 1) { - os << var; + if (firstTerm) { + if (coeff < 0) os << "-"; + firstTerm = false; } else { - os << coeff << " * " << var; + os << (coeff < 0 ? " - " : " + "); + } + const int64_t absCoeff = std::llabs(coeff); + if (absCoeff == 1) { + os << expr; + } else { + os << absCoeff << " * " << expr; } return success(); }; - llvm::DenseMap traversalCoeff; llvm::DenseMap gapCoeff; llvm::DenseMap replicationCoeff; for (size_t p = segStart; p < segEnd; ++p) { const int64_t coeff = suffixCoeff[p]; - - if (pieces[p] == LayoutPieceKind::Traversal) { - const int64_t ti = pieceIndex[p]; - if (traversalDims[ti].getSize() == 1) continue; - traversalCoeff[ti] = coeff; - } else if (pieces[p] == LayoutPieceKind::Gap) { + if (pieces[p] == LayoutPieceKind::Gap) { if (foldGapVarsToZero) continue; - const int64_t gk = pieceIndex[p]; - gapCoeff[gk] = coeff; - } else { - const int64_t ek = pieceIndex[p]; - replicationCoeff[ek] = coeff; + gapCoeff[pieceIndex[p]] = coeff; + } else if (pieces[p] == LayoutPieceKind::Replication) { + replicationCoeff[pieceIndex[p]] = coeff; + } + } + + llvm::DenseMap traversalCoeff; + for (size_t p = segStart; p < segEnd; ++p) { + if (pieces[p] != LayoutPieceKind::Traversal) continue; + const int64_t ti = pieceIndex[p]; + if (traversalDims[ti].getSize() == 1) continue; + traversalCoeff[ti] = suffixCoeff[p]; + } + + llvm::SmallVector traversalExprs; + traversalExprs.reserve(traversalDims.size()); + for (int64_t i = 0; i < static_cast(traversalDims.size()); ++i) { + traversalExprs.push_back("i" + std::to_string(i)); + } + + // Apply roll(a,b) transforms left-to-right: + // t_a <- (t_a - t_b) mod extent(a). + if (isSlotLine && !rolls.empty()) { + if (!rotomDims || rolls.size() % 2 != 0) return failure(); + for (size_t i = 0; i < rolls.size(); i += 2) { + const int64_t fromIdx = rolls[i]; + const int64_t toIdx = rolls[i + 1]; + if (fromIdx < 0 || toIdx < 0 || + fromIdx >= static_cast(rotomDims.size()) || + toIdx >= static_cast(rotomDims.size())) { + return failure(); + } + auto fromDim = dyn_cast(rotomDims[fromIdx]); + auto toDim = dyn_cast(rotomDims[toIdx]); + if (!fromDim || !toDim) return failure(); + FailureOr maybeFromTrav = + traversalIndexForRotomDim(traversalDims, fromDim); + FailureOr maybeToTrav = + traversalIndexForRotomDim(traversalDims, toDim); + if (failed(maybeFromTrav) || failed(maybeToTrav)) return failure(); + const int64_t fromTrav = *maybeFromTrav; + const int64_t toTrav = *maybeToTrav; + std::string diffExpr = + "(" + traversalExprs[fromTrav] + " - " + traversalExprs[toTrav] + ")"; + traversalExprs[fromTrav] = modExpr(diffExpr, fromDim.getSize()); } } @@ -81,7 +144,7 @@ static LogicalResult emitSegmentAddress( if (traversalDims[oldIdx].getSize() == 1) continue; auto it = traversalCoeff.find(oldIdx); if (it != traversalCoeff.end()) { - if (failed(emitTerm(it->second, "i" + std::to_string(oldIdx)))) + if (failed(emitTerm(it->second, traversalExprs[oldIdx]))) return failure(); } } @@ -107,7 +170,8 @@ static FailureOr emitSplitCtSlotIsl( ArrayRef pieceIndex, const SmallVector& traversalDims, const SmallVector& replicationDims, const SmallVector& gapDims, int64_t numTraversalComponents, - int64_t numReplication, int64_t numGap) { + int64_t numReplication, int64_t numGap, ArrayRef rolls, + ArrayAttr rotomDims) { if (prefix > pieces.size()) return failure(); int64_t numCt = 1; @@ -170,7 +234,8 @@ static FailureOr emitSplitCtSlotIsl( if (failed(emitSegmentAddress(os, firstTerm, pieces, pieceIndex, traversalDims, gapDims, replicationDims, numTraversalComponents, 0, prefix, - foldGapVarsToZero))) + foldGapVarsToZero, rolls, rotomDims, + /*isSlotLine=*/false))) return failure(); if (firstTerm) os << "0"; @@ -180,7 +245,8 @@ static FailureOr emitSplitCtSlotIsl( if (failed(emitSegmentAddress(os, firstTerm, pieces, pieceIndex, traversalDims, gapDims, replicationDims, numTraversalComponents, prefix, pieces.size(), - foldGapVarsToZero))) + foldGapVarsToZero, rolls, rotomDims, + /*isSlotLine=*/true))) return failure(); if (firstTerm) os << "0"; @@ -215,10 +281,13 @@ static FailureOr lowerToIslImpl(LayoutAttr layout) { const int64_t numReplication = static_cast(data.replicationDims.size()); const int64_t numGap = static_cast(data.gapDims.size()); - return emitSplitCtSlotIsl(data.n, data.ctPrefixLen, data.pieces, - data.pieceIndex, data.traversalDims, - data.replicationDims, data.gapDims, - numTraversalComponents, numReplication, numGap); + DenseI64ArrayAttr rollsAttr = layout.getRolls(); + ArrayRef rolls = + rollsAttr ? rollsAttr.asArrayRef() : ArrayRef{}; + return emitSplitCtSlotIsl( + data.n, data.ctPrefixLen, data.pieces, data.pieceIndex, + data.traversalDims, data.replicationDims, data.gapDims, + numTraversalComponents, numReplication, numGap, rolls, layout.getDims()); } } // namespace diff --git a/lib/Dialect/Rotom/Utils/RotomTensorExtLayoutLoweringTest.cpp b/lib/Dialect/Rotom/Utils/RotomTensorExtLayoutLoweringTest.cpp index 620e8979eb..0b70e2f6ab 100644 --- a/lib/Dialect/Rotom/Utils/RotomTensorExtLayoutLoweringTest.cpp +++ b/lib/Dialect/Rotom/Utils/RotomTensorExtLayoutLoweringTest.cpp @@ -75,8 +75,11 @@ TEST(RotomTensorExtLayoutLoweringTest, ColumnMajor4x4Evaluate) { {9, 10, 11, 12}, {13, 14, 15, 16}, }; - std::vector> packed = - evaluateLayoutOnMatrix(relation.value(), matrix); + std::vector> packed = evaluateLayout( + relation.value(), [&](const std::vector& domainPoint) -> int { + // Traversal dims are {dim1, dim0}, so relation vars are [col, row]. + return matrix[domainPoint[1]][domainPoint[0]]; + }); std::vector> expected = { {1, 5, 9, 13, 2, 6, 10, 14, 3, 7, 11, 15, 4, 8, 12, 16}, @@ -177,8 +180,11 @@ TEST(RotomTensorExtLayoutLoweringTest, SplitColumnMajor4x4Evaluate) { {9, 10, 11, 12}, {13, 14, 15, 16}, }; - std::vector> packed = - evaluateLayoutOnMatrix(relation.value(), matrix); + std::vector> packed = evaluateLayout( + relation.value(), [&](const std::vector& domainPoint) -> int { + // Traversal dims are {dim1, dim0}, so relation vars are [col, row]. + return matrix[domainPoint[1]][domainPoint[0]]; + }); // Column-major packing, split into 4 ciphertexts of 4 slots: one column per // ciphertext. @@ -210,7 +216,7 @@ TEST(RotomTensorExtLayoutLoweringTest, PreprocessAddsImplicitGap) { EXPECT_EQ(data->pieces[1], LayoutPieceKind::Traversal); } -TEST(RotomTensorExtLayoutLoweringTest, PreprocessReordersUniqueTraversalDims) { +TEST(RotomTensorExtLayoutLoweringTest, PreprocessPreservesTraversalDimsOrder) { MLIRContext context; context.loadDialect(); DimAttr d0 = DimAttr::get(&context, /*dim=*/0, /*size=*/4, /*stride=*/1); @@ -221,8 +227,101 @@ TEST(RotomTensorExtLayoutLoweringTest, PreprocessReordersUniqueTraversalDims) { FailureOr data = preprocessLayoutAttr(layout); ASSERT_TRUE(succeeded(data)); ASSERT_EQ(data->traversalDims.size(), 2); - EXPECT_EQ(data->traversalDims[0].getDim(), 0); - EXPECT_EQ(data->traversalDims[1].getDim(), 1); + EXPECT_EQ(data->traversalDims[0].getDim(), 1); + EXPECT_EQ(data->traversalDims[1].getDim(), 0); +} + +TEST(RotomTensorExtLayoutLoweringTest, RolledRowMajor2x2Evaluate) { + MLIRContext context; + context.loadDialect(); + DimAttr d0 = DimAttr::get(&context, /*dim=*/0, /*size=*/2, /*stride=*/1); + DimAttr d1 = DimAttr::get(&context, /*dim=*/1, /*size=*/2, /*stride=*/1); + ArrayAttr dims = ArrayAttr::get(&context, {d0, d1}); + LayoutAttr layout = LayoutAttr::get( + &context, dims, /*n=*/4, + DenseI64ArrayAttr::get(&context, ArrayRef{0, 1})); + + FailureOr isl = + RotomTensorExtLayoutLowering::lowerToTensorExtIsl(layout); + ASSERT_TRUE(succeeded(isl)); + auto relation = getIntegerRelationFromIslStr(*isl); + ASSERT_TRUE(succeeded(relation)); + + std::vector> matrix = { + {1, 2}, + {3, 4}, + }; + std::vector> packed = + evaluateLayoutOnMatrix(relation.value(), matrix); + std::vector> expected = {{1, 4, 3, 2}}; + EXPECT_EQ(packed, expected); + EXPECT_EQ(unpackLayoutToMatrix(relation.value(), packed, {2, 2}), matrix); +} + +TEST(RotomTensorExtLayoutLoweringTest, RolledRowMajor4x4Evaluate) { + MLIRContext context; + context.loadDialect(); + DimAttr d0 = DimAttr::get(&context, /*dim=*/0, /*size=*/4, /*stride=*/1); + DimAttr d1 = DimAttr::get(&context, /*dim=*/1, /*size=*/4, /*stride=*/1); + ArrayAttr dims = ArrayAttr::get(&context, {d0, d1}); + LayoutAttr layout = LayoutAttr::get( + &context, dims, /*n=*/16, + DenseI64ArrayAttr::get(&context, ArrayRef{0, 1})); + + FailureOr isl = + RotomTensorExtLayoutLowering::lowerToTensorExtIsl(layout); + ASSERT_TRUE(succeeded(isl)); + auto relation = getIntegerRelationFromIslStr(*isl); + ASSERT_TRUE(succeeded(relation)); + + std::vector> matrix = { + {1, 2, 3, 4}, + {5, 6, 7, 8}, + {9, 10, 11, 12}, + {13, 14, 15, 16}, + }; + std::vector> packed = + evaluateLayoutOnMatrix(relation.value(), matrix); + + // ``roll(0,1)``: diagonal ``(i0 - i1) mod 4`` classes, listed in Rotom order. + std::vector> expected = { + {1, 6, 11, 16, 5, 10, 15, 4, 9, 14, 3, 8, 13, 2, 7, 12}, + }; + EXPECT_EQ(packed, expected); + EXPECT_EQ(unpackLayoutToMatrix(relation.value(), packed, {4, 4}), matrix); +} + +TEST(RotomTensorExtLayoutLoweringTest, RolledInternalRowMajor4x4Evaluate) { + MLIRContext context; + context.loadDialect(); + DimAttr d0 = DimAttr::get(&context, /*dim=*/0, /*size=*/4, /*stride=*/1); + DimAttr d1 = DimAttr::get(&context, /*dim=*/1, /*size=*/4, /*stride=*/1); + ArrayAttr dims = ArrayAttr::get(&context, {d0, d1}); + LayoutAttr layout = LayoutAttr::get( + &context, dims, /*n=*/16, + DenseI64ArrayAttr::get(&context, ArrayRef{1, 0})); + + FailureOr isl = + RotomTensorExtLayoutLowering::lowerToTensorExtIsl(layout); + ASSERT_TRUE(succeeded(isl)); + auto relation = getIntegerRelationFromIslStr(*isl); + ASSERT_TRUE(succeeded(relation)); + + std::vector> matrix = { + {1, 2, 3, 4}, + {5, 6, 7, 8}, + {9, 10, 11, 12}, + {13, 14, 15, 16}, + }; + std::vector> packed = + evaluateLayoutOnMatrix(relation.value(), matrix); + + // ``roll(1,0)``: cyclic column order within each row (dims index high first). + std::vector> expected = { + {1, 2, 3, 4, 6, 7, 8, 5, 11, 12, 9, 10, 16, 13, 14, 15}, + }; + EXPECT_EQ(packed, expected); + EXPECT_EQ(unpackLayoutToMatrix(relation.value(), packed, {4, 4}), matrix); } } // namespace diff --git a/tests/Dialect/Rotom/IR/layout.mlir b/tests/Dialect/Rotom/IR/layout.mlir index 5077c990e2..6e925f3c6c 100644 --- a/tests/Dialect/Rotom/IR/layout.mlir +++ b/tests/Dialect/Rotom/IR/layout.mlir @@ -1,37 +1,43 @@ // RUN: heir-opt --verify-diagnostics --split-input-file %s -#d0 = #rotom.dim -#d1 = #rotom.dim +#d0 = #rotom.dim<[0:4:1]> +#d1 = #rotom.dim<[1:4:1]> // A simple 4x4 layout with 16 slots should have pow2 slot dims. -#layout_ok = #rotom.layout +#layout_ok = #rotom.layout func.func private @ok(tensor<16xi32> {foo.bar = #layout_ok}) // ----- // This forces a non-pow2 slot dim size after splitting: size=3 divides 12 but not pow2. -// expected-error @below {{slot dim size must be a power of two, got 3}} -#bad = #rotom.layout], n = 12> +#bad = #rotom.layout // expected-error {{slot dim size must be a power of two, got 3}} func.func private @bad(tensor<16xi32> {foo.bar = #bad}) // ----- // Splitting case: size > n causes a ct/slot split, and slot-side size becomes n (must be pow2). -#split_ok = #rotom.layout], n = 8> +#split_ok = #rotom.layout func.func private @split_ok(tensor<16xi32> {foo.bar = #split_ok}) // ----- // size > n but not divisible => verifier error (mirrors Python assert size % n == 0). -// expected-error @below {{dim size 10 must be divisible by remaining slot capacity 8}} -#split_bad = #rotom.layout], n = 8> +#split_bad = #rotom.layout // expected-error {{dim size 10 must be divisible by remaining slot capacity 8}} func.func private @split_bad(tensor<16xi32> {foo.bar = #split_bad}) // ----- -#d0 = #rotom.dim -#d1 = #rotom.dim +#d0 = #rotom.dim<[0:4:1]> +#d1 = #rotom.dim<[1:4:1]> // n == 0 is invalid. -// expected-error @below {{`n` must be > 0, got 0}} -#n0_ok = #rotom.layout +#n0_ok = #rotom.layout // expected-error {{`n` must be > 0, got 0}} func.func private @n0_ok(tensor<16xi32> {foo.bar = #n0_ok}) + +// ----- + +#d0 = #rotom.dim<[0:4:1]> +#d1 = #rotom.dim<[1:2:1]> + +// Mismatched extents for a roll pair. +#bad_roll = #rotom.layout // expected-error {{rolled dims must have the same extent (size)}} +func.func private @bad_roll(tensor<16xi32> {foo.bar = #bad_roll}) diff --git a/tests/Dialect/Rotom/IR/syntax.mlir b/tests/Dialect/Rotom/IR/syntax.mlir index cbbb56ca7e..325c704b18 100644 --- a/tests/Dialect/Rotom/IR/syntax.mlir +++ b/tests/Dialect/Rotom/IR/syntax.mlir @@ -1,10 +1,23 @@ -// RUN: heir-opt %s +// RUN: heir-opt %s | FileCheck %s -#d0 = #rotom.dim -#d1 = #rotom.dim -#layout = #rotom.layout -module { - func.func @f(%arg0: tensor<4x4xf32> {rotom.layout = #layout}) { +#d0 = #rotom.dim<[0:4:1]> +#d1 = #rotom.dim<[1:4:1]> +#plain = #rotom.layout +#rolled = #rotom.layout + +// CHECK: #dim = #rotom.dim<[2:8:4]> +// CHECK: #layout = #rotom.layout +// CHECK: #layout1 = #rotom.layout +// CHECK: module attributes +// CHECK-SAME: rotom.dim_attr = #dim +// CHECK-SAME: rotom.plain_layout = #layout +// CHECK-SAME: rotom.rolled_layout = #layout1 +module attributes { + rotom.dim_attr = #rotom.dim<[2:8:4]>, + rotom.plain_layout = #plain, + rotom.rolled_layout = #rolled +} { + func.func @f(%arg0: tensor<4x4xf32>) { return } } diff --git a/tests/Dialect/Rotom/IR/verifier.mlir b/tests/Dialect/Rotom/IR/verifier.mlir index 1bb5290025..7211afb804 100644 --- a/tests/Dialect/Rotom/IR/verifier.mlir +++ b/tests/Dialect/Rotom/IR/verifier.mlir @@ -1,14 +1,14 @@ // RUN: heir-opt --verify-diagnostics --split-input-file %s // expected-error @below {{`size` must be > 0, got 0}} -func.func private @bad_size(tensor<16xi32> {foo = #rotom.dim}) +func.func private @bad_size(tensor<16xi32> {foo = #rotom.dim<[0:0:1]>}) // ----- // expected-error @below {{`stride` must be > 0, got 0}} -func.func private @bad_stride(tensor<16xi32> {foo = #rotom.dim}) +func.func private @bad_stride(tensor<16xi32> {foo = #rotom.dim<[0:8:0]>}) // ----- // expected-error @below {{`dim` must be >= -2, got -3}} -func.func private @bad_dim(tensor<16xi32> {foo = #rotom.dim}) +func.func private @bad_dim(tensor<16xi32> {foo = #rotom.dim<[-3:8:1]>}) diff --git a/tests/Dialect/Rotom/Transforms/doctest.mlir b/tests/Dialect/Rotom/Transforms/doctest.mlir index abf18e3984..6f8ca1cebe 100644 --- a/tests/Dialect/Rotom/Transforms/doctest.mlir +++ b/tests/Dialect/Rotom/Transforms/doctest.mlir @@ -5,9 +5,9 @@ // arguments, region arguments, and op results via the attribute-association // rules used by `findAttributeAssociatedWith`. -#d0 = #rotom.dim -#d1 = #rotom.dim -#layout = #rotom.layout +#d0 = #rotom.dim<[0:4:1]> +#d1 = #rotom.dim<[1:4:1]> +#layout = #rotom.layout // CHECK: func.func @f(%arg0: tensor<4x4xf32> {tensor_ext.layout = // CHECK: arith.constant {tensor_ext.layout = diff --git a/tests/Dialect/Rotom/Transforms/doctest_seed_layout.mlir b/tests/Dialect/Rotom/Transforms/doctest_seed_layout.mlir index 645e6eeb00..8ae956689b 100644 --- a/tests/Dialect/Rotom/Transforms/doctest_seed_layout.mlir +++ b/tests/Dialect/Rotom/Transforms/doctest_seed_layout.mlir @@ -1,7 +1,18 @@ // RUN: heir-opt %s --rotom-seed-layout=n=8 | FileCheck %s module { - // CHECK: func.func @test_seeding(%{{.*}}: !secret.secret> {rotom.seed = #rotom.seed, #rotom.dim, #rotom.dim], n = 8>, #rotom.layout, #rotom.dim, #rotom.dim], n = 8>, #rotom.layout, #rotom.dim, #rotom.dim], n = 8>, #rotom.layout, #rotom.dim, #rotom.dim], n = 8>]>}, %{{.*}}: tensor<4x4xf32> {rotom.seed = #rotom.seed, #rotom.dim, #rotom.dim], n = 8>, #rotom.layout, #rotom.dim, #rotom.dim], n = 8>, #rotom.layout, #rotom.dim, #rotom.dim], n = 8>, #rotom.layout, #rotom.dim, #rotom.dim], n = 8>]>}) + // CHECK: func.func @test_seeding( + // CHECK-SAME: !secret.secret> {rotom.seed = #rotom.seed + // CHECK-SAME: #rotom.layout + // CHECK-SAME: #rotom.layout + // CHECK-SAME: #rotom.layout + // CHECK-SAME: ]>}, %{{.*}}: tensor<4x4xf32> {rotom.seed = #rotom.seed + // CHECK-SAME: #rotom.layout + // CHECK-SAME: #rotom.layout + // CHECK-SAME: #rotom.layout + // CHECK-SAME: ]>}) func.func @test_seeding(%arg0: !secret.secret>, %arg1: tensor<4x4xf32>) -> !secret.secret> { // CHECK: secret.generic(%{{.*}}: !secret.secret>) %0 = secret.generic(%arg0 : !secret.secret>) { diff --git a/tests/Dialect/Rotom/Transforms/materialize_column_major.mlir b/tests/Dialect/Rotom/Transforms/materialize_column_major.mlir index 2b4507ad1e..55ad6fe353 100644 --- a/tests/Dialect/Rotom/Transforms/materialize_column_major.mlir +++ b/tests/Dialect/Rotom/Transforms/materialize_column_major.mlir @@ -1,15 +1,15 @@ // RUN: heir-opt %s --rotom-materialize-tensor-ext-layout | FileCheck %s // Rotom ``[1:4:1][0:4:1]`` with ``n = 16``: column-major in-slot after `;`. -#d0 = #rotom.dim -#d1 = #rotom.dim -#layout = #rotom.layout +#d0 = #rotom.dim<[0:4:1]> +#d1 = #rotom.dim<[1:4:1]> +#layout = #rotom.layout // CHECK: func.func @f(%arg0: tensor<4x4xf32> {tensor_ext.layout = // CHECK-DAG: #tensor_ext.layout< // CHECK-DAG: [i0, i1] -> [ct, slot] : // CHECK-DAG: ct = 0 -// CHECK-DAG: slot = i0 + 4 * i1 +// CHECK-DAG: slot = 4 * i0 + i1 module { func.func @f(%arg0: tensor<4x4xf32> {rotom.layout = #layout}) { return diff --git a/tests/Dialect/Rotom/Transforms/materialize_ct_slot_split.mlir b/tests/Dialect/Rotom/Transforms/materialize_ct_slot_split.mlir index fca07783e1..fca9a6eb35 100644 --- a/tests/Dialect/Rotom/Transforms/materialize_ct_slot_split.mlir +++ b/tests/Dialect/Rotom/Transforms/materialize_ct_slot_split.mlir @@ -3,11 +3,11 @@ // Rotom ``[0:2:2];[1:2:2][0:2:1][1:2:1]`` with ``n = 8``: first dim is // ciphertext traversal, remaining dims pack within each CT. This is a tiled // row-major layout. -#d0 = #rotom.dim -#d1 = #rotom.dim -#d2 = #rotom.dim -#d3 = #rotom.dim -#layout = #rotom.layout +#d0 = #rotom.dim<[0:2:2]> +#d1 = #rotom.dim<[1:2:2]> +#d2 = #rotom.dim<[0:2:1]> +#d3 = #rotom.dim<[1:2:1]> +#layout = #rotom.layout // CHECK: func.func @f(%arg0: tensor<2x2x2x2xf32> {tensor_ext.layout = // CHECK-DAG: #tensor_ext.layout< diff --git a/tests/Dialect/Rotom/Transforms/materialize_explicit_gap_dim.mlir b/tests/Dialect/Rotom/Transforms/materialize_explicit_gap_dim.mlir index c3fb2852fd..78c8ce2867 100644 --- a/tests/Dialect/Rotom/Transforms/materialize_explicit_gap_dim.mlir +++ b/tests/Dialect/Rotom/Transforms/materialize_explicit_gap_dim.mlir @@ -1,9 +1,9 @@ // RUN: heir-opt %s --rotom-materialize-tensor-ext-layout | FileCheck %s // Rotom ``[0:4:1][G:2:1]`` with ``n = 8``: in-slot row with explicit gap. -#d0 = #rotom.dim -#g0 = #rotom.dim -#layout = #rotom.layout +#d0 = #rotom.dim<[0:4:1]> +#g0 = #rotom.dim<[-2:2:1]> +#layout = #rotom.layout // CHECK: func.func @f(%arg0: tensor<4xf32> {tensor_ext.layout = // CHECK-DAG: #tensor_ext.layout< diff --git a/tests/Dialect/Rotom/Transforms/materialize_implicit_gap_dim.mlir b/tests/Dialect/Rotom/Transforms/materialize_implicit_gap_dim.mlir index e06d366b1f..d0b1659c18 100644 --- a/tests/Dialect/Rotom/Transforms/materialize_implicit_gap_dim.mlir +++ b/tests/Dialect/Rotom/Transforms/materialize_implicit_gap_dim.mlir @@ -3,8 +3,8 @@ // Rotom ``[G:2:1][0:4:1]`` with ``n = 8``: Row-major, first 4. Implicit gap // dimension, ``[G:2:1]``, should be added in front. -#d0 = #rotom.dim -#layout = #rotom.layout +#d0 = #rotom.dim<[0:4:1]> +#layout = #rotom.layout // CHECK: func.func @f(%arg0: tensor<4xf32> {tensor_ext.layout = // CHECK-DAG: #tensor_ext.layout< diff --git a/tests/Dialect/Rotom/Transforms/materialize_multi_dim.mlir b/tests/Dialect/Rotom/Transforms/materialize_multi_dim.mlir index 1fcf4354ad..9e0cb334cb 100644 --- a/tests/Dialect/Rotom/Transforms/materialize_multi_dim.mlir +++ b/tests/Dialect/Rotom/Transforms/materialize_multi_dim.mlir @@ -1,9 +1,9 @@ // RUN: heir-opt %s --rotom-materialize-tensor-ext-layout | FileCheck %s -#d0 = #rotom.dim -#d1 = #rotom.dim -#d2 = #rotom.dim -#layout = #rotom.layout +#d0 = #rotom.dim<[0:2:1]> +#d1 = #rotom.dim<[1:2:1]> +#d2 = #rotom.dim<[2:4:1]> +#layout = #rotom.layout // CHECK: func.func @f(%arg0: tensor<2x2x4xf32> {tensor_ext.layout = // CHECK-DAG: #tensor_ext.layout< diff --git a/tests/Dialect/Rotom/Transforms/materialize_non_contiguous_dim_ids.mlir b/tests/Dialect/Rotom/Transforms/materialize_non_contiguous_dim_ids.mlir index fe21e4dbeb..c765073d89 100644 --- a/tests/Dialect/Rotom/Transforms/materialize_non_contiguous_dim_ids.mlir +++ b/tests/Dialect/Rotom/Transforms/materialize_non_contiguous_dim_ids.mlir @@ -1,8 +1,8 @@ // RUN: heir-opt %s --rotom-materialize-tensor-ext-layout | FileCheck %s -#d0 = #rotom.dim -#d2 = #rotom.dim -#layout = #rotom.layout +#d0 = #rotom.dim<[0:2:1]> +#d2 = #rotom.dim<[2:2:1]> +#layout = #rotom.layout // CHECK: func.func @f(%arg0: tensor<2x2x2xf32> {tensor_ext.layout = // CHECK-DAG: #tensor_ext.layout< diff --git a/tests/Dialect/Rotom/Transforms/materialize_repeated_column_major.mlir b/tests/Dialect/Rotom/Transforms/materialize_repeated_column_major.mlir index 22135428fa..c23a4fb92e 100644 --- a/tests/Dialect/Rotom/Transforms/materialize_repeated_column_major.mlir +++ b/tests/Dialect/Rotom/Transforms/materialize_repeated_column_major.mlir @@ -2,18 +2,18 @@ // Rotom ``[R:4:1];[1:4:1][0:4:1]`` with ``n = 16`` (replication + column-major // traversals in ``dims``). Replication is projected to ciphertext index ``ct`` -// via existential ``d2``; slots pack ``i0 + 4 * i1``. -#d0 = #rotom.dim -#d1 = #rotom.dim -#d2 = #rotom.dim -#layout = #rotom.layout +// via existential ``d2``; slots pack ``4 * i0 + i1``. +#d0 = #rotom.dim<[0:4:1]> +#d1 = #rotom.dim<[1:4:1]> +#d2 = #rotom.dim<[-1:4:1]> +#layout = #rotom.layout // CHECK: func.func @f(%arg0: tensor<4x4xf32> {tensor_ext.layout = // CHECK-DAG: #tensor_ext.layout< // CHECK-DAG: [i0, i1] -> [ct, slot] : // CHECK-DAG: exists d2 // CHECK-DAG: ct = d2 -// CHECK-DAG: slot = i0 + 4 * i1 +// CHECK-DAG: slot = 4 * i0 + i1 module { func.func @f(%arg0: tensor<4x4xf32> {rotom.layout = #layout}) { return diff --git a/tests/Dialect/Rotom/Transforms/materialize_replication_dim.mlir b/tests/Dialect/Rotom/Transforms/materialize_replication_dim.mlir index ffe27666f2..a815c5de32 100644 --- a/tests/Dialect/Rotom/Transforms/materialize_replication_dim.mlir +++ b/tests/Dialect/Rotom/Transforms/materialize_replication_dim.mlir @@ -2,9 +2,9 @@ // Rotom ``[0:4:1][R:2:1]`` with ``n = 8``: row-major where each value is // repeated twice. -#d0 = #rotom.dim -#r0 = #rotom.dim -#layout = #rotom.layout +#d0 = #rotom.dim<[0:4:1]> +#r0 = #rotom.dim<[-1:2:4]> +#layout = #rotom.layout // CHECK: func.func @f(%arg0: tensor<4xf32> {tensor_ext.layout = // CHECK-DAG: #tensor_ext.layout< diff --git a/tests/Dialect/Rotom/Transforms/materialize_rolled_row_major.mlir b/tests/Dialect/Rotom/Transforms/materialize_rolled_row_major.mlir new file mode 100644 index 0000000000..2fd65871e2 --- /dev/null +++ b/tests/Dialect/Rotom/Transforms/materialize_rolled_row_major.mlir @@ -0,0 +1,18 @@ +// RUN: heir-opt %s --rotom-materialize-tensor-ext-layout | FileCheck %s + +// Rotom ``roll(0, 1)`` on a 4x4 row-major layout: slot indices group +// diagonal ``(i0 - i1) mod 4`` classes in row-major slot order. +#d0 = #rotom.dim<[0:4:1]> +#d1 = #rotom.dim<[1:4:1]> +#layout = #rotom.layout + +// CHECK: func.func @f(%arg0: tensor<4x4xf32> {tensor_ext.layout = +// CHECK-DAG: #tensor_ext.layout< +// CHECK-DAG: [i0, i1] -> [ct, slot] : +// CHECK-DAG: ct = 0 +// CHECK-DAG: slot = 4 * ((i0 - i1) - 4 * floor(((i0 - i1)) / 4)) + i1 +module { + func.func @f(%arg0: tensor<4x4xf32> {rotom.layout = #layout}) { + return + } +} diff --git a/tests/Dialect/Rotom/Transforms/materialize_row_major.mlir b/tests/Dialect/Rotom/Transforms/materialize_row_major.mlir index e1460ad85f..feab9a0b81 100644 --- a/tests/Dialect/Rotom/Transforms/materialize_row_major.mlir +++ b/tests/Dialect/Rotom/Transforms/materialize_row_major.mlir @@ -1,9 +1,9 @@ // RUN: heir-opt %s --rotom-materialize-tensor-ext-layout | FileCheck %s // Rotom ``[0:4:1][1:4:1]`` with ``n = 16``: row-major. -#d0 = #rotom.dim -#d1 = #rotom.dim -#layout = #rotom.layout +#d0 = #rotom.dim<[0:4:1]> +#d1 = #rotom.dim<[1:4:1]> +#layout = #rotom.layout // CHECK: func.func @f(%arg0: tensor<4x4xf32> {tensor_ext.layout = // CHECK-DAG: #tensor_ext.layout< diff --git a/tests/Dialect/Rotom/Transforms/materialize_slot_gap_dims.mlir b/tests/Dialect/Rotom/Transforms/materialize_slot_gap_dims.mlir index 1ebdaedbbe..69f5ec07f7 100644 --- a/tests/Dialect/Rotom/Transforms/materialize_slot_gap_dims.mlir +++ b/tests/Dialect/Rotom/Transforms/materialize_slot_gap_dims.mlir @@ -2,11 +2,11 @@ // Two tensor axes (2x2) plus explicit gap dims: linear address includes gap // variables with g_k = 0 (payload only at gap index 0; other indices zero-fill). -#d0 = #rotom.dim -#d1 = #rotom.dim -#g0 = #rotom.dim -#g1 = #rotom.dim -#layout = #rotom.layout +#d0 = #rotom.dim<[0:2:1]> +#d1 = #rotom.dim<[1:2:2]> +#g0 = #rotom.dim<[-2:2:1]> +#g1 = #rotom.dim<[-2:2:4]> +#layout = #rotom.layout // CHECK: func.func @f(%arg0: tensor<2x2xf32> {tensor_ext.layout = // CHECK-DAG: #tensor_ext.layout< diff --git a/tests/Dialect/Rotom/Transforms/materialize_tiled_duplicate_dim.mlir b/tests/Dialect/Rotom/Transforms/materialize_tiled_duplicate_dim.mlir index 70b2155eaf..fcbffd72cb 100644 --- a/tests/Dialect/Rotom/Transforms/materialize_tiled_duplicate_dim.mlir +++ b/tests/Dialect/Rotom/Transforms/materialize_tiled_duplicate_dim.mlir @@ -3,11 +3,11 @@ // Tiled row-major style: dim ids repeat across traversals, but (dim, stride) // pairs are distinct. With ``n = 8``, Rotom ``;`` split: first traversal in // ``dims`` is ciphertext index, the rest pack within each CT. -#d0s2 = #rotom.dim -#d1s2 = #rotom.dim -#d0s1 = #rotom.dim -#d1s1 = #rotom.dim -#layout = #rotom.layout +#d0s2 = #rotom.dim<[0:2:2]> +#d1s2 = #rotom.dim<[1:2:2]> +#d0s1 = #rotom.dim<[0:2:1]> +#d1s1 = #rotom.dim<[1:2:1]> +#layout = #rotom.layout // CHECK: func.func @f(%arg0: tensor<2x2x2x2xf32> {tensor_ext.layout = // CHECK-DAG: #tensor_ext.layout< diff --git a/tests/Dialect/Rotom/Transforms/materialize_vector.mlir b/tests/Dialect/Rotom/Transforms/materialize_vector.mlir index 7df2e0909b..9b079e13a3 100644 --- a/tests/Dialect/Rotom/Transforms/materialize_vector.mlir +++ b/tests/Dialect/Rotom/Transforms/materialize_vector.mlir @@ -1,7 +1,7 @@ // RUN: heir-opt %s --rotom-materialize-tensor-ext-layout | FileCheck %s -#d0 = #rotom.dim -#layout = #rotom.layout +#d0 = #rotom.dim<[0:4:1]> +#layout = #rotom.layout // CHECK: func.func @f(%arg0: tensor<4xf32> {tensor_ext.layout = // CHECK-DAG: #tensor_ext.layout< diff --git a/tests/Dialect/Rotom/Transforms/seed_layout.mlir b/tests/Dialect/Rotom/Transforms/seed_layout.mlir index 19fc07ade0..b4c295ad85 100644 --- a/tests/Dialect/Rotom/Transforms/seed_layout.mlir +++ b/tests/Dialect/Rotom/Transforms/seed_layout.mlir @@ -1,7 +1,18 @@ // RUN: heir-opt %s --rotom-seed-layout=n=8 --split-input-file | FileCheck %s module { - // CHECK: func.func @test_seeding(%{{.*}}: !secret.secret> {rotom.seed = #rotom.seed, #rotom.dim, #rotom.dim], n = 8>, #rotom.layout, #rotom.dim, #rotom.dim], n = 8>, #rotom.layout, #rotom.dim, #rotom.dim], n = 8>, #rotom.layout, #rotom.dim, #rotom.dim], n = 8>]>}, %{{.*}}: tensor<4x4xf32> {rotom.seed = #rotom.seed, #rotom.dim, #rotom.dim], n = 8>, #rotom.layout, #rotom.dim, #rotom.dim], n = 8>, #rotom.layout, #rotom.dim, #rotom.dim], n = 8>, #rotom.layout, #rotom.dim, #rotom.dim], n = 8>]>}) + // CHECK: func.func @test_seeding( + // CHECK-SAME: !secret.secret> {rotom.seed = #rotom.seed + // CHECK-SAME: #rotom.layout + // CHECK-SAME: #rotom.layout + // CHECK-SAME: #rotom.layout + // CHECK-SAME: ]>}, %{{.*}}: tensor<4x4xf32> {rotom.seed = #rotom.seed + // CHECK-SAME: #rotom.layout + // CHECK-SAME: #rotom.layout + // CHECK-SAME: #rotom.layout + // CHECK-SAME: ]>}) func.func @test_seeding(%arg0: !secret.secret>, %arg1: tensor<4x4xf32>) -> !secret.secret> { // CHECK: secret.generic(%{{.*}}: !secret.secret>) %0 = secret.generic(%arg0 : !secret.secret>) { @@ -16,7 +27,15 @@ module { // ----- module { - // CHECK: func.func @test_seeding_3d(%{{.*}}: !secret.secret> {rotom.seed = #rotom.seed, #rotom.dim, #rotom.dim], n = 8>, #rotom.layout, #rotom.dim, #rotom.dim], n = 8>, #rotom.layout, #rotom.dim, #rotom.dim], n = 8>, #rotom.layout, #rotom.dim, #rotom.dim], n = 8>, #rotom.layout, #rotom.dim, #rotom.dim], n = 8>, #rotom.layout, #rotom.dim, #rotom.dim], n = 8>]>}) + // CHECK: func.func @test_seeding_3d( + // CHECK-SAME: !secret.secret> {rotom.seed = #rotom.seed + // CHECK-SAME: #rotom.layout + // CHECK-SAME: #rotom.layout + // CHECK-SAME: #rotom.layout + // CHECK-SAME: #rotom.layout + // CHECK-SAME: #rotom.layout + // CHECK-SAME: ]>}) func.func @test_seeding_3d(%arg0: !secret.secret>) -> !secret.secret> { // CHECK: secret.generic(%{{.*}}: !secret.secret>) %0 = secret.generic(%arg0 : !secret.secret>) { @@ -30,7 +49,18 @@ module { // ----- module { - // CHECK: func.func @test_seeding_non_pow2(%{{.*}}: !secret.secret> {rotom.seed = #rotom.seed, #rotom.dim, #rotom.dim], n = 8>, #rotom.layout, #rotom.dim, #rotom.dim], n = 8>, #rotom.layout, #rotom.dim, #rotom.dim], n = 8>, #rotom.layout, #rotom.dim, #rotom.dim], n = 8>]>}, %{{.*}}: tensor<3x3xf32> {rotom.seed = #rotom.seed, #rotom.dim, #rotom.dim], n = 8>, #rotom.layout, #rotom.dim, #rotom.dim], n = 8>, #rotom.layout, #rotom.dim, #rotom.dim], n = 8>, #rotom.layout, #rotom.dim, #rotom.dim], n = 8>]>}) + // CHECK: func.func @test_seeding_non_pow2( + // CHECK-SAME: !secret.secret> {rotom.seed = #rotom.seed + // CHECK-SAME: #rotom.layout + // CHECK-SAME: #rotom.layout + // CHECK-SAME: #rotom.layout + // CHECK-SAME: ]>}, %{{.*}}: tensor<3x3xf32> {rotom.seed = #rotom.seed + // CHECK-SAME: #rotom.layout + // CHECK-SAME: #rotom.layout + // CHECK-SAME: #rotom.layout + // CHECK-SAME: ]>}) func.func @test_seeding_non_pow2(%arg0: !secret.secret>, %arg1: tensor<3x3xf32>) -> !secret.secret> { // CHECK: secret.generic(%{{.*}}: !secret.secret>) %0 = secret.generic(%arg0 : !secret.secret>) {