
Add MLIR verifier to check HardwareConstraint.vector_shapes consistency with tile sizes #1094

Open
tgymnich wants to merge 4 commits into main from tim/vector_shape-consistency

Conversation

Contributor

@tgymnich tgymnich commented Mar 10, 2026

  • change waves_per_block to DenseI32ArrayAttr to allow using struct(params) in assemblyFormat
  • verify vector_shapes consistency with WorkgroupConstraint and TilingConstraint

fixes #864
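As a rough sketch of what the verifier enforces, in Python pseudocode (the actual implementation is a C++ MLIR verifier; function and parameter names here are illustrative, not the real API):

```python
# Illustrative sketch of the consistency rule: each vector_shapes entry for a
# dimension must match the tile size resolved from the corresponding
# WorkgroupConstraint / TilingConstraint. Names are made up for illustration.
def check_vector_shapes(vector_shapes, resolved_tile_sizes):
    """Return a list of error messages; empty means consistent."""
    errors = []
    for dim, shape in vector_shapes.items():
        tile = resolved_tile_sizes.get(dim)
        if tile is None:
            continue  # no constraint resolves this dimension, nothing to check
        if shape != tile:
            errors.append(
                f"vector_shapes[{dim}] = {shape} does not match tile size {tile}"
            )
    return errors

# One consistent and one inconsistent dimension:
print(check_vector_shapes({"M": 16, "N": 16}, {"M": 16, "N": 32}))
```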

@tgymnich tgymnich changed the title tim/vector shape consistency Add MLIR verifier to check HardwareConstraint.vector_shapes consistency with tile sizes Mar 10, 2026
Contributor

ftynse commented Mar 10, 2026

Fly-by on a draft: please document your understanding of what vector_shapes are and how they relate to other things in docs/ir_design.rst.

…rams) in assemblyFormat

Signed-off-by: Tim Gymnich <tim@gymni.ch>
…onstraint

Signed-off-by: Tim Gymnich <tim@gymni.ch>
Signed-off-by: Tim Gymnich <tim@gymni.ch>
@tgymnich tgymnich force-pushed the tim/vector_shape-consistency branch from 5fc38c2 to eed86ef Compare March 12, 2026 16:19
@tgymnich tgymnich marked this pull request as ready for review March 12, 2026 16:19
@github-actions

Water Code Coverage

Filename                                                           Functions  Missed Functions  Executed       Lines      Missed Lines     Cover    Branches   Missed Branches     Cover
-----------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
lib/Transforms/MemrefDecomposition.cpp                                    28                 0   100.00%         600                49    91.83%         104                46    55.77%
lib/Transforms/AllocToAlloca.cpp                                           2                 0   100.00%          17                 0   100.00%           0                 0         -
lib/Transforms/CheckStaticAssertions.cpp                                   2                 0   100.00%          22                 1    95.45%           8                 4    50.00%
lib/Transforms/GPUModuleToBinary.cpp                                      19                 5    73.68%         339               115    66.08%         128                57    55.47%
lib/Transforms/DropTransformOps.cpp                                        2                 0   100.00%          16                 0   100.00%           2                 0   100.00%
lib/Transforms/GPUToGPURuntime.cpp                                        14                 0   100.00%         298                23    92.28%          40                17    57.50%
lib/Transforms/SLPVectorizer.cpp                                          61                 3    95.08%        1065                99    90.70%         558               166    70.25%
lib/Transforms/AccessCheckers.cpp                                         35                 1    97.14%         446                40    91.03%         124                30    75.81%
lib/Transforms/AssembleISA.cpp                                             4                 1    75.00%          30                 2    93.33%           2                 1    50.00%
lib/Dialect/Wave/Transforms/LoweringPatterns.cpp                          45                 2    95.56%         916               142    84.50%         264                79    70.08%
lib/Dialect/Wave/Transforms/PropagateDefaultsFromConstraints.cpp           3                 3     0.00%          35                35     0.00%          12                12     0.00%
lib/Dialect/Wave/Transforms/TypeConverter.cpp                              7                 2    71.43%          96                26    72.92%          32                17    46.88%
lib/Dialect/Wave/Transforms/LowerReadWriteOps.cpp                         10                 0   100.00%         238                18    92.44%          58                11    81.03%
lib/Dialect/Wave/Transforms/DetectNormalForms.cpp                          4                 0   100.00%          51                 0   100.00%           8                 0   100.00%
lib/Dialect/Wave/Transforms/ExpandVariadicReductions.cpp                   2                 0   100.00%          23                 1    95.65%           6                 1    83.33%
lib/Dialect/Wave/Transforms/InferTypes.cpp                                97                14    85.57%        1554               139    91.06%         816               426    47.79%
lib/Dialect/Wave/Transforms/LowerWaveToMLIR.cpp                            5                 0   100.00%         129                 1    99.22%          16                 2    87.50%
lib/Dialect/Wave/Transforms/Utils.cpp                                      5                 0   100.00%          85                 5    94.12%          22                 4    81.82%
lib/Dialect/Wave/Transforms/ResolveDistributedAllocations.cpp              7                 0   100.00%         183                16    91.26%          32                14    56.25%
lib/Dialect/Wave/IR/WaveOps.cpp                                          117                11    90.60%        2020               220    89.11%         836               177    78.83%
lib/Dialect/Wave/IR/WaveAttrs.cpp                                         72                 7    90.28%         917                94    89.75%         406                61    84.98%
lib/Dialect/Wave/IR/IndexExpr.cpp                                         11                 0   100.00%         119                 1    99.16%          24                 3    87.50%
lib/Dialect/Wave/IR/WaveDialect.cpp                                       13                 0   100.00%         473                24    94.93%         172                15    91.28%
lib/Dialect/Wave/IR/WaveTypes.cpp                                          9                 1    88.89%          75                 8    89.33%          18                 3    83.33%
lib/Dialect/Wave/IR/WaveInterfaces.cpp                                    81                 3    96.30%        1164                45    96.13%         482                56    88.38%
lib/Dialect/Wave/IR/WaveUtils.cpp                                          7                 0   100.00%         129                 7    94.57%          52                10    80.77%
lib/Dialect/NormalForm/Transforms/LowerNormalFormModule.cpp                3                 0   100.00%          34                 6    82.35%           8                 2    75.00%
lib/Dialect/NormalForm/IR/NormalFormDialect.cpp                            1                 0   100.00%           6                 0   100.00%           0                 0         -
lib/Dialect/NormalForm/IR/NormalFormOps.cpp                               12                 0   100.00%         201                11    94.53%          58                 9    84.48%
lib/Pipelines/Pipelines.cpp                                                2                 0   100.00%          27                 0   100.00%           0                 0         -
lib/Analysis/InUseForSpeculation.cpp                                      12                 1    91.67%         142                 8    94.37%          32                 4    87.50%
include/water/Dialect/Wave/Transforms/LoweringPatterns.h                   1                 0   100.00%           3                 0   100.00%           0                 0         -
include/water/Dialect/Wave/IR/IndexExpr.h                                  1                 0   100.00%          10                 0   100.00%           2                 0   100.00%
include/water/Dialect/Wave/IR/WaveInterfaces.h                            37                 3    91.89%         144                 8    94.44%           8                 2    75.00%
include/water/Dialect/Wave/IR/WaveTypes.h                                  1                 0   100.00%           5                 0   100.00%           4                 0   100.00%
include/water/Dialect/Wave/IR/WaveUtils.h                                  1                 0   100.00%           5                 1    80.00%           4                 2    50.00%
include/water/Dialect/Wave/IR/WaveAttrs.h                                  4                 0   100.00%          14                 0   100.00%           0                 0         -
include/water/Dialect/NormalForm/IR/NormalFormInterfaces.h                 1                 1     0.00%           4                 4     0.00%           0                 0         -
include/water/Analysis/InUseForSpeculation.h                              12                 3    75.00%          39                17    56.41%          16                10    37.50%
-----------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
TOTAL                                                                    750                61    91.87%       11674              1166    90.01%        4354              1241    71.50%

Download full HTML report

Contributor

@ftynse ftynse left a comment


Thanks! Let's iterate on the document to refine and align our understanding. There is a lot of underdocumented and potentially redundant behavior there that we need to iron out.

If you are using an agent for code understanding, make sure to instruct it to only look at Python code (mine otherwise gives priority to the better-documented C++/MLIR) and specifically say the usage may be inconsistent and that it should infer the behavior from code rather than believe the comments.

Comment thread docs/ir_design.rst
Contributor


It's a bit of a mess, but we already have docs/wave/ir_design_notes.rst, could we keep everything there?

Comment thread docs/ir_design.rst
``vector_shapes`` is an optional ``DictionaryAttr`` on
``#wave.hardware_constraint``. Each entry maps a dimension name (a string
matching a ``#wave.symbol``) to an integer specifying how many elements a
single wave processes along that dimension in one instance of an operation
Contributor


Is it per wave or per workgroup? I keep being confused by it. Also, does this imply that the product of values in vector_shapes must be a multiple of threads_per_wave?

Similarly, there is some implicit notion of "hardware-compatible" sizes, like the ones from the mma kind, but also read/write widths. Does this mean that vector shapes should be a multiple of those?

Contributor Author

@tgymnich tgymnich Mar 13, 2026


Also, does this imply that the product of values in vector_shapes must be a multiple of threads_per_wave

Yes, modulo potential masking that may or may not happen.

Similarly, there is some implicit notion of "hardware-compatible" sizes, like the ones from the mma kind, but also read/write widths. Does this mean that vector shapes should be a multiple of those?

No, vector_shapes can be of any size (though it makes sense for the size to be at least a multiple of threads_per_wave), since elementwise operations are not really limited to fixed sizes.
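The divisibility rule being discussed can be sketched as follows (a hypothetical helper for illustration, not part of the codebase):

```python
from math import prod

# Sketch of the rule above: the product of vector_shapes values should be a
# multiple of threads_per_wave, otherwise masking kicks in on read/write.
# Function name and shapes are illustrative assumptions.
def is_wave_divisible(vector_shapes, threads_per_wave):
    return prod(vector_shapes.values()) % threads_per_wave == 0

print(is_wave_divisible({"M": 16, "N": 16}, 64))  # 256 % 64 == 0 -> True
print(is_wave_divisible({"M": 16, "N": 3}, 64))   # 48 % 64 != 0 -> False
```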

Contributor


AFAIK, masking happens for read/write and the replacement element is 0, which may be problematic if the value is then used as RHS of a division or a modulo... Maybe let's emit a warning (beware that you may have to use emitWarning(op->getLoc()) instead of op->emitWarning due to a verification order bug) when it is not divisible.

Contributor


After a live discussion, this appears to be an additional tiling level, so vector shape is per operation instance inside wave, not just one wave.

Comment thread docs/ir_design.rst
* determine how many elements each thread processes (``elements_per_thread``),
* compute memory access strides, and
* drive the expansion (unrolling) pass that replicates operations until the
workgroup tile is covered.
Contributor


workgroup or wave?

Contributor Author

@tgymnich tgymnich Mar 13, 2026


wave. And workgroup tile should rather be block.
waves_per_block also models the same concept as WorkgroupConstraint and WaveConstraint

Contributor


wave. And workgroup tile should rather be block.

Workgroup is an alias for block. We should try to use AMD terminology consistently.

waves_per_block also models the same concept as WorkgroupConstraint and WaveConstraint

Then we should have a check for them matching or, if this doesn't break any existing functionality, that only one of the two mechanisms is present.

Comment thread docs/ir_design.rst

There are two cases, depending on whether ``mma_type`` is present.

**When mma_type is set,** ``vector_shapes`` is derived from the MMA
Contributor


It is worth noting that individual mma instructions may override the mma kind provided in hardware constraints...

Comment thread docs/ir_design.rst

Additional entries may be provided for dimensions the MMA analysis does not
cover (e.g. a batch dimension), and in that case both ``mma_type`` and explicit
``vector_shapes`` coexist.
Contributor


The code overwrites mma shapes by hw shapes AFAICS:

if hardware_constraint.vector_shapes:
    custom.vector_shapes.update(hardware_constraint.vector_shapes)

(and the coding agent doesn't see it unless you shove its nose in it).

Contributor


This is being resolved in #1141

Comment thread docs/ir_design.rst
reason about the per-wave tile.

**When mma_type is absent,** the MLIR verifier enforces that each
``vector_shapes`` entry **matches** the resolved tile size from the
Contributor


I'm not convinced that vector shapes must match workgroup tile size. Not saying I'm right, but please convince me one way or another.

Contributor Author


In the presence of WaveConstraint it must match that. Since elementwise ops can operate on any vector size, we don't need unrolling like we do for mma.

Contributor


Okay, so this joins the thread above: vector shapes and wave constraints are redundant. I think this has more to do with the redundancy, and less to do with the kind of ops we have. At HW level, elementwise ops should still operate on 1/2/4 elements, it's only a question of which level of the stack does the unrolling: wave, mlir or llvm.

Comment thread docs/ir_design.rst
Comment on lines +104 to +105
that dimension. Unlike with mma_operations that have a fixed size, element wise operations
can operate on any number of elements_per_thread and thus don't need to be expanded multiple times.
Contributor


This doesn't sound right. None of the operations may actually operate on an arbitrary number of elements per thread; there are instructions that usually work on 1, 2 or 4 elements per thread, sometimes depending on element size. Hence the expansion process. It may just replicate some operations more than others, depending on the "native" size they support.
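The replication arithmetic described here can be sketched like this (names are illustrative assumptions, not actual pass internals):

```python
# Sketch of the expansion process: an op whose hardware-native width is 1, 2
# or 4 elements gets replicated until the per-thread element count is covered.
# Function and parameter names are made up for illustration.
def replication_factor(elements_per_thread, native_width):
    if elements_per_thread % native_width != 0:
        raise ValueError("per-thread count must be divisible by native width")
    return elements_per_thread // native_width

print(replication_factor(8, 4))  # 8 elements at native width 4 -> 2 copies
print(replication_factor(8, 2))  # native width 2 -> 4 copies
```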

Contributor Author

@tgymnich tgymnich Mar 13, 2026


Well, on the hardware level you are right. But from how PyWave lowers to mlir, we can just have an operation like %2 = math.exp %arg0 : vector<4xf32> operating on an arbitrary number of vector elements, which would not work with mma. MMA seems to be the only kind of operation like that.

Contributor


https://mlir.llvm.org/docs/Dialects/Vector/#vectorcontract-vectorcontractionop this one would. Same as above, this is more about where we choose to do the unrolling. Let's double check whether this happens in wave expansion or not. I'd expect it to happen.

Comment thread docs/ir_design.rst

This means that in non-MMA programs, there is no separate expansion step:
``vector_shapes`` equals the tile size and each operation appears exactly once
per dimension.
Contributor


I don't think this is correct.

Contributor Author

@tgymnich tgymnich Mar 13, 2026


there might be in pywave... But this line relates to the verifier change I just made, which forbids this.

Contributor


We shouldn't have a verifier that forbids things we can express in pywave, unless as the last resort for unsound semantics.

Comment thread docs/ir_design.rst
hardware constraint. When ``kind`` is absent, the
``PropagateDefaultsFromConstraints`` pass fills it from the hardware
constraint's ``mma_type``. When multiple ``wave.mma`` ops exist in the same
function, each carries its own ``kind`` and its own effective vector shapes.
Contributor


Carries the vector shapes where?

Contributor Author


extra_attrs: Additional attributes to set on the fx_node after creation (e.g., index, vector_shapes). These are not passed to the dataclass constructor.

Inside of extra_attrs in CustomOp

Contributor


Yes, but this paragraph seems to be about MLIR. And it doesn't carry vector shapes AFAIK. I may end up adding it though.

Comment thread docs/ir_design.rst
``elements_per_thread`` is related to ``vector_shapes`` conceptually: the
vector shape for a dimension gives the total elements a wave handles, and
dividing by ``threads_per_wave`` (for a reduction dimension) or accounting for
thread count per workgroup dimension gives the per-thread count. The
Contributor


I can't parse this. My understanding is that EPT should be directly related: we know the number of elements per workgroup, there's a waves_per_block in hw constraints (which yet again may be redundant with both wave constraints and vector_shapes), which should tell us the number of elements per wave. We also know the number of threads in a wave, which should give us EPT. What happens when operation-specified EPT doesn't match the one that would be inferred by the process above — I don't know.
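The inference chain described here, sketched as plain arithmetic (all names are illustrative; the real relationship between these quantities is exactly what this thread is trying to pin down):

```python
from math import prod

# Sketch: elements per workgroup -> elements per wave (divide by number of
# waves in the block) -> elements per thread (divide by threads_per_wave).
def infer_ept(workgroup_tile, waves_per_block, threads_per_wave):
    elements_per_workgroup = prod(workgroup_tile)
    elements_per_wave = elements_per_workgroup // prod(waves_per_block)
    return elements_per_wave // threads_per_wave

# e.g. a 64x64 workgroup tile, 2x2 waves per block, 64 threads per wave:
print(infer_ept((64, 64), (2, 2), 64))  # 4096 / 4 / 64 -> 16
```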

Contributor Author


it is conceptually and directly related :) As everything in wave.

Contributor


But I still can't parse what it says. Consider rephrasing.

and dividing by threads_per_wave (for a reduction dimension)

Why specifically for a reduction dimension? Threads per wave is a scalar, how do we divide a dictionary per that?

accounting for thread count per workgroup dimension

Accounting how exactly? Is this different from division above?

Also, thread count per workgroup dimension subsumes thread count per wave.

Signed-off-by: Tim Gymnich <tim@gymni.ch>
Contributor

@ftynse ftynse left a comment


The existing check looks right, we need one more, and please update the document to make it less confusing.


Comment thread docs/ir_design.rst
``#wave.hardware_constraint``. Each entry maps a dimension name (a string
matching a ``#wave.symbol``) to an integer specifying how many elements a
single wave processes along that dimension in one instance of an operation
before expansion has replicated it.
Contributor


Suggested change
before expansion has replicated it.
after the expansion process has replicated it.

Comment thread docs/ir_design.rst
* determine how many elements each thread processes (``elements_per_thread``),
* compute memory access strides, and
* drive the expansion (unrolling) pass that replicates operations until the
workgroup tile is covered.
Contributor


Suggested change
workgroup tile is covered.
wave tile is covered.


}

// Verify consistency between constraints and vector_shapes (when mma_type
// is absent). Each vector_shapes entry must match the resolved tile size
Contributor


Suggested change
// is absent). Each vector_shapes entry must match the resolved tile size
// is absent). Each vector_shapes entry must be less than or equal to the resolved tile size

}
}
}

Contributor


We also want a check that waves_per_block in the hardware constraint == evaluated workgroup tile size / evaluated wave tile size. And it may make sense to add an error or warning if one tile size is not divisible by the other.
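A minimal sketch of the requested check, assuming per-dimension tile sizes are already evaluated to integers (function and variable names are illustrative, not the actual verifier code):

```python
# Sketch: waves_per_block[d] should equal workgroup_tile[d] / wave_tile[d]
# for each dimension; non-divisible tile sizes get at least a warning.
def check_waves_per_block(waves_per_block, workgroup_tile, wave_tile):
    errors, warnings = [], []
    for d, (wg, wv) in enumerate(zip(workgroup_tile, wave_tile)):
        if wg % wv != 0:
            warnings.append(
                f"dim {d}: workgroup tile {wg} not divisible by wave tile {wv}"
            )
        if waves_per_block[d] != wg // wv:
            errors.append(
                f"dim {d}: waves_per_block {waves_per_block[d]} != {wg} // {wv}"
            )
    return errors, warnings

print(check_waves_per_block([2, 2], [64, 64], [32, 32]))  # ([], [])
```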
