Use MPReplayBuffer with Ray
Please see this document, too.
Ray is an open-source framework which enables users to build distributed applications easily.
Because Ray is designed to utilize multiple machines, its architecture doesn't use shared memory except for immutable objects (Ref). However, if you use only a single machine, you might want to use shared memory for the replay buffer to avoid expensive interprocess data sharing.
With cpprb 10.6+ and Python 3.8+, you can use MPReplayBuffer and
MPPrioritizedReplayBuffer with Ray.
A key trick is to set authkey inside the Ray actors, which allows the Ray
worker processes to communicate with the SyncManager process.
import base64
import multiprocessing as mp

import ray

ray.init()


@ray.remote
class RemoteWorker:
    # base64-encode the key because pickling the raw AuthenticationString raises:
    #   TypeError: Pickling an AuthenticationString object is disallowed for security reasons
    encoded = base64.b64encode(mp.current_process().authkey)

    def __init__(self):
        # Restore the authkey inside the Ray worker process so it can
        # authenticate with the `SyncManager` server process.
        mp.current_process().authkey = base64.b64decode(self.encoded)

    def run(self, some_resource):
        pass
w = RemoteWorker.remote()

This trick overwrites the process-wide authkey, which might conflict
with other multiprocessing processes you use inside the actor.
We also have to select the SharedMemory backend and a SyncManager
context. With this configuration, the main data are placed in shared memory
and synchronization objects (e.g. Lock and Event) are accessed
through SyncManager proxies.
import multiprocessing as mp
from cpprb import MPReplayBuffer
buffer_size = 1e+6
m = mp.get_context().Manager() # SyncManager
rb = MPReplayBuffer(buffer_size,
{"done": {}},
ctx = m, # Context
                    backend="SharedMemory")

In the end, the (pseudo) example becomes like this;
# See: https://ymd_h.gitlab.io/cpprb/examples/mp_with_ray/
import base64
import multiprocessing as mp
import time
from cpprb import ReplayBuffer, MPPrioritizedReplayBuffer
import gym
import numpy as np
import ray
class Model:
    """Minimal stand-in for an RL model used by this example."""

    def __init__(self, env):
        self.env = env
        self.w = None  # weights; set by the learner and shipped to explorers

    def train(self, transitions):
        """
        Update model weights and return |TD|
        """
        n_transitions = transitions["obs"].shape[0]
        # omit: a real model would compute TD errors here
        return np.zeros(shape=(n_transitions,))

    def __call__(self, obs):
        """
        Choose action from observation
        """
        # omit: a real model would act on `obs`; sample randomly instead
        return self.env.action_space.sample()
@ray.remote
class Explorer:
    # base64-encode the key because pickling the raw AuthenticationString raises:
    #   TypeError: Pickling an AuthenticationString object is disallowed for security reasons
    encoded = base64.b64encode(mp.current_process().authkey)

    def __init__(self):
        # Restore `authkey` so this Ray worker can talk to the `SyncManager`.
        # Important: Do not pass `MPReplayBuffer` here, because it is not ready.
        mp.current_process().authkey = base64.b64decode(self.encoded)

    @staticmethod
    def _initial_obs(env):
        """Reset `env` and return the first observation (old & new Gym reset APIs)."""
        reset = env.reset()
        if isinstance(reset, tuple):
            # Gym New API: (obs, info)
            obs, _ = reset
        else:
            # Gym Old API: obs only
            obs = reset
        return obs

    def run(self, env_name, global_rb, env_dict, q, stop):
        """Collect transitions locally and flush them into `global_rb` until `stop` is set."""
        try:
            local_capacity = 200
            local_rb = ReplayBuffer(local_capacity, env_dict)

            env = gym.make(env_name)
            model = Model(env)

            obs = self._initial_obs(env)
            while True:
                if stop.is_set():
                    print("Stop")
                    break

                if not q.empty():
                    # Fresh weights arrived from the learner.
                    model.w = q.get()

                act = model(obs)
                stepped = env.step(act)
                if len(stepped) == 4:
                    # Gym Old API
                    next_obs, rew, done, _ = stepped
                else:
                    # Gym New API
                    next_obs, rew, term, trunc, _ = stepped
                    done = term | trunc

                local_rb.add(obs=obs, act=act, rew=rew, next_obs=next_obs, done=done)

                if done or local_rb.get_stored_size() == local_capacity:
                    # Flush the local buffer into the shared buffer at episode
                    # end or when the local buffer is full.
                    local_rb.on_episode_end()
                    global_rb.add(**local_rb.get_all_transitions())
                    local_rb.clear()
                    obs = self._initial_obs(env)
                else:
                    obs = next_obs
        finally:
            # Signal the learner and all other explorers to stop, too.
            stop.set()
        return None
def run():
    """Train with one learner process and several Ray explorer actors.

    Explorers push transitions into a shared ``MPPrioritizedReplayBuffer``
    (backed by shared memory), while this process samples mini-batches,
    trains the model, updates priorities, and periodically broadcasts the
    latest weights to every explorer through per-explorer queues.
    """
    n_explorers = 4

    nwarmup = 100        # minimum stored transitions before training starts
    ntrain = int(1e+2)   # number of training iterations
    update_freq = 100    # broadcast weights every `update_freq` iterations

    env_name = "CartPole-v1"
    env = gym.make(env_name)

    buffer_size = 1e+6
    env_dict = {"obs": {"shape": env.observation_space.shape},
                "act": {},
                "rew": {},
                "next_obs": {"shape": env.observation_space.shape},
                "done": {}}
    alpha = 0.5
    beta = 0.4
    batch_size = 32

    ray.init()

    # `BaseContext.Manager()` automatically starts `SyncManager`
    # Ref: https://github.com/python/cpython/blob/3.9/Lib/multiprocessing/context.py#L49-L58
    m = mp.get_context().Manager()

    q = [m.Queue() for _ in range(n_explorers)]
    stop = m.Event()
    stop.clear()

    rb = MPPrioritizedReplayBuffer(buffer_size, env_dict, alpha=alpha,
                                   ctx=m, backend="SharedMemory")

    model = Model(env)

    explorers = []
    jobs = []

    print("Start Explorers")
    for i in range(n_explorers):
        explorers.append(Explorer.remote())
        jobs.append(explorers[-1].run.remote(env_name, rb, env_dict, q[i], stop))

    print("Start Warmup")
    while rb.get_stored_size() < nwarmup and not stop.is_set():
        time.sleep(1)

    print("Start Training")
    for i in range(ntrain):
        if stop.is_set():
            break

        s = rb.sample(batch_size, beta)

        absTD = model.train(s)
        rb.update_priorities(s["indexes"], absTD)

        if i % update_freq == 0:
            # Bug fix: the original `q[i].put(model.w)` indexed the queue
            # list by the training step, so only explorer 0 ever received
            # weights and an IndexError would occur once i >= n_explorers.
            # Broadcast the latest weights to every explorer instead.
            for qi in q:
                qi.put(model.w)

    print("Finish Training")
    stop.set()

    # Give explorers a grace period to observe `stop` and finish cleanly
    # before shutting down the SyncManager they depend on.
    _, still_running = ray.wait(jobs, timeout=10)

    m.shutdown()


if __name__ == "__main__":
    run()