Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
40 changes: 17 additions & 23 deletions include/TaskflowDialect/TaskflowOps.td
Original file line number Diff line number Diff line change
Expand Up @@ -39,11 +39,17 @@ def TaskflowTaskOp : TaskflowOpBase<"task", [
1. Memory dependencies: memrefs that are read or written by the task
2. Value dependencies: SSA values from producer tasks

The `read_memrefs` and `write_memrefs` attributes record the actural
original memrefs that this task accesses,
enabling data placement analysis for multi-CGRA mapping.

Example:
// Memory input: %mem, Value input: %val
// Memory inputs: %mem, Value inputs: %val
$out_mem, %out_val = taskflow.task "Task_0"
memory_inputs(%mem : memref<4xi32>)
value_inputs(%val : i32) {
read_inputs(%mem : memref<4xi32>)
value_inputs(%val : i32)
original_read_memrefs(%arg0 : memref<?x8x6xi32>)
original_write_memrefs(%arg5 : memref<?xi32>) {
^bb0(%a0: memref<4xi32>, %a1: i32):
affine.for %i = 0 to 4 {
%v = affine.load %a0[%i] : memref<4xi32>
Expand All @@ -55,28 +61,22 @@ def TaskflowTaskOp : TaskflowOpBase<"task", [
}];

let arguments = (ins
Variadic<AnyMemRef>:$memory_inputs,
Variadic<AnyMemRef>:$read_memrefs,
Variadic<AnyMemRef>:$write_memrefs,
Variadic<AnyType>:$value_inputs,
StrAttr:$task_name
StrAttr:$task_name,
Variadic<AnyMemRef>:$original_read_memrefs,
Variadic<AnyMemRef>:$original_write_memrefs
);

let results = (outs
Variadic<AnyMemRef>:$memory_outputs,
Variadic<AnyMemRef>:$write_outputs,
Variadic<AnyType>:$value_outputs
);

let regions = (region SizedRegion<1>:$body);

// let hasCustomAssemblyFormat = 1;

// let assemblyFormat = [{
// (`memory_inputs` `(` $memory_inputs^ `:` type($memory_inputs) `)`)?
// (`value_inputs` `(` $value_inputs^ `:` type($value_inputs) `)`)?
// attr-dict-with-keyword
// $body
// `->` `(` type($memory_outputs) `,` type($value_outputs) `)`
// }];

let hasCustomAssemblyFormat = 1;
}

// Defines the yield operation to terminate a Taskflow task.
Expand All @@ -97,13 +97,7 @@ def TaskflowYieldOp : TaskflowOpBase<"yield", [Terminator, Pure, ReturnLike, Att
Variadic<AnyMemRef>:$memory_results,
Variadic<AnyType>:$value_results);

// let assemblyFormat = [{
// (`memory_outputs` `(` $memory_results^ `:` type($memory_results) `)`)?
// (`value_outputs` `(` $value_results^ `:` type($value_results) `)`)?
// attr-dict
// }];

// let hasCustomAssemblyFormat = 1;
let hasCustomAssemblyFormat = 1;

let builders = [
// Default builder for empty yield.
Expand Down
6 changes: 5 additions & 1 deletion include/TaskflowDialect/TaskflowPasses.h
Original file line number Diff line number Diff line change
Expand Up @@ -16,9 +16,13 @@ namespace taskflow {
#define GEN_PASS_DECL
#include "TaskflowDialect/TaskflowPasses.h.inc"
std::unique_ptr<mlir::Pass> createConstructHyperblockFromTaskPass();
std::unique_ptr<mlir::Pass> createCanonicalizeTaskPass();
std::unique_ptr<mlir::Pass> createClassifyCountersPass();

//=========================================================//
// Optimization Passes
//=========================================================//
std::unique_ptr<mlir::Pass> createAffineLoopTreeSerializationPass();

#define GEN_PASS_REGISTRATION
#include "TaskflowDialect/TaskflowPasses.h.inc"
} // namespace taskflow
Expand Down
32 changes: 17 additions & 15 deletions include/TaskflowDialect/TaskflowPasses.td
Original file line number Diff line number Diff line change
Expand Up @@ -6,28 +6,30 @@
include "mlir/Pass/PassBase.td"

//=========================================================//
// Passes for the Taskflow dialect
// Passes for Task Level Optimizations
//=========================================================//
def ConstructHyperblockFromTask : Pass<"construct-hyperblock-from-task", "func::FuncOp"> {
let summary = "Constructs hyperblocks and counter chain from Taskflow tasks";
def AffineLoopTreeSerialization : Pass<"affine-loop-tree-serialization", "ModuleOp">{
let summary = "Serializes top-level affine.for loops into minimized task operations";
let description = [{
This pass constructs hyperblocks and counter chain from Taskflow tasks.
This pass converts top-level affine.for loops in a function into
minimized and canonicalized task operations.
}];
let constructor = "taskflow::createConstructHyperblockFromTaskPass()";
let constructor = "taskflow::createAffineLoopTreeSerializationPass()";
let dependentDialects = [
"mlir::taskflow::TaskflowDialect",
"mlir::affine::AffineDialect",
"mlir::func::FuncDialect"];
}

def CanonicalizeTask: Pass<"canonicalize-task", "func::FuncOp">{
let summary = "Canonicalizes tasks by splitting each hyperblock into a separate atomic task";
//=========================================================//
// Passes for the Taskflow dialect
//=========================================================//
def ConstructHyperblockFromTask : Pass<"construct-hyperblock-from-task", "func::FuncOp">{
let summary = "Constructs hyperblocks from Taskflow tasks by detecting perfect nested loop bands";
let description = [{
This pass splits tasks so that each task contains exactly one hyperblock.
This creates atomic task units that can be analyzed and optimized independently.

Input: Task with N hyperblocks
Output: N atomic tasks, each containing one hyperblock

This is a prerequisite pass before fusion optimizations.
This pass constructs hyperblocks from Taskflow tasks by detecting perfect nested loop bands.
}];
let constructor = "taskflow::createCanonicalizeTaskPass()";
let constructor = "taskflow::createConstructHyperblockFromTaskPass()";
}

def ClassifyCounters : Pass<"classify-counters", "ModuleOp">{
Expand Down
109 changes: 83 additions & 26 deletions lib/Conversion/AffineToTaskflow/AffineToTaskflowPass.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
#include "mlir/IR/MLIRContext.h"
#include "mlir/IR/Operation.h"
#include "mlir/IR/Value.h"
#include "mlir/IR/ValueRange.h"
#include "mlir/Pass/Pass.h"
#include "mlir/Support/LLVM.h"
#include "llvm/ADT/ArrayRef.h"
Expand All @@ -30,7 +31,6 @@ namespace {
//------------------------------------------------------------------------------
// Helper Functions.
//------------------------------------------------------------------------------

// Collects memrefs that are loaded (read) within a given operation scope.
static void collectReadMemrefs(Operation *op, SetVector<Value> &read_memrefs) {
op->walk([&](Operation *nested_op) {
Expand Down Expand Up @@ -104,15 +104,41 @@ updateOperationOperands(Operation *op,
}
}

//------------------------------------------------------------------------------
// Analyzes all the original memory access info before conversion.
//------------------------------------------------------------------------------
struct MemrefAccessInfo {
SetVector<Value> read_memrefs;
SetVector<Value> write_memrefs;
};

static DenseMap<Operation *, MemrefAccessInfo>
analyzeMemrefAccesses(func::FuncOp func_op) {
DenseMap<Operation *, MemrefAccessInfo> loop_to_memref_info;

func_op.walk([&](affine::AffineForOp for_op) {
llvm::errs() << "\nAnalyzing memref accesses for loop:\n" << for_op << "\n";
MemrefAccessInfo access_info;

collectReadMemrefs(for_op.getOperation(), access_info.read_memrefs);
collectWrittenMemrefs(for_op.getOperation(), access_info.write_memrefs);

loop_to_memref_info[for_op] = access_info;
});

return loop_to_memref_info;
}

//------------------------------------------------------------------------------
// Task Conversion
//------------------------------------------------------------------------------

// Converts a top-level affine.for to a taskflow.task operation.
static TaskflowTaskOp convertLoopToTask(OpBuilder &builder,
affine::AffineForOp for_op,
DenseMap<Value, Value> &value_mapping,
int task_id) {
static TaskflowTaskOp convertLoopToTask(
OpBuilder &builder, affine::AffineForOp for_op,
DenseMap<Value, Value> &value_mapping,
const DenseMap<Operation *, MemrefAccessInfo> &loop_to_original_memref_info,
int task_id) {
Location loc = for_op.getLoc();
std::string task_name = "Task_" + std::to_string(task_id);

Expand All @@ -125,33 +151,35 @@ static TaskflowTaskOp convertLoopToTask(OpBuilder &builder,
// Step 1: Collects read and written memrefs.
//-------------------------------------------------------------------
SetVector<Value> read_memrefs;
SetVector<Value> written_memrefs;
SetVector<Value> write_memrefs;
collectReadMemrefs(for_op.getOperation(), read_memrefs);
collectWrittenMemrefs(for_op.getOperation(), written_memrefs);
collectWrittenMemrefs(for_op.getOperation(), write_memrefs);

llvm::errs() << "Read memrefs for loop:\n" << for_op << "\n";
for (Value memref : read_memrefs) {
llvm::errs() << memref << "\n";
}

llvm::errs() << "Written memrefs for loop:\n" << for_op << "\n";
for (Value memref : written_memrefs) {
for (Value memref : write_memrefs) {
llvm::errs() << memref << "\n";
}

// Collects original memref access info.
auto it = loop_to_original_memref_info.find(for_op.getOperation());
assert(it != loop_to_original_memref_info.end() &&
"Original memref access info not found for the loop");
const MemrefAccessInfo &original_memref_info = it->second;
SetVector<Value> original_read_memrefs = original_memref_info.read_memrefs;
SetVector<Value> original_write_memrefs = original_memref_info.write_memrefs;

//-------------------------------------------------------------------
// Step 2: Determines memory inputs and outputs.
//-------------------------------------------------------------------
// Memory inputs: ALL memrefs that are accessed (read OR written).
// This ensures WAR and WAW dependencies are respected.
SetVector<Value> accessed_memrefs;
accessed_memrefs.insert(read_memrefs.begin(), read_memrefs.end());
accessed_memrefs.insert(written_memrefs.begin(), written_memrefs.end());

// Memory outputs: ONLY memrefs that are written.
// This ensures RAW and WAW dependencies are respected.
SetVector<Value> output_memrefs;
output_memrefs.insert(written_memrefs.begin(), written_memrefs.end());
output_memrefs.insert(write_memrefs.begin(), write_memrefs.end());

//-------------------------------------------------------------------
// Step 3: Collects external SSA values (non-memref).
Expand All @@ -167,17 +195,28 @@ static TaskflowTaskOp convertLoopToTask(OpBuilder &builder,
//-------------------------------------------------------------------
// Step 4: Resolves inputs through value mapping.
//-------------------------------------------------------------------
SmallVector<Value> memory_inputs;
SmallVector<Value> read_inputs;
SmallVector<Value> write_inputs;
SmallVector<Value> value_inputs;
IRMapping mapping;

// Resolves memory inputs.
for (Value memref : accessed_memrefs) {
// Resolves read inputs.
for (Value memref : read_memrefs) {
Value resolved_memref = value_mapping.lookup(memref);
if (!resolved_memref) {
resolved_memref = memref;
}
read_inputs.push_back(resolved_memref);
mapping.map(memref, resolved_memref);
}

// Resolves write inputs.
for (Value memref : write_memrefs) {
Value resolved_memref = value_mapping.lookup(memref);
if (!resolved_memref) {
resolved_memref = memref;
}
memory_inputs.push_back(resolved_memref);
write_inputs.push_back(resolved_memref);
mapping.map(memref, resolved_memref);
}

Expand Down Expand Up @@ -211,9 +250,12 @@ static TaskflowTaskOp convertLoopToTask(OpBuilder &builder,
loc,
/*memory_outputs=*/memory_output_types,
/*value_outputs=*/value_output_types,
/*memory_inputs=*/memory_inputs,
/*read_inputs=*/read_inputs,
/*write_inputs=*/write_inputs,
/*value_inputs=*/value_inputs,
/*task_name=*/builder.getStringAttr(task_name));
/*task_name=*/builder.getStringAttr(task_name),
/*original_read_memrefs=*/original_read_memrefs.getArrayRef(),
/*original_write_memrefs=*/original_write_memrefs.getArrayRef());

//-------------------------------------------------------------------
// Step 7: Builds the task body.
Expand All @@ -223,8 +265,15 @@ static TaskflowTaskOp convertLoopToTask(OpBuilder &builder,

// Adds block arguments (memory inputs first, then value inputs).
DenseMap<Value, BlockArgument> input_to_block_arg;
// Memory input arguments.
for (Value memref : accessed_memrefs) {
// Memory read input arguments.
for (Value memref : read_memrefs) {
BlockArgument arg = task_body->addArgument(memref.getType(), loc);
mapping.map(memref, arg);
input_to_block_arg[memref] = arg;
}

// Memory write input arguments.
for (Value memref : write_memrefs) {
BlockArgument arg = task_body->addArgument(memref.getType(), loc);
mapping.map(memref, arg);
input_to_block_arg[memref] = arg;
Expand Down Expand Up @@ -270,7 +319,7 @@ static TaskflowTaskOp convertLoopToTask(OpBuilder &builder,
//-------------------------------------------------------------------
// Memory outputs.
for (auto [memref, task_output] :
llvm::zip(output_memrefs, task_op.getMemoryOutputs())) {
llvm::zip(output_memrefs, task_op.getWriteOutputs())) {
value_mapping[memref] = task_output;
}

Expand All @@ -285,6 +334,8 @@ static LogicalResult convertFuncToTaskflow(func::FuncOp func_op) {

llvm::errs() << "\n===Converting function: " << func_op.getName() << "===\n";

DenseMap<Operation *, MemrefAccessInfo> loop_to_original_memref_info =
analyzeMemrefAccesses(func_op);
OpBuilder builder(func_op.getContext());
SmallVector<affine::AffineForOp> loops_to_erase;
DenseMap<Value, Value> value_mapping;
Expand All @@ -298,13 +349,19 @@ static LogicalResult convertFuncToTaskflow(func::FuncOp func_op) {
ops_to_process.push_back(&op);
}

llvm::errs() << "ops_to_process:\n";
for (Operation *op : ops_to_process) {
llvm::errs() << *op << "\n";
}

// Processes each operation in order (top to bottom).
for (Operation *op : ops_to_process) {
if (auto for_op = dyn_cast<affine::AffineForOp>(op)) {
// Converts affine.for to taskflow.task.
OpBuilder builder(for_op);
TaskflowTaskOp task_op = convertLoopToTask(
builder, for_op, value_mapping, task_id_counter++);
TaskflowTaskOp task_op =
convertLoopToTask(builder, for_op, value_mapping,
loop_to_original_memref_info, task_id_counter++);

// Replaces uses of loop results with task value outputs.
for (auto [loop_result, task_value_output] :
Expand Down
6 changes: 6 additions & 0 deletions lib/Conversion/TaskflowToNeura/TaskflowToNeuraPass.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -54,6 +54,12 @@ struct HyperblockToKernelPattern
return failure();
}

// Asserts that each task contains only one hyperblock.
int hyperblock_count = 0;
task_op.walk([&](TaskflowHyperblockOp op) { hyperblock_count++; });
assert(hyperblock_count == 1 &&
"Each taskflow.task should contain only one hyperblock");

Block &hb_block = hyperblock_op.getBody().front();
Block &task_block = task_op.getBody().front();

Expand Down
Loading