Ray Support

1. Goal

Make it possible to pass MPReplayBuffer and MPPrioritizedReplayBuffer to functions decorated with @ray.remote.

Ref: https://github.com/ymd-h/cpprb/discussions/17

2. Problem

MPReplayBuffer / MPPrioritizedReplayBuffer use shared memory (backed by mmap.mmap objects), which Ray cannot serialize:

from cpprb import MPPrioritizedReplayBuffer
import ray

rb = MPPrioritizedReplayBuffer(4, {"done": {}})
ray.util.inspect_serializability(rb, name="MPPrioritizedReplayBuffer")
================================================================================
Checking Serializability of <cpprb.PyReplayBuffer.MPPrioritizedReplayBuffer object at 0x7f41c4ea6040>
================================================================================
!!! FAIL serialization: cannot pickle 'mmap.mmap' object
WARNING: Did not find non-serializable object in <cpprb.PyReplayBuffer.MPPrioritizedReplayBuffer object at 0x7f41c4ea6040>. This may be an oversight.
================================================================================
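The failure does not depend on Ray or cpprb: the standard pickle module rejects an mmap object with the same kind of TypeError. A minimal stdlib-only reproduction, in which an anonymous 16-byte mapping stands in for the buffer's shared memory (try_pickle_mmap is an illustrative name):

```python
import mmap
import pickle


def try_pickle_mmap():
    # Anonymous mapping, standing in for the shared memory inside the buffer.
    m = mmap.mmap(-1, 16)
    try:
        pickle.dumps(m)
    except TypeError as e:
        return str(e)  # the "cannot pickle ..." message
    finally:
        m.close()
    return None  # only reached if pickling unexpectedly succeeds


print(try_pickle_mmap())
```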

3. Tasks

To work with Ray, we have to handle the following multiprocessing objects:

  • RawValue / RawArray
    • For Python 3.8+, SharedMemory (doc) can be used.
  • Value (aka. RawValue + RLock)
  • Event
    • asyncio.Event is a possible candidate (Ref)
    • SyncManager.Event
  • Lock
    • SyncManager.Lock
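The "Value (aka. RawValue + RLock)" entry can be checked directly: multiprocessing.Value returns a synchronized wrapper whose get_obj() and get_lock() expose the two parts, so both parts need a Ray-compatible replacement. A minimal stdlib sketch (value_parts is an illustrative name):

```python
import ctypes
import multiprocessing as mp


def value_parts():
    v = mp.Value(ctypes.c_int, 0)  # lock=True by default, so an RLock is created
    raw = v.get_obj()              # the RawValue part: a ctypes.c_int in shared memory
    lock = v.get_lock()            # the RLock part that guards it
    with lock:                     # usual pattern for an atomic read-modify-write
        v.value += 1               # RLock is re-entrant, so the property can lock again
    return type(raw).__name__, type(lock).__name__, raw.value


print(value_parts())  # ('c_int', 'RLock', 1)
```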

3.1 🆗 SharedMemory for RawValue and RawArray

The following code works. SharedMemory has a custom __reduce__ method (Ref), so a SharedMemory object itself can be passed to a @ray.remote function.

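from multiprocessing.shared_memory import SharedMemory
import numpy as np
import ray

def ray_test():
    # Shared memory is per host: the tasks must run on the driver's node.
    ray.init()

    # Room for three int32 values (the OS may round the block size up).
    shm = SharedMemory(create=True, size=np.dtype(np.int32).itemsize * 3)
    a = np.ndarray(shape=(3,), dtype=np.int32, buffer=shm.buf)
    a[:] = 0
    print(a)  # -> [0 0 0]

    @ray.remote
    def add(name, shape, dtype):
        # Attach to the existing block by its name.
        m = SharedMemory(name=name)
        b = np.ndarray(shape=shape, dtype=dtype, buffer=m.buf)
        b += 2
        del b  # release the view on m.buf before closing
        m.close()

    @ray.remote
    def add_shm(shm, shape, dtype):
        # The argument was pickled by name, so this is a fresh attachment.
        b = np.ndarray(shape=shape, dtype=dtype, buffer=shm.buf)
        b += 2
        del b
        shm.close()

    ray.get(add.remote(shm.name, a.shape, a.dtype))
    print(a)  # -> [2 2 2]

    ray.get(add_shm.remote(shm, a.shape, a.dtype))
    print(a)  # -> [4 4 4]

    # shm.close() raises BufferError while a view of shm.buf is still alive.
    del a
    shm.close()
    shm.unlink()

if __name__ == "__main__":
    ray_test()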
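Why add_shm works can be seen without Ray: SharedMemory.__reduce__ serializes only the block's name, and unpickling attaches to the existing block under that name, so writes through the copy are visible in the original. A minimal stdlib sketch (roundtrip_shared_memory is an illustrative name):

```python
import pickle
from multiprocessing.shared_memory import SharedMemory


def roundtrip_shared_memory():
    shm = SharedMemory(create=True, size=4)
    try:
        shm.buf[0] = 7
        # Only the name goes into the pickle, not the bytes of the block.
        clone = pickle.loads(pickle.dumps(shm))
        same_name = clone.name == shm.name
        clone.buf[1] = 9  # write through the second attachment
        seen = (shm.buf[0], shm.buf[1])
        clone.close()
        return same_name, seen
    finally:
        shm.close()
        shm.unlink()


print(roundtrip_shared_memory())  # (True, (7, 9))
```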

3.2 🆖 asyncio.Event for Event

The Ray example uses asyncio.Event, but the Event is not shared: it lives in a single process and is modified remotely through @ray.remote calls.
This works with Ray, but only with Ray.

Additionally, this approach requires a one-to-one architecture,
whereas in MPReplayBuffer each Event object is shared by one Learner and all Explorers.
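For reference, the sharing pattern MPReplayBuffer needs (one Event waited on by several Explorers and set by one Learner) can be sketched with plain multiprocessing and no Ray, assuming a SyncManager Event proxy. explorer and run_explorers are illustrative names, not cpprb APIs:

```python
import multiprocessing as mp


def explorer(event, results, explorer_id):
    # Hypothetical Explorer: blocks until the Learner sets the shared Event.
    event.wait()
    results.put(explorer_id)


def run_explorers(n_explorers=3):
    with mp.Manager() as manager:
        event = manager.Event()  # a single Event proxy shared by every process
        results = manager.Queue()
        procs = [mp.Process(target=explorer, args=(event, results, i))
                 for i in range(n_explorers)]
        for p in procs:
            p.start()
        event.set()  # the (hypothetical) Learner releases all Explorers at once
        for p in procs:
            p.join()
        return sorted(results.get() for _ in range(n_explorers))


if __name__ == "__main__":
    print(run_explorers())  # [0, 1, 2]
```

The `__main__` guard matters here because the spawn and forkserver start methods re-import the main module in each child.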

3.3 🆗 SyncManager.Lock for Lock and SyncManager.Event for Event

As proposed in the discussion, proxy-based Lock and Event objects from SyncManager can be used with Ray.

The trick is to set the manager's authkey in the Ray workers beforehand, so that the proxies they receive can authenticate with the manager process.

import base64
import ctypes
import multiprocessing as mp
from multiprocessing.managers import SyncManager

import ray

def run():
    ray.init()

    m = SyncManager()
    m.start()

    v = m.Value(ctypes.c_int, 0)
    lock = m.Lock()
    authkey = mp.current_process().authkey

    # Encode base64 to avoid following error:
    #   TypeError: Pickling an AuthenticationString object is disallowed for security reasons
    encoded = base64.b64encode(authkey)

    def auth_fn(*args):
        mp.current_process().authkey = base64.b64decode(encoded)

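    # Private Ray API: runs auth_fn on every existing and future worker.
    # In newer Ray versions this module may live at ray._private.worker.
    ray.worker.global_worker.run_function_on_all_workers(auth_fn)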

    @ray.remote
    def remote_lock(v, L):
        with L:
            v.value += 1

    try:
        print(v.value) # -> 0
        ray.get([remote_lock.remote(v, lock),
                 remote_lock.remote(v, lock)])
        print(v.value) # -> 2
    finally:
        m.shutdown()

if __name__ == "__main__":
    run()
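The base64 step in the code above works around AuthenticationString.__reduce__, which refuses to pickle outside process spawning. The base64 output is plain bytes, so it can travel inside auth_fn's closure and be decoded back to the same key. A stdlib-only sketch of both halves (check_authkey_encoding is an illustrative name):

```python
import base64
import multiprocessing as mp
import pickle


def check_authkey_encoding():
    authkey = mp.current_process().authkey  # an AuthenticationString (bytes subclass)
    try:
        pickle.dumps(authkey)
        refused = False
    except TypeError:
        refused = True  # pickling is disallowed outside process spawning
    encoded = base64.b64encode(authkey)  # plain bytes
    pickle.dumps(encoded)                # pickles without error
    restored = base64.b64decode(encoded)
    return refused, restored == authkey


print(check_authkey_encoding())  # (True, True)
```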

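Normally, a proxy serializes its authkey only when it is pickled while a child process is being spawned (ref); otherwise the authkey is left out.
If no authkey is passed (to a manager or to a rebuilt proxy), the authkey of the current process is used (ref).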
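Both points can be observed without Ray. Pickled outside spawning, a SyncManager proxy carries only the manager address and object id; the unpickled proxy then authenticates with current_process().authkey. The round trip below succeeds because this process already holds the manager's key, while a Ray worker with a different key would fail authentication, which is why auth_fn installs the key first. A minimal sketch (roundtrip_proxy is an illustrative name):

```python
import multiprocessing as mp
import pickle
from multiprocessing.managers import SyncManager


def roundtrip_proxy():
    manager = SyncManager()  # no authkey given: uses current_process().authkey
    manager.start()
    try:
        value = manager.Value("i", 41)
        # Outside process spawning, the proxy's pickle omits the authkey.
        data = pickle.dumps(value)
        key_in_payload = bytes(mp.current_process().authkey) in data
        # The rebuilt proxy falls back to current_process().authkey to connect.
        clone = pickle.loads(data)
        clone.value += 1
        del clone
        return key_in_payload, value.value
    finally:
        manager.shutdown()


if __name__ == "__main__":
    print(roundtrip_proxy())  # (False, 42)
```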

4. Implementation of multiprocessing

Class Diagram (partial)

classDiagram
     MPReplayBuffer *-- SharedBuffer
     SharedBuffer *-- RawArray
     MPReplayBuffer *-- ProcessSafeRingBufferIndex
     MPReplayBuffer *-- Value
     MPReplayBuffer *-- Event
     ProcessSafeRingBufferIndex *-- RingBufferIndex
     ProcessSafeRingBufferIndex *-- Lock
     RingBufferIndex *-- RawValue
     Value *-- RawValue
     Value *-- RLock
     Event *-- Condition
     Event *-- Semaphore
     Condition *-- Semaphore
     Condition *-- Lock
     SemLock <|-- Semaphore
     SemLock <|-- Lock
     SemLock <|-- RLock
     SemLock *-- _multiprocessing_SemLock
     SemLock: + __getstate__(self)
     SemLock: + __setstate__(self, state)
     class _multiprocessing_SemLock {
         <<Implemented in C>>
         + _rebuild(handle, kind, maxvalue, name)
     }

Notes

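  • SemLock.__getstate__ (ref) calls context.assert_spawning, so it raises RuntimeError whenever the object is pickled outside process spawning (i.e., by ordinary pickling).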
  • Precisely speaking, Value, RawValue, and RawArray are not classes but helper functions to create sharedctypes objects.
  • Because of a limitation of Mermaid's diagram syntax, _multiprocessing.SemLock is shown as _multiprocessing_SemLock.
  • SemLock.__setstate__ calls _multiprocessing.SemLock._rebuild to reopen the semaphore in the new process. (ref)
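The first two notes can be checked with the standard library alone: ordinary pickling of a Lock trips assert_spawning, and the sharedctypes helpers are plain functions. A minimal sketch (check_notes is an illustrative name):

```python
import inspect
import multiprocessing as mp
import multiprocessing.sharedctypes as sharedctypes
import pickle


def check_notes():
    # Note 1: SemLock.__getstate__ calls assert_spawning, so plain pickling raises.
    try:
        pickle.dumps(mp.Lock())
        lock_message = None
    except RuntimeError as e:
        lock_message = str(e)
    # Note 2: Value, RawValue, and RawArray are factory functions, not classes.
    helpers_are_functions = all(
        inspect.isfunction(f)
        for f in (sharedctypes.Value, sharedctypes.RawValue, sharedctypes.RawArray)
    )
    return lock_message, helpers_are_functions


print(check_notes())
# ('Lock objects should only be shared between processes through inheritance', True)
```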
Edited by Yamada Hiroyuki