Description
Required prerequisites
- I have read the documentation https://tilelang.com.
- I have searched the Issue Tracker to confirm this hasn't already been reported. (If it has, comment there instead.)
What version of TileLang are you using?
0.1.6.post2+cuda.git422fb129
System information
Python: 3.12.9 | packaged by Anaconda, Inc. | (main, Feb 6 2025, 18:56:27) [GCC 11.2.0] linux
TileLang: 0.1.6.post2+cuda.git422fb129
PyTorch: 2.7.0+cu128
Problem description
Layout inference succeeds when reducing a (1, 7168) buffer into (1, 224), but fails when reducing (16, 448) into (16, 14), even though both kernels use the same factor-32 reduction pattern (j // 32).
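For reference, both kernels below accumulate every block of 32 consecutive columns of B into one column of A, i.e. a factor-32 reduction along the last axis (7168 // 32 == 224, 448 // 32 == 14). A minimal NumPy sketch of that semantics, assuming A starts at zero (illustration only, not part of the repro):

import numpy as np

# Reference for A_local[i, j // 32] += B[i, j] at the failing (16, 448) shape:
# group the 448 columns into 14 blocks of 32 and sum within each block.
B = np.arange(16 * 448, dtype=np.int32).reshape(16, 448)
A = B.reshape(16, 14, 32).sum(axis=-1, dtype=np.int32)
assert A.shape == (16, 14)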
Reproducible example code
The Python snippet:

import tilelang as tl
import tilelang.language as T

tl.disable_cache()


@tl.jit
def get_wrong_kernel(M: int = 4096):
    dtype = "int32"
    num_threads = 128

    @T.prim_func
    def main(A: T.Tensor((16, 14), dtype=dtype), B: T.Tensor((16, 448), dtype=dtype)):
        with T.Kernel(1, threads=num_threads) as (bx,):
            A_local = T.alloc_fragment((16, 14), dtype=dtype)
            B_local = T.alloc_fragment((16, 448), dtype=dtype)
            T.copy(A, A_local)
            T.copy(B, B_local)
            # Fails: layout inference cannot find a layout for this shape.
            for i, j in T.Parallel(16, 448):
                A_local[i, j // 32] += B[i, j]

    return main


@tl.jit
def get_correct_kernel(M: int = 4096):
    dtype = "int32"
    num_threads = 128

    @T.prim_func
    def main(A: T.Tensor((1, 224), dtype=dtype), B: T.Tensor((1, 7168), dtype=dtype)):
        with T.Kernel(1, threads=num_threads) as (bx,):
            A_local = T.alloc_fragment((1, 224), dtype=dtype)
            B_local = T.alloc_fragment((1, 7168), dtype=dtype)
            T.copy(A, A_local)
            T.copy(B, B_local)
            # Succeeds with the same reduction pattern at this shape.
            for i, j in T.Parallel(1, 7168):
                A_local[i, j // 32] += B[i, j]

    return main


kernel = get_wrong_kernel()
print(kernel.get_kernel_source())

Traceback
Traceback (most recent call last):
  File "~/qwq/test5.py", line 44, in <module>
    kernel = get_wrong_kernel()
             ^^^^^^^^^^^^^^^^^^
  File "~/tilelang/tilelang/jit/__init__.py", line 287, in __call__
    self._kernel_cache[key] = self.compile(*args, **kwargs, **tune_params)
                              ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "~/tilelang/tilelang/jit/__init__.py", line 238, in compile
    kernel_result = compile(
                    ^^^^^^^^
  File "~/tilelang/tilelang/jit/__init__.py", line 98, in compile
    return cached(
           ^^^^^^^
  File "~/tilelang/tilelang/cache/__init__.py", line 30, in cached
    return _kernel_cache_instance.cached(
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "~/tilelang/tilelang/cache/kernel_cache.py", line 159, in cached
    return JITKernel(
           ^^^^^^^^^^
  File "~/tilelang/tilelang/jit/kernel.py", line 131, in __init__
    adapter = self._compile_and_create_adapter(func, out_idx)
              ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "~/tilelang/tilelang/jit/kernel.py", line 231, in _compile_and_create_adapter
    artifact = tilelang.lower(
               ^^^^^^^^^^^^^^^
  File "~/tilelang/tilelang/engine/lower.py", line 250, in lower
    mod = LowerAndLegalize(mod, target)
          ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "~/tilelang/tilelang/engine/phase.py", line 123, in LowerAndLegalize
    mod = tilelang.transform.LayoutInference()(mod)
          ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "~/tilelang/3rdparty/tvm/python/tvm/ir/transform.py", line 167, in __call__
    return _ffi_transform_api.RunPass(self, mod)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "python/tvm_ffi/cython/function.pxi", line 758, in core.Function.__call__
  File "<unknown>", line 0, in tvm::transform::Pass::operator()(tvm::IRModule) const
  File "<unknown>", line 0, in tvm::transform::Pass::operator()(tvm::IRModule, tvm::transform::PassContext const&) const
  File "<unknown>", line 0, in tvm::tir::transform::PrimFuncPassNode::operator()(tvm::IRModule, tvm::transform::PassContext const&) const
  File "<unknown>", line 0, in std::_Function_handler<tvm::tir::PrimFunc (tvm::tir::PrimFunc, tvm::IRModule, tvm::transform::PassContext), tvm::tl::LayoutInference()::{lambda(tvm::tir::PrimFunc, tvm::IRModule const&, tvm::transform::PassContext const&)#1}>::_M_invoke(std::_Any_data const&, tvm::tir::PrimFunc&&, tvm::IRModule&&, tvm::transform::PassContext&&)
  File "<unknown>", line 0, in tvm::tl::BufferUseDefCollector::Run()
  File "<unknown>", line 0, in tvm::tl::BufferUseDefCollector::InferInFreeMode(tvm::ffi::Map<tvm::tir::Buffer, tvm::tl::Layout, void>&, tvm::ffi::Map<tvm::tir::Buffer, tvm::tl::Layout, void> const&)
  File "<unknown>", line 0, in tvm::runtime::detail::LogFatal::Entry::Finalize()
tvm.error.InternalError: Check failed: (min_reg_num < INT64_MAX) is false: no available layout found

Expected behavior
No response
Additional context
No response