@@ -849,6 +849,315 @@ __global__ void append_decode_cache_T_quant_neox_rope_kernel(
 #endif
 }
 
+template <typename T,
+          int VecSize = 4,
+          int RoundType = 0,
+          int HeadDim = 128,
+          bool is_scale_channel_wise = false,
+          bool IsFP8 = true,
+          bool IsDynamic = true>
+__global__ void append_decode_cache_T_int8_neox_rope_kernel(
+    const T* __restrict__ quant_qkv,    // [bsz, num_heads + 2 * kv_num_heads,
+                                        // head_size]
+    uint8_t* __restrict__ key_cache,    // [num_blocks, kv_num_heads,
+                                        // block_size, head_size // 2]
+    uint8_t* __restrict__ value_cache,  // [num_blocks, kv_num_heads,
+                                        // block_size, head_size // 2]
+    T* __restrict__ qkv_out,
+    const int* __restrict__ block_tables,  // [bsz, max_blocks_per_seq]
+    const int* __restrict__ cu_seqlens_q,
+    const int* __restrict__ seq_lens,          // [bsz]
+    const int* __restrict__ seq_lens_encoder,  // [bsz]
+    const float* __restrict__ cos_emb,
+    const float* __restrict__ sin_emb,
+    T* __restrict__ cache_k_scale,
+    T* __restrict__ cache_v_scale,
+    const int max_seq_len,
+    const int max_blocks_per_seq,
+    const int num_heads,
+    const int block_size,
+    const float max_bound,
+    const float min_bound,
+    const int kv_num_heads,
+    const bool rope_3d,
+    const float rms_norm_eps) {
+  static_assert(HeadDim == 128, "just support HeadDim be 128 now!");
+  static_assert(VecSize == 4, "just support VecSize be 4 now, 32 * 4!");
+  constexpr int NUM_WARPS = 4;
+  const int tid = threadIdx.x;
+  const int wid = tid / 32;
+  const int lane_id = tid % 32;
+  const int bid = blockIdx.x, head_idx = blockIdx.y * NUM_WARPS + wid;
+  int q_head_idx, k_head_idx, v_idx;
+  const int64_t hidden_size = (num_heads + 2 * kv_num_heads) * HeadDim;
+  constexpr int half_head_size = HeadDim / 2;
+  const int start_token_idx = cu_seqlens_q[bid];
+  if (seq_lens_encoder[bid] > 0) return;
+  const int write_seq_id = seq_lens[bid];
+  if (write_seq_id == 0) return;
+  const int* block_table_now = nullptr;
+
+  block_table_now = block_tables + bid * max_blocks_per_seq;
+  const int block_idx = __ldg(&block_table_now[write_seq_id / block_size]);
+  const int block_offset = write_seq_id % block_size;
+
+  float thread_m2 = 0.0f;
+  float warp_m2 = 0.0f;
+#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900))
+  cudaGridDependencySynchronize();
+#endif
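+  // One warp per head: head_idx in [0, num_heads) handles a query head,
+  // [num_heads, num_heads + kv_num_heads) a key head, and the remaining
+  // kv_num_heads handle value heads.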
+  if (head_idx < num_heads) {
+    // q
+    using LoadT = AlignedVector<T, VecSize>;
+    using LoadBiasT = AlignedVector<T, VecSize>;
+    constexpr int HalfVecSize = VecSize / 2;
+    using LoadEmbT = AlignedVector<float, VecSize>;
+
+    LoadT src_vec;
+    LoadT src_vec_right;
+    LoadBiasT out_vec;
+    LoadBiasT out_vec_right;
+    LoadEmbT cos_emb_vec;
+    LoadEmbT sin_emb_vec;
+    const T* qkv_now = quant_qkv + start_token_idx * hidden_size;
+    T* qkv_out_now = qkv_out + start_token_idx * hidden_size;
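+    // NeoX-style rope: each lane rotates elements in the first half of the
+    // head dim with their partners half_head_size positions away. With
+    // VecSize == 4 and HeadDim == 128 only lanes 0-15 enter this loop.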
+#pragma unroll
+    for (uint32_t head_bias = lane_id * VecSize; head_bias < half_head_size;
+         head_bias += 32 * VecSize) {
+      const int bias_idx = head_idx * HeadDim + head_bias;
+      Load<T, VecSize>(&qkv_now[bias_idx], &src_vec);
+      Load<T, VecSize>(&qkv_now[bias_idx + half_head_size], &src_vec_right);
+      // q rope
+      const uint32_t emb_idx = write_seq_id * HeadDim + head_bias;
+      const uint32_t new_emb_idx =
+          rope_3d ? emb_idx + bid * max_seq_len * HeadDim : emb_idx;
+      Load<float, VecSize>(&cos_emb[new_emb_idx], &cos_emb_vec);
+      Load<float, VecSize>(&sin_emb[new_emb_idx], &sin_emb_vec);
+#pragma unroll
+      for (int i = 0; i < VecSize; i++) {
+        // dequant + add_bias + rope
+        float input_left = static_cast<float>(src_vec[i]);
+        float input_right = static_cast<float>(src_vec_right[i]);
+
+        const float cos_tmp = cos_emb_vec[i];
+        const float sin_tmp = sin_emb_vec[i];
+        float tmp1 = input_left * cos_tmp - input_right * sin_tmp;
+        float tmp2 = input_right * cos_tmp + input_left * sin_tmp;
+        thread_m2 += tmp1 * tmp1 + tmp2 * tmp2;
+        out_vec[i] = static_cast<T>(tmp1);
+        out_vec_right[i] = static_cast<T>(tmp2);
+      }
+      Store<T, VecSize>(out_vec, &qkv_out_now[bias_idx]);
+      Store<T, VecSize>(out_vec_right, &qkv_out_now[bias_idx + half_head_size]);
+    }
+  } else if (head_idx < num_heads + 2 * kv_num_heads) {
+    // k
+    constexpr int KV_VEC_SIZE = 16 / sizeof(uint8_t);  // 16
+    using LoadPadKVT = AlignedVector<uint8_t, KV_VEC_SIZE>;
+    const uint32_t kv_head_idx = (head_idx - num_heads) % kv_num_heads;
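+    // First token written into a fresh block: zero-fill this kv head's slice
+    // of the key (or value) block before any quantized data lands in it.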
+    if (block_offset == 0) {
+      // pad zero for this kv_head_idx for this block
+      LoadPadKVT pad_cache_vec;
+      *(reinterpret_cast<uint4*>(pad_cache_vec.val)) = make_uint4(0, 0, 0, 0);
+      if (head_idx < num_heads + kv_num_heads) {
+        constexpr int num_vecs_per_head_dim = HeadDim / KV_VEC_SIZE;
+        constexpr int num_token_each_time = 32 / num_vecs_per_head_dim;
+        const uint32_t tgt_idx =
+            (block_idx * kv_num_heads + kv_head_idx) * block_size * HeadDim +
+            lane_id % num_vecs_per_head_dim * KV_VEC_SIZE;
+        for (int block_i = lane_id / num_vecs_per_head_dim;
+             block_i < block_size;
+             block_i += num_token_each_time) {
+          Store<uint8_t, KV_VEC_SIZE>(pad_cache_vec,
+                                      &key_cache[tgt_idx + block_i * HeadDim]);
+        }
+      } else {
+        const int num_vecs_per_head_dim = block_size / KV_VEC_SIZE;
+        const int num_token_each_time = 32 / num_vecs_per_head_dim;
+        const uint32_t tgt_idx =
+            (block_idx * kv_num_heads + kv_head_idx) * HeadDim * block_size +
+            lane_id % num_vecs_per_head_dim * KV_VEC_SIZE;
+        for (int block_i = lane_id / num_vecs_per_head_dim; block_i < HeadDim;
+             block_i += num_token_each_time) {
+          Store<uint8_t, KV_VEC_SIZE>(
+              pad_cache_vec, &value_cache[tgt_idx + block_i * block_size]);
+        }
+      }
+      __syncwarp();
+    }
+
+    constexpr int K_VEC_SIZE = 4;
+    constexpr int HALF_K_VEC_SIZE = 2;
+    using LoadKVResT = AlignedVector<uint8_t, K_VEC_SIZE>;
+    using LoadKVT = AlignedVector<uint8_t, HALF_K_VEC_SIZE>;
+    using LoadT = AlignedVector<T, HALF_K_VEC_SIZE>;
+    using LoadBiasT = AlignedVector<T, HALF_K_VEC_SIZE>;
+    using LoadEmbT = AlignedVector<float, HALF_K_VEC_SIZE>;
+    LoadKVResT cache_vec;
+    LoadT src_vec1, src_vec1_right, src_vec2, src_vec2_right;
+    LoadBiasT out_vec1, out_vec2;
+    LoadEmbT cos_emb_vec1, cos_emb_vec2;
+    LoadEmbT sin_emb_vec1, sin_emb_vec2;
+
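+    // Each lane owns four head-dim elements of the current token:
+    // head_bias, head_bias + 1, head_bias + 8 and head_bias + 9, matching the
+    // element order expected by the tiled cache layout below.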
+    const T* qkv_now = quant_qkv + start_token_idx * hidden_size;
+    const int head_bias = lane_id / 4 * 16 + lane_id % 4 * 2;
+    const int bias_idx = head_idx * HeadDim + head_bias;
+    Load<T, HALF_K_VEC_SIZE>(&qkv_now[bias_idx], &src_vec1);
+    Load<T, HALF_K_VEC_SIZE>(&qkv_now[bias_idx + 8], &src_vec2);
+    T scale = T(1.0f);
+    const int k_head_idx = head_idx - num_heads;
+    const int v_head_idx = head_idx - num_heads - kv_num_heads;
+    if (head_idx < num_heads + kv_num_heads) {
+      Load<T, HALF_K_VEC_SIZE>(
+          &qkv_now[head_idx * HeadDim + (head_bias + half_head_size) % HeadDim],
+          &src_vec1_right);
+      Load<T, HALF_K_VEC_SIZE>(
+          &qkv_now[head_idx * HeadDim +
+                   (head_bias + 8 + half_head_size) % HeadDim],
+          &src_vec2_right);
+
+      const uint32_t emb_idx = write_seq_id * HeadDim + head_bias;
+      const uint32_t new_emb_idx =
+          rope_3d ? emb_idx + bid * max_seq_len * HeadDim : emb_idx;
+      Load<float, HALF_K_VEC_SIZE>(&cos_emb[new_emb_idx], &cos_emb_vec1);
+      Load<float, HALF_K_VEC_SIZE>(&cos_emb[new_emb_idx + 8], &cos_emb_vec2);
+      Load<float, HALF_K_VEC_SIZE>(&sin_emb[new_emb_idx], &sin_emb_vec1);
+      Load<float, HALF_K_VEC_SIZE>(&sin_emb[new_emb_idx + 8], &sin_emb_vec2);
+    }
+
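+    // NeoX rope is applied to k only; v is written through unrotated. When
+    // head_bias >= half_head_size the loaded "right" value is the first-half
+    // partner, so the rotation sign flips.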
+    if (head_idx < num_heads + kv_num_heads) {
+      float input_left = static_cast<float>(src_vec1[0]);
+      float input_right = static_cast<float>(src_vec1_right[0]);
+      float cos_tmp = cos_emb_vec1[0];
+      float sin_tmp = sin_emb_vec1[0];
+      float tmp1 = 0;
+      if (head_bias < half_head_size) {
+        tmp1 = input_left * cos_tmp - input_right * sin_tmp;
+      } else {
+        tmp1 = input_left * cos_tmp + input_right * sin_tmp;
+      }
+      out_vec1[0] = static_cast<T>(tmp1);
+      input_left = static_cast<float>(src_vec1[1]);
+      input_right = static_cast<float>(src_vec1_right[1]);
+      cos_tmp = cos_emb_vec1[1];
+      sin_tmp = sin_emb_vec1[1];
+      if (head_bias < half_head_size) {
+        tmp1 = input_left * cos_tmp - input_right * sin_tmp;
+      } else {
+        tmp1 = input_left * cos_tmp + input_right * sin_tmp;
+      }
+      out_vec1[1] = static_cast<T>(tmp1);
+    } else {
+      out_vec1[0] = src_vec1[0];
+      out_vec1[1] = src_vec1[1];
+    }
+
+    // rope
+    if (head_idx < num_heads + kv_num_heads) {
+      float input_left = static_cast<float>(src_vec2[0]);
+      float input_right = static_cast<float>(src_vec2_right[0]);
+      float cos_tmp = cos_emb_vec2[0];
+      float sin_tmp = sin_emb_vec2[0];
+      float tmp1 = 0;
+      if (head_bias < half_head_size) {
+        tmp1 = input_left * cos_tmp - input_right * sin_tmp;
+      } else {
+        tmp1 = input_left * cos_tmp + input_right * sin_tmp;
+      }
+      out_vec2[0] = static_cast<T>(tmp1);
+      input_left = static_cast<float>(src_vec2[1]);
+      input_right = static_cast<float>(src_vec2_right[1]);
+      cos_tmp = cos_emb_vec2[1];
+      sin_tmp = sin_emb_vec2[1];
+      if (head_bias < half_head_size) {
+        tmp1 = input_left * cos_tmp - input_right * sin_tmp;
+      } else {
+        tmp1 = input_left * cos_tmp + input_right * sin_tmp;
+      }
+      out_vec2[1] = static_cast<T>(tmp1);
+    } else {
+      out_vec2[0] = src_vec2[0];
+      out_vec2[1] = src_vec2[1];
+    }
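+    // Dynamic per-token scale: take the max |value| of this head's token
+    // across the warp, quantize with scale = 448 / max (448 is the largest
+    // finite FP8 E4M3 value), and write the reciprocal out as the per-token
+    // dequantization scale at (block, kv_head, token).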
+    if constexpr (IsDynamic) {
+      // reduce max, 1 head per warp
+      T local_max = -INFINITY;
+#pragma unroll
+      for (int i = 0; i < HALF_K_VEC_SIZE; i++) {
+        local_max = __hmax(local_max, __habs(out_vec1[i]));
+        local_max = __hmax(local_max, __habs(out_vec2[i]));
+      }
+#pragma unroll
+      for (int m_offset = 16; m_offset > 0; m_offset /= 2) {
+        local_max =
+            __hmax(local_max, __shfl_xor_sync(0xffffffff, local_max, m_offset));
+      }
+      scale = __hdiv(448, local_max);
+
+      int cache_offset;
+      if (head_idx < num_heads) {
+        cache_offset = 0;
+      } else if (head_idx < num_heads + 2 * kv_num_heads) {
+        cache_offset = block_idx * kv_num_heads * block_size +
+                       (head_idx - num_heads) % kv_num_heads * block_size +
+                       block_offset;
+      }
+      T* cache_k_scale_now = cache_k_scale + cache_offset;
+      T* cache_v_scale_now = cache_v_scale + cache_offset;
+      if (lane_id == 0) {
+        if (head_idx < num_heads + kv_num_heads) {
+          cache_k_scale_now[0] = __hdiv(1, scale);
+        } else {
+          cache_v_scale_now[0] = __hdiv(1, scale);
+        }
+      }
+    } else {
+      if (head_idx < num_heads + kv_num_heads) {
+        scale = __ldg(&cache_k_scale[kv_head_idx]);
+      } else {
+        scale = __ldg(&cache_v_scale[kv_head_idx]);
+      }
+    }
+
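+    // Quantize this lane's four rotated (k) or passthrough (v) values to
+    // 8-bit with the scale chosen above, clamped to [min_bound, max_bound].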
+#pragma unroll
+    for (uint32_t i = 0; i < HALF_K_VEC_SIZE; i++) {
+      cache_vec[i] = QuantToC8<T, true, IsFP8, RoundType>(
+          scale, out_vec1[i], max_bound, min_bound);
+      cache_vec[i + HALF_K_VEC_SIZE] = QuantToC8<T, true, IsFP8, RoundType>(
+          scale, out_vec2[i], max_bound, min_bound);
+    }
+    if (head_idx < num_heads + kv_num_heads) {
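+      // Key cache: tokens are permuted within 16-token tiles of the
+      // [block, kv_head, block_size, head_dim] layout; this lane's four
+      // quantized bytes land in one contiguous 4-byte slot.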
+      const int start_block_16 =
+          block_offset / 16 * 16 + block_offset % 8 + lane_id / 4 % 2 * 8;
+      const uint32_t tgt_cache_idx =
+          block_idx * kv_num_heads * block_size * HeadDim +
+          kv_head_idx * block_size * HeadDim + start_block_16 * HeadDim +
+          lane_id / 4 / 2 * 32 + (block_offset % 16) / 8 * 16 + lane_id % 4 * 4;
+      Store<uint8_t, K_VEC_SIZE>(cache_vec, &key_cache[tgt_cache_idx]);
+    } else {
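+      // Value cache is head_dim-major ([block, kv_head, head_dim, block_size]),
+      // so the four bytes scatter to block_size-strided / tiled offsets and
+      // are written one byte at a time.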
+      const uint32_t base_tgt_cache_idx =
+          block_idx * kv_num_heads * HeadDim * block_size +
+          kv_head_idx * HeadDim * block_size +
+          (lane_id / 4 * 16 + lane_id % 4 * 2) * block_size +
+          block_offset / 16 % 2 * 8 * block_size + block_offset / 16 / 2 * 32;
+      const uint32_t tgt_cache_idx1 = base_tgt_cache_idx +
+                                      block_offset % 8 / 2 * 4     // per 4
+                                      + block_offset % 16 / 8 * 2  // per 2
+                                      + block_offset % 2;          // per 1
+      const uint32_t tgt_cache_idx2 = tgt_cache_idx1 + block_size;
+      const uint32_t tgt_cache_idx3 = tgt_cache_idx1 + 16;
+      const uint32_t tgt_cache_idx4 = tgt_cache_idx3 + block_size;
+      value_cache[tgt_cache_idx1] = cache_vec[0];
+      value_cache[tgt_cache_idx2] = cache_vec[1];
+      value_cache[tgt_cache_idx3] = cache_vec[2];
+      value_cache[tgt_cache_idx4] = cache_vec[3];
+    }
+  }
+#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900))
+  cudaTriggerProgrammaticLaunchCompletion();
+#endif
+}
+
 template <typename T,
           int VecSize = 4,
           int RoundType = 0,