
[Question] Heap Sort Kernel Shows Different Behavior for Different Config #1370

@WXY277

Description


Problem Description

When implementing a Heap-based TopK kernel in TileLang for attention mechanisms, I've observed inconsistent behavior depending on the topk parameter value:

  • When topk=16: The kernel works correctly, producing valid topk scores and indices that match reference implementations.
  • When topk=8: The same kernel logic produces incorrect results, often containing -inf values or wrong indices.

The algorithm is implemented correctly (verified in Python/NumPy), and the same kernel structure works for topk=16, suggesting this may be a compiler optimization or code generation issue.
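
For reference, here is the per-row heap update as a standalone NumPy sketch, roughly what I verified on the host (topk_minheap is just an illustrative name for this write-up):

import numpy as np

def topk_minheap(scores, k):
    # Fixed-size min-heap: heap_scores[0] is the smallest score retained so far.
    heap_scores = np.full(k, -np.inf, dtype=np.float32)
    heap_idx = np.full(k, -1, dtype=np.int64)
    for i, s in enumerate(scores):
        if s > heap_scores[0]:
            # Replace the current minimum, then sift it down to restore the heap.
            heap_scores[0], heap_idx[0] = s, i
            cur = 0
            while True:
                left, right, smallest = 2 * cur + 1, 2 * cur + 2, cur
                if left < k and heap_scores[left] < heap_scores[smallest]:
                    smallest = left
                if right < k and heap_scores[right] < heap_scores[smallest]:
                    smallest = right
                if smallest == cur:
                    break
                heap_scores[cur], heap_scores[smallest] = heap_scores[smallest], heap_scores[cur]
                heap_idx[cur], heap_idx[smallest] = heap_idx[smallest], heap_idx[cur]
                cur = smallest
    return heap_scores, heap_idx

scores = np.random.default_rng(0).standard_normal(512, dtype=np.float32)
for k in (8, 16):
    _, idx = topk_minheap(scores, k)
    assert set(idx) == set(np.argsort(scores)[-k:])  # passes for both k=8 and k=16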

Important Observation: During debugging, I discovered that changing the loop from for tx in T.Parallel(BLOCK_L): to for tx in T.serial(BLOCK_L): in the heap update section resolves the issue for topk=8. However, this is suboptimal for performance since each row is processed independently and should be parallelizable. This suggests the problem may be related to how TileLang handles parallel loops with certain loop bounds.

Environment Information

- TileLang Version: tilelang-0.1.6.post2+cu121.gitb8240b7a
- CUDA Version: 12.1
- GPU: NVIDIA RTX 3090

Minimal Reproduction Example

Here's a minimal reproduction script that demonstrates the issue:

import torch
import tilelang
import tilelang.language as T
from typing import Optional

@tilelang.jit(
    out_idx=[2, 3],
    pass_configs={
        tilelang.PassConfigKey.TL_ENABLE_FAST_MATH: True,
        tilelang.PassConfigKey.TL_DISABLE_TMA_LOWER: True,
        tilelang.PassConfigKey.TL_DISABLE_WARP_SPECIALIZED: True,
    }
)
def minimal_heap_kernel(batch, seq_len, s_len, heads, head_dim, topk):
    dtype = "bfloat16"
    accum_dtype = "float"
    idx_dtype = "int32"
    
    q_shape = [batch, seq_len, heads, head_dim]
    k_shape = [batch, s_len, heads, head_dim]
    out_scores_shape = [batch, seq_len, heads, topk]
    out_indices_shape = [batch, seq_len, heads, topk]

    BLOCK_L = 32
    BLOCK_S = 8
    BLOCK_D = head_dim
    threads = BLOCK_L
    
    num_s_blocks = tilelang.cdiv(s_len, BLOCK_S)

    @T.prim_func
    def kernel(
        Q: T.Tensor(q_shape, dtype),
        K: T.Tensor(k_shape, dtype),
        OutScores: T.Tensor(out_scores_shape, accum_dtype),
        OutIndices: T.Tensor(out_indices_shape, idx_dtype),
    ):
        with T.Kernel(tilelang.cdiv(seq_len, BLOCK_L), heads, batch, threads=threads) as (bx, by, bz):
            i_b = bz
            i_h = by
            base_l = bx * BLOCK_L
            
            Q_shared = T.alloc_shared([BLOCK_L, BLOCK_D], dtype)
            K_shared = T.alloc_shared([BLOCK_S, BLOCK_D], dtype)
            score_shared = T.alloc_shared([BLOCK_L, BLOCK_S], accum_dtype)
            
            topk_scores = T.alloc_shared([BLOCK_L, topk], accum_dtype)
            topk_indices = T.alloc_shared([BLOCK_L, topk], idx_dtype)
            
            acc_s = T.alloc_fragment([BLOCK_L, BLOCK_S], accum_dtype)
            
            temp_score = T.alloc_var(accum_dtype)
            temp_idx = T.alloc_var(idx_dtype)
            cur_pos = T.alloc_var(idx_dtype)
            smallest_pos = T.alloc_var(idx_dtype)
            left_pos = T.alloc_var(idx_dtype)
            right_pos = T.alloc_var(idx_dtype)

            for l_idx, k_idx in T.Parallel(BLOCK_L, topk):
                topk_scores[l_idx, k_idx] = -T.infinity(accum_dtype)
                topk_indices[l_idx, k_idx] = -1
            
            for l_idx, d in T.Parallel(BLOCK_L, BLOCK_D):
                tq = base_l + l_idx
                if tq < seq_len:
                    Q_shared[l_idx, d] = Q[i_b, tq, i_h, d]
                else:
                    Q_shared[l_idx, d] = 0
            
            for s_block in T.serial(num_s_blocks):
                base_s = s_block * BLOCK_S
                
                for s_idx, d in T.Parallel(BLOCK_S, BLOCK_D):
                    ts = base_s + s_idx
                    if ts < s_len:
                        K_shared[s_idx, d] = K[i_b, ts, i_h, d]
                    else:
                        K_shared[s_idx, d] = 0
                
                T.sync_threads()
                
                T.clear(acc_s)
                T.gemm(Q_shared, K_shared, acc_s, transpose_B=True, policy=T.GemmWarpPolicy.FullRow)
                T.copy(acc_s, score_shared)
                
                T.sync_threads()
                
                # Heap update: each query row is processed independently, maintaining a size-topk min-heap in shared memory.
                for tx in T.Parallel(BLOCK_L):
                    my_l_idx = tx
                    my_tq = base_l + my_l_idx
                    
                    if my_tq < seq_len:
                        for s_idx in T.serial(BLOCK_S):
                            ts = base_s + s_idx
                            if ts < s_len:
                                cur_score = score_shared[my_l_idx, s_idx]
                                
                                if cur_score > topk_scores[my_l_idx, 0]:
                                    topk_scores[my_l_idx, 0] = cur_score
                                    topk_indices[my_l_idx, 0] = ts
                                    
                                    # Sift the replaced root down to restore the min-heap property.
                                    cur_pos = 0
                                    for ki in T.serial(topk):
                                        smallest_pos = cur_pos
                                        left_pos = cur_pos * 2 + 1
                                        right_pos = cur_pos * 2 + 2
                                        
                                        if left_pos < topk:
                                            if topk_scores[my_l_idx, left_pos] < topk_scores[my_l_idx, smallest_pos]:
                                                smallest_pos = left_pos
                                        
                                        if right_pos < topk:
                                            if topk_scores[my_l_idx, right_pos] < topk_scores[my_l_idx, smallest_pos]:
                                                smallest_pos = right_pos
                                        
                                        if smallest_pos != cur_pos:
                                            temp_score = topk_scores[my_l_idx, cur_pos]
                                            temp_idx = topk_indices[my_l_idx, cur_pos]
                                            topk_scores[my_l_idx, cur_pos] = topk_scores[my_l_idx, smallest_pos]
                                            topk_indices[my_l_idx, cur_pos] = topk_indices[my_l_idx, smallest_pos]
                                            topk_scores[my_l_idx, smallest_pos] = temp_score
                                            topk_indices[my_l_idx, smallest_pos] = temp_idx
                                            cur_pos = smallest_pos
                
                T.sync_threads()
            
            for li, ki in T.Parallel(BLOCK_L, topk):
                lq = base_l + li
                if lq < seq_len:
                    OutScores[i_b, lq, i_h, ki] = topk_scores[li, ki]
                    OutIndices[i_b, lq, i_h, ki] = topk_indices[li, ki]

    return kernel

def create_test_data(batch=1, seq_len=64, s_len=64, heads=1, head_dim=16):

    torch.manual_seed(42)
    
    Q = torch.randn(batch, seq_len, heads, head_dim, dtype=torch.bfloat16, device='cuda')
    K = torch.randn(batch, s_len, heads, head_dim, dtype=torch.bfloat16, device='cuda')
    
    Q = Q / torch.sqrt(torch.tensor(head_dim, dtype=torch.float32))
    K = K / torch.sqrt(torch.tensor(head_dim, dtype=torch.float32))
    
    return Q, K



def reference_topk(Q, K, topk):
    batch, seq_len, heads, head_dim = Q.shape
    scores = torch.einsum("bqhd,bkhd->bqhk", Q.float(), K.float())

    topk_scores, topk_indices = torch.topk(scores, k=topk, dim=-1)
    
    return topk_scores, topk_indices


def test_topk(topk_values):
    batch = 1
    seq_len = 512
    s_len = seq_len
    heads = 1
    head_dim = 128
    
    Q, K = create_test_data(batch, seq_len, s_len, heads, head_dim)
    
    for topk in topk_values:
        print(f"\n{'='*60}")
        print(f"Testing topk={topk}")
        print(f"{'='*60}")
        
        # Compile the kernel for this topk value
        kernel = minimal_heap_kernel(batch, seq_len, s_len, heads, head_dim, topk)
        
        # Prepare output buffers (note: these are overwritten below, since the
        # jitted kernel returns its own outputs via out_idx=[2, 3])
        out_scores = torch.empty(batch, seq_len, heads, topk,
                                 dtype=torch.float32, device='cuda')
        out_indices = torch.empty(batch, seq_len, heads, topk,
                                  dtype=torch.int32, device='cuda')

        out_scores, out_indices = kernel(Q, K)
        
        ref_scores, ref_indices = reference_topk(Q, K, topk)
        
        print(f"Output shape: {out_scores.shape}")
        print(f"Ref shape: {ref_scores.shape}")
        
        has_inf = torch.any(out_scores.isinf())
        has_neg_one = torch.any(out_indices == -1)
        
        if has_inf:
            print("❌ Kernel output contains -inf")
        if has_neg_one:
            print("❌ Kernel output contains -1 indices")
        
        fused_set = set(out_indices[0, 0, 0].cpu().numpy())
        ref_set = set(ref_indices[0, 0, 0].cpu().numpy())
        
        match_ratio = len(fused_set.intersection(ref_set)) / topk
        print(f"Element set match rate: {match_ratio:.2%}")
        
        print(f"\nKernel indices (first position): {out_indices[0, 0, 0].cpu().numpy()}")
        print(f"Reference indices: {ref_indices[0, 0, 0].cpu().numpy()}")
        
        print(f"\nKernel scores: {out_scores[0, 0, 0].cpu().numpy()}")
        print(f"Reference scores: {ref_scores[0, 0, 0].cpu().numpy()}")
            

if __name__ == "__main__":
    test_topk([8, 16])

Observed Behavior

============================================================
Testing topk=8
============================================================
Output shape: torch.Size([1, 512, 1, 8])
Ref shape: torch.Size([1, 512, 1, 8])
❌ Kernel output contains -inf
❌ Kernel output contains -1 indices
Element set match rate: 87.50%

Kernel indices (first position): [245 288 393 291 157 216 310  -1]
Reference indices: [291 310 157 216 393 288 245 419]

Kernel scores: [0.17208526 0.18024139 0.1807008  0.26306465 0.21612965 0.19079374
 0.22436143       -inf]
Reference scores: [0.2630647  0.22436148 0.21612965 0.19079375 0.18070081 0.1802414
 0.17208527 0.1612985 ]

============================================================
Testing topk=16
============================================================
Output shape: torch.Size([1, 512, 1, 16])
Ref shape: torch.Size([1, 512, 1, 16])
Element set match rate: 100.00%

Kernel indices (first position): [252 496 391 117 141 393 335  91 288 445 291 157 216 419 310 245]
Reference indices: [291 310 157 216 393 288 245 419  91 335 445 391 141 117 496 252]

Kernel scores: [0.14278702 0.14497198 0.14963867 0.14788239 0.14858896 0.1807008
 0.15289012 0.15979369 0.18024139 0.14967075 0.26306465 0.21612965
 0.19079374 0.16129847 0.22436143 0.17208526]
Reference scores: [0.2630647  0.22436148 0.21612965 0.19079375 0.18070081 0.1802414
 0.17208527 0.1612985  0.1597937  0.15289015 0.14967078 0.1496387
 0.14858896 0.14788243 0.14497203 0.14278702]

Expected Behavior

The heap-based TopK kernel should produce correct results for all values of topk, not just for certain values like 16.

Specifically:

  1. No -inf values in the output scores
  2. No -1 values in the output indices (unless there are actually fewer than topk valid elements)
  3. The set of indices returned should match the reference implementation
  4. The scores should be close to the reference values, allowing for small floating-point differences. A comparison sketch is given below.
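
A check along these lines is what I would expect to pass for every topk (a sketch only; compare_outputs is an illustrative helper, and the kernel output is sorted first because the heap stores results in heap order rather than descending order):

def compare_outputs(out_scores, out_indices, ref_scores, ref_indices):
    # Sort the kernel output descending so it lines up with torch.topk's ordering.
    sorted_scores, order = torch.sort(out_scores, dim=-1, descending=True)
    sorted_indices = torch.gather(out_indices, -1, order)

    assert not torch.any(sorted_scores.isinf()), "output contains -inf scores"
    assert not torch.any(sorted_indices == -1), "output contains -1 indices"
    # Scores should agree up to small floating-point differences.
    torch.testing.assert_close(sorted_scores, ref_scores, rtol=1e-3, atol=1e-3)
    # Index sets should match at every query position (ignoring tie ordering).
    for pos in range(out_scores.shape[1]):
        assert set(sorted_indices[0, pos, 0].tolist()) == set(ref_indices[0, pos, 0].tolist())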

Please let me know if you need any additional information or if there are specific tests I should run to help diagnose this issue.
