Description
Required prerequisites
- I have read the documentation at https://tilelang.com.
- I have searched the Issue Tracker to confirm this hasn't already been reported. (Comment there if it has.)
Questions
Problem Description
When implementing a Heap-based TopK kernel in TileLang for attention mechanisms, I've observed inconsistent behavior depending on the topk parameter value:
- When topk=16: The kernel works correctly, producing valid top-k scores and indices that match the reference implementation.
- When topk=8: The same kernel logic produces incorrect results, often containing -inf values or wrong indices.
The algorithm is implemented correctly (verified in Python/NumPy), and the same kernel structure works for topk=16, suggesting this may be a compiler optimization or code generation issue.
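For context, this is the heap update the kernel is meant to mirror, shown as a minimal pure-NumPy sketch (an illustration of the intended logic, not my exact verification script; heap_topk is just an illustrative name):

import numpy as np

def heap_topk(scores, topk):
    # Min-heap of size topk: heap_scores[0] is always the smallest kept score.
    heap_scores = np.full(topk, -np.inf, dtype=np.float32)
    heap_indices = np.full(topk, -1, dtype=np.int64)
    for idx, score in enumerate(scores):
        if score > heap_scores[0]:
            # Replace the root, then sift down to restore the heap property.
            heap_scores[0] = score
            heap_indices[0] = idx
            cur = 0
            while True:
                smallest = cur
                left, right = 2 * cur + 1, 2 * cur + 2
                if left < topk and heap_scores[left] < heap_scores[smallest]:
                    smallest = left
                if right < topk and heap_scores[right] < heap_scores[smallest]:
                    smallest = right
                if smallest == cur:
                    break
                heap_scores[cur], heap_scores[smallest] = heap_scores[smallest], heap_scores[cur]
                heap_indices[cur], heap_indices[smallest] = heap_indices[smallest], heap_indices[cur]
                cur = smallest
    # Results are returned in heap order (unsorted), like the kernel's output.
    return heap_scores, heap_indices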
Important Observation: During debugging, I discovered that changing the loop from for tx in T.Parallel(BLOCK_L): to for tx in T.serial(BLOCK_L): in the heap update section resolves the issue for topk=8 (see the snippet below). However, this is suboptimal for performance, since each row is processed independently and should be parallelizable. This suggests the problem may be related to how TileLang handles parallel loops with certain loop bounds.
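For completeness, the workaround changes only the loop construct around the heap update; the loop body stays the same:

# Original (incorrect results for topk=8):
#   for tx in T.Parallel(BLOCK_L):
#       ...
# Workaround (correct results, but serializes rows that are independent):
for tx in T.serial(BLOCK_L):
    ...  # heap update body unchanged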
Environment Information
- TileLang Version: tilelang-0.1.6.post2+cu121.gitb8240b7a
- CUDA Version: 12.1
- GPU: NVIDIA RTX 3090
Minimal Reproduction Example
Here's a minimal reproduction script that demonstrates the issue:
import torch
import tilelang
import tilelang.language as T
from typing import Optional


@tilelang.jit(
    out_idx=[2, 3],
    pass_configs={
        tilelang.PassConfigKey.TL_ENABLE_FAST_MATH: True,
        tilelang.PassConfigKey.TL_DISABLE_TMA_LOWER: True,
        tilelang.PassConfigKey.TL_DISABLE_WARP_SPECIALIZED: True,
    }
)
def minimal_heap_kernel(batch, seq_len, s_len, heads, head_dim, topk):
    dtype = "bfloat16"
    accum_dtype = "float"
    idx_dtype = "int32"

    q_shape = [batch, seq_len, heads, head_dim]
    k_shape = [batch, s_len, heads, head_dim]
    out_scores_shape = [batch, seq_len, heads, topk]
    out_indices_shape = [batch, seq_len, heads, topk]

    BLOCK_L = 32
    BLOCK_S = 8
    BLOCK_D = head_dim
    threads = BLOCK_L
    num_s_blocks = tilelang.cdiv(s_len, BLOCK_S)

    @T.prim_func
    def kernel(
        Q: T.Tensor(q_shape, dtype),
        K: T.Tensor(k_shape, dtype),
        OutScores: T.Tensor(out_scores_shape, accum_dtype),
        OutIndices: T.Tensor(out_indices_shape, idx_dtype),
    ):
        with T.Kernel(tilelang.cdiv(seq_len, BLOCK_L), heads, batch, threads=threads) as (bx, by, bz):
            i_b = bz
            i_h = by
            base_l = bx * BLOCK_L

            Q_shared = T.alloc_shared([BLOCK_L, BLOCK_D], dtype)
            K_shared = T.alloc_shared([BLOCK_S, BLOCK_D], dtype)
            score_shared = T.alloc_shared([BLOCK_L, BLOCK_S], accum_dtype)
            topk_scores = T.alloc_shared([BLOCK_L, topk], accum_dtype)
            topk_indices = T.alloc_shared([BLOCK_L, topk], idx_dtype)
            acc_s = T.alloc_fragment([BLOCK_L, BLOCK_S], accum_dtype)

            temp_score = T.alloc_var(accum_dtype)
            temp_idx = T.alloc_var(idx_dtype)
            cur_pos = T.alloc_var(idx_dtype)
            smallest_pos = T.alloc_var(idx_dtype)
            left_pos = T.alloc_var(idx_dtype)
            right_pos = T.alloc_var(idx_dtype)

            # Initialize the per-row heaps.
            for l_idx, k_idx in T.Parallel(BLOCK_L, topk):
                topk_scores[l_idx, k_idx] = -T.infinity(accum_dtype)
                topk_indices[l_idx, k_idx] = -1

            # Load the query tile into shared memory.
            for l_idx, d in T.Parallel(BLOCK_L, BLOCK_D):
                tq = base_l + l_idx
                if tq < seq_len:
                    Q_shared[l_idx, d] = Q[i_b, tq, i_h, d]
                else:
                    Q_shared[l_idx, d] = 0

            for s_block in T.serial(num_s_blocks):
                base_s = s_block * BLOCK_S

                # Load the current key tile.
                for s_idx, d in T.Parallel(BLOCK_S, BLOCK_D):
                    ts = base_s + s_idx
                    if ts < s_len:
                        K_shared[s_idx, d] = K[i_b, ts, i_h, d]
                    else:
                        K_shared[s_idx, d] = 0
                T.sync_threads()

                T.clear(acc_s)
                T.gemm(Q_shared, K_shared, acc_s, transpose_B=True, policy=T.GemmWarpPolicy.FullRow)
                T.copy(acc_s, score_shared)
                T.sync_threads()

                # Heap update: each row keeps a size-topk min-heap of its best scores.
                # Changing this loop to T.serial(BLOCK_L) makes topk=8 produce correct results.
                for tx in T.Parallel(BLOCK_L):
                    my_l_idx = tx
                    my_tq = base_l + my_l_idx
                    if my_tq < seq_len:
                        for s_idx in T.serial(BLOCK_S):
                            ts = base_s + s_idx
                            if ts < s_len:
                                cur_score = score_shared[my_l_idx, s_idx]
                                if cur_score > topk_scores[my_l_idx, 0]:
                                    # Replace the heap root, then sift down.
                                    topk_scores[my_l_idx, 0] = cur_score
                                    topk_indices[my_l_idx, 0] = ts
                                    cur_pos = 0
                                    for ki in T.serial(topk):
                                        smallest_pos = cur_pos
                                        left_pos = cur_pos * 2 + 1
                                        right_pos = cur_pos * 2 + 2
                                        if left_pos < topk:
                                            if topk_scores[my_l_idx, left_pos] < topk_scores[my_l_idx, smallest_pos]:
                                                smallest_pos = left_pos
                                        if right_pos < topk:
                                            if topk_scores[my_l_idx, right_pos] < topk_scores[my_l_idx, smallest_pos]:
                                                smallest_pos = right_pos
                                        if smallest_pos != cur_pos:
                                            temp_score = topk_scores[my_l_idx, cur_pos]
                                            temp_idx = topk_indices[my_l_idx, cur_pos]
                                            topk_scores[my_l_idx, cur_pos] = topk_scores[my_l_idx, smallest_pos]
                                            topk_indices[my_l_idx, cur_pos] = topk_indices[my_l_idx, smallest_pos]
                                            topk_scores[my_l_idx, smallest_pos] = temp_score
                                            topk_indices[my_l_idx, smallest_pos] = temp_idx
                                        cur_pos = smallest_pos
                T.sync_threads()

            # Write back the per-row top-k scores and indices.
            for li, ki in T.Parallel(BLOCK_L, topk):
                lq = base_l + li
                if lq < seq_len:
                    OutScores[i_b, lq, i_h, ki] = topk_scores[li, ki]
                    OutIndices[i_b, lq, i_h, ki] = topk_indices[li, ki]

    return kernel
def create_test_data(batch=1, seq_len=64, s_len=64, heads=1, head_dim=16):
    torch.manual_seed(42)
    Q = torch.randn(batch, seq_len, heads, head_dim, dtype=torch.bfloat16, device='cuda')
    K = torch.randn(batch, s_len, heads, head_dim, dtype=torch.bfloat16, device='cuda')
    Q = Q / torch.sqrt(torch.tensor(head_dim, dtype=torch.float32))
    K = K / torch.sqrt(torch.tensor(head_dim, dtype=torch.float32))
    return Q, K


def reference_topk(Q, K, topk):
    batch, seq_len, heads, head_dim = Q.shape
    scores = torch.einsum("bqhd,bkhd->bqhk", Q.float(), K.float())
    topk_scores, topk_indices = torch.topk(scores, k=topk, dim=-1)
    return topk_scores, topk_indices


def test_topk(topk_values):
    batch = 1
    seq_len = 512
    s_len = seq_len
    heads = 1
    head_dim = 128
    Q, K = create_test_data(batch, seq_len, s_len, heads, head_dim)

    for topk in topk_values:
        print(f"\n{'='*60}")
        print(f"Testing topk={topk}")
        print(f"{'='*60}")

        # Compile the kernel.
        kernel = minimal_heap_kernel(batch, seq_len, s_len, heads, head_dim, topk)

        # Prepare outputs (the kernel actually allocates and returns its own via out_idx).
        out_scores = torch.empty(batch, seq_len, heads, topk,
                                 dtype=torch.float32, device='cuda')
        out_indices = torch.empty(batch, seq_len, heads, topk,
                                  dtype=torch.int32, device='cuda')

        out_scores, out_indices = kernel(Q, K)
        ref_scores, ref_indices = reference_topk(Q, K, topk)

        print(f"Output shape: {out_scores.shape}")
        print(f"Ref shape: {ref_scores.shape}")

        has_inf = torch.any(out_scores.isinf())
        has_neg_one = torch.any(out_indices == -1)
        if has_inf:
            print("❌ Kernel output contains -inf")
        if has_neg_one:
            print("❌ Kernel output contains -1 indices")

        fused_set = set(out_indices[0, 0, 0].cpu().numpy())
        ref_set = set(ref_indices[0, 0, 0].cpu().numpy())
        match_ratio = len(fused_set.intersection(ref_set)) / topk
        print(f"Element set match rate: {match_ratio:.2%}")

        print(f"\nKernel indices (first position): {out_indices[0, 0, 0].cpu().numpy()}")
        print(f"Reference indices: {ref_indices[0, 0, 0].cpu().numpy()}")
        print(f"\nKernel scores: {out_scores[0, 0, 0].cpu().numpy()}")
        print(f"Reference scores: {ref_scores[0, 0, 0].cpu().numpy()}")


if __name__ == "__main__":
    test_topk([8, 16])

Observed Behavior
============================================================
Testing topk=8
============================================================
Output shape: torch.Size([1, 512, 1, 8])
Ref shape: torch.Size([1, 512, 1, 8])
❌ Kernel output contains -inf
❌ Kernel output contains -1 indices
Element set match rate: 87.50%
Kernel indices (first position): [245 288 393 291 157 216 310 -1]
Reference indices: [291 310 157 216 393 288 245 419]
Kernel scores: [0.17208526 0.18024139 0.1807008 0.26306465 0.21612965 0.19079374
0.22436143 -inf]
Reference scores: [0.2630647 0.22436148 0.21612965 0.19079375 0.18070081 0.1802414
0.17208527 0.1612985 ]
============================================================
Testing topk=16
============================================================
Output shape: torch.Size([1, 512, 1, 16])
Ref shape: torch.Size([1, 512, 1, 16])
Element set match rate: 100.00%
Kernel indices (first position): [252 496 391 117 141 393 335 91 288 445 291 157 216 419 310 245]
Reference indices: [291 310 157 216 393 288 245 419 91 335 445 391 141 117 496 252]
Kernel scores: [0.14278702 0.14497198 0.14963867 0.14788239 0.14858896 0.1807008
0.15289012 0.15979369 0.18024139 0.14967075 0.26306465 0.21612965
0.19079374 0.16129847 0.22436143 0.17208526]
Reference scores: [0.2630647 0.22436148 0.21612965 0.19079375 0.18070081 0.1802414
0.17208527 0.1612985 0.1597937 0.15289015 0.14967078 0.1496387
0.14858896 0.14788243 0.14497203 0.14278702]
Expected Behavior
The heap-based TopK kernel should produce correct results for all values of topk, not just for certain values like 16.
Specifically:
- No -inf values in the output scores
- No -1 values in the output indices (unless there are actually fewer than topk valid elements)
- The set of indices returned should match the reference implementation
- The scores should be close to the reference values (allowing for small floating-point differences); a sketch of such a check follows below
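For reference, a check along these lines is what I would expect to pass for every topk. This is only a sketch of the criteria above; check_topk is an illustrative helper, not part of the reproduction script:

import torch

def check_topk(out_scores, out_indices, ref_scores, ref_indices, atol=1e-3):
    # No -inf scores and no -1 indices should remain after a full pass over K.
    assert not torch.any(out_scores.isinf()).item()
    assert not torch.any(out_indices == -1).item()
    # The kernel returns the top-k in heap order, so compare values after sorting.
    out_sorted, _ = torch.sort(out_scores.float(), dim=-1, descending=True)
    ref_sorted, _ = torch.sort(ref_scores.float(), dim=-1, descending=True)
    assert torch.allclose(out_sorted, ref_sorted, atol=atol)
    # Index sets should match at every (batch, query, head) position.
    assert torch.equal(
        torch.sort(out_indices.long(), dim=-1).values,
        torch.sort(ref_indices.long(), dim=-1).values,
    )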
Please let me know if you need any additional information or if there are specific tests I should run to help diagnose this issue.