Large Batch Experience Replay
1 Overview
Large Batch Experience Replay (LaBER) was proposed by T. Lahire et al.1
The authors theoretically derived the optimal sampling probability \( p_i^{\ast} \) that minimizes the variance of the gradient estimate:
\[ p_i^{\ast} \propto \| \nabla _{\theta} L(Q_{\theta}(s_i, a_i), y_i) \| \text{,}\]
where \(L(\cdot,\, \cdot)\) is a loss function.
Computing this requires a full backpropagation for every transition and is expensive, so the authors proposed the surrogate priority \(\hat{p}_i \propto \| \partial L(q_i, y_i) / \partial q_i \| \), where \( q_i = Q_{\theta}(s_i, a_i) \).
Since \( \| \nabla _{\theta} L(Q_{\theta}(s_i, a_i), y_i) \| = \| \partial L(q_i, y_i) / \partial q_i \cdot \nabla _{\theta} Q_{\theta}(s_i, a_i) \| \) by the chain rule, the surrogate priority is a good approximation when \( \nabla _{\theta} Q_{\theta}(s_i, a_i) \) is almost constant across samples.
Moreover, when the loss function is the squared (L2) loss, the surrogate priority reduces to the absolute TD error.
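For example, with \( L(q_i, y_i) = \frac{1}{2} (q_i - y_i)^2 \),
\[ \hat{p}_i \propto \left\| \frac{\partial L(q_i, y_i)}{\partial q_i} \right\| = | q_i - y_i | \text{,}\]
which is exactly the absolute TD error, since \( y_i \) is the TD target.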
Using the TD error as a priority is reasonable, but one of the biggest problems with PER is that the stored priorities are always outdated, because they were computed with an older version of the network. However, re-computing the priorities of all transitions in the buffer at every sampling step is far too expensive.
LaBER instead first samples an \(m\)-times larger batch uniformly from the buffer, computes surrogate priorities for it with the current network, and then sub-samples the final batch according to those priorities, as sketched below.
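To make the two-stage sampling concrete, the following is a minimal NumPy sketch, independent of cpprb; the helper name laber_subsample is hypothetical, and the weights follow the LaBER-mean normalization (mean priority of the large batch divided by the priority of each selected transition).

import numpy as np

def laber_subsample(priorities, batch_size, rng=None):
    # "priorities": surrogate priorities of the m * batch_size transitions
    # that were pre-sampled uniformly from the buffer.
    if rng is None:
        rng = np.random.default_rng()
    priorities = np.asarray(priorities, dtype=np.float64)
    p = priorities / priorities.sum()
    # Draw the final (small) batch according to the surrogate priorities.
    indexes = rng.choice(priorities.shape[0], size=batch_size, p=p)
    # LaBER-mean style importance weights.
    weights = priorities.mean() / priorities[indexes]
    return indexes, weights

# e.g. indexes, weights = laber_subsample(np.abs(td_errors), batch_size=32)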
According to the authors, LaBER can also be combined with non-uniform sampling such as PER (they call this PER-LaBER); however, the combination does not improve performance much, despite the additional computational cost.
cpprb provides three helper classes: LaBERmean, LaBERlazy, and LaBERmax. Unless you have a specific reason, use LaBERmean, which is theoretically and experimentally the best variant. These classes are constructed with the following parameters:
| Parameter | Type | Description | Default |
|---|---|---|---|
| batch_size | int | Desired final batch size (output size) | |
| m | int | Multiplication factor (input size is m * batch_size) | 4 |
| eps | float | Small positive constant to avoid 0 priority (keyword only) | 1e-6 |
After construction, these classes can be used as functors. Call them with the priorities keyword argument and, optionally, any other environment values; the call returns a dict containing "indexes" into the large batch, the corresponding "weights", and any passed environment values sub-sampled accordingly.
laber = LaBERmean(32, 4)

sample = laber(priorities=[...],  # 32 * 4 surrogate priorities
               # Optional: any additional environment values can be
               # passed and are sub-sampled together.
               obs=...,
               act=...,
               ...)

2 Example Usage
The following pseudocode shows the overall usage.
from cpprb import ReplayBuffer, LaBERmean
buffer_size = int(1e+6)
env_dict = # Define environment values
batch_size = 32
m = 4
n_iteration = int(1e+6)
rb = ReplayBuffer(buffer_size, env_dict)
laber = LaBERmean(batch_size, m)
env = # Create Env
policy = # Create Policy Network
observation = env.reset()
for _ in range(n_iteration):
action = policy(observation)
next_observation, reward, done, _ = env.step(action)
rb.add(obs=observation,
act=action,
rew=reward,
next_obs=next_observation,
done=done)
sample = rb.sample(batch_size * m)
absTD = # Calculate surrogate priority using network
idx_weights = laber(priorities=absTD)
indexes = idx_weights["indexes"]
weights = idx_weights["weights"]
policy.train((absTD[indexes] * weights).mean())
if done:
observation = env.reset()
rb.on_episode_end()
else:
        observation = next_observation

The full example code is as follows:
# Example for Large Batch Experience Replay (LaBER)
# Ref: https://arxiv.org/abs/2110.01528
import os
import datetime
import numpy as np
import gym
import tensorflow as tf
from tensorflow.keras.models import Sequential, clone_model
from tensorflow.keras.layers import Dense
from tensorflow.keras.optimizers import Adam
from tensorflow.summary import create_file_writer
from cpprb import ReplayBuffer, LaBERmean
gamma = 0.99
batch_size = 64
N_iteration = int(1e+6)
target_update_freq = 10000
eval_freq = 1000
egreedy = 1.0
decay_egreedy = lambda e: max(e*0.99, 0.1)
# Use 4 times larger batch for initial uniform sampling
# Use LaBER-mean, which is the best variant
m = 4
LaBER = LaBERmean(batch_size, m)
# Log
dir_name = datetime.datetime.now().strftime("%Y%m%d-%H%M%S")
logdir = os.path.join("logs", dir_name)
writer = create_file_writer(logdir + "/metrics")
writer.set_as_default()
# Env
env = gym.make('CartPole-v1')
eval_env = gym.make('CartPole-v1')
# For CartPole: input 4, output 2
model = Sequential([Dense(64,activation='relu',
input_shape=(env.observation_space.shape)),
Dense(env.action_space.n)])
target_model = clone_model(model)
# Loss Function
@tf.function
def Huber_loss(absTD):
return tf.where(absTD > 1.0, absTD, tf.math.square(absTD))
@tf.function
def MSE(absTD):
return tf.math.square(absTD)
loss_func = Huber_loss
optimizer = Adam()
buffer_size = int(1e+6)
env_dict = {"obs":{"shape": env.observation_space.shape},
"act":{"shape": 1,"dtype": np.ubyte},
"rew": {},
"next_obs": {"shape": env.observation_space.shape},
"done": {}}
# Nstep
nstep = 3
# nstep = False
if nstep:
Nstep = {"size": nstep, "rew": "rew", "next": "next_obs"}
discount = tf.constant(gamma ** nstep)
else:
Nstep = None
discount = tf.constant(gamma)
rb = ReplayBuffer(buffer_size,env_dict,Nstep=Nstep)
@tf.function
def Q_func(model,obs,act,act_shape):
return tf.reduce_sum(model(obs) * tf.one_hot(act,depth=act_shape), axis=1)
@tf.function
def DQN_target_func(model,target,next_obs,rew,done,gamma,act_shape):
return gamma*tf.reduce_max(target(next_obs),axis=1)*(1.0-done) + rew
@tf.function
def Double_DQN_target_func(model,target,next_obs,rew,done,gamma,act_shape):
"""
Double DQN: https://arxiv.org/abs/1509.06461
"""
act = tf.math.argmax(model(next_obs),axis=1)
return gamma*tf.reduce_sum(target(next_obs)*tf.one_hot(act,depth=act_shape), axis=1)*(1.0-done) + rew
target_func = Double_DQN_target_func
def evaluate(model,env):
obs = env.reset()
total_rew = 0
while True:
Q = tf.squeeze(model(obs.reshape(1,-1)))
act = np.argmax(Q)
obs, rew, done, _ = env.step(act)
total_rew += rew
if done:
return total_rew
# Start Experiment
observation = env.reset()
# Warming up
for n_step in range(100):
action = env.action_space.sample() # Random Action
next_observation, reward, done, info = env.step(action)
rb.add(obs=observation,
act=action,
rew=reward,
next_obs=next_observation,
done=done)
observation = next_observation
if done:
observation = env.reset()
rb.on_episode_end()
n_episode = 0
observation = env.reset()
for n_step in range(N_iteration):
if np.random.rand() < egreedy:
action = env.action_space.sample()
else:
Q = tf.squeeze(model(observation.reshape(1,-1)))
action = np.argmax(Q)
egreedy = decay_egreedy(egreedy)
next_observation, reward, done, info = env.step(action)
rb.add(obs=observation,
act=action,
rew=reward,
next_obs=next_observation,
done=done)
observation = next_observation
# Uniform sampling
sample = rb.sample(batch_size * m)
with tf.GradientTape() as tape:
tape.watch(model.trainable_weights)
Q = Q_func(model,
tf.constant(sample["obs"]),
tf.constant(sample["act"].ravel()),
tf.constant(env.action_space.n))
target_Q = tf.stop_gradient(target_func(model,target_model,
tf.constant(sample['next_obs']),
tf.constant(sample["rew"].ravel()),
tf.constant(sample["done"].ravel()),
discount,
tf.constant(env.action_space.n)))
tf.summary.scalar("Target Q", data=tf.reduce_mean(target_Q), step=n_step)
absTD = tf.math.abs(target_Q - Q)
# Sub-sample according to surrogate priorities
# When loss is L2 or Huber, and no activation at the last layer,
# |TD| is surrogate priority.
sample = LaBER(priorities=absTD)
indexes = tf.constant(sample["indexes"])
weights = tf.constant(sample["weights"])
absTD = tf.gather(absTD, indexes)
        assert absTD.shape == weights.shape, f"BUG: absTD.shape: {absTD.shape}, weights.shape: {weights.shape}"
loss = tf.reduce_mean(loss_func(absTD)*weights)
grad = tape.gradient(loss, model.trainable_weights)
optimizer.apply_gradients(zip(grad, model.trainable_weights))
tf.summary.scalar("Loss vs training step", data=loss, step=n_step)
if done:
observation = env.reset()
rb.on_episode_end()
n_episode += 1
if n_step % target_update_freq == 0:
target_model.set_weights(model.get_weights())
if n_step % eval_freq == eval_freq-1:
eval_rew = evaluate(model,eval_env)
tf.summary.scalar("episode reward vs training step",data=eval_rew,step=n_step)3 Notes
We add eps to the priorities to avoid zero priority; however, the original implementation does not have it. If you don't want to add this small positive constant, you can pass eps=0 to the constructor (i.e. __init__).
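For example, the following constructs LaBERmean without the offset, matching the original implementation:

laber = LaBERmean(32, 4, eps=0)  # no small constant is added to the priorities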
4 Technical Detail
Since computing the surrogate priority usually requires a forward calculation of the network, we implement LaBER separately from the replay buffer. As a result, LaBERmean etc. become simple classes, so they are implemented as ordinary Python classes instead of Cython cdef classes.