Add MLIR verifier to check HardwareConstraint.vector_shapes consistency with tile sizes #1094
tgymnich wants to merge 4 commits into
Conversation
Fly-by on a draft, please document your understanding of what
…rams) in assemblyFormat Signed-off-by: Tim Gymnich <tim@gymni.ch>
…onstraint Signed-off-by: Tim Gymnich <tim@gymni.ch>
5fc38c2 to eed86ef
Water Code Coverage
ftynse
left a comment
Thanks! Let's iterate on the document to refine and align our understanding. There is a lot of underdocumented and potentially redundant behavior there that we need to iron out.
If you are using an agent for code understanding, make sure to instruct it to only look at Python code (mine otherwise gives priority to the better-documented C++/MLIR) and to specifically say the usage may be inconsistent and that it should infer the behavior from code rather than believe the comments.
It's a bit of a mess, but we already have docs/wave/ir_design_notes.rst, could we keep everything there?
| ``vector_shapes`` is an optional ``DictionaryAttr`` on | ||
| ``#wave.hardware_constraint``. Each entry maps a dimension name (a string | ||
| matching a ``#wave.symbol``) to an integer specifying how many elements a | ||
| single wave processes along that dimension in one instance of an operation |
Is it per wave or per workgroup? I keep being confused by it. Also, does this imply that the product of values in vector_shapes must be a multiple of threads_per_wave?
Similarly, there is some implicit notion of "hardware-compatible" sizes, like the ones from the mma kind, but also read/write widths. Does this mean that vector shapes should be a multiple of those?
Also, does this imply that the product of values in vector_shapes must be a multiple of threads_per_wave
Yes, modulo potential masking that may or may not happen.
Similarly, there is some implicit notion of "hardware-compatible" sizes, like the ones from the mma kind, but also read/write widths. Does this mean that vector shapes should be a multiple of those?
No, vector_shapes can be of any size (though it makes sense for the size to be at least a multiple of threads_per_wave), since element-wise operations are not really limited to fixed sizes.
AFAIK, masking happens for read/write and the replacement element is 0, which may be problematic if the value is then used as RHS of a division or a modulo... Maybe let's emit a warning (beware that you may have to use emitWarning(op->getLoc()) instead of op->emitWarning due to a verification order bug) when it is not divisible.
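To make the suggested check concrete, here is a minimal Python sketch of the divisibility condition such a warning would test. The real verifier lives in C++; the function and variable names here are illustrative only, not actual project API:

```python
from math import prod

def vector_shapes_divisible(vector_shapes: dict[str, int],
                            threads_per_wave: int) -> bool:
    """True when the product of all vector_shapes entries is a multiple of
    threads_per_wave, i.e. no masking padding would be required."""
    return prod(vector_shapes.values()) % threads_per_wave == 0
```

With threads_per_wave = 64, shapes {"M": 16, "N": 16} (product 256) pass, while {"M": 16, "N": 3} (product 48) would trigger the warning.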
After a live discussion, this appears to be an additional tiling level, so vector shape is per operation instance inside wave, not just one wave.
| * determine how many elements each thread processes (``elements_per_thread``), | ||
| * compute memory access strides, and | ||
| * drive the expansion (unrolling) pass that replicates operations until the | ||
| workgroup tile is covered. |
wave. And workgroup tile should rather be block.
waves_per_block also models the same concept as WorkgroupConstraint and WaveConstraint
wave. And workgroup tile should rather be block.
Workgroup is an alias for block. We should try to use AMD terminology consistently.
waves_per_block also models the same concept as WorkgroupConstraint and WaveConstraint
Then we should have a check for them matching or, if this doesn't break any existing functionality, that only one of the two mechanisms is present.
| There are two cases, depending on whether ``mma_type`` is present. | ||
| **When mma_type is set,** ``vector_shapes`` is derived from the MMA |
It is worth noting that individual mma instructions may override the mma kind provided in hardware constraints...
| Additional entries may be provided for dimensions the MMA analysis does not | ||
| cover (e.g. a batch dimension), and in that case both ``mma_type`` and explicit | ||
| ``vector_shapes`` coexist. |
The code overwrites mma shapes by hw shapes AFAICS:
wave/wave_lang/kernel/wave/utils/mma_utils.py
Lines 141 to 142 in f591a21
(and the coding agent doesn't see it unless you shove its nose in it).
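For illustration, the overwrite described above boils down to a dictionary merge in which hardware-constraint entries take precedence. This is a hedged Python sketch of that behavior (names are made up, not the actual mma_utils code):

```python
def merge_vector_shapes(mma_shapes: dict[str, int],
                        hw_shapes: dict[str, int]) -> dict[str, int]:
    # The later dict wins on key collisions, so explicit hardware-constraint
    # entries overwrite the MMA-derived ones rather than merely extending them.
    return {**mma_shapes, **hw_shapes}
```

E.g. merging MMA-derived {"M": 16, "N": 16, "K": 16} with hardware-constraint {"B": 1, "M": 32} yields M = 32, not 16.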
| reason about the per-wave tile. | ||
| **When mma_type is absent,** the MLIR verifier enforces that each | ||
| ``vector_shapes`` entry **matches** the resolved tile size from the |
I'm not convinced that vector shapes must match workgroup tile size. Not saying I'm right, but please convince me one way or another.
In the presence of WaveConstraint it must match that. Since element wise ops can operate on any vector size, we don't need unrolling like we do for mma.
Okay, so this joins the thread above: vector shapes and wave constraints are redundant. I think this has more to do with the redundancy, and less to do with the kind of ops we have. At HW level, elementwise ops should still operate on 1/2/4 elements, it's only a question of which level of the stack does the unrolling: wave, mlir or llvm.
| that dimension. Unlike with mma_operations that have a fixed size, element wise operations | ||
| can operate on any number of elements_per_thread and thus don't need to be expanded multiple times. |
This doesn't sound right. None of the operations may actually operate on an arbitrary number of elements per thread; there are instructions that usually work on 1, 2 or 4 elements per thread, sometimes depending on element size. Hence the expansion process. It may just replicate some operations more than others, depending on the "native" size they support.
Well, on the hardware level you are right. But from how PyWave lowers to mlir, we can just have an operation like %2 = math.exp %arg0 : vector<4xf32> operating on an arbitrary number of vector elements, which would not work with mma. MMA seems to be the only kind of operation like that.
https://mlir.llvm.org/docs/Dialects/Vector/#vectorcontract-vectorcontractionop this one would. Same as above, this is more about where we choose to do the unrolling. Let's double check whether this happens in wave expansion or not. I'd expect it to happen.
| This means that in non-MMA programs, there is no separate expansion step: | ||
| ``vector_shapes`` equals the tile size and each operation appears exactly once | ||
| per dimension. |
I don't think this is correct.
there might be in pywave... But this line relates to the verifier change I just made, which forbids this.
We shouldn't have a verifier that forbids things we can express in pywave, unless as the last resort for unsound semantics.
| hardware constraint. When ``kind`` is absent, the | ||
| ``PropagateDefaultsFromConstraints`` pass fills it from the hardware | ||
| constraint's ``mma_type``. When multiple ``wave.mma`` ops exist in the same | ||
| function, each carries its own ``kind`` and its own effective vector shapes. |
Carries the vector shapes where?
extra_attrs: Additional attributes to set on the fx_node after creation (e.g., index, vector_shapes). These are not passed to the dataclass constructor.
Inside of extra_attrs in CustomOp
Yes, but this paragraph seems to be about MLIR. And it doesn't carry vector shapes AFAIK. I may end up adding it though.
| ``elements_per_thread`` is related to ``vector_shapes`` conceptually: the | ||
| vector shape for a dimension gives the total elements a wave handles, and | ||
| dividing by ``threads_per_wave`` (for a reduction dimension) or accounting for | ||
| thread count per workgroup dimension gives the per-thread count. The |
I can't parse this. My understanding is that EPT should be directly related: we know the number of elements per workgroup, there's a waves_per_block in hw constraints (which yet again may be redundant with both wave constraints and vector_shapes), which should tell us the number of elements per wave. We also know the number of threads in a wave, which should give us EPT. What happens when operation-specified EPT doesn't match the one that would be inferred by the process above, I don't know.
it is conceptually and directly related :) As everything in wave.
But I still can't parse what it says. Consider rephrasing.
and dividing by threads_per_wave (for a reduction dimension)
Why specifically for a reduction dimension? Threads per wave is a scalar, how do we divide a dictionary by that?
accounting for thread count per workgroup dimension
Accounting how exactly? Is this different from division above?
Also, thread count per workgroup dimension subsumes thread count per wave.
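The inference chain described in the comment above can be written out as simple arithmetic, one dimension at a time. A sketch under the assumption that everything divides evenly; the numbers are made up for illustration:

```python
workgroup_tile = 512       # elements along one dimension per workgroup/block
waves_per_block = 2        # hardware-constraint entry for the same dimension
threads_per_wave = 64

# Elements handled by a single wave along this dimension.
elements_per_wave = workgroup_tile // waves_per_block        # 256

# Per-thread count once lanes of the wave are distributed over the dimension.
elements_per_thread = elements_per_wave // threads_per_wave  # 4
```

The open question from the thread remains: what should happen when an operation-specified elements_per_thread disagrees with the value inferred this way.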
Signed-off-by: Tim Gymnich <tim@gymni.ch>
ftynse
left a comment
The existing check looks right, we need one more, and please update the document to make it less confusing.
| ``vector_shapes`` is an optional ``DictionaryAttr`` on | ||
| ``#wave.hardware_constraint``. Each entry maps a dimension name (a string | ||
| matching a ``#wave.symbol``) to an integer specifying how many elements a | ||
| single wave processes along that dimension in one instance of an operation |
After a live discussion, this appears to be an additional tiling level, so vector shape is per operation instance inside wave, not just one wave.
| ``#wave.hardware_constraint``. Each entry maps a dimension name (a string | ||
| matching a ``#wave.symbol``) to an integer specifying how many elements a | ||
| single wave processes along that dimension in one instance of an operation | ||
| before expansion has replicated it. |
| before expansion has replicated it. | |
| after the expansion process has replicated it. |
| * determine how many elements each thread processes (``elements_per_thread``), | ||
| * compute memory access strides, and | ||
| * drive the expansion (unrolling) pass that replicates operations until the | ||
| workgroup tile is covered. |
| workgroup tile is covered. | |
| wave tile is covered. |
| Additional entries may be provided for dimensions the MMA analysis does not | ||
| cover (e.g. a batch dimension), and in that case both ``mma_type`` and explicit | ||
| ``vector_shapes`` coexist. |
| } | ||
| // Verify consistency between constraints and vector_shapes (when mma_type | ||
| // is absent). Each vector_shapes entry must match the resolved tile size |
| // is absent). Each vector_shapes entry must match the resolved tile size | |
| // is absent). Each vector_shapes entry must be less than or equal to the resolved tile size |
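A Python sketch of the relaxed condition this suggestion asks for (illustrative names only; the actual check would live in the C++ verifier):

```python
def check_vector_shapes_vs_tiles(vector_shapes: dict[str, int],
                                 tile_sizes: dict[str, int]) -> list[str]:
    """Collect error messages for entries exceeding the resolved tile size.

    Entries for dimensions without a resolved tile size are skipped."""
    errors = []
    for dim, vs in vector_shapes.items():
        tile = tile_sizes.get(dim)
        if tile is not None and vs > tile:
            errors.append(f"vector shape for {dim} ({vs}) exceeds tile size ({tile})")
    return errors
```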
| } | ||
| } | ||
| } | ||
We also want a check that waves_per_block in hardware constraint == evaluated workgroup tile size / evaluated wave tile size. And it may make sense to add an error or warning if one tile size is not divisible by the other.
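A hedged sketch of that additional check in Python (the real one would be C++ inside the verifier; all names here are illustrative):

```python
import warnings

def check_waves_per_block(waves_per_block: list[int],
                          workgroup_tile: list[int],
                          wave_tile: list[int]) -> bool:
    """Per dimension: warn when the workgroup tile is not divisible by the
    wave tile, and require waves_per_block == workgroup_tile / wave_tile."""
    ok = True
    for wpb, wg, wv in zip(waves_per_block, workgroup_tile, wave_tile):
        if wg % wv != 0:
            warnings.warn(f"workgroup tile {wg} not divisible by wave tile {wv}")
        if wpb != wg // wv:
            ok = False
    return ok
```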
fixes #864