
[BUG] StructuralEqual fails on identical objects in LegalizeSafeMemoryAccess transformation #1343

@LJC00118



What version of TileLang are you using?

0.1.6.post2+cuda.gite2b10c58

System information

Python: 3.12.9 | packaged by Anaconda, Inc. | (main, Feb 6 2025, 18:56:27) [GCC 11.2.0] linux
TileLang: 0.1.6.post2+cuda.gite2b10c58
PyTorch: 2.7.0+cu128

Problem description

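The output of the LegalizeSafeMemoryAccess transformation fails TVM's structural equality check against the expected function, even though the compared nodes appear identical. The failure occurs when comparing the `dom.extent` nodes of the kernel's thread-extent IterVar: both sides print as `num_blocks`, yet `assert_structural_equal` rejects them.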
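One way two nodes can print identically yet compare unequal is how a structural-equality check treats free variables. The snippet below is a deliberately simplified, pure-Python model of that idea, not TVM's implementation; `Var` and `toy_structural_equal` are hypothetical names used only for illustration. In the model, two free variables with the same name match only if they are the same object, unless `map_free_vars=True`, in which case a consistent one-to-one mapping between left and right variables is built.

```python
class Var:
    """Hypothetical stand-in for a TIR variable: identity, not name, defines it."""

    def __init__(self, name: str):
        self.name = name

    def __repr__(self) -> str:
        # Two distinct Vars with the same name print identically.
        return self.name


def toy_structural_equal(lhs, rhs, map_free_vars=False, _fwd=None, _bwd=None):
    """Simplified model of structural equality over tuples, constants and Vars."""
    fwd = {} if _fwd is None else _fwd  # id(lhs var) -> id(rhs var)
    bwd = {} if _bwd is None else _bwd  # id(rhs var) -> id(lhs var)
    if isinstance(lhs, Var) and isinstance(rhs, Var):
        if lhs is rhs:
            return True
        if not map_free_vars:
            # Distinct free vars never match, even when their names are equal.
            return False
        # Require a consistent one-to-one mapping in both directions.
        if fwd.setdefault(id(lhs), id(rhs)) != id(rhs):
            return False
        return bwd.setdefault(id(rhs), id(lhs)) == id(lhs)
    if isinstance(lhs, tuple) and isinstance(rhs, tuple):
        return len(lhs) == len(rhs) and all(
            toy_structural_equal(a, b, map_free_vars, fwd, bwd)
            for a, b in zip(lhs, rhs)
        )
    return type(lhs) is type(rhs) and lhs == rhs


a = Var("num_blocks")
b = Var("num_blocks")
print(a, b)                                        # num_blocks num_blocks
print(toy_structural_equal(a, b))                  # False
print(toy_structural_equal(a, b, map_free_vars=True))  # True
```

The bidirectional mapping is there so that, with `map_free_vars=True`, one left-hand variable cannot be matched to two different right-hand variables (or vice versa).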

Reproducible example code

The Python snippets:

from tilelang import tvm as tvm
import tilelang as tl
import tilelang.language as T


def issue_1303_buggy_kernel():
    num_threads = 256
    size_0 = T.dynamic('size_0')
    size_1 = T.dynamic('size_1')

    @T.prim_func
    def main(num_blocks: T.int32, idx_out: T.Tensor[(size_0, size_1), "int32"]):
        with T.Kernel(num_blocks, threads=num_threads) as block_idx:
            idx_out[block_idx, block_idx] = 0

    @T.prim_func
    def expected(num_blocks: T.int32, idx_out: T.Tensor[(size_0, size_1), "int32"]):
        with T.Kernel(num_blocks, threads=num_threads) as block_idx:
            if block_idx < size_1:
                if block_idx < size_0:
                    idx_out[block_idx, block_idx] = 0

    return main, expected


def test_issue_1303():
    func, expected = issue_1303_buggy_kernel()
    mod = tvm.IRModule({func.attrs["global_symbol"]: func})
    transformed = tl.transform.LegalizeSafeMemoryAccess()(mod)
    # print(transformed["main"].body.block.body.node.dom.extent)
    # print(expected.body.block.body.node.dom.extent)
    tvm.ir.assert_structural_equal(transformed["main"].body.block.body.node.dom.extent, expected.body.block.body.node.dom.extent)
    # tvm.ir.assert_structural_equal(transformed["main"].body, expected.body)


if __name__ == "__main__":
    test_issue_1303()

Traceback

Traceback (most recent call last):
  File "~/qwq/test3.py", line 37, in <module>
    test_issue_1303()
  File "~/qwq/test3.py", line 32, in test_issue_1303
    tvm.ir.assert_structural_equal(transformed["main"].body.block.body.node.dom.extent, expected.body.block.body.node.dom.extent)
  File "~/tilelang/3rdparty/tvm/python/tvm/ir/base.py", line 252, in assert_structural_equal
    _ffi_node_api.StructuralEqual(lhs, rhs, True, map_free_vars)  # type: ignore # pylint: disable=no-member
    ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "python/tvm_ffi/cython/function.pxi", line 758, in core.Function.__call__
  File "~/tilelang/3rdparty/tvm/src/node/structural_equal.cc", line 68, in bool tvm::NodeStructuralEqualAdapter(const ffi::Any&, const ffi::Any&, bool, bool)
    TVM_FFI_THROW(ValueError) << oss.str();

ValueError: StructuralEqual check failed, caused by lhs at <root>:
num_blocks
^^^^^^^^^^
and rhs at <root>:
num_blocks
^^^^^^^^^^

Expected behavior

No response

Additional context

No response


Labels: bug
