diff --git a/docs/reference/security_advisory_subroutine_hash_collision.rst b/docs/reference/security_advisory_subroutine_hash_collision.rst new file mode 100644 index 000000000000..0de960a2b180 --- /dev/null +++ b/docs/reference/security_advisory_subroutine_hash_collision.rst @@ -0,0 +1,84 @@ +.. Licensed to the Apache Software Foundation (ASF) under one + or more contributor license agreements. See the NOTICE file + distributed with this work for additional information + regarding copyright ownership. The ASF licenses this file + to you under the Apache License, Version 2.0 (the + "License"); you may not use this file except in compliance + with the License. You may obtain a copy of the License at + +.. http://www.apache.org/licenses/LICENSE-2.0 + +.. Unless required by applicable law or agreed to in writing, + software distributed under the License is distributed on an + "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + KIND, either express or implied. See the License for the + specific language governing permissions and limitations + under the License. + +.. _security-advisory-subroutine-hash-collision: + +Security Advisory: Subroutine Cache Hash Collision +=================================================== + +Summary +------- + +``SubroutineMixin._get_subroutine()`` in ``python/tvm/relax/frontend/nn/subroutine.py`` +used ``ir.structural_hash`` as the sole cache lookup key without a subsequent +``structural_equal`` verification. If two different ``arg_sinfo`` values produced the +same 64-bit hash, the cache would return a previously compiled function with +mismatched parameter shapes, leading to silently incorrect compiled output. + +Severity +-------- + +**Low.** The ``structural_hash`` function returns a 64-bit integer. A natural hash +collision requires approximately 2^32 distinct inputs (birthday bound), making +accidental collision extremely unlikely in normal compilation workflows. 
The issue +is primarily a **correctness defect** rather than a practically exploitable security +vulnerability. + +Affected Code +------------- + +- **File**: ``python/tvm/relax/frontend/nn/subroutine.py`` +- **Method**: ``SubroutineMixin._get_subroutine()`` +- **Trigger condition**: ``define_subroutine = True`` on an ``nn.Module`` subclass, + with two or more calls using different input shapes within the same compilation session. + +Root Cause +---------- + +The subroutine cache (``cls._gvar``) was keyed by +``(structural_hash(arg_sinfo, map_free_vars=True), is_dataflow)``. +A hash match was treated as proof of structural equality, skipping the necessary +``structural_equal`` check. This is inconsistent with the pattern used elsewhere in +TVM (e.g., ``block_builder.cc`` uses ``StructuralHash`` + ``StructuralEqual`` together +in ``std::unordered_map``). + +Impact +------ + +If a collision occurred: + +1. The cache returned a ``GlobalVar`` bound to a function compiled for a different + input shape. +2. The caller would invoke this wrong function with mismatched arguments. +3. The compiled Relax IR module would contain an incorrect function call. +4. At inference time, the model would produce wrong numerical results **without + any error or warning**. + +Fix +--- + +The cache now stores a list of ``(arg_sinfo, result)`` pairs per hash bucket. +On lookup, each candidate is verified with ``structural_equal`` before returning. +This follows the standard hash-table pattern: hash for bucket selection, equality +for final verification. + +Recommendations +--------------- + +- Update to the patched version of TVM. +- If you maintain custom code that caches TVM IR nodes by ``structural_hash``, + ensure that a ``structural_equal`` check is always performed on cache hits. 
diff --git a/python/tvm/relax/frontend/nn/subroutine.py b/python/tvm/relax/frontend/nn/subroutine.py index c62491be9ff5..0851d1e81ea3 100644 --- a/python/tvm/relax/frontend/nn/subroutine.py +++ b/python/tvm/relax/frontend/nn/subroutine.py @@ -25,6 +25,7 @@ import typing from tvm import ir, relax +from tvm.ir import structural_equal from tvm.relax.frontend import nn @@ -143,8 +144,9 @@ def _get_subroutine( is_dataflow = block_builder.current_block_is_dataflow() lookup_key = (ir.structural_hash(arg_sinfo, map_free_vars=True), is_dataflow) - if lookup_key in cls._gvar: - return cls._gvar[lookup_key] + for cached_sinfo, cached_result in cls._gvar.get(lookup_key, []): + if structural_equal(cached_sinfo, arg_sinfo, map_free_vars=True): + return cached_result func_name = _camel_to_snake(cls.__name__) func_params = [relax.Var(name, sinfo) for name, sinfo in zip(func_args, arg_sinfo.fields)] @@ -175,5 +177,7 @@ def _get_subroutine( mod = block_builder.get() mod.update_func(gvar, relax.utils.copy_with_new_vars(mod[gvar])) - cls._gvar[lookup_key] = (gvar, is_nn_tensor_output) - return cls._gvar[lookup_key] + result = (gvar, is_nn_tensor_output) + bucket = cls._gvar.setdefault(lookup_key, []) + bucket.append((arg_sinfo, result)) + return result diff --git a/tests/python/relax/test_frontend_nn_subroutines.py b/tests/python/relax/test_frontend_nn_subroutines.py index 9ea44781b8bc..63e9be8fb4b6 100644 --- a/tests/python/relax/test_frontend_nn_subroutines.py +++ b/tests/python/relax/test_frontend_nn_subroutines.py @@ -97,5 +97,55 @@ def activation( assert_structural_equal(Expected, tvm_mod, True) +def test_different_shapes_produce_distinct_subroutines(): + """Regression test: same Module class with different input shapes + must generate distinct subroutines, not reuse a cached one.""" + + class Activation(nn.Module): + define_subroutine = True + + def forward(self, state: relax.Expr) -> relax.Var: + return nn.op.silu(state) + + class Model(nn.Module): + def __init__(self): + 
self.act_a = Activation() + self.act_b = Activation() + + def forward(self, x: relax.Expr, y: relax.Expr) -> relax.Var: + a = self.act_a(x) + b = self.act_b(y) + return nn.op.add(a, b) + + mod = Model() + batch_size = tvm.tir.Var("batch_size", "int64") + tvm_mod, _ = mod.export_tvm( + spec={ + "forward": { + "x": nn.spec.Tensor((batch_size, 32), "float32"), + "y": nn.spec.Tensor((batch_size, 64), "float32"), + } + }, + debug=True, + ) + + # Collect all private functions (subroutines) in the module + subroutine_funcs = [ + func + for gvar, func in tvm_mod.functions.items() + if isinstance(func, relax.Function) + and gvar.name_hint not in ( + "forward", + "_initialize_effect", + ) + ] + + # There must be two distinct activation subroutines (one for dim=32, one for dim=64), + # not a single cached one reused for both. + assert len(subroutine_funcs) == 2, ( + f"Expected 2 distinct subroutines for different input shapes, got {len(subroutine_funcs)}" + ) + + if __name__ == "__main__": tvm.testing.main()