
Commit 7862372

[https://nvbugs/5552889][fix] fix: Prevent empty batch when using attention DP with disagg (NVIDIA#8372)
Signed-off-by: Patrice Castonguay <[email protected]>
1 parent 27c6c84 commit 7862372

File tree

4 files changed: +145 -20 lines

tensorrt_llm/_torch/pyexecutor/py_executor.py
tests/integration/defs/disaggregated/test_configs/disagg_config_deepseek_v3_lite_empty_batch.yaml
tests/integration/defs/disaggregated/test_disaggregated.py
tests/integration/test_lists/test-db/l0_dgx_h100.yml

tensorrt_llm/_torch/pyexecutor/py_executor.py

Lines changed: 25 additions & 12 deletions
@@ -796,12 +796,7 @@ def _executor_loop_pp(self):
                     f'{len(scheduled_batch.generation_requests)} generation requests'
                 )

-                if self.enable_attention_dp:
-                    tp_batch_sizes = self.dist.tp_allgather(
-                        scheduled_batch.batch_size)
-                    can_queue = 0 not in tp_batch_sizes
-                else:
-                    can_queue = scheduled_batch.batch_size > 0
+                can_queue = self._can_queue(scheduled_batch)

                 if not can_queue:
                     self.micro_batches[microbatch_id] = None

@@ -948,6 +943,16 @@ def wait_on_pp_send_handles(self, microbatch_id):
             self.send_handles[microbatch_id].wait()
             self.send_handles[microbatch_id] = None

+    def _can_queue(self, scheduled_batch):
+
+        if self.enable_attention_dp:
+            tp_batch_sizes = self.dist.tp_allgather(scheduled_batch.batch_size)
+            can_queue = 0 not in tp_batch_sizes
+        else:
+            can_queue = scheduled_batch.batch_size > 0
+
+        return can_queue
+
     def _prepare_and_schedule_batch(self):
         new_requests = self._fetch_and_activate_new_requests()
         if self.should_stop_processing:

@@ -1052,8 +1057,8 @@ def _executor_loop(self):

             finished_requests = []

-            if scheduled_batch.batch_size > 0 or (
-                    self.enable_attention_dp and self.dist.tp_size > 1):
+            can_queue = self._can_queue(scheduled_batch)
+            if can_queue:
                 if self.kv_cache_transceiver:
                     # For generation requests which have completed KV cache transfer
                     self._prepare_disagg_gen_transmission_complete(

@@ -1065,8 +1070,11 @@

             self._kv_connector_start_batch(scheduled_batch)

-            if scheduled_batch.batch_size > 0 or (
-                    self.enable_attention_dp and self.dist.tp_size > 1):
+            # If using a KV connector, we need to call _can_queue again since scheduled_batch might have changed
+            if self.kv_connector_manager:
+                can_queue = self._can_queue(scheduled_batch)
+
+            if can_queue:
                 # init_disagg_gen_requests must be before drafter loop, otherwise draft requests do not have initialized matchers.
                 # init_disagg_gen_requests must be before engine forward, where the prev_seq_slot is updated.
                 if self.guided_decoder is not None:

@@ -1180,7 +1188,8 @@ def _executor_loop_overlap(self):

             self._pause_requests(scheduled_batch.paused_requests)

-            if scheduled_batch.batch_size > 0:
+            can_queue = self._can_queue(scheduled_batch)
+            if can_queue:
                 if self.kv_cache_transceiver:
                     # For generation requests which have completed KV cache transfer
                     self._prepare_disagg_gen_transmission_complete(

@@ -1189,7 +1198,11 @@

             self._kv_connector_start_batch(scheduled_batch)

-            if scheduled_batch.batch_size > 0:
+            # If using a KV connector, we need to call _can_queue again since scheduled_batch might have changed
+            if self.kv_connector_manager:
+                can_queue = self._can_queue(scheduled_batch)
+
+            if can_queue:

                 # The generation requests that do not have a batch_idx
                 # need to be in front of the batch due to the assumptions
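
Why gate on the gathered batch sizes rather than the local one? Under attention DP each TP rank schedules its own batch, but the forward pass still runs collectives across ranks, so every rank must reach the same queue-or-skip decision. Below is a minimal sketch of the agreement rule from _can_queue; FakeDist is a hypothetical stand-in for the executor's real MPI-backed dist object, and the scenario values are invented for illustration.

# Minimal sketch of the _can_queue agreement rule under attention DP.
class FakeDist:
    """Hypothetical stand-in for self.dist; returns pre-baked allgather results."""

    def __init__(self, tp_batch_sizes):
        self._tp_batch_sizes = tp_batch_sizes

    def tp_allgather(self, local_value):
        # The real call is a collective that gathers one value per TP rank;
        # here we simply return what every rank would see afterwards.
        return self._tp_batch_sizes


def can_queue(dist, local_batch_size, enable_attention_dp):
    if enable_attention_dp:
        # Every rank sees the same gathered list, so every rank agrees:
        # if any rank has an empty batch, all ranks skip this iteration.
        return 0 not in dist.tp_allgather(local_batch_size)
    return local_batch_size > 0


# Rank 1 is empty (e.g. still waiting on a disagg KV-cache transfer),
# so both ranks skip instead of one rank running a forward pass alone.
dist = FakeDist(tp_batch_sizes=[4, 0])
assert can_queue(dist, local_batch_size=4, enable_attention_dp=True) is False
assert can_queue(dist, local_batch_size=4, enable_attention_dp=False) is True
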
tests/integration/defs/disaggregated/test_configs/disagg_config_deepseek_v3_lite_empty_batch.yaml

Lines changed: 61 additions & 0 deletions (new file; full contents below)
hostname: localhost
port: 8000
model: DeepSeek-V3-Lite/bf16
backend: "pytorch"
context_servers:
  num_instances: 1
  build_config:
    max_batch_size: 10
    max_num_tokens: 512
    max_seq_len: 768
  max_batch_size: 10
  max_num_tokens: 512
  max_seq_len: 768
  tensor_parallel_size: 2
  moe_expert_parallel_size: 2
  enable_attention_dp: true
  pipeline_parallel_size: 1
  print_iter_log: true
  cuda_graph_config: null
  disable_overlap_scheduler: true
  kv_cache_config:
    enable_block_reuse: false
    free_gpu_memory_fraction: 0.05
    max_tokens: 512
  cache_transceiver_config:
    max_tokens_in_buffer: 8448
    backend: DEFAULT
  urls:
    - "localhost:8001"
generation_servers:
  num_instances: 1
  build_config:
    max_batch_size: 1
    max_num_tokens: 2048
    max_seq_len: 2560
  tensor_parallel_size: 1
  moe_expert_parallel_size: 1
  enable_attention_dp: false
  enable_lm_head_tp_in_adp: false
  pipeline_parallel_size: 1
  max_batch_size: 1
  max_num_tokens: 2048
  max_seq_len: 2560
  cuda_graph_config:
    enable_padding: true
    batch_sizes:
      - 1
  print_iter_log: true
  kv_cache_config:
    enable_block_reuse: false
    free_gpu_memory_fraction: 0.7
    max_tokens: 2560
  moe_config:
    backend: CUTLASS
  cache_transceiver_config:
    max_tokens_in_buffer: 8448
    backend: DEFAULT
  stream_interval: 1
  num_postprocess_workers: 1
  urls:
    - "localhost:8002"

tests/integration/defs/disaggregated/test_disaggregated.py

Lines changed: 58 additions & 8 deletions
@@ -155,6 +155,10 @@ def get_test_config(test_desc, example_dir, test_root):
         (4,
          f"{test_configs_root}/disagg_config_ctxtp1_gentp1_deepseek_v3_lite_one_mtp_ctxpp2_gentp2.yaml"
          ),
+        "deepseek_v3_lite_bf16_empty_batch":
+        (3,
+         f"{test_configs_root}/disagg_config_deepseek_v3_lite_empty_batch.yaml"
+         ),
     }

     if test_desc not in config_map:

@@ -1280,14 +1284,19 @@ def run_disaggregated_benchmark(example_dir,
                                 benchmark_model_root,
                                 shared_gpt_path,
                                 env=None,
-                                cwd=None):
+                                cwd=None,
+                                num_ranks=2,
+                                random_input_len=16,
+                                random_output_len=64,
+                                num_prompts=100,
+                                max_concurrency=32,
+                                skip_warmup=False):
     """Run disaggregated test with given configuration."""
     run_env = env.copy()
     run_env["UCX_TLS"] = "^ib"
-    num_rank = 2
     workers_cmd = [
         'mpirun', '--allow-run-as-root', '--oversubscribe', '-n',
-        str(num_rank), 'trtllm-serve', 'disaggregated_mpi_worker', '-c',
+        str(num_ranks), 'trtllm-serve', 'disaggregated_mpi_worker', '-c',
         config_file
     ]

@@ -1339,15 +1348,15 @@
         '--dataset-path',
         shared_gpt_path,
         '--random-input-len',
-        '256',
+        str(random_input_len),
         '--random-output-len',
-        '64',
+        str(random_output_len),
         '--random-prefix-len',
         '0',
         '--num-prompts',
-        '320',
+        str(num_prompts),
         '--max-concurrency',
-        '32',
+        str(max_concurrency),
         '--host',
         'localhost',
         '--port',

@@ -1358,7 +1367,8 @@
         'e2el,ttft',
     ]
     # warm up
-    check_call(benchmark_cmd, env=env)
+    if not skip_warmup:
+        check_call(benchmark_cmd, env=env)
     output = check_output(benchmark_cmd, env=env)
     e2el_pattern = r"Median E2EL \(ms\):\s*(\d+\.?\d*)"
     ttft_pattern = r"Median TTFT \(ms\):\s*(\d+\.?\d*)"

@@ -1468,3 +1478,43 @@ def test_disaggregated_benchmark_on_diff_backends(

     assert ucx_e2el > 0 and nixl_e2el > 0 and nixl_e2el < 1.05 * ucx_e2el
     assert ucx_ttft > 0 and nixl_ttft > 0 and nixl_ttft < 1.05 * ucx_ttft
+
+
+@pytest.mark.parametrize("benchmark_model_root", ['DeepSeek-V3-Lite-bf16'],
+                         indirect=True)
+def test_disaggregated_deepseek_v3_lite_bf16_empty_batch(
+        disaggregated_example_root, llm_venv, benchmark_model_root,
+        benchmark_root, shared_gpt_path):
+
+    src_dst_dict = {
+        benchmark_model_root:
+        f"{llm_venv.get_working_directory()}/DeepSeek-V3-Lite/bf16",
+    }
+    for src, dst in src_dst_dict.items():
+        if not os.path.islink(dst):
+            os.makedirs(os.path.dirname(dst), exist_ok=True)
+            os.symlink(src, dst, target_is_directory=True)
+
+    test_desc = "deepseek_v3_lite_bf16_empty_batch"
+    num_ranks, config_file = get_test_config(test_desc,
+                                             disaggregated_example_root,
+                                             os.path.dirname(__file__))
+
+    env = llm_venv._new_env.copy()
+    e2el, ttft = run_disaggregated_benchmark(
+        disaggregated_example_root,
+        config_file,
+        benchmark_root,
+        benchmark_model_root,
+        shared_gpt_path,
+        env=env,
+        cwd=llm_venv.get_working_directory(),
+        num_ranks=num_ranks,
+        num_prompts=10,
+        max_concurrency=10,
+        random_input_len=384,
+        random_output_len=1536,
+        skip_warmup=True)
+    print(f"E2EL: {e2el} ms, TTFT: {ttft} ms")
+
+    assert e2el > 0 and ttft > 0
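
After the benchmark run, run_disaggregated_benchmark recovers its return values by regex from the benchmark's stdout. A self-contained illustration of that parsing step, reusing the exact patterns from the diff on an invented output string (the metric values are made up for the example; the real text comes from check_output(benchmark_cmd, env=env)):

import re

# Invented benchmark output for illustration only.
output = """
Median TTFT (ms): 123.45
Median E2EL (ms): 678.90
"""

# Same patterns as in the diff above.
e2el_pattern = r"Median E2EL \(ms\):\s*(\d+\.?\d*)"
ttft_pattern = r"Median TTFT \(ms\):\s*(\d+\.?\d*)"

e2el = float(re.search(e2el_pattern, output).group(1))
ttft = float(re.search(ttft_pattern, output).group(1))

# Mirrors the new test's final check.
assert e2el > 0 and ttft > 0
print(f"E2EL: {e2el} ms, TTFT: {ttft} ms")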

tests/integration/test_lists/test-db/l0_dgx_h100.yml

Lines changed: 1 addition & 0 deletions
@@ -133,6 +133,7 @@ l0_dgx_h100:
   - disaggregated/test_disaggregated.py::test_disaggregated_deepseek_v3_lite_fp8_overlap_cuda_graph[DeepSeek-V3-Lite-fp8]
   - disaggregated/test_disaggregated.py::test_disaggregated_deepseek_v3_lite_bf16_cache_aware_balance[DeepSeek-V3-Lite-bf16]
   - disaggregated/test_disaggregated.py::test_disaggregated_deepseek_v3_lite_bf16_conditional[DeepSeek-V3-Lite-bf16]
+  - disaggregated/test_disaggregated.py::test_disaggregated_deepseek_v3_lite_bf16_empty_batch[DeepSeek-V3-Lite-bf16]
   - accuracy/test_disaggregated_serving.py::TestDeepSeekV3Lite::test_guided_decoding[xgrammar-mtp_nextn=2]
   - disaggregated/test_disaggregated.py::test_disaggregated_deepseek_v3_lite_fp8_ctxpp2_gentp2_one_mtp[DeepSeek-V3-Lite-fp8]
   - disaggregated/test_workers.py::test_workers_conditional_disaggregation_deepseek_v3_lite_bf16[DeepSeek-V3-Lite-bf16]
