Status: Open
Labels: bug (Something isn't working)
Description
Required prerequisites
- I have read the documentation https://tilelang.com.
- I have searched the Issue Tracker to make sure this hasn't already been reported (comment there if it has).
What version of TileLang are you using?
0.1.6.post2+cuda.gite2b10c58
System information
Python: 3.12.9 | packaged by Anaconda, Inc. | (main, Feb 6 2025, 18:56:27) [GCC 11.2.0] linux
TileLang: 0.1.6.post2+cuda.gite2b10c58
PyTorch: 2.7.0+cu128
Problem description
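When a buffer allocated with `T.alloc_fragment` is read inside a `T.vectorized` loop, the generated CUDA code does not vectorize the access. In the example below, the reads from `A_shared` and from the global tensor `A` are each lowered to a single `int4` (128-bit) load, but the read from `A_local` is emitted as a scalar `for` loop over 4 elements.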
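The scalar loop appears to be a missed optimization rather than a required fallback: the generated copy into `A_local` already writes it with `*(int4*)(A_local + (i * 4))`, and the scalar loop reads the same `i_1 * 4 + j` pattern. The plain-Python sketch below checks this under stated assumptions: the per-thread fragment layout is inferred from the generated CUDA (not taken from TileLang's layout API), buffer bases are assumed 16-byte aligned, and `is_int4_loadable` / `fragment_location` are hypothetical helpers written only for illustration.

```python
# Plain-Python sketch (no TileLang needed). The layout is inferred from the
# generated CUDA in this issue, not from TileLang's layout API.
M, num_threads, lanes = 4096, 32, 4      # lanes: int32 elements per int4
numel = M // lanes                       # 1024 groups of 4 elements


def is_int4_loadable(offsets, elem_bytes=4, vector_bytes=16):
    """Hypothetical helper: True if `offsets` (element offsets into a buffer
    whose base is ASSUMED 16-byte aligned) form one contiguous, aligned int4."""
    base = offsets[0]
    contiguous = offsets == list(range(base, base + vector_bytes // elem_bytes))
    aligned = (base * elem_bytes) % vector_bytes == 0
    return contiguous and aligned


def fragment_location(g):
    """Hypothetical helper: map a global index of A to (thread, A_local index),
    as implied by A_local[i*4 + j] = A_shared[i*128 + threadIdx.x*4 + j]."""
    i, rem = divmod(g, num_threads * lanes)
    thread, j = divmod(rem, lanes)
    return thread, i * lanes + j


for t in range(num_threads):
    # The emitted bound (1055 - t) >> 5 equals T.ceildiv(numel - t, num_threads).
    trips = (1055 - t) >> 5
    assert trips == -(-(numel - t) // num_threads)
    for i_1 in range(trips):
        glob = [i_1 * 128 + t * 4 + j for j in range(lanes)]  # A_shared / A / B
        locs = [fragment_location(g) for g in glob]
        assert all(owner == t for owner, _ in locs)            # one thread owns all 4
        local = [l for _, l in locs]                           # == i_1*4 + j
        assert is_int4_loadable(glob) and is_int4_loadable(local)

print("all A_local groups are contiguous and int4-aligned")
```

If every assertion holds, each 4-lane group read from `A_local` is contiguous, 16-byte aligned, and owned by a single thread, which is the same condition that lets the `A_shared` and `A` reads become `int4` loads.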
Reproducible example code
The Python snippets:
```python
import tilelang as tl
import tilelang.language as T
from tilelang import tvm as tvm

tl.disable_cache()


@tl.jit
def get_wrong_kernel(M: int = 4096):
    dtype = "int32"
    num_threads = 32
    numel = M // 4  # number of 4-element groups

    @T.prim_func
    def expected(A: T.Tensor((M, ), dtype=dtype), B: T.Tensor((M, ), dtype=dtype)):
        with T.Kernel(1, threads=num_threads) as (bx, ):
            A_shared = T.alloc_shared((M, ), dtype=dtype)
            A_local = T.alloc_fragment((M, ), dtype=dtype)
            thread_idx = T.get_thread_binding()
            T.copy(A, A_shared)
            T.copy(A_shared, A_local)
            for i in T.serial(T.ceildiv(numel - thread_idx, num_threads)):
                id = thread_idx + i * num_threads
                id_in = id * 4
                # Read from shared memory: lowered to an int4 load
                for j in T.vectorized(4):
                    B[id_in + j] = A_shared[id_in + j]
                # Read from the fragment: emitted as a scalar loop (the bug)
                for j in T.vectorized(4):
                    B[id_in + j] = A_local[id_in + j]
                # Read from global memory: lowered to an int4 load
                for j in T.vectorized(4):
                    B[id_in + j] = A[id_in + j]

    return expected


kernel = get_wrong_kernel()
print(kernel.get_kernel_source())
```

Generated CUDA code snippets:
```cuda
#pragma unroll
for (int i = 0; i < 32; ++i) {
  *(int4*)(A_local + (i * 4)) = *(int4*)(A_shared + ((i * 128) + (((int)threadIdx.x) * 4)));
}
for (int i_1 = 0; i_1 < ((1055 - ((int)threadIdx.x)) >> 5); ++i_1) {
  *(int4*)(B + ((i_1 * 128) + (((int)threadIdx.x) * 4))) = *(int4*)(A_shared + ((i_1 * 128) + (((int)threadIdx.x) * 4)));
  for (int j = 0; j < 4; ++j) {
    B[(((i_1 * 128) + (((int)threadIdx.x) * 4)) + j)] = A_local[((i_1 * 4) + j)]; // Not vectorized
  }
  *(int4*)(B + ((i_1 * 128) + (((int)threadIdx.x) * 4))) = *(int4*)(A + ((i_1 * 128) + (((int)threadIdx.x) * 4)));
}
```

Traceback
Expected behavior
No response
Additional context
No response