Use MPReplayBuffer with Ray

Note

Please see this document, too.

Ray is an open-source (OSS) framework that enables users to build distributed applications easily.

Because Ray can span multiple machines, its architecture does not use shared memory except for immutable objects (Ref). However, if you use only a single machine, you might want to place the replay buffer in shared memory to avoid expensive inter-process data transfer.

With cpprb 10.6+ and Python 3.8+, you can use MPReplayBuffer and MPPrioritizedReplayBuffer with Ray.

The key trick is to set `authkey` inside Ray actors, which allows the Ray worker to communicate with the `SyncManager` process.

import base64
import multiprocessing as mp

import ray

# Start (or connect to) a Ray runtime for this driver process.
ray.init()

@ray.remote
class RemoteWorker:
    """Minimal Ray actor that can talk to a `multiprocessing` SyncManager.

    The driver's authkey is captured when this class body executes and is
    written back into the actor's process in ``__init__``.
    """

    # Evaluated once, when the class body runs in the driver process.
    # NOTE(review): this relies on the class being shipped to the worker by
    # value (as happens for classes defined in the driver's `__main__`);
    # verify before moving it into an importable module.
    # Encode base64 to avoid following error:
    #   TypeError: Pickling an AuthenticationString object is disallowed for security reasons
    encoded = base64.b64encode(mp.current_process().authkey)

    def __init__(self):
        # Overwrite this process's authkey with the driver's one so that
        # connections to the driver-side SyncManager can authenticate.
        mp.current_process().authkey = base64.b64decode(self.encoded)

    def run(self, some_resource):
        """Placeholder for the actor's work; does nothing in this sketch."""
        pass

# Create the actor; its __init__ (which restores the authkey) runs in a Ray worker.
w = RemoteWorker.remote()
Warning

This trick overwrites the process-wide authkey, which might conflict with other processes you use inside the actor.

We also have to select the SharedMemory backend and a SyncManager context. With this configuration, the main data is placed in shared memory, and synchronization objects (e.g. Lock and Event) are accessed through SyncManager proxies.

import multiprocessing as mp

from cpprb import MPReplayBuffer

# cpprb documents the buffer size as an integer; 1e+6 alone would be a float.
buffer_size = int(1e+6)
m = mp.get_context().Manager()  # SyncManager

# Data lives in shared memory; locks/events are SyncManager proxies.
rb = MPReplayBuffer(buffer_size,
                    {"done": {}},
                    ctx=m,  # Context
                    backend="SharedMemory")

Putting it all together, a (pseudo) example looks like this:

# See: https://ymd_h.gitlab.io/cpprb/examples/mp_with_ray/

import base64
import multiprocessing as mp
import time

from cpprb import ReplayBuffer, MPPrioritizedReplayBuffer
import gym
import numpy as np
import ray


class Model:
    """Placeholder model used by the example.

    The learning logic is intentionally omitted: ``train`` reports a zero
    |TD| for every transition and ``__call__`` samples a random action.
    """

    def __init__(self, env):
        # Keep the environment so actions can be sampled from its action space.
        self.env = env
        # Model weights; ``None`` until something assigns them.
        self.w = None

    def train(self, transitions):
        """
        Update model weights and return |TD|

        ``transitions["obs"]`` is indexed by sample along its first axis;
        one |TD| value is returned per sample.
        """
        n_samples = transitions["obs"].shape[0]
        # omit: the actual weight update
        return np.zeros(shape=(n_samples,))

    def __call__(self, obs):
        """
        Choose action from observation
        """
        # omit: the actual policy; a random action stands in for it
        return self.env.action_space.sample()

def _reset_env(env):
    """Reset ``env`` and return only the initial observation.

    Supports both Gym APIs: the old ``reset()`` returns the observation,
    the new one returns an ``(obs, info)`` tuple.
    """
    reset = env.reset()
    if isinstance(reset, tuple):
        # Gym New API
        obs, _ = reset
        return obs
    # Gym Old API
    return reset


def _step_env(env, act):
    """Step ``env`` with ``act`` and return ``(next_obs, rew, done)``.

    Supports both Gym APIs: the old ``step()`` returns 4 values, the new one
    returns 5, with separate termination and truncation flags.
    """
    stepped = env.step(act)
    if len(stepped) == 4:
        # Gym Old API
        next_obs, rew, done, _ = stepped
    else:
        # Gym New API: the episode ends on either termination or truncation.
        next_obs, rew, term, trunc, _ = stepped
        done = term or trunc
    return next_obs, rew, done


@ray.remote
class Explorer:
    """Ray actor that collects transitions into a shared replay buffer.

    Each actor steps its own environment, stages transitions in a small local
    ``ReplayBuffer`` and flushes them into the global buffer passed to ``run``.
    """

    # Evaluated once, when the class body runs in the driver process, so it
    # holds the driver's authkey. base64 avoids:
    #   TypeError: Pickling an AuthenticationString object is disallowed for security reasons
    encoded = base64.b64encode(mp.current_process().authkey)

    def __init__(self):
        # Set up 'authkey' to communicate with `SyncManager`.
        # Important: Do not pass `MPReplayBuffer`, because it is not ready.
        mp.current_process().authkey = base64.b64decode(self.encoded)

    def run(self, env_name, global_rb, env_dict, q, stop):
        """Explore ``env_name`` until ``stop`` is set.

        Args:
            env_name: Gym environment id passed to ``gym.make``.
            global_rb: Shared replay buffer that receives flushed transitions.
            env_dict: Field specification for the local staging buffer.
            q: Queue through which new model weights arrive.
            stop: Shared event; polled every step and set on exit (including
                on error) so that every other participant stops too.
        """
        try:
            buffer_size = 200

            local_rb = ReplayBuffer(buffer_size, env_dict)
            env = gym.make(env_name)
            model = Model(env)

            obs = _reset_env(env)

            while True:
                if stop.is_set():
                    print("Stop")
                    break
                if not q.empty():
                    model.w = q.get()

                act = model(obs)
                next_obs, rew, done = _step_env(env, act)

                local_rb.add(obs=obs, act=act, rew=rew, next_obs=next_obs, done=done)

                # Flush the staged transitions when the episode ends or the
                # local buffer is full.
                if done or local_rb.get_stored_size() == buffer_size:
                    local_rb.on_episode_end()
                    global_rb.add(**local_rb.get_all_transitions())
                    local_rb.clear()

                # Start a new episode only when the current one actually ended;
                # a full staging buffer must not truncate the episode.
                obs = _reset_env(env) if done else next_obs
        finally:
            stop.set()
        return None


def run():
    """Run the distributed prioritized-replay example.

    Starts ``n_explorers`` Ray actors that fill a shared-memory
    ``MPPrioritizedReplayBuffer``. The function then waits until ``nwarmup``
    transitions are stored and runs a (dummy) training loop that sends the
    model weights to every explorer each ``update_freq`` steps. Finally it
    signals all explorers to stop and shuts the SyncManager down.
    """
    n_explorers = 4

    nwarmup = 100
    ntrain = int(1e+2)
    update_freq = 100
    env_name = "CartPole-v1"

    env = gym.make(env_name)

    # cpprb documents the buffer size as an integer; 1e+6 alone is a float.
    buffer_size = int(1e+6)
    env_dict = {"obs": {"shape": env.observation_space.shape},
                "act": {},
                "rew": {},
                "next_obs": {"shape": env.observation_space.shape},
                "done": {}}
    alpha = 0.5
    beta = 0.4
    batch_size = 32


    ray.init()


    # `BaseContext.Manager()` automatically starts `SyncManager`
    # Ref: https://github.com/python/cpython/blob/3.9/Lib/multiprocessing/context.py#L49-L58
    m = mp.get_context().Manager()
    # One weight queue per explorer, so every explorer can receive each update.
    q = [m.Queue() for _ in range(n_explorers)]
    stop = m.Event()
    stop.clear()

    rb = MPPrioritizedReplayBuffer(buffer_size, env_dict, alpha=alpha,
                                   ctx=m, backend="SharedMemory")

    model = Model(env)

    explorers = []
    jobs = []

    print("Start Explorers")
    for i in range(n_explorers):
        explorers.append(Explorer.remote())
        jobs.append(explorers[-1].run.remote(env_name, rb, env_dict, q[i], stop))


    print("Start Warmup")
    while rb.get_stored_size() < nwarmup and not stop.is_set():
        time.sleep(1)

    print("Start Training")
    for i in range(ntrain):
        if stop.is_set():
            break

        s = rb.sample(batch_size, beta)
        absTD = model.train(s)
        rb.update_priorities(s["indexes"], absTD)

        if i % update_freq == 0:
            # Broadcast the new weights to every explorer's queue. `i` is the
            # training step, not an explorer index.
            for explorer_q in q:
                explorer_q.put(model.w)

    print("Finish Training")

    stop.set()
    # `ray.wait` returns after the first finished job unless `num_returns` is
    # given; wait for all explorers so none is still using the manager proxies
    # when the manager shuts down.
    ray.wait(jobs, num_returns=len(jobs), timeout=10)

    m.shutdown()

# Run the example only when executed as a script, not when imported.
if __name__ == "__main__":
    run()