Skip to content

Commit 635126f

Browse files
authored
Merge branch 'develop' into sgl
2 parents c8fccdc + 9f4512c commit 635126f

File tree

146 files changed

+6925
-2891
lines changed

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

146 files changed

+6925
-2891
lines changed

benchmarks/README.md

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -45,6 +45,7 @@ python -m pip install -r requirements.txt
4545
--debug:开启debug模式,逐条打印payload和output内容,默认False
4646
--shuffle:是否打乱数据集,默认False不打乱
4747
--seed:打乱数据集时的随机种子,默认0
48+
--pd-metrics:开启PD分离metrics指标收集,会添加请求参数collect_metrics=True,默认False
4849
```
4950

5051
##### /v1/chat/completions接口压测单条数据调试

benchmarks/backend_request_func.py

Lines changed: 81 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -51,6 +51,7 @@ class RequestFuncInput:
5151
ignore_eos: bool = False
5252
language: Optional[str] = None
5353
debug: bool = False
54+
pd_metrics: bool = False
5455
response_format: Optional[dict] = None
5556
random_flag: bool = False
5657

@@ -74,6 +75,73 @@ class RequestFuncOutput:
7475
prompt_len: int = 0
7576
prompt_tokens: int = 0 # 推理侧返回输入token数
7677
error: str = ""
78+
metrics: dict = field(default_factory=dict)
79+
80+
81+
def safe_cost(a, b):
    """Return the elapsed time ``a - b``, or ``None`` if either operand is ``None``."""
    return None if (a is None or b is None) else a - b
86+
87+
88+
def metrics_summary(metrics, token_timestamps):
    """Build a dict of per-stage durations from collected metrics records.

    Args:
        metrics: Sequence of mapping-like records (each must support ``.get``).
            Only the first and the last record are read.
        token_timestamps: Sequence of client-side receive timestamps; the
            first two entries are used for the transmission durations.

    Returns:
        ``{}`` when ``metrics`` is empty or fewer than two timestamps are
        given. Otherwise a dict mapping each ``*_cost_time`` key to the
        difference of two timestamps, or ``None`` when either timestamp is
        missing (see ``safe_cost``).
    """
    if not metrics or len(token_timestamps) < 2:
        return {}

    first = metrics[0]
    last = metrics[-1]

    def span(record, end_key, start_key):
        # Duration between two timestamp fields of the same record.
        return safe_cost(record.get(end_key), record.get(start_key))

    return {
        # Total prefill time.
        "prefill_cost_time": span(first, "send_request_output_to_decode_time", "arrival_time"),
        # Prefill preparation time.
        "prefill_prepare_cost_time": span(first, "inference_start_time", "arrival_time"),
        # Preprocessing time.
        "preprocess_cost_time": span(first, "scheduler_recv_req_time", "arrival_time"),
        # Time the request spent cached in the scheduler.
        "cache_in_scheduler_cost_time": span(first, "engine_get_req_time", "scheduler_recv_req_time"),
        # Time spent asking for decode resources.
        "ask_decode_resource_cost_time": span(
            first, "ask_decode_resource_finish_time", "ask_decode_resource_start_time"
        ),
        # Prefill first-token inference time.
        "prefill_first_token_infer_cost_time": span(
            first, "engine_recv_first_token_time", "inference_start_time"
        ),
        # Time prefill waited for the cache transfer.
        "wait_sending_cache_cost_time": span(
            first, "send_request_output_to_decode_time", "wait_for_sending_cache_time"
        ),
        # Decode resource pre-allocation time.
        "decode_preallocate_cost_time": span(last, "decode_preallocate_req_time", "decode_recv_req_time"),
        # Decode inference preparation time.
        "decode_prepare_cost_time": span(
            last, "decode_inference_start_time", "decode_recv_first_token_time"
        ),
        # Decode second-token inference time.
        "decode_second_token_infer_cost_time": span(
            last, "decode_recv_second_token_time", "decode_inference_start_time"
        ),
        # Return-path time of the first token (client receive - decode receive).
        "first_token_transmission_cost_time": safe_cost(
            token_timestamps[0], last.get("decode_recv_first_token_time")
        ),
        # Return-path time of the second token (client receive - decode receive).
        "second_token_transmission_cost_time": safe_cost(
            token_timestamps[1], last.get("decode_recv_second_token_time")
        ),
    }
77145

78146

79147
async def async_request_eb_openai_chat_completions(
@@ -97,6 +165,7 @@ async def async_request_eb_openai_chat_completions(
97165
"continuous_usage_stats": True,
98166
},
99167
"max_tokens": request_func_input.output_len,
168+
"collect_metrics": request_func_input.pd_metrics,
100169
}
101170
if request_func_input.response_format:
102171
payload["response_format"] = request_func_input.response_format
@@ -125,11 +194,13 @@ async def async_request_eb_openai_chat_completions(
125194
output = RequestFuncOutput()
126195
output.prompt_len = 0
127196
output.no = request_func_input.no
197+
metrics_list = []
128198
request_id = "None"
129199

130200
ttft = 0.0
131201
st = time.perf_counter()
132202
most_recent_timestamp = st
203+
token_timestamps = []
133204
try:
134205
async with session.post(url=api_url, json=payload, headers=headers) as response:
135206
data = {}
@@ -144,6 +215,10 @@ async def async_request_eb_openai_chat_completions(
144215
# print("####chunk:", chunk, type(chunk))
145216
timestamp = time.perf_counter()
146217
data = json.loads(chunk)
218+
# print("####data:", json.dumps(data, indent=2, ensure_ascii=False))
219+
220+
if "metrics" in data:
221+
metrics_list.append(data["metrics"])
147222

148223
if request_id == "None" and "id" in data:
149224
request_id = data["id"]
@@ -169,16 +244,22 @@ async def async_request_eb_openai_chat_completions(
169244

170245
output.generated_text += content or ""
171246
output.reasoning_content += reason_content or ""
247+
# print(f"####content:{data}")
172248
output.arrival_time.append(choices[0].get("arrival_time", timestamp))
173249
elif usage := data.get("usage", {}):
174250
output.output_tokens = usage.get("completion_tokens", 0)
175251
output.prompt_tokens = usage.get("prompt_tokens", 0)
176252

177253
most_recent_timestamp = timestamp
254+
token_timestamps.append(time.time())
178255

179256
# output.generated_text = generated_text
180257
# 在流式结束时,记录最后一个 chunk 收到的时间戳
181258
output.end_timestamp = most_recent_timestamp
259+
260+
# 新增metrics统计,计算首token过滤空包
261+
output.metrics = metrics_summary(metrics_list, token_timestamps[1:])
262+
182263
if output.generated_text.strip() == "":
183264
output.success = False
184265
output.error = "No generated text found!"

benchmarks/benchmark_serving.py

Lines changed: 66 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -318,6 +318,7 @@ async def benchmark(
318318
selected_percentiles: list[float],
319319
ignore_eos: bool,
320320
debug: bool,
321+
pd_metrics: bool,
321322
goodput_config_dict: dict[str, float],
322323
max_concurrency: Optional[int],
323324
lora_modules: Optional[Iterable[str]],
@@ -352,6 +353,7 @@ async def benchmark(
352353
logprobs=logprobs,
353354
ignore_eos=ignore_eos,
354355
debug=debug,
356+
pd_metrics=pd_metrics,
355357
extra_body=extra_body,
356358
response_format=response_format,
357359
random_flag=random_flag,
@@ -446,6 +448,7 @@ async def limited_request_func(request_func_input, pbar):
446448
output_len=output_len,
447449
logprobs=logprobs,
448450
debug=debug,
451+
pd_metrics=pd_metrics,
449452
ignore_eos=ignore_eos,
450453
extra_body=extra_body,
451454
response_format=response_format,
@@ -548,6 +551,7 @@ async def limited_request_func(request_func_input, pbar):
548551
"generated_texts": [output.generated_text for output in outputs],
549552
"reasoning_contents": [output.reasoning_content for output in outputs],
550553
"errors": [output.error for output in outputs],
554+
"metrics": [output.metrics for output in outputs],
551555
}
552556

553557
def process_one_metric(
@@ -583,6 +587,49 @@ def process_one_metric(
583587
print("{:<40} {:<10.2f}".format(f"P{p_word} {metric_name} (ms):", value))
584588
result[f"p{p_word}_{metric_attribute_name}_ms"] = value
585589

590+
def process_pd_metrics(model_outputs, metric_key):
    """Print mean, median, percentiles and sample count for one PD metric.

    Args:
        model_outputs: Iterable of objects exposing a ``metrics`` mapping.
        metric_key: Key to look up in each ``metrics`` mapping. Outputs whose
            value is missing or ``None`` are ignored.

    The percentiles are parsed from the comma-separated
    ``args.metric_percentiles`` string (empty entries are skipped). Collected
    values are multiplied by 1000 before printing (seconds -> milliseconds).
    Prints a warning and returns early if no output carries the metric.
    """
    percentiles = [
        float(token)
        for token in (raw.strip() for raw in args.metric_percentiles.split(","))
        if token
    ]

    # Collect every non-None value recorded under this metric key.
    values = [
        item.metrics[metric_key]
        for item in model_outputs
        if item.metrics.get(metric_key, None) is not None
    ]

    if not values:
        print(f"[WARN] metric_key '{metric_key}' not found in outputs.")
        return

    arr = np.array(values) * 1000  # seconds -> milliseconds

    row = "{:<40} {:<10.2f}"
    print("{s:{c}^{n}}".format(s=metric_key, n=50, c="-"))
    print(row.format(f"Mean {metric_key} (ms):", np.mean(arr)))
    print(row.format(f"Median {metric_key} (ms):", np.median(arr)))
    for p in percentiles:
        # Whole-number percentiles are labelled without a trailing ".0".
        label = str(int(p)) if int(p) == p else str(p)
        print(row.format(f"P{label} {metric_key} (ms):", np.percentile(arr, p)))
    # The sample count is an integer, so it is printed without decimals
    # (the float format previously rendered it as e.g. "2.00").
    print("{:<40} {:<10}".format(f"Successful {metric_key}:", len(arr)))
632+
586633
def process_one_length(
587634
# E.g., "ttft"
588635
metric_attribute_name: str,
@@ -624,6 +671,19 @@ def process_one_length(
624671
process_one_metric("s_itl", "S_ITL", "Infer Inter-token Latency")
625672
process_one_metric("e2el", "E2EL", "End-to-end Latency")
626673
process_one_metric("s_e2el", "S_E2EL", "Infer End-to-end Latency")
674+
if any(item.metrics for item in outputs):
675+
process_pd_metrics(outputs, "prefill_cost_time")
676+
process_pd_metrics(outputs, "prefill_prepare_cost_time")
677+
process_pd_metrics(outputs, "preprocess_cost_time")
678+
process_pd_metrics(outputs, "cache_in_scheduler_cost_time")
679+
process_pd_metrics(outputs, "ask_decode_resource_cost_time")
680+
process_pd_metrics(outputs, "prefill_first_token_infer_cost_time")
681+
process_pd_metrics(outputs, "wait_sending_cache_cost_time")
682+
process_pd_metrics(outputs, "decode_preallocate_cost_time")
683+
process_pd_metrics(outputs, "decode_prepare_cost_time")
684+
process_pd_metrics(outputs, "decode_second_token_infer_cost_time")
685+
process_pd_metrics(outputs, "first_token_transmission_cost_time")
686+
process_pd_metrics(outputs, "second_token_transmission_cost_time")
627687
process_one_length("input_len", "Cached Tokens", "Cached Tokens")
628688
process_one_length("s_input_len", "Input Length", "Infer Input Length")
629689
process_one_length("output_len", "Output Length", "Output Length")
@@ -941,6 +1001,7 @@ def main(args: argparse.Namespace):
9411001
selected_percentiles=[float(p) for p in args.metric_percentiles.split(",")],
9421002
ignore_eos=args.ignore_eos,
9431003
debug=args.debug,
1004+
pd_metrics=args.pd_metrics,
9441005
goodput_config_dict=goodput_config_dict,
9451006
max_concurrency=args.max_concurrency,
9461007
lora_modules=args.lora_modules,
@@ -1129,6 +1190,11 @@ def main(args: argparse.Namespace):
11291190
action="store_true",
11301191
help="shuffle dataset",
11311192
)
1193+
parser.add_argument(
1194+
"--pd-metrics",
1195+
action="store_true",
1196+
help="请求时增加PD分离参数,metrics: True",
1197+
)
11321198
parser.add_argument(
11331199
"--drop-ratio",
11341200
type=float,

custom_ops/gpu_ops/append_attn/append_attention_func.cuh

Lines changed: 22 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -2451,6 +2451,7 @@ __global__ void merge_multi_chunks_v2_kernel(
24512451
if (bid == -1) {
24522452
continue;
24532453
}
2454+
const uint32_t local_seq_id = qid - cu_seqlens_q[bid];
24542455
const int seq_len_q = seq_lens_q[bid];
24552456
if (seq_len_q == 0) continue;
24562457
int seq_len_kv = seq_lens_kv[bid];
@@ -2494,14 +2495,32 @@ __global__ void merge_multi_chunks_v2_kernel(
24942495
}
24952496
#pragma unroll 2
24962497
for (int i = ty; i < num_chunks_this_seq; i += bdy) {
2497-
uint32_t offset = (qid * num_chunks + i) * num_heads + hid;
2498+
uint32_t offset;
2499+
if (ENABLE_PREFILL) {
2500+
offset = (qid * num_chunks + i) * num_heads + hid;
2501+
} else {
2502+
offset =
2503+
((bid * speculate_max_draft_token_num + local_seq_id) * num_chunks +
2504+
i) *
2505+
num_heads +
2506+
hid;
2507+
}
24982508
float m_prev = m;
24992509
float d_prev = d;
25002510
const float m_now = multi_m[offset];
25012511
const float d_now = multi_d[offset];
25022512
m = max(m_prev, m_now);
2503-
offset = (qid * num_chunks * num_heads + i * num_heads + hid) * head_dim +
2504-
vid * vec_size;
2513+
if (ENABLE_PREFILL) {
2514+
offset =
2515+
(qid * num_chunks * num_heads + i * num_heads + hid) * head_dim +
2516+
vid * vec_size;
2517+
} else {
2518+
offset = ((bid * speculate_max_draft_token_num + local_seq_id) *
2519+
num_chunks * num_heads +
2520+
i * num_heads + hid) *
2521+
head_dim +
2522+
vid * vec_size;
2523+
}
25052524
Load<T, vec_size>(&multi_out[offset], &load_vec);
25062525
const float scale1 = __expf(m_prev - m), scale2 = __expf(m_now - m);
25072526
const T scale1_T = static_cast<T>(scale1),

custom_ops/gpu_ops/append_attn/get_block_shape_and_split_kv_block.cu

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -179,11 +179,11 @@ __global__ void split_q_block(const int *__restrict__ seq_lens_q,
179179
const int num_rows_per_block,
180180
const int group_size) {
181181
// one block one warp
182-
const int lane_id = threadIdx.x % warpSize;
182+
const int lane_id = threadIdx.x % WARP_SIZE;
183183
int prev_offset = 0;
184184

185185
// loop on warp tile:[base, base+32)
186-
for (int base = 0; base < bsz; base += warpSize) {
186+
for (int base = 0; base < bsz; base += WARP_SIZE) {
187187
const int bid = base + lane_id;
188188

189189
// calculate loop_times for bid
@@ -199,13 +199,13 @@ __global__ void split_q_block(const int *__restrict__ seq_lens_q,
199199
// prefix sum for each lane, get the start offset in this tile
200200
// inclusive scan
201201
int x = loop_times;
202-
for (int offset = 1; offset < warpSize; offset <<= 1) {
202+
for (int offset = 1; offset < WARP_SIZE; offset <<= 1) {
203203
int y = __shfl_up_sync(0xffffffff, x, offset);
204204
if (lane_id >= offset) x += y;
205205
}
206206
// exclusive prefix sum
207207
int bid_offset = x - loop_times;
208-
int tile_sum = __shfl_sync(0xffffffff, x, warpSize - 1);
208+
int tile_sum = __shfl_sync(0xffffffff, x, WARP_SIZE - 1);
209209

210210
// write batch_ids and tile_ids_per_batch
211211
if (bid < bsz && loop_times > 0) {

0 commit comments

Comments
 (0)