[BUG] Layout Inference Fails for Cases Requiring Replication #1374

@LJC00118

Description

What version of TileLang are you using?

0.1.6.post2+cuda.git422fb129

System information

Python: 3.12.9 | packaged by Anaconda, Inc. | (main, Feb 6 2025, 18:56:27) [GCC 11.2.0], Linux
TileLang: 0.1.6.post2+cuda.git422fb129
PyTorch: 2.7.0+cu128

Problem description

Layout inference succeeds when reducing a {1, 7168} fragment into {1, 224}, but fails when reducing {16, 448} into {16, 14}, even though both use the same reduction pattern (j // 32).
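
For reference, the loop body implements a blocked column reduction: every group of 32 consecutive columns of B accumulates into one column of A. A minimal NumPy sketch of the intended semantics (illustration only; NumPy is not part of the reproducer):

import numpy as np

# Reference semantics of the loop: A[i, j // 32] += B[i, j]
B = np.arange(16 * 448, dtype=np.int32).reshape(16, 448)
A = np.zeros((16, 14), dtype=np.int32)
for i in range(16):
    for j in range(448):
        A[i, j // 32] += B[i, j]

# Equivalent to summing each group of 32 consecutive columns.
assert np.array_equal(A, B.reshape(16, 14, 32).sum(axis=-1))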

Reproducible example code

The following Python snippet reproduces the failure:

import tilelang as tl
import tilelang.language as T


tl.disable_cache()


@tl.jit
def get_wrong_kernel(M: int = 4096):
    dtype = "int32"
    num_threads = 128

    @T.prim_func
    def main(A: T.Tensor((16, 14), dtype=dtype), B: T.Tensor((16, 448), dtype=dtype)):
        with T.Kernel(1, threads=num_threads) as (bx, ):
            A_local = T.alloc_fragment((16, 14), dtype=dtype)
            B_local = T.alloc_fragment((16, 448), dtype=dtype)
            
            T.copy(A, A_local)
            T.copy(B, B_local)
            # Accumulate every 32 consecutive columns of B into one column of A_local.
            for i, j in T.Parallel(16, 448):
                A_local[i, j // 32] += B[i, j]

    return main

@tl.jit
def get_correct_kernel(M: int = 4096):
    dtype = "int32"
    num_threads = 128

    @T.prim_func
    def main(A: T.Tensor((1, 224), dtype=dtype), B: T.Tensor((1, 7168), dtype=dtype)):
        with T.Kernel(1, threads=num_threads) as (bx, ):
            A_local = T.alloc_fragment((1, 224), dtype=dtype)
            B_local = T.alloc_fragment((1, 7168), dtype=dtype)
            
            T.copy(A, A_local)
            T.copy(B, B_local)
            # Same j // 32 reduction pattern at a different shape; this one compiles.
            for i, j in T.Parallel(1, 7168):
                A_local[i, j // 32] += B[i, j]

    return main

kernel = get_wrong_kernel()
print(kernel.get_kernel_source())
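
For comparison, swapping in the {1, 224} variant compiles cleanly (consistent with the description above):

kernel_ok = get_correct_kernel()
print(kernel_ok.get_kernel_source())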

Traceback

Traceback (most recent call last):
  File "~/qwq/test5.py", line 44, in <module>
    kernel = get_wrong_kernel()
             ^^^^^^^^^^^^^^^^^^
  File "~/tilelang/tilelang/jit/__init__.py", line 287, in __call__
    self._kernel_cache[key] = self.compile(*args, **kwargs, **tune_params)
                              ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "~/tilelang/tilelang/jit/__init__.py", line 238, in compile
    kernel_result = compile(
                    ^^^^^^^^
  File "~/tilelang/tilelang/jit/__init__.py", line 98, in compile
    return cached(
           ^^^^^^^
  File "~/tilelang/tilelang/cache/__init__.py", line 30, in cached
    return _kernel_cache_instance.cached(
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "~/tilelang/tilelang/cache/kernel_cache.py", line 159, in cached
    return JITKernel(
           ^^^^^^^^^^
  File "~/tilelang/tilelang/jit/kernel.py", line 131, in __init__
    adapter = self._compile_and_create_adapter(func, out_idx)
              ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "~/tilelang/tilelang/jit/kernel.py", line 231, in _compile_and_create_adapter
    artifact = tilelang.lower(
               ^^^^^^^^^^^^^^^
  File "~/tilelang/tilelang/engine/lower.py", line 250, in lower
    mod = LowerAndLegalize(mod, target)
          ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "~/tilelang/tilelang/engine/phase.py", line 123, in LowerAndLegalize
    mod = tilelang.transform.LayoutInference()(mod)
          ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "~/tilelang/3rdparty/tvm/python/tvm/ir/transform.py", line 167, in __call__
    return _ffi_transform_api.RunPass(self, mod)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "python/tvm_ffi/cython/function.pxi", line 758, in core.Function.__call__
  File "<unknown>", line 0, in tvm::transform::Pass::operator()(tvm::IRModule) const
  File "<unknown>", line 0, in tvm::transform::Pass::operator()(tvm::IRModule, tvm::transform::PassContext const&) const
  File "<unknown>", line 0, in tvm::tir::transform::PrimFuncPassNode::operator()(tvm::IRModule, tvm::transform::PassContext const&) const
  File "<unknown>", line 0, in std::_Function_handler<tvm::tir::PrimFunc (tvm::tir::PrimFunc, tvm::IRModule, tvm::transform::PassContext), tvm::tl::LayoutInference()::{lambda(tvm::tir::PrimFunc, tvm::IRModule const&, tvm::transform::PassContext const&)#1}>::_M_invoke(std::_Any_data const&, tvm::tir::PrimFunc&&, tvm::IRModule&&, tvm::transform::PassContext&&)
  File "<unknown>", line 0, in tvm::tl::BufferUseDefCollector::Run()
  File "<unknown>", line 0, in tvm::tl::BufferUseDefCollector::InferInFreeMode(tvm::ffi::Map<tvm::tir::Buffer, tvm::tl::Layout, void>&, tvm::ffi::Map<tvm::tir::Buffer, tvm::tl::Layout, void> const&)
  File "<unknown>", line 0, in tvm::runtime::detail::LogFatal::Entry::Finalize()
tvm.error.InternalError: Check failed: (min_reg_num < INT64_MAX) is false: no available layout found

Expected behavior

No response

Additional context

No response
