Preprint
IN-CONTEXT REINFORCEMENT LEARNING
WITH ALGORITHM DISTILLATION
Michael Laskin∗ Luyu Wang∗ Junhyuk Oh Emilio Parisotto Stephen Spencer
Richie Steigerwald DJ Strouse Steven Hansen Angelos Filos Ethan Brooks
Maxime Gazeau Himanshu Sahni Satinder Singh Volodymyr Mnih
DeepMind
ABSTRACT
We propose Algorithm Distillation (AD), a method for distilling reinforcement
learning (RL) algorithms into neural networks by modeling their training his-
tories with a causal sequence model. Algorithm Distillation treats learning to
reinforcement learn as an across-episode sequential prediction problem. A dataset
of learning histories is generated by a source RL algorithm, and then a causal
transformer is trained by autoregressively predicting actions given their preceding
learning histories as context. Unlike sequential policy prediction architectures that
distill post-learning or expert sequences, AD is able to improve its policy entirely
in-context without updating its network parameters. We demonstrate that AD can
reinforcement learn in-context in a variety of environments with sparse rewards,
combinatorial task structure, and pixel-based observations, and find that AD learns
a more data-efficient RL algorithm than the one that generated the source data.
Figure 1: Algorithm Distillation (AD) has two steps – (i) a dataset of learning histories is collected from
individual single-task RL algorithms solving different tasks; (ii) a causal transformer predicts actions from
these histories using across-episodic contexts. Since the RL policy improves throughout the learning histories,
by predicting actions accurately AD learns to output an improved policy relative to the one seen in its context.
AD models state-action-reward tokens, and does not condition on returns.
∗ Equal contribution. Corresponding authors: mlaskin, luyuwang, vmnih@[Link].
1 INTRODUCTION
Transformers have emerged as powerful neural network architectures for sequence modeling (Vaswani
et al., 2017). A striking property of pre-trained transformers is their ability to adapt to downstream
tasks through prompt conditioning or in-context learning. After pre-training on large offline datasets,
large transformers have been shown to generalize to downstream tasks in text completion (Brown
et al., 2020), language understanding (Devlin et al., 2018), and image generation (Yu et al., 2022).
Recent work demonstrated that transformers can also learn policies from offline data by treating offline
Reinforcement Learning (RL) as a sequential prediction problem. While Chen et al. (2021) showed
that transformers can learn single-task policies from offline RL data via imitation learning, subsequent
work showed that transformers can also extract multi-task policies in both same-domain (Lee et al.,
2022) and cross-domain settings (Reed et al., 2022). These works suggest a promising paradigm
for extracting generalist multi-task policies – first collect a large and diverse dataset of environment
interactions, then extract a policy from the data via sequential modeling. We refer to the family
of approaches that learns policies from offline RL data via imitation learning as Offline Policy
Distillation, or simply Policy Distillation¹ (PD).
Despite its simplicity and scalability, a substantial drawback of PD is that the resulting policy does
not improve incrementally from additional interaction with the environment. For instance, the Multi-
Game Decision Transformer (MGDT, Lee et al., 2022) learns a return-conditioned policy that plays
many Atari games while Gato (Reed et al., 2022) learns a policy that solves tasks across diverse
environments by inferring tasks through context, but neither method can improve its policy in-context
through trial and error. MGDT adapts the transformer to new tasks by finetuning the model weights
while Gato requires prompting with an expert demonstration to adapt to a new task. In short, Policy
Distillation methods learn policies but not Reinforcement Learning algorithms.
We hypothesize that the reason Policy Distillation does not improve through trial and error is that it
trains on data that does not show learning progress. Current methods either learn policies from data
that contains no learning (e.g. by distilling fixed expert policies) or data with learning (e.g. the replay
buffer of an RL agent) but with a context size that is too small to capture policy improvement.
Our key observation is that the sequential nature of learning within RL algorithm training could,
in principle, make it possible to model the process of reinforcement learning itself as a causal
sequence prediction problem. Specifically, if a transformer’s context is long enough to include policy
improvement due to learning updates, it should be able to represent not only a fixed policy but a policy
improvement operator by attending to states, actions and rewards from previous episodes. This opens
the possibility that any RL algorithm can be distilled into a sufficiently powerful sequence model
such as a transformer via imitation learning, converting it into an in-context RL algorithm.
We present Algorithm Distillation (AD), a method that learns an in-context policy improvement
operator by optimizing a causal sequence prediction loss on the learning histories of an RL algorithm.
AD has two components. First, a large multi-task dataset is generated by saving the training histories
of an RL algorithm on many individual tasks. Next, a transformer models actions causally using the
preceding learning history as its context. Since the policy improves throughout the course of training
of the source RL algorithm, AD is forced to learn the improvement operator in order to accurately
model the actions at any given point in the training history. Crucially, the transformer context size
must be sufficiently large (i.e. across-episodic) to capture improvement in the training data. The full
method is shown in Fig. 1.
We show that by imitating gradient-based RL algorithms using a causal transformer with sufficiently
large contexts, AD can reinforcement learn new tasks entirely in-context. We evaluate AD across
a number of partially observed environments that require exploration, including the pixel-based
Watermaze (Morris, 1981) from DMLab (Beattie et al., 2016). We show that AD is capable of
in-context exploration, temporal credit assignment, and generalization. We also show that AD learns
a more data-efficient algorithm than the one that generated the source data for transformer training.
To the best of our knowledge, AD is the first method to demonstrate in-context reinforcement learning
via sequential modeling of offline data with an imitation loss.
¹ What we refer to as Policy Distillation is similar to Rusu et al. (2016) but the policy is distilled from offline data, not a teacher network.
2 BACKGROUND
Partially Observable Markov Decision Processes: A Markov Decision Process (MDP) consists
of states s ∈ S, actions a ∈ A, rewards r ∈ R, a discount factor γ, and a transition probability
function p(st+1 |st , at ), where t is an integer denoting the timestep and (S, A) are state and action
spaces. In environments described by an MDP, at each timestep t the agent observes the state st ,
selects an action at ∼ π(·|st ) from its policy, and then observes the next state st+1 ∼ p(·|st , at )
sampled from the transition dynamics of the environment. In this work, we operate in the Partially
Observable Markov Decision Process (POMDP) setting where instead of states s ∈ S the agent
receives observations o ∈ O that only have partial information about the true state of the environment.
State information may be incomplete because the goal of the task is not observable and must instead be inferred from rewards using memory, because the observations are pixel-based, or both.
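To make the interaction protocol concrete, the following is a minimal sketch of one episode of agent-environment interaction in the partially observed setting; the `env` and `policy` interfaces here are illustrative stand-ins rather than any particular library's API.

```python
# Minimal sketch of one episode in a POMDP. `env` and `policy` are hypothetical
# interfaces used only for illustration; they are not part of any specific library.

def run_episode(env, policy, max_steps=50):
    """Roll out a policy that only sees observations o_t, not the full state s_t."""
    trajectory = []                       # collected (o_t, a_t, r_t) tuples
    obs = env.reset()                     # the agent receives a partial observation
    for _ in range(max_steps):
        action = policy.act(obs)          # a_t ~ pi(. | o_t), possibly using internal memory
        obs_next, reward, done = env.step(action)
        trajectory.append((obs, action, reward))
        obs = obs_next
        if done:
            break
    return trajectory
```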
Online and Offline Reinforcement Learning: Reinforcement Learning algorithms aim to
maximize the return, defined as the cumulative sum of rewards $\sum_t \gamma^t r_t$, throughout an agent's
lifetime or episode of training. RL algorithms broadly fall into two categories: on-policy
algorithms (Williams, 1992), where the agent directly maximizes a Monte-Carlo estimate of the
total return, and off-policy algorithms (Mnih et al., 2013; 2015), where an agent learns and maximizes a value
function that approximates the total future return. Most RL algorithms maximize returns through
trial-and-error by directly interacting with the environment. However, offline RL (Levine et al., 2020)
has recently emerged as an alternate and often complementary paradigm for RL where an agent
aims to extract return maximizing policies from offline data gathered by another agent. The offline
dataset consists of (s, a, r) tuples which are often used to train an off-policy agent, though other
algorithms for extracting return maximizing policies from offline data are also possible.
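For concreteness, the return being maximized is simply the discounted sum of rewards; a plain-Python rendering (no particular RL library assumed) is:

```python
def discounted_return(rewards, gamma):
    """Return sum_t gamma^t * r_t for a sequence of per-step rewards."""
    return sum((gamma ** t) * r for t, r in enumerate(rewards))

# Example: three rewards of 1.0 with gamma = 0.9 give 1 + 0.9 + 0.81 = 2.71.
assert abs(discounted_return([1.0, 1.0, 1.0], 0.9) - 2.71) < 1e-9
```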
Self-Attention and Transformers: The self-attention (Vaswani et al., 2017) operation begins by projecting input data $X$ with three separate matrices onto $D$-dimensional vectors called queries $Q$, keys $K$, and values $V$. These vectors are then passed through the attention function:
$$\mathrm{Attention}(Q, K, V) = \mathrm{softmax}\!\left(QK^T / \sqrt{D}\right) V. \quad (1)$$
The $QK^T$ term computes an inner product between two projections of the input data $X$. The inner product is then normalized by the softmax and used to weight the value projections $V$, producing a $D$-dimensional output for each position.
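A minimal NumPy sketch of Eq. (1) is shown below (single head, no masking; the projection matrices `W_q`, `W_k`, `W_v` are assumed to be learned elsewhere). In the causal transformers used later in the paper, a lower-triangular mask would additionally be applied to the attention scores so that each position attends only to earlier tokens.

```python
import numpy as np

def softmax(x, axis=-1):
    x = x - x.max(axis=axis, keepdims=True)   # subtract max for numerical stability
    e = np.exp(x)
    return e / e.sum(axis=axis, keepdims=True)

def self_attention(X, W_q, W_k, W_v):
    """Single-head scaled dot-product self-attention over a sequence X of shape (T, d_in)."""
    Q, K, V = X @ W_q, X @ W_k, X @ W_v       # queries, keys, values, each of shape (T, D)
    D = Q.shape[-1]
    scores = Q @ K.T / np.sqrt(D)             # (T, T) scaled inner products
    return softmax(scores, axis=-1) @ V       # attention-weighted combination of values
```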
Transformers (Vaswani et al., 2017; Devlin et al., 2018; Brown et al., 2020) utilize self-attention as
a core part of the architecture to process sequential data such as text sequences. Transformers are
usually pre-trained with a self-supervised objective that predicts tokens within the sequential data.
Common prediction tasks include predicting randomly masked out tokens (Devlin et al., 2018) or
applying a causal mask and predicting the next token (Radford et al., 2018).
Offline Policy Distillation: We refer to the family of methods that treat offline Reinforcement
Learning as a sequential prediction problem as Offline Policy Distillation, or Policy Distillation
(PD) for brevity. Rather than learning a value function from offline data, PD extracts policies by
predicting actions in the offline data (i.e. behavior cloning) with a sequence model and either return
conditioning (Chen et al., 2021; Lee et al., 2022) or filtering out suboptimal data (Reed et al., 2022).
Initially proposed to learn single-task policies (Chen et al., 2021; Janner et al., 2021), PD was recently
extended to learn multi-task policies from diverse offline data (Lee et al., 2022; Reed et al., 2022).
In-Context Learning: In-context learning refers to the ability to infer tasks from context. For
example, large language models like GPT-3 (Brown et al., 2020) or Gopher (Rae et al., 2021) can
be directed at solving tasks such as text completion, code generation, and text summarization by
specifying the task through language as a prompt. This ability to infer the task from a prompt is
often called in-context learning. We use the terms ‘in-weights learning’ and ‘in-context learning’
from prior work on sequence models (Brown et al., 2020; Chan et al., 2022) to distinguish between
gradient-based learning with parameter updates and gradient-free learning from context, respectively.
3 METHOD
Over the course of its lifetime a capable reinforcement learning (RL) agent will exhibit complex
behaviours, such as exploration, temporal credit assignment, and planning. Our key insight is that an
agent’s actions, regardless of the environment it inhabits, its internal structure, and implementation,
can be viewed as a function of its past experience, which we refer to as its history. Formally, we write:
$$\mathcal{H} \ni h_t := (o_0, a_0, r_0, \ldots, o_{t-1}, a_{t-1}, r_{t-1}, o_t, a_t, r_t) = (o_{\leq t}, r_{\leq t}, a_{\leq t}) \quad (2)$$
and we refer to a long² history-conditioned policy as an algorithm:
$$P : \mathcal{H} \cup \mathcal{O} \to \Delta(\mathcal{A}), \quad (3)$$
where ∆(A) denotes the space of probability distributions over the action space A. Eqn. (3) suggests
that, similar to a policy, an algorithm can be unrolled in an environment to generate sequences of ob-
servations, rewards, and actions. For brevity, we denote the algorithm as $P$ and environment (i.e. task)
as $\mathcal{M}$, such that the history of learning for any given task $\mathcal{M}$ is generated by the algorithm $P_{\mathcal{M}}$:
$$(O_0, A_0, R_0, \ldots, O_T, A_T, R_T) \sim P_{\mathcal{M}}. \quad (4)$$
Here, we’re denoting random variables with uppercase Latin letters, e.g. O, A, R, and their values
with lowercase Latin letters, e.g. o, a, r. By viewing algorithms as long history-conditioned policies,
we hypothesize that any algorithm that generated a set of learning histories can be distilled into
a neural network by performing behavioral cloning over actions. Next, we present a method that,
provided agents’ lifetimes, learns a sequence model with behavioral cloning to map long histories
to distributions over actions.
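As a purely illustrative picture of how a history as in Eq. (2) can be consumed by a sequence model, one can flatten the (observation, action, reward) tuples into a single token stream that spans many episodes; the tokenization itself (e.g. discretization of observations and rewards) is left abstract here.

```python
def flatten_history(history):
    """Flatten [(o_0, a_0, r_0), ..., (o_t, a_t, r_t)] into one stream ..., o_t, a_t, r_t, ...

    Each element is assumed to already be tokenized. Episode boundaries are not
    treated specially, so the stream can span several episodes of training.
    """
    tokens = []
    for obs, action, reward in history:
        tokens.extend([obs, action, reward])
    return tokens
```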
3.1 ALGORITHM DISTILLATION
Suppose the agents' lifetimes, which we also call learning histories, are generated by the source
algorithm $P^{source}$ for many individual tasks $\{\mathcal{M}_n\}_{n=1}^{N}$, producing the dataset $\mathcal{D}$:
$$\mathcal{D} := \left\{ \left( o^{(n)}_0, a^{(n)}_0, r^{(n)}_0, \ldots, o^{(n)}_T, a^{(n)}_T, r^{(n)}_T \right) \sim P^{source}_{\mathcal{M}_n} \right\}_{n=1}^{N}. \quad (5)$$
Then we distill the source algorithm’s behaviour into a sequence model that maps long histories
to probabilities over actions with a negative log likelihood (NLL) loss and refer to this process as
algorithm distillation (AD). In this work, we consider neural network models $P_\theta$ with parameters
$\theta$, which we train by minimizing the following loss function:
$$\mathcal{L}(\theta) := -\sum_{n=1}^{N} \sum_{t=1}^{T-1} \log P_\theta\!\left(A = a^{(n)}_t \,\middle|\, h^{(n)}_{t-1}, o^{(n)}_t\right). \quad (6)$$
Intuitively, a sequence model with fixed parameters that is trained with AD should amortise the source
RL algorithm P source and by doing so exhibit similarly complex behaviours, such as exploration
and temporal credit assignment. Since the RL policy improves throughout the learning history of
the source algorithm, accurate action prediction requires the sequence model to not only infer the
current policy from the preceding context but also infer the improved policy, therefore distilling
the policy improvement operator.
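A simplified sketch of the loss in Eq. (6) is given below. The `model.log_prob(context, obs, action)` call is a hypothetical interface standing in for any causal sequence model that returns log P_theta(A = a_t | h_{t-1}, o_t); in practice the double loop would be computed in parallel over sampled across-episodic context windows rather than full histories.

```python
def algorithm_distillation_loss(model, dataset):
    """Negative log-likelihood of actions given the preceding learning history (Eq. 6).

    `dataset` holds one learning history per task, each a list of (o, a, r) tuples;
    `model.log_prob` is a hypothetical interface, not a specific library call.
    """
    loss = 0.0
    for history in dataset:                 # one learning history per task M_n
        for t in range(1, len(history)):
            obs, action, _ = history[t]
            context = history[:t]           # h_{t-1}: everything preceding timestep t
            loss -= model.log_prob(context, obs, action)
    return loss
```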
3.2 PRACTICAL IMPLEMENTATION
In practice, we implement AD as a two-step procedure. First, a dataset of learning histories is
collected by running an individual gradient-based RL algorithm on many different tasks. Next, a
sequence model with multi-episodic context is trained to predict actions from the histories. We
describe these two steps below and detail the full practical implementation in Algorithm 1.
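In pseudocode, the two steps might look roughly as follows. All lowercase names (`sample_task`, `run_rl_algorithm`, `sample_context_window`, `action_prediction_nll`, `train_step`, `model`, `num_tasks`, `num_training_steps`) are placeholders for components that Algorithm 1 specifies in detail; the context-sampling scheme shown is a simplification.

```python
import random

# Step 1: data generation -- save the full training history of a single-task
# RL algorithm on each sampled task.
dataset = []
for n in range(num_tasks):
    task = sample_task()                       # draw task M_n from the task distribution
    history = run_rl_algorithm(task)           # (o, a, r) stream over the entire training run
    dataset.append(history)

# Step 2: model training -- behavioral cloning of actions with an
# across-episodic context window.
for step in range(num_training_steps):
    history = random.choice(dataset)
    context = sample_context_window(history)   # long enough to span several episodes
    loss = action_prediction_nll(model, context)
    train_step(model, loss)                    # gradient update on the sequence model
```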
Dataset Generation: A dataset of learning histories is collected by training N individual single-task
gradient-based RL algorithms. To prevent overfitting to any specific task during sequence model
training, a task M is sampled randomly from a task distribution for each RL run. The data generation
step is RL algorithm agnostic - any RL algorithm can be distilled. We show results distilling UCB
exploration (Lai & Robbins, 1985), an on-policy actor-critic (Mnih et al., 2016), and an off-policy
DQN (Mnih et al., 2013), in both distributed and single-stream settings. We denote the dataset of learning histories as $\mathcal{D}$ in Eq. 5.

Training the Sequence Model: Once a dataset of learning histories $\mathcal{D}$ is collected, a sequential prediction model is trained to predict actions given the preceding learning history as its across-episodic context.
² Long enough to span learning updates, e.g. across episodes.
Figure 3: Adversarial Bandit (Section 5): AD, RL², and ED evaluated on a 10-arm bandit with 100 trials. The source data for AD comes from learning histories from UCB (Lai & Robbins, 1985). During training, the reward is distributed under odd arms 95% of the time; during evaluation, it is distributed under even arms 95% of the time. Both AD and RL² can in-context learn in-distribution tasks, but AD generalizes better out of distribution. Running RL² with a transformer generally doesn't offer an advantage over the original LSTM variant. ED performs poorly both in and out of distribution relative to AD and RL². Scores are normalized relative to UCB.
which it receives a reward of r = 1 once again. Otherwise, the reward is r = 0. The room size is 9 × 9,
making the task space combinatorial with 81² = 6561 possible tasks. This environment is similar to
the one considered in Chen et al. (2021) except the key and door are invisible and the reward is semi-
sparse (r = 1 for both key and the door). The agent is randomly reset. The episode length is 50 steps.
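For intuition about the combinatorial task structure, a toy enumeration of key/door configurations consistent with the description above is sketched here; this is our illustration, not the paper's environment code.

```python
import itertools
import random

GRID_SIZE = 9                                      # 9 x 9 room, as described above
CELLS = [(x, y) for x in range(GRID_SIZE) for y in range(GRID_SIZE)]

# A task is an (invisible key position, invisible door position) pair,
# giving 81 * 81 = 6561 possible tasks.
ALL_TASKS = list(itertools.product(CELLS, CELLS))
assert len(ALL_TASKS) == 81 ** 2

def sample_task(rng=random):
    """Sample one key/door configuration; the agent can only discover it via r = 1 rewards."""
    return rng.choice(ALL_TASKS)
```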
DMLab Watermaze: a partially observable 3D visual DMLab environment based on the classic
Morris Watermaze (Morris, 1981). The task is to navigate the water maze to find a randomly spawned
trap-door. The maze walls have color patterns that can be used to remember the goal location.
Observations are pixel images of size 72 × 96 × 3. There are 8 possible actions in total, including
going forward, backward, left, or right, rotating left or right, and rotating left or right while going
forward. The episode length is 50, and the agent resets at the center of the map. Similar to Dark
Room, the agent cannot see the location of the goal from the observations and must infer it through
the reward of r = 1 if reached and r = 0 otherwise; however, the goal space is continuous and
therefore there are an infinite number of goals.
4.2 BASELINES
The main aim of this work is to investigate to what extent AD reinforcement learns in-context relative
to prior related work. AD is most closely related to Policy Distillation, where a policy is learned
with a sequential model from offline interaction data. In-context online meta-RL is also related
though not directly comparable to AD, since AD is an in-context offline meta-RL method. Still, we
consider both types of baselines to better contextualize our work. For a more detailed discussion of
these baseline choices we refer the reader to Appendix B. Our baselines include:
Expert Distillation (ED): this baseline is exactly the same as AD but the source data consists of expert
trajectories only, rather than learning histories. ED is most similar to Gato (Reed et al., 2022) except
ED models state-action-reward sequences like AD, while Gato models state-action sequences.
Source Algorithm: we compare AD to the gradient-based source RL algorithm that generates the
training data for distillation. We include running the source algorithm from scratch as a baseline to
compare the data-efficiency of in-context RL to the in-weights source algorithm.
RL² (Duan et al., 2016): an online meta-RL algorithm where exploration and fast in-context adaptation are learned jointly by maximizing a multi-episodic value function. RL² is not directly comparable to AD for similar reasons to why online and offline RL algorithms are not directly comparable – RL² gets to interact with the environment during training while AD does not. We use RL²'s asymptotic performance as an approximate upper bound for AD.
4.3 EVALUATION
After pre-training, the AD transformer Pθ can reinforcement learn in-context. Evaluation is exactly
the same as with an in-weights RL algorithm except the learning happens entirely in-context without
updating the transformer network parameters. Given an MDP (or POMDP), the transformer interacts
with the environment and populates its own context (i.e. without demonstrations), where the context
is a queue containing the last c transitions. The transformer’s performance is then evaluated in terms