
Commit 3ed230f

Fix local model support in VERL (#299) (#300)
(cherry picked from commit f2869ce)
Co-authored-by: Yuge Zhang <Yuge.Zhang@microsoft.com>

1 parent 1d2515c commit 3ed230f

File tree

2 files changed: +27 additions, -4 deletions

.github/workflows/examples-calc-x.yml

Lines changed: 16 additions & 0 deletions
@@ -137,6 +137,22 @@ jobs:
       WANDB_BASE_URL: ${{ secrets.MSR_WANDB_BASE_URL }}
       WANDB_API_KEY: ${{ secrets.MSR_WANDB_API_KEY }}
 
+    - name: Calc-X training with local model
+      run: |
+        set -ex
+        source .venv/bin/activate
+        cd examples/calc_x
+        ../../scripts/restart_ray.sh
+        sleep 5
+        hf download Qwen/Qwen2.5-0.5B-Instruct --local-dir data/qwen_model
+        PYTHONUNBUFFERED=1 python train_calc_agent.py --val-file data/test_mini.parquet --ci --model $(realpath data/qwen_model)
+        sleep 10
+      shell: bash
+      env:
+        WANDB_BASE_URL: ${{ secrets.MSR_WANDB_BASE_URL }}
+        WANDB_API_KEY: ${{ secrets.MSR_WANDB_API_KEY }}
+      id: calc_x_train_local_model
+
     - name: Calc-X training LLM Proxy
       run: |
         set -ex
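The new CI step downloads the checkpoint into `data/qwen_model` and passes `$(realpath data/qwen_model)` to `--model`. A minimal sketch of what that expansion does, assuming the checkpoint directory exists relative to the working directory (the paths here are illustrative, not from the workflow):

```python
import os

# Sketch (assumption, not the workflow itself): what $(realpath data/qwen_model)
# expands to -- an absolute path to the downloaded checkpoint directory.
local_dir = "data/qwen_model"            # where `hf download` placed the model
model_arg = os.path.realpath(local_dir)  # the value passed via --model

# The trainer receives an unambiguous filesystem path instead of a
# relative one that depends on the current working directory.
assert os.path.isabs(model_arg)
print(model_arg)
```

Passing an absolute path keeps the local-model case distinct from a Hugging Face model id like `Qwen/Qwen2.5-0.5B-Instruct`, which is exactly the distinction the trainer change below has to handle.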

agentlightning/verl/trainer.py

Lines changed: 11 additions & 4 deletions
@@ -12,6 +12,7 @@
 
 import numpy as np
 import torch
+import verl
 from codetiming import Timer
 from omegaconf import OmegaConf
 from tqdm import tqdm
@@ -298,14 +299,20 @@ def fit(self):
         assert self.async_rollout_mode, "If agent mode is enabled, async server must be enabled"
         if self.adapter is not None and not isinstance(self.adapter, TraceToTripletBase):
             raise ValueError("Adapter must be a TraceToTripletBase for currently VERL implementation.")
+        verl_version = verl.__version__
+        if verl_version == "0.5.0":
+            # Note (Zhiyuan): To avoid further patch into vllm async server, using the same sentence to get the naming here.
+            # However, it is possible that verl updates the naming and causes incompatibility.
+            # Reference: https://github.com/volcengine/verl/blob/5b5e09d9cc20625e436d01f69d9cc739ff681c54/verl/workers/rollout/vllm_rollout/vllm_async_server.py#L217
+            model = "/".join(self.config.actor_rollout_ref.model.path.split("/")[-2:])
+        else:
+            # For other versions (e.g., 0.6.0), we use the full path to the model.
+            model = self.config.actor_rollout_ref.model.path
         self.agent_mode_daemon = AgentModeDaemon(
             self.config.agentlightning.port,
             self.config.actor_rollout_ref.rollout.n,
             train_information={
-                # Note (Zhiyuan): To avoid further patch into vllm async server, using the same sentence to get the naming here.
-                # However, it is possible that verl updates the naming and causes incompatibility.
-                # Reference: https://github.com/volcengine/verl/blob/5b5e09d9cc20625e436d01f69d9cc739ff681c54/verl/workers/rollout/vllm_rollout/vllm_async_server.py#L217
-                "model": "/".join(self.config.actor_rollout_ref.model.path.split("/")[-2:]),
+                "model": model,
                 "temperature": self.config.actor_rollout_ref.rollout.temperature,
             },
             tokenizer=self.tokenizer,
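The gist of this commit is the version gate on the model name: under verl 0.5.0 the vLLM async server registers the model under the last two components of its path, while other versions take the full path. A standalone sketch of that logic (`resolve_model_name` is a hypothetical helper, not a function in the trainer):

```python
# Minimal sketch (not the trainer code itself) of the version-dependent
# model naming introduced by this commit.
def resolve_model_name(model_path: str, verl_version: str) -> str:
    if verl_version == "0.5.0":
        # verl 0.5.0's vLLM async server derives the served model name
        # from the last two components of the path.
        return "/".join(model_path.split("/")[-2:])
    # Other versions (e.g., 0.6.0) accept the full path.
    return model_path

# A Hugging Face id already has exactly two components, so it survives intact;
# a local checkpoint path gets truncated under 0.5.0 and kept whole otherwise.
print(resolve_model_name("Qwen/Qwen2.5-0.5B-Instruct", "0.5.0"))  # Qwen/Qwen2.5-0.5B-Instruct
print(resolve_model_name("/home/ci/data/qwen_model", "0.5.0"))    # data/qwen_model
print(resolve_model_name("/home/ci/data/qwen_model", "0.6.0"))    # /home/ci/data/qwen_model
```

This is why the fix matters for local models: for a Hugging Face id the truncation is a no-op, but for a local directory the two behaviors diverge, so the name handed to AgentModeDaemon must match what the running verl version actually serves.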
