Skip to content

Commit 2b3cc41

Browse files
authored
Fix GPU pipeline issues (#17)
1 parent c51bc63 commit 2b3cc41

File tree

6 files changed

+79
-11
lines changed

6 files changed

+79
-11
lines changed

.github/workflows/examples.yml

Lines changed: 25 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -39,7 +39,7 @@ jobs:
3939
- name: Upload dependencies artifact
4040
uses: actions/upload-artifact@v4
4141
with:
42-
name: dependencies-python
42+
name: dependencies-${{ matrix.setup }}
4343
path: requirements-freeze.txt
4444
compression-level: 0
4545
- name: Prepare Spider dataset
@@ -97,6 +97,7 @@ jobs:
9797
source .venv/bin/activate
9898
cd examples/calc_x
9999
../../scripts/restart_ray.sh
100+
sleep 5
100101
PYTHONUNBUFFERED=1 python calc_agent.py &
101102
bash train_ci.sh
102103
pkill -f calc_agent.py && echo "SIGTERM sent to calc_agent.py" || echo "No calc_agent.py process found"
@@ -105,15 +106,27 @@ jobs:
105106
sleep 5
106107
done
107108
echo "calc_agent.py has finished."
109+
sleep 10
108110
shell: bash
109111
env:
110112
WANDB_API_KEY: ${{ secrets.WANDB_API_KEY }}
113+
id: calc_x_train
114+
115+
- name: Validate Calc-X training
116+
run: |
117+
set -ex
118+
. .venv/bin/activate
119+
python scripts/validate_example_wandb.py ${{ steps.calc_x_train.outputs.project_name }} ${{ steps.calc_x_train.outputs.run_name }}
120+
env:
121+
WANDB_API_KEY: ${{ secrets.WANDB_API_KEY }}
122+
111123
- name: Spider training
112124
run: |
113125
set -ex
114126
source .venv/bin/activate
115127
cd examples/spider
116128
../../scripts/restart_ray.sh
129+
sleep 5
117130
PYTHONUNBUFFERED=1 python sql_agent.py --trainer.n-workers 10 &
118131
bash train_ci.sh
119132
pkill -f sql_agent.py && echo "SIGTERM sent to sql_agent.py" || echo "No sql_agent.py process found"
@@ -122,11 +135,22 @@ jobs:
122135
sleep 5
123136
done
124137
echo "sql_agent.py has finished."
138+
sleep 10
125139
shell: bash
126140
env:
127141
VERL_API_BASE: http://localhost:9991/
128142
WANDB_API_KEY: ${{ secrets.WANDB_API_KEY }}
143+
id: spider_train
129144
if: success() || failure()
145+
146+
- name: Validate Spider training
147+
run: |
148+
set -ex
149+
. .venv/bin/activate
150+
python scripts/validate_example_wandb.py ${{ steps.spider_train.outputs.project_name }} ${{ steps.spider_train.outputs.run_name }}
151+
env:
152+
WANDB_API_KEY: ${{ secrets.WANDB_API_KEY }}
153+
130154
- name: Cleanup
131155
run: ./scripts/cleanup.sh
132156
if: success() || failure()

agentlightning/verl/trainer.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -336,6 +336,11 @@ def fit(self):
336336
if is_last_step:
337337
pprint(f"Final validation metrics: {last_val_metrics}")
338338
progress_bar.close()
339+
340+
# This exit logic is to ensure a robust CI.
341+
pprint(f"Flush the logger...")
342+
del logger # Make sure the loggers are flushed and closed properly
343+
pprint(f"Training finished at step {self.global_steps}.")
339344
return
340345

341346
progress_bar.update(1)

examples/calc_x/train_ci.sh

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,14 +1,15 @@
1-
21
#!/bin/bash
32

4-
set -e
3+
set -ex
54

65
export N_GPUS=1
76
export BASE_MODEL=Qwen/Qwen2.5-1.5B-Instruct
87
export DATA_DIR=data
98
export ROLLOUT_TP_SIZE=1
109
export EXPERIMENT_NAME="calc_x_$(date +%Y%m%d%H%M%S)"
1110
export PROJECT_NAME=AgentLightningCI
11+
echo "project_name=${PROJECT_NAME}" >> $GITHUB_OUTPUT
12+
echo "run_name=${EXPERIMENT_NAME}" >> $GITHUB_OUTPUT
1213

1314
PYTHONUNBUFFERED=1 python -m agentlightning.verl \
1415
algorithm.adv_estimator=grpo \
@@ -48,6 +49,6 @@ PYTHONUNBUFFERED=1 python -m agentlightning.verl \
4849
trainer.experiment_name=${EXPERIMENT_NAME} \
4950
trainer.nnodes=1 \
5051
trainer.save_freq=256 \
51-
trainer.test_freq=3 \
52+
trainer.test_freq=6 \
5253
trainer.total_epochs=1 \
53-
trainer.total_training_steps=3 $@
54+
trainer.total_training_steps=6 $@

examples/spider/train_ci.sh

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,17 +1,17 @@
11
#!/bin/bash
22

3-
set -e
3+
set -ex
44

55
export N_GPUS=1
66
export BASE_MODEL=Qwen/Qwen2.5-Coder-1.5B-Instruct
77
export DATA_DIR=data
88
export ROLLOUT_TP_SIZE=1
99
export EXPERIMENT_NAME="spider_$(date +%Y%m%d%H%M%S)"
10-
export PROJECT_NAME=AgentLightning
10+
export PROJECT_NAME=AgentLightningCI
11+
echo "project_name=${PROJECT_NAME}" >> $GITHUB_OUTPUT
12+
echo "run_name=${EXPERIMENT_NAME}" >> $GITHUB_OUTPUT
1113

12-
echo "Starting training script..."
13-
14-
python -m agentlightning.verl \
14+
PYTHONUNBUFFERED=1 python -m agentlightning.verl \
1515
agentlightning.port=9991 \
1616
algorithm.adv_estimator=grpo \
1717
data.train_files=${DATA_DIR}/train_spider.parquet \

scripts/restart_ray.sh

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
11
#!/bin/bash
22

3-
ray stop
3+
ray stop --force
44
env RAY_DEBUG=legacy HYDRA_FULL_ERROR=1 VLLM_USE_V1=1 ray start --head --dashboard-host=0.0.0.0

scripts/validate_example_wandb.py

Lines changed: 38 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,38 @@
"""Check that a W&B run exists and report whether its ``val/reward`` improved.

Usage:
    python validate_example_wandb.py <project> <run_name>

Messages are printed with GitHub Actions workflow-command prefixes
(``::error::``, ``::warning``, ``::notice``).

Exit status:
    0 -- the run was found and has ``val/reward`` history (a run that did not
         improve only produces a warning, not a failure).
    1 -- the run was not found, or it has no ``val/reward`` history.
    2 -- the script was called with the wrong number of arguments.
"""

import sys

import wandb


def main(argv):
    """Validate the run named ``argv[2]`` in the project named ``argv[1]``.

    Args:
        argv: Full argument vector, including the program name at index 0.
    """
    if len(argv) != 3:
        # Stop here: without this exit the indexing below raises IndexError
        # and hides the real cause (e.g. empty step outputs in CI).
        print(
            "Usage: python validate_example_wandb.py <project> <run_name>",
            file=sys.stderr,
        )
        sys.exit(2)

    project = argv[1]
    run_name = argv[2]
    api = wandb.Api()
    entity_name = api.default_entity
    print("Default entity:", entity_name)
    print("Project:", project)
    print("Run name:", run_name)

    runs = api.runs(f"{entity_name}/{project}", filters={"displayName": run_name})
    matched_run = None
    for candidate in runs:
        print(f"Found run: {candidate.name} (ID: {candidate.id})")
        # Re-check the name locally instead of relying only on the server filter.
        if candidate.name == run_name:
            matched_run = candidate
            break
    if matched_run is None:
        print(f"::error::Run with name '{run_name}' not found in project '{project}'.")
        sys.exit(1)

    hist = matched_run.history(keys=["val/reward"], pandas=True)
    print("History:", hist)
    if hist.empty:
        print("::error::No history found for the run.")
        sys.exit(1)

    rewards = hist["val/reward"]
    first, last = rewards.iloc[0], rewards.iloc[-1]
    if last <= first:
        # Lack of improvement is reported but deliberately does not fail the job.
        print(
            f"::warning title=Training no improvement::No improvement (run_name={run_name} start={first:.4f}, end={last:.4f})"
        )
    else:
        print(
            f"::notice title=Training completed::Run has improved (run_name={run_name} start={first:.4f}, end={last:.4f})"
        )


if __name__ == "__main__":
    main(sys.argv)

0 commit comments

Comments
 (0)