-
Notifications
You must be signed in to change notification settings - Fork 302
Expand file tree
/
Copy pathmain_fully_async.py
More file actions
61 lines (49 loc) · 1.44 KB
/
main_fully_async.py
File metadata and controls
61 lines (49 loc) · 1.44 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
"""
Main entrypoint for fully async training.
"""
import sys
from skyrl.train.config import SkyRLTrainConfig
from skyrl.train.entrypoints.main_base import BasePPOExp, validate_cfg
from skyrl.train.fully_async_trainer import FullyAsyncRayPPOTrainer
import asyncio
from skyrl.train.utils import initialize_ray
import ray
class FullyAsyncPPOExp(BasePPOExp):
    """PPO experiment variant that trains with ``FullyAsyncRayPPOTrainer``.

    ``get_trainer`` constructs the fully async trainer, and ``run`` drives
    that trainer's ``train()`` coroutine to completion with ``asyncio.run``.
    """

    def get_trainer(
        self,
        cfg,
        tracker,
        tokenizer,
        train_dataset,
        eval_dataset,
        inference_engine_client,
        generator,
        colocate_pg,
    ):
        """Construct the fully async trainer for this experiment.

        Every argument is forwarded, unchanged and by keyword, to
        ``FullyAsyncRayPPOTrainer``.

        Returns:
            The newly created ``FullyAsyncRayPPOTrainer`` instance.
        """
        trainer_kwargs = {
            "cfg": cfg,
            "tracker": tracker,
            "tokenizer": tokenizer,
            "train_dataset": train_dataset,
            "eval_dataset": eval_dataset,
            "inference_engine_client": inference_engine_client,
            "generator": generator,
            "colocate_pg": colocate_pg,
        }
        return FullyAsyncRayPPOTrainer(**trainer_kwargs)

    def run(self):
        """Set up the trainer and run its async training loop until it ends."""
        # NOTE(review): _setup_trainer is not defined in this class; presumably
        # it comes from BasePPOExp -- confirm.
        async_trainer = self._setup_trainer()
        # train() is handed to asyncio.run, which runs it on a fresh event loop
        # and blocks until training finishes.
        training_coro = async_trainer.train()
        asyncio.run(training_coro)
@ray.remote(num_cpus=1)
def skyrl_entrypoint(cfg: SkyRLTrainConfig):
    """Build the fully async PPO experiment from ``cfg`` and run it.

    Declared as a Ray remote function requesting one CPU. Launching the
    experiment through a Ray task is intended to make sure the training loop
    is not run on the head node.

    Args:
        cfg: Training configuration passed to ``FullyAsyncPPOExp``.
    """
    FullyAsyncPPOExp(cfg).run()
def main() -> None:
    """Parse CLI overrides, validate the config, and launch training on Ray.

    Builds a ``SkyRLTrainConfig`` from the command-line arguments after the
    script name, validates it, initializes Ray with it, and then blocks until
    the remote ``skyrl_entrypoint`` task has finished.
    """
    cli_overrides = sys.argv[1:]
    cfg = SkyRLTrainConfig.from_cli_overrides(cli_overrides)

    # Validate the arguments before any Ray resources are requested.
    validate_cfg(cfg)

    initialize_ray(cfg)

    # Submit the experiment as a Ray task and wait for it to complete.
    entrypoint_ref = skyrl_entrypoint.remote(cfg)
    ray.get(entrypoint_ref)


if __name__ == "__main__":
    main()