Add support for specialization constants#2304
Conversation
|
View rendered docs @ https://intelpython.github.io/dpctl/pulls/2304/index.html |
b7f8d82 to
8c5651d
Compare
8dbb320 to
043aa83
Compare
82736b8 to
1a0a910
Compare
55d4458 to
97871d9
Compare
| @@ -0,0 +1,25 @@ | |||
| # Data Parallel Control (dpctl) | |||
| # | |||
| # Copyright 2020-2025 Intel Corporation | |||
There was a problem hiding this comment.
It seems we need to update the copyright year everywhere in dpctl
| spv_bytes: bytes | bytearray | memoryview, | ||
| ) -> tuple[SpecializationConstantInfo]: | ||
| """ | ||
| Parses SPIR-V byte stream to extract information about specializations, |
There was a problem hiding this comment.
Needs to add rendering documentation for the new functions
|
|
||
|
|
||
| class SpirvOpCode(IntEnum): | ||
| OpName = 5 |
There was a problem hiding this comment.
I wonder if it makes sense to import the interesting constants directly from the SPIRV-Headers?
Or probably to add just a ref comment with a link to the code from where the constants come?
There was a problem hiding this comment.
there's no wheel package on PyPI for SPIRV-headers, which is why this wasn't done. In theory there is already spirv-dis which disassembles SPIRV and can be used to get some of the info, but for that same reason this wasn't done
A link makes sense in the ref though
| } | ||
| } | ||
| else { | ||
| error_handler("clSetProgramSpecializationConstant is not available " |
There was a problem hiding this comment.
Missing to release clProgram on error
|
|
||
| if word_count == 0: | ||
| raise ValueError(f"Invalid SPIR-V instruction at word index {i}") | ||
|
|
There was a problem hiding this comment.
Missing boundary check:
if i + word_count > len(words):
raise ValueError(f"Invalid SPIR-V: instruction at offset {i} extends beyond buffer")| ) | ||
| elif isinstance(args[0], str): | ||
| target_obj = np.ascontiguousarray(args[1], dtype=args[0]) | ||
|
|
There was a problem hiding this comment.
Do we need an explicit error handling here in case of else?
| # attempt to coerce to a numpy array | ||
| target_obj = np.ascontiguousarray(target_obj) | ||
| else: | ||
| raise TypeError( |
There was a problem hiding this comment.
Probably it'd better to move raise TypeError(...) here from line 335, otherwise that branch is unreachable.
| @@ -444,8 +480,22 @@ _CreateKernelBundleWithIL_ze_impl(const context &SyclCtx, | |||
| ZeDevice = get_native<ze_be>(SyclDev); | |||
|
|
|||
| // Specialization constants are not supported by DPCTL at the moment | |||
There was a problem hiding this comment.
Is the comment obsolete now?
There was a problem hiding this comment.
yep, good catch
|
|
||
| dtype_str = type_info["dtype"] | ||
| raw_default = defaults.get(target_id) | ||
| default_value = None |
There was a problem hiding this comment.
It seems when raw_default is True or False (from OpSpecConstantTrue/OpSpecConstantFalse at lines 141, 146), the default value is not set in the result. Is that intended?
| the second argument is interpreted as a pointer to the data. | ||
|
|
||
| Note that when constructing from a buffer, the | ||
| :class:`.SpecializationConstant`, shares memory with the original object. |
There was a problem hiding this comment.
We probably need to warn also about object lifetime: if the source object is deleted, the SpecializationConstant holds a dangling pointer.
also removes "v" as a permitted specialization constant intermediate data type, as composite specialization constants are broken into multiple specialization constants, so structs end up passed as a single constant while the program expects multiple, and therefore, doesn't work as intended
also adds spec_id, itemsize, and default_value fields
97871d9 to
e2e4826
Compare
This PR introduces support for specialization constants in dpctl, including both a Cython class
SpecializationConstantfor construction and passing of the constructed class tocreate_kernel_bundle_from_spirvvia a newspecializationskeyword argument.The
SpecializationConstantclass supports multiple constructors, including from Python buffers, a dtype string and a Python buffer (casting to the dtype via NumPy), and a number of bytes and a pointer as integers.Also introduces
dpctl.program.utilswithparse_spirv_specializationsutility function, allowing the user to query a SPIR-V directly from Python.