PixelTransformer: Sample Conditioned Signal Generation
this joint distribution directly, we observe that it can be further factorized as a product of conditional distributions using the chain rule:

p(v_{g_1}, v_{g_2}, \ldots, v_{g_N} \mid S_0) = \prod_{n} p(v_{g_n} \mid S_0, v_{g_1}, \ldots, v_{g_{n-1}})

Denoting by S_n ≡ S_0 ∪ {v_{g_j}}_{j=1}^{n}, we obtain:

p(I \mid S_0) = \prod_{n} p(v_{g_n} \mid S_{n-1})    (1)

Sample Conditioned Value Prediction. The key observation from Eq. 1 is that all the factors are in the form of p(v_x | S). That is, the only queries we need to answer are: 'given some observed samples S, what is the distribution of possible values at location x?' To learn a sample conditioned generative model for images, we therefore propose to learn a function f_θ to infer p(v_x | S) for arbitrary inputs x and S. Concretely, we formulate our task as that of learning a function f_θ(x, {(x_k, v_k)}) that can predict the value distribution at an arbitrary query location x given a set of arbitrary sample (position, value) pairs {(x_k, v_k)}.

In summary:

• The task of inferring p(I | S_0) can be reduced to queries of the form p(v_x | S).

• We propose to learn a function f_θ(x, {(x_k, v_k)}) that can predict p(v_x | {v_{x_k}}) for arbitrary inputs.

While we used images as a motivating example, our formulation is also applicable for modeling distributions of other dense spatially varying signals. For RGB images, x ∈ R^2, v ∈ R^3, but other spatial signals e.g. polynomials (x ∈ R^1, v ∈ R^1), 3D shapes represented as Signed Distance Fields (x ∈ R^3, v ∈ R^1), or videos (x ∈ R^3, v ∈ R^3) can also be handled by learning f_θ(x, {(x_k, v_k)}) of the corresponding form (see Section 6).
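To make this interface concrete, the sketch below evaluates the chain-rule factorization of Eq. 1 for a simple 1D signal. Everything model-specific is a placeholder: toy_f_theta is a nearest-neighbour stand-in rather than the learned transformer, and each conditional is simplified to a single Gaussian instead of the value parametrization the model actually predicts.

```python
import math

def chain_rule_log_likelihood(f_theta, S0, targets):
    """Evaluate Eq. 1: sum of per-location conditional log-likelihoods,
    where each query conditions on S_0 plus all previously incorporated values.
    f_theta(x, samples) is assumed to return (mean, std) of p(v_x | S)."""
    samples = list(S0)                      # S_0
    total = 0.0
    for (x, v) in targets:                  # any ordering g_1, ..., g_N of target locations
        mean, std = f_theta(x, samples)     # parameters of p(v_{g_n} | S_{n-1})
        total += -0.5 * ((v - mean) / std) ** 2 - math.log(std * math.sqrt(2 * math.pi))
        samples.append((x, v))              # S_n = S_{n-1} with (g_n, v_{g_n}) added
    return total

# Stand-in predictor: copy the value of the nearest observed sample, fixed uncertainty.
def toy_f_theta(x, samples):
    _, v = min(samples, key=lambda s: abs(s[0] - x))
    return v, 0.1

S0 = [(0.0, 0.2), (1.0, 0.8)]               # observed (position, value) pairs of a 1D signal
targets = [(0.25, 0.35), (0.5, 0.5), (0.75, 0.65)]
print(chain_rule_log_likelihood(toy_f_theta, S0, targets))
```

Because f_θ accepts an arbitrary conditioning set S and an arbitrary query x, the same call also answers one-off questions (e.g. the value distribution at one location given a single observed sample elsewhere), and replacing the evaluation step with sampling turns the loop into a generation procedure.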
3. Related Work

Autoregressive Generative Models. Closely related to our work, autoregressive generative modeling approaches also factorize the joint distribution into per-location conditional distributions. Seminal works such as Wavenet (van den Oord et al., 2016a), PixelRNN (van den Oord et al., 2016c) and PixelCNN (van den Oord et al., 2016b) showed that we can learn the distribution over the values of the 'next' timestep/pixel given the values of the previous ones, and thereby learn a generative model for the corresponding domain (speech/images). Subsequent approaches have further improved over these works by modifying the parametrization (Salimans et al., 2017), incorporating hierarchy (van den Oord et al., 2017; Razavi et al., 2019), or (similar to ours) foregoing convolutions in favor of alternate base architectures (Chen et al., 2020; Parmar et al., 2018) such as Transformers (Vaswani et al., 2017). While this line of work has led to impressive results, the core distribution modeled is that of the 'next' value given 'previous' values. More formally, while we aim to predict p(v_x | S) for arbitrary x, S, the prior autoregressive generative models only infer this for cases where S contains pixels in some sequential (e.g. raster) order and x is the immediate 'next' position. Although using masked convolutions can allow handling many possible inference orders (Jain et al., 2020), the limited receptive field of convolutions still limits such orders to locally continuous sequences. Our work can therefore be viewed as a generalization of previous 'sequential' autoregressive models in two ways: a) allowing any query position x, and b) handling arbitrary samples S for conditioning. This allows us to answer questions that prior autoregressive models cannot e.g. 'if the top-left pixel is blue, how likely is the bottom-right one to be green?', 'what is the mean image given some observations?', or 'given values of 10 specific pixels, sample likely images'.

Implicit Neural Representations. There has been a growing interest in learning neural networks to represent 3D textured scenes (Sitzmann et al., 2019), radiance fields (Mildenhall et al., 2020; Martin-Brualla et al., 2021; Zhang et al., 2020) or more generic spatial signals (Sitzmann et al., 2020; Tancik et al., 2020). The overall approach across these methods is to represent the underlying signal by learning a function g_φ that maps query positions x to corresponding values v (e.g. pixel location to intensity). Our learned f_θ(·, {(x_k, v_k)}) can similarly be thought of as mapping query positions to a corresponding value (distribution), while being conditioned on some sample values. A key difference, however, is the ability to generalize – the above-mentioned approaches learn an independent network per instance e.g. a separate g_φ is used to model each scene, therefore requiring from thousands to millions of samples to fit g_φ for a specific scene. In contrast, our approach uses a common f_θ across all instances and can therefore generalize to unseen ones given only a sparse set of samples. Although some recent approaches (Xu et al., 2019; Park et al., 2019; Mescheder et al., 2019) have shown similar ability to generalize and infer novel 3D shapes/scenes given input image(s), these cannot handle sparse input samples and do not allow inferring a distribution over the output space.

Latent Variable based Generative Models. Our approach, similar to sequential autoregressive models, factorizes the image distribution as products of per-pixel distributions. An alternate approach to generative modeling, however, is to transform a prior distribution over latent variables to the output distribution via a learned decoder. Several approaches allow learning such a decoder by leveraging diverse objectives e.g. adversarial loss (Goodfellow et al.,
Figure 3. Inferred Mean Images. We visualize the mean image predicted by our learned model on random instances of the Cat Faces
dataset. Top row: ground-truth image. Rows 2-8: Predictions using increasing number of observed pixels |S|.
account not just the initial samples S_0, but also the subsequent n − 1 samples (hence the difference from ω_n in Eq. 4). v′_n represents a value then sampled for the pixel g_n from the distribution parametrized by ω_n.

Randomized Sampling Order. While we sample the values one pixel at a time, the ordering of pixels g_1, ..., g_N need not correspond to anything specific e.g. it is not necessary that g_1 should be the top-left pixel and g_N be the bottom-right one. In fact, as our model f_θ is trained using arbitrary sets of samples S, using a structured sampling ordering e.g. raster order would make the testing setup differ from training. Instead, for every sample I ∼ p(I|S) that we draw, we use a new random order in which the pixels of the image grid are sampled.

Sidestepping Memory Bottlenecks. As Eq. 5 indicates, the input to f_θ when sampling the (n + 1)-th pixel is a set of size K + n – the initial K observations and the subsequent n samples. Unfortunately, our model's memory requirement, due to the self-attention modules, grows quadratically with this input size. This makes it infeasible to autoregressively sample a very large number of pixels. However, we empirically observe that given a sufficient number of (random) samples, subsequent pixel value distributions do not exhibit a high variance. We leverage this observation to design a hybrid sampling strategy. When generating an image with N pixels, we sample the first N′ (typically 2048) autoregressively i.e. following Eq. 5 and Eq. 6. For the remaining N − N′ pixels, we simply use their mean value estimate conditioned on the initial and generated K + N′ samples (using Eq. 4). While this may lead to some loss in detail, we qualitatively show that the effects are not prohibitive and that the sample diversity is preserved.
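The randomized ordering and the hybrid strategy can be summarized in the short sketch below. It is an illustration only: toy_f_theta is a nearest-neighbour stand-in for the learned model, the per-pixel conditional is simplified to a single Gaussian rather than the parametrization of Eq. 4, and n_auto plays the role of N′.

```python
import random

def hybrid_sample(f_theta, S0, pixel_grid, n_auto=2048, rng=random):
    """Hybrid generation: sample the first n_auto pixels autoregressively,
    then fill the remaining pixels with their conditional mean estimate.
    f_theta(x, samples) is expected to return (mean, std) of p(v_x | S)."""
    order = [x for x in pixel_grid if x not in dict(S0)]
    rng.shuffle(order)                       # a fresh random ordering for every drawn image

    samples = list(S0)                       # running set: K initial + generated samples
    image = dict(S0)

    # Autoregressive phase: each pixel conditions on all previously drawn values.
    for x in order[:n_auto]:
        mean, std = f_theta(x, samples)
        v = rng.gauss(mean, std)
        samples.append((x, v))
        image[x] = v

    # Mean-value phase: one conditional query per remaining pixel; S stops growing here.
    for x in order[n_auto:]:
        mean, _ = f_theta(x, samples)
        image[x] = mean

    return image

# Toy usage with a stand-in predictor (nearest observed value, fixed uncertainty).
def toy_f_theta(x, samples):
    (_, v) = min(samples, key=lambda s: (s[0][0] - x[0]) ** 2 + (s[0][1] - x[1]) ** 2)
    return v, 0.05

S0 = [((0, 0), 0.9), ((7, 7), 0.1)]
grid = [(i, j) for i in range(8) for j in range(8)]
img = hybrid_sample(toy_f_theta, S0, grid, n_auto=16)
```

The point of the second phase is that the conditioning set stops growing after n_auto steps, so the per-query attention cost of the remaining pixels stays bounded while, empirically, little sample diversity is lost.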
5. Experiments

To qualitatively and quantitatively demonstrate the efficacy of our approach, we consider the task of generating images given a set of pixels with known values. The goal of our experiments is twofold – a) to validate that our predictions account for the observed pixels, and b) to show that the generated samples are diverse and plausible.

Datasets. We examine our approach on three different image datasets – CIFAR10 (Krizhevsky, 2009), MNIST (LeCun et al., 1998), and the Cat Faces (Wu et al., 2020) dataset, while using the standard image splits. Note that we only require the images for training – class or attribute labels are not leveraged for learning our models i.e. even on CIFAR10, we learn a class-agnostic generative model.

Training Setup. We vary the number of observed pixels S randomly between 4 and 2048 (with uniform sampling in log-scale), while the number of query samples Q is set to 2048. During training, the locations x are treated as varying over a continuous domain, using bilinear sampling to obtain the corresponding value – this helps our implementation be agnostic to the image resolution in the dataset. While we train a separate network f_θ for each dataset, we use the exact same model, hyper-parameters etc. across them.
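The sampling scheme above can be sketched as follows. The ranges (4 to 2048 observed pixels, log-uniform, with 2048 queries) follow the text; the helper names, the H x W x C array layout, and the normalization of positions to [0, 1]^2 are assumptions made for illustration.

```python
import math
import random

import numpy as np

def bilinear_value(image, x, y):
    """Read an H x W x C image at a continuous location (x, y) in [0, 1]^2."""
    h, w = image.shape[:2]
    fx, fy = x * (w - 1), y * (h - 1)
    x0, y0 = int(fx), int(fy)
    x1, y1 = min(x0 + 1, w - 1), min(y0 + 1, h - 1)
    ax, ay = fx - x0, fy - y0
    top = (1 - ax) * image[y0, x0] + ax * image[y0, x1]
    bot = (1 - ax) * image[y1, x0] + ax * image[y1, x1]
    return (1 - ay) * top + ay * bot

def make_training_example(image, min_obs=4, max_obs=2048, n_query=2048, rng=random):
    """Draw |S| log-uniformly in [min_obs, max_obs], then build observed and query
    (position, value) pairs at continuous locations, with values read off bilinearly."""
    n_obs = int(round(math.exp(rng.uniform(math.log(min_obs), math.log(max_obs)))))
    def draw(n):
        pts = [(rng.random(), rng.random()) for _ in range(n)]
        return [((x, y), bilinear_value(image, x, y)) for (x, y) in pts]
    return draw(n_obs), draw(n_query)

# Toy usage on a random 32 x 32 RGB array standing in for a dataset image.
img = np.random.rand(32, 32, 3)
observed, queries = make_training_example(img, n_query=8)
```

Because positions are continuous and values are interpolated, the same pipeline applies regardless of the native resolution of the dataset images.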
Qualitative Results: Mean Image Prediction. We first examine the expected image Ī inferred by our model given some samples S. We visualize in Figure 3 our predictions on the Cat Faces dataset using a varying number of input samples. We observe that even when using as few as 4 pixels in S, our model predicts a cat-like mean image that, with some exceptions, captures the coarse color accurately.
Figure 4. Image Samples. Sample images generated by our learned model on three datasets (left: MNIST, middle: Cat Faces, right:
CIFAR10) given |S| = 32 observed pixels. Top row: ground-truth image from which S is drawn. Row 2: A nearest neighbor visualization
of S – for each image pixel we assign it the color of the closest observed sample in S. Rows 3-5: Randomly sampled images from p(I|S).
[Figure: SSIM (left) and classification accuracy (right) plotted against the number of observed pixels |S|. SSIM panel legend: Decoder + Optimization, Ours (Mean Image), Ours (Image Samples). Accuracy panel legend: GT Image, Mean Image, Image Samples.]
Figure 8. Shape Generation. Sample 3D shapes generated given |S| = 32 observed SDF values at random locations. Top row: ground-
truth 3D shape. Row 2: A visualization of S – a sphere is centred at each position with color indicating value (red implies higher SDF).
Rows 3-5: Randomly sampled 3D shapes from our predicted conditional distribution.
Interestingly, we see that even when using images generated from as few as 16 pixels, we obtain about a 30% classification accuracy (or over 60% with 128 pixels). As we observe more pixels, the accuracy matches that of using the ground-truth images. Finally, we see that using the sampled images yields better results compared to the mean image, as the sampled ones look more 'real'.

6. Beyond Images: 1D and 3D Signals

While we leveraged our proposed framework for generating images given some pixel observations, our formulation is applicable beyond images. In particular, assuming the availability of (unlabeled) examples, our approach can learn to generate any dense spatial signal given some (position, value) samples. In this section, we empirically demonstrate this by learning to generate 1D (polynomial) and 3D (shapes and videos) signals using our framework.

We would like to emphasize that across these settings, where we are learning to generate rather different spatial signals, we use the same training objective and model design. That is, except for the dimensionality of input/output layers and distribution parametrization to handle the corresponding
Krizhevsky, A. Learning multiple layers of features from tiny images. 2009.

LeCun, Y., Bottou, L., Bengio, Y., and Haffner, P. Gradient-based learning applied to document recognition. Proceedings of the IEEE, 86(11):2278–2324, 1998.

Li, K. and Malik, J. Implicit maximum likelihood estimation. arXiv preprint arXiv:1809.09087, 2018.

Martin-Brualla, R., Radwan, N., Sajjadi, M. S. M., Barron, J. T., Dosovitskiy, A., and Duckworth, D. NeRF in the Wild: Neural Radiance Fields for Unconstrained Photo Collections. In CVPR, 2021.

Mescheder, L., Oechsle, M., Niemeyer, M., Nowozin, S., and Geiger, A. Occupancy networks: Learning 3d reconstruction in function space. In CVPR, 2019.

Mildenhall, B., Srinivasan, P. P., Tancik, M., Barron, J. T., Ramamoorthi, R., and Ng, R. Nerf: Representing scenes as neural radiance fields for view synthesis. In ECCV, 2020.

Sitzmann, V., Martel, J., Bergman, A., Lindell, D., and Wetzstein, G. Implicit neural representations with periodic activation functions. In NeurIPS, 2020.

Tancik, M., Srinivasan, P., Mildenhall, B., Fridovich-Keil, S., Raghavan, N., Singhal, U., Ramamoorthi, R., Barron, J., and Ng, R. Fourier features let networks learn high frequency functions in low dimensional domains. In NeurIPS, 2020.

Thomee, B., Shamma, D. A., Friedland, G., Elizalde, B., Ni, K., Poland, D., Borth, D., and Li, L.-J. Yfcc100m: The new data in multimedia research. Communications of the ACM, 59(2):64–73, 2016.

Ulyanov, D., Vedaldi, A., and Lempitsky, V. Deep image prior. In CVPR, 2018.

van den Oord, A., Dieleman, S., Zen, H., Simonyan, K., Vinyals, O., Graves, A., Kalchbrenner, N., Senior, A., and Kavukcuoglu, K. Wavenet: A generative model for raw audio. arXiv preprint arXiv:1609.03499, 2016a.
van den Oord, A., Kalchbrenner, N., Espeholt, L., Vinyals, O., Graves, A., et al. Conditional image generation with pixelcnn decoders. In NeurIPS, 2016b.

van den Oord, A., Kalchbrenner, N., and Kavukcuoglu, K. Pixel recurrent neural networks. arXiv preprint arXiv:1601.06759, 2016c.

van den Oord, A., Vinyals, O., et al. Neural discrete representation learning. In NeurIPS, 2017.

Vaswani, A., Shazeer, N., Parmar, N., Uszkoreit, J., Jones, L., Gomez, A. N., Kaiser, Ł., and Polosukhin, I. Attention is all you need. In NeurIPS, 2017.

Vondrick, C., Pirsiavash, H., and Torralba, A. Generating videos with scene dynamics. In NeurIPS, 2016.

Wu, S., Rupprecht, C., and Vedaldi, A. Unsupervised learning of probably symmetric deformable 3d objects from images in the wild. In CVPR, 2020.

Xu, Q., Wang, W., Ceylan, D., Mech, R., and Neumann, U. Disn: Deep implicit surface network for high-quality single-view 3d reconstruction. In NeurIPS, 2019.

Zhang, K., Riegler, G., Snavely, N., and Koltun, V. Nerf++: Analyzing and improving neural radiance fields. arXiv preprint arXiv:2010.07492, 2020.

Appendix

Log-likelihood under Value Distribution. The predicted value distribution for a query position x is of the form p(v; ω), where ω ≡ {(q^b, µ^b, σ^b)}_{b=1}^{B}. We reiterate q^b ∈ R^1 is the probability of assignment to bin b, c^b + µ^b is the mean of the corresponding gaussian distribution with uniform variance σ^b ∈ R^1.

Under this parametrization, we compute the log-likelihood of a value v^* by finding the closest bin b^*, and computing the log-likelihood of assignment to this bin as well as the log-probability of the value under the corresponding gaussian. We additionally use a weight α = 0.1 to balance the classification and gaussian log-likelihood terms.

b^* = \operatorname{argmin}_b \lVert v^* - c^b \rVert

\log p(v^*; \omega) \equiv \log q^{b^*} - \alpha \left( \log \sigma^{b^*} + \left( \frac{v^* - c^{b^*} - \mu^{b^*}}{\sigma^{b^*}} \right)^2 \right)
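A minimal sketch of this computation follows, assuming fixed bin centers c^b and model-predicted per-bin probabilities q, offsets mu, and scales sigma (the array names and shapes are illustrative, not the paper's implementation).

```python
import numpy as np

def value_log_likelihood(v_star, c, q, mu, sigma, alpha=0.1):
    """Log-likelihood of a scalar value under the binned parametrization: pick the
    bin whose center is closest to v_star, then combine that bin's log-probability
    with the weighted gaussian term around c[b] + mu[b]."""
    b_star = int(np.argmin(np.abs(v_star - c)))        # b* = argmin_b ||v* - c^b||
    z = (v_star - c[b_star] - mu[b_star]) / sigma[b_star]
    return float(np.log(q[b_star]) - alpha * (np.log(sigma[b_star]) + z ** 2))

# Toy usage with B = 4 bins spanning [0, 1].
c = np.array([0.125, 0.375, 0.625, 0.875])             # fixed bin centers c^b
q = np.array([0.1, 0.2, 0.3, 0.4])                     # predicted bin probabilities q^b
mu = np.zeros(4)                                       # predicted within-bin offsets mu^b
sigma = np.full(4, 0.05)                               # predicted scales sigma^b
print(value_log_likelihood(0.7, c, q, mu, sigma))
```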
VAE Training and Inference. We train a variational autoencoder (Kingma & Welling, 2013) on the CIFAR10 dataset with a bottleneck layer of dimension 4 × 4 × 64 i.e. spatial size 4 and feature size 64. We consequently obtain a decoder D which we use for inference given some observed samples S. Specifically, we optimize for a latent variable that minimizes the reconstruction loss for the observed samples (with an additional prior biasing towards the zero vector). Denoting by I(x) the value of image I (bilinearly sampled) at position x, the image I^* inferred using a decoder D by optimizing over S can be computed as: