Labels: bug
Description
Required prerequisites
- I have read the documentation https://tilelang.com.
- I have searched the Issue Tracker to make sure this hasn't already been reported (if it has, comment there).
What version of TileLang are you using?
0.1.6.post2+cuda.gite2b10c58
System information
Python: 3.12.9 | packaged by Anaconda, Inc. | (main, Feb 6 2025, 18:56:27) [GCC 11.2.0] linux
TileLang: 0.1.6.post2+cuda.gite2b10c58
PyTorch: 2.7.0+cu128
Problem description
The `LegalizeSafeMemoryAccess` pass produces output that fails TVM's structural equality check against the expected function, even though the two objects appear identical. The failure occurs when comparing the `dom.extent` nodes, which both print as `num_blocks`.
Reproducible example code
The Python snippet:
```python
from tilelang import tvm as tvm
import tilelang as tl
import tilelang.language as T


def issue_1303_buggy_kernel():
    num_threads = 256
    size_0 = T.dynamic('size_0')
    size_1 = T.dynamic('size_1')

    @T.prim_func
    def main(num_blocks: T.int32, idx_out: T.Tensor[(size_0, size_1), "int32"]):
        with T.Kernel(num_blocks, threads=num_threads) as block_idx:
            idx_out[block_idx, block_idx] = 0

    @T.prim_func
    def expected(num_blocks: T.int32, idx_out: T.Tensor[(size_0, size_1), "int32"]):
        with T.Kernel(num_blocks, threads=num_threads) as block_idx:
            if block_idx < size_1:
                if block_idx < size_0:
                    idx_out[block_idx, block_idx] = 0

    return main, expected


def test_issue_1303():
    func, expected = issue_1303_buggy_kernel()
    mod = tvm.IRModule({func.attrs["global_symbol"]: func})
    transformed = tl.transform.LegalizeSafeMemoryAccess()(mod)
    # print(transformed["main"].body.block.body.node.dom.extent)
    # print(expected.body.block.body.node.dom.extent)
    tvm.ir.assert_structural_equal(transformed["main"].body.block.body.node.dom.extent, expected.body.block.body.node.dom.extent)
    # tvm.ir.assert_structural_equal(transformed["main"].body, expected.body)


if __name__ == "__main__":
    test_issue_1303()
```

Traceback
```
Traceback (most recent call last):
  File "~/qwq/test3.py", line 37, in <module>
    test_issue_1303()
  File "~/qwq/test3.py", line 32, in test_issue_1303
    tvm.ir.assert_structural_equal(transformed["main"].body.block.body.node.dom.extent, expected.body.block.body.node.dom.extent)
  File "~/tilelang/3rdparty/tvm/python/tvm/ir/base.py", line 252, in assert_structural_equal
    _ffi_node_api.StructuralEqual(lhs, rhs, True, map_free_vars) # type: ignore # pylint: disable=no-member
    ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "python/tvm_ffi/cython/function.pxi", line 758, in core.Function.__call__
  File "~/tilelang/3rdparty/tvm/src/node/structural_equal.cc", line 68, in bool tvm::NodeStructuralEqualAdapter(const ffi::Any&, const ffi::Any&, bool, bool)
    TVM_FFI_THROW(ValueError) << oss.str();
ValueError: StructuralEqual check failed, caused by lhs at <root>:
num_blocks
^^^^^^^^^^
and rhs at <root>:
num_blocks
^^^^^^^^^^
```

Expected behavior
No response
Additional context
No response