Ray Support
1. Goal
We can pass MPReplayBuffer and MPPrioritizedReplayBuffer to functions wrapped with @ray.remote.
2. Problem
MPReplayBuffer and MPPrioritizedReplayBuffer use shared memory, which Ray cannot serialize.
from cpprb import MPPrioritizedReplayBuffer
import ray
rb = MPPrioritizedReplayBuffer(4, {"done": {}})
ray.util.inspect_serializability(rb, name="MPPrioritizedReplayBuffer")
================================================================================
Checking Serializability of <cpprb.PyReplayBuffer.MPPrioritizedReplayBuffer object at 0x7f41c4ea6040>
================================================================================
!!! FAIL serialization: cannot pickle 'mmap.mmap' object
WARNING: Did not find non-serializable object in <cpprb.PyReplayBuffer.MPPrioritizedReplayBuffer object at 0x7f41c4ea6040>. This may be an oversight.
================================================================================
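The failure reported above can be reproduced without Ray or cpprb. The following is a minimal sketch using only the standard library: it tries to pickle an anonymous `mmap.mmap` object, which is the object type named in the error message. The helper name `try_pickle_mmap` is hypothetical and used only for illustration.

```python
import mmap
import pickle


def try_pickle_mmap(size=16):
    """Try to pickle an anonymous memory map and report what happens."""
    m = mmap.mmap(-1, size)  # anonymous mapping, not backed by a file
    try:
        pickle.dumps(m)
    except TypeError as e:
        # Plain pickle refuses mmap objects.
        return f"TypeError: {e}"
    finally:
        m.close()
    return "pickled"


if __name__ == "__main__":
    print(try_pickle_mmap())
```

Any object that holds such a mapping without defining its own reduction hits the same limit, which is why the tasks below look for replacements that define how they are serialized.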
3. Tasks
To work with Ray, we have to manage the following multiprocessing objects:
- RawValue / RawArray
  - For Python 3.8+, SharedMemory (doc) can be used.
- Value (aka. RawValue + RLock)
- Event
  - asyncio.Event is a possible candidate (Ref)
  - SyncManager.Event
- Lock
  - SyncManager.Lock
3.1 🆗 SharedMemory for RawValue and RawArray
The following code works. SharedMemory has a custom __reduce__ method (Ref), so it can be passed to functions wrapped with @ray.remote.
from multiprocessing.shared_memory import SharedMemory

import numpy as np
import ray


def ray_test():
    ray.init()

    # The size is in bytes; three int32 values need only 4 * 3 of them.
    shm = SharedMemory(create=True, size=32 * 3)
    a = np.ndarray(shape=(3,), dtype=np.int32, buffer=shm.buf)
    print(a)

    # Variant 1: pass only the name and attach inside the worker.
    @ray.remote
    def add(name, shape, dtype):
        m = SharedMemory(name=name)
        b = np.ndarray(shape=shape, dtype=dtype, buffer=m.buf)
        b += 2

    # Variant 2: pass the SharedMemory object itself.
    @ray.remote
    def add_shm(shm, shape, dtype):
        b = np.ndarray(shape=shape, dtype=dtype, buffer=shm.buf)
        b += 2

    ray.get(add.remote(shm.name, a.shape, a.dtype))
    print(a)

    ray.get(add_shm.remote(shm, a.shape, a.dtype))
    print(a)

    shm.close()
    shm.unlink()


if __name__ == "__main__":
    ray_test()
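To see what the custom __reduce__ buys us independently of Ray, the sketch below round-trips a SharedMemory object through plain `pickle` in a single process and checks that the unpickled copy is attached to the same block. It uses only the standard library; the function name `roundtrip_shared_memory` is hypothetical.

```python
import pickle
from multiprocessing.shared_memory import SharedMemory


def roundtrip_shared_memory(size=16):
    """Pickle and unpickle a SharedMemory, then write through the copy."""
    shm = SharedMemory(create=True, size=size)
    try:
        # Unpickling attaches to the existing block by name instead of
        # copying the underlying bytes.
        clone = pickle.loads(pickle.dumps(shm))
        try:
            clone.buf[0] = 42
            # The write through the clone is visible through the original.
            return shm.name == clone.name, shm.buf[0]
        finally:
            clone.close()
    finally:
        shm.close()
        shm.unlink()


if __name__ == "__main__":
    print(roundtrip_shared_memory())
```

This mirrors what happens when the object is passed to a remote function, under the assumption that the serializer in use honors __reduce__ as plain pickle does.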
3.2 🆖 asyncio.Event for Event
The Ray example uses asyncio.Event, but the Event itself is not shared. Instead, the Event is changed remotely through @ray.remote.
This approach works with Ray, but only with Ray.
Additionally, it requires a one-to-one architecture, whereas in MPReplayBuffer each Event object is shared by one Learner and all Explorers.
3.3 🆗 SyncManager.Lock for Lock and SyncManager.Event for Event
As proposed in the discussion, proxy-based Lock and Event can be used with Ray.
The trick is to set authkey in the Ray workers beforehand.
import base64
import ctypes
import multiprocessing as mp
from multiprocessing.managers import SyncManager

import ray


def run():
    ray.init()

    m = SyncManager()
    m.start()

    v = m.Value(ctypes.c_int, 0)
    lock = m.Lock()

    authkey = mp.current_process().authkey

    # Encode base64 to avoid following error:
    #   TypeError: Pickling an AuthenticationString object is disallowed for security reasons
    encoded = base64.b64encode(authkey)

    def auth_fn(*args):
        # Give each Ray worker the same authkey as the manager's parent process.
        mp.current_process().authkey = base64.b64decode(encoded)

    ray.worker.global_worker.run_function_on_all_workers(auth_fn)

    @ray.remote
    def remote_lock(v, L):
        with L:
            v.value += 1

    try:
        print(v.value)  # -> 0

        ray.get([remote_lock.remote(v, lock),
                 remote_lock.remote(v, lock)])

        print(v.value)  # -> 2
    finally:
        m.shutdown()


if __name__ == "__main__":
    run()
Usually, authkey is serialized only when the start method is spawn (ref).
If no authkey is passed, the one for the current process is used (ref).
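The reason for the base64 step can be checked in isolation. The sketch below, which uses only the standard library and a hypothetical helper name `authkey_roundtrip`, first tries to pickle the current process's authkey directly, then sends it through base64 and pickle and compares the result with the original bytes.

```python
import base64
import multiprocessing as mp
import pickle


def authkey_roundtrip():
    """Return (direct pickling worked, base64 round trip preserved the key)."""
    authkey = mp.current_process().authkey

    try:
        pickle.dumps(authkey)
        direct_pickle_ok = True
    except TypeError:
        # AuthenticationString refuses to be pickled outside of spawning.
        direct_pickle_ok = False

    # b64encode returns a plain bytes object, which pickles normally.
    encoded = base64.b64encode(authkey)
    decoded = base64.b64decode(pickle.loads(pickle.dumps(encoded)))

    return direct_pickle_ok, decoded == bytes(authkey)


if __name__ == "__main__":
    print(authkey_roundtrip())
```

Note that the encoded key is as sensitive as the key itself, so this sketch only illustrates the mechanism and is not a recommendation on how to distribute secrets.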
4. Implementation of multiprocessing
(A part of) Class Diagram
classDiagram
MPReplayBuffer *-- SharedBuffer
SharedBuffer *-- RawArray
MPReplayBuffer *-- ProcessSafeRingBufferIndex
MPReplayBuffer *-- Value
MPReplayBuffer *-- Event
ProcessSafeRingBufferIndex *-- RingBufferIndex
ProcessSafeRingBufferIndex *-- Lock
RingBufferIndex *-- RawValue
Value *-- RawValue
Value *-- RLock
Event *-- Condition
Event *-- Semaphore
Condition *-- Semaphore
Condition *-- Lock
SemLock <|-- Semaphore
SemLock <|-- Lock
SemLock <|-- RLock
SemLock *-- _multiprocessing_SemLock
SemLock: + __getstate__(self)
SemLock: + __setstate__(self, state)
class _multiprocessing_SemLock {
<<Implemented in C>>
+ _rebuild(handle, kind, maxvalue, name)
}
Notes
- SemLock.__getstate__ (ref) raises RuntimeError for usual pickling.
- Precisely speaking, Value, RawValue, and RawArray are not classes but helper functions that create sharedctypes objects.
- Because of a syntax limitation of Mermaid diagrams, _multiprocessing.SemLock is shown as _multiprocessing_SemLock.
- SemLock.__setstate__ calls _multiprocessing.SemLock._rebuild to remap the semaphore in the new process. (ref)
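The first note can be observed directly. As a minimal sketch with the standard library only (the helper name `try_pickle_lock` is hypothetical), the code below creates an ordinary multiprocessing Lock, which is a SemLock subclass in the diagram above, and tries to pickle it outside of process spawning.

```python
import multiprocessing as mp
import pickle


def try_pickle_lock():
    """Try to pickle a multiprocessing.Lock and report the outcome."""
    lock = mp.Lock()
    try:
        pickle.dumps(lock)
    except RuntimeError as e:
        # SemLock.__getstate__ rejects pickling outside of process spawning.
        return type(e).__name__
    return "pickled"


if __name__ == "__main__":
    print(try_pickle_lock())
```

This is the behavior that makes the SemLock-based Lock, RLock, and Event unsuitable for passing to Ray as they are, and it motivates the proxy-based replacements in section 3.3.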