##### Copyright 2023 The TensorFlow Datasets Authors.

In [1]:
#@title Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

# TFDS for Jax and PyTorch

<table class="tfo-notebook-buttons" align="left">
  <td>
    <a target="_blank" href="https://www.tensorflow.org/datasets/tfless_tfds"><img src="https://www.tensorflow.org/images/tf_logo_32px.png" />View on TensorFlow.org</a>
  </td>
  <td>
    <a target="_blank" href="https://colab.research.google.com/github/tensorflow/datasets/blob/master/docs/data_source.ipynb"><img src="https://www.tensorflow.org/images/colab_logo_32px.png" />Run in Google Colab</a>
  </td>
  <td>
    <a target="_blank" href="https://github.com/tensorflow/datasets/blob/master/docs/data_source.ipynb"><img src="https://www.tensorflow.org/images/GitHub-Mark-32px.png" />View source on GitHub</a>
  </td>
  <td>
    <a href="https://storage.googleapis.com/tensorflow_docs/datasets/docs/data_source.ipynb"><img src="https://www.tensorflow.org/images/download_logo_32px.png" />Download notebook</a>
  </td>
</table>

TFDS has always been framework-agnostic. For instance, you can easily load
datasets in
[NumPy format](https://www.tensorflow.org/datasets/api_docs/python/tfds/as_numpy)
for usage in Jax and PyTorch.

TensorFlow and its data loading solution
([`tf.data`](https://www.tensorflow.org/guide/data)) are first-class citizens in
our API by design.

We extended TFDS to support TensorFlow-less NumPy-only data loading. This can
be convenient for usage in ML frameworks such as Jax and PyTorch. Indeed,
for the latter users, TensorFlow can:

- reserve GPU/TPU memory;
- increase build time in CI/CD;
- take time to import at runtime.

TensorFlow is no longer a dependency to read datasets.

ML pipelines need a data loader to load examples, decode them, and present
them to the model. Data loaders use the
"source/sampler/loader" paradigm:

```
 TFDS dataset       ┌────────────────┐
   on disk          │                │
        ┌──────────►│      Data      │
|..|... │     |     │     source     ├─┐
├──┼────┴─────┤     │                │ │
│12│image12   │     └────────────────┘ │    ┌────────────────┐
├──┼──────────┤                        │    │                │
│13│image13   │                        ├───►│      Data      ├───► ML pipeline
├──┼──────────┤                        │    │     loader     │
│14│image14   │     ┌────────────────┐ │    │                │
├──┼──────────┤     │                │ │    └────────────────┘
|..|...       |     │     Index      ├─┘
                    │    sampler     │
                    │                │
                    └────────────────┘
```

- The data source is responsible for accessing and decoding examples from a TFDS
dataset on the fly.
- The index sampler is responsible for determining the order in which records
are processed. This is important to implement global transformations (e.g.,
global shuffling, sharding, repeating for multiple epochs) before reading any
records.
- The data loader orchestrates the loading by leveraging the data source and the
index sampler. It allows performance optimization (e.g., pre-fetching,
multiprocessing or multithreading).


## TL;DR

`tfds.data_source` is an API to create data sources:

1. for fast prototyping in pure-Python pipelines;
2. to manage data-intensive ML pipelines at scale.

## Setup

Let's install and import the needed dependencies:

In [2]:
!pip install array_record
!pip install grain-nightly
!pip install jax jaxlib
!pip install tfds-nightly

import os
os.environ.pop('TFDS_DATA_DIR', None)

import tensorflow_datasets as tfds





Collecting grain-nightly


  Downloading grain_nightly-0.2.12.dev20250714-cp310-cp310-manylinux2014_x86_64.manylinux_2_17_x86_64.whl.metadata (16 kB)


Collecting cloudpickle (from grain-nightly)


  Downloading cloudpickle-3.1.1-py3-none-any.whl.metadata (7.1 kB)
Collecting more-itertools>=9.1.0 (from grain-nightly)


  Downloading more_itertools-10.7.0-py3-none-any.whl.metadata (37 kB)




Downloading grain_nightly-0.2.12.dev20250714-cp310-cp310-manylinux2014_x86_64.manylinux_2_17_x86_64.whl (488 kB)


Downloading more_itertools-10.7.0-py3-none-any.whl (65 kB)
Downloading cloudpickle-3.1.1-py3-none-any.whl (20 kB)


Installing collected packages: more-itertools, cloudpickle, grain-nightly


Successfully installed cloudpickle-3.1.1 grain-nightly-0.2.12.dev20250714 more-itertools-10.7.0


Collecting jax


  Downloading jax-0.6.2-py3-none-any.whl.metadata (13 kB)


Collecting jaxlib
  Downloading jaxlib-0.6.2-cp310-cp310-manylinux2014_x86_64.whl.metadata (1.3 kB)


Downloading jax-0.6.2-py3-none-any.whl (2.7 MB)


Downloading jaxlib-0.6.2-cp310-cp310-manylinux2014_x86_64.whl (89.9 MB)


Installing collected packages: jaxlib, jax


Successfully installed jax-0.6.2 jaxlib-0.6.2


Collecting tfds-nightly


  Downloading tfds_nightly-4.9.9.dev202508060045-py3-none-any.whl.metadata (11 kB)








Downloading tfds_nightly-4.9.9.dev202508060045-py3-none-any.whl (5.3 MB)


Installing collected packages: tfds-nightly


Successfully installed tfds-nightly-4.9.9.dev202508060045


## Data sources

Data sources are basically Python sequences. So they need to implement the
following protocol:

```python
from typing import SupportsIndex

class RandomAccessDataSource(Protocol):
  """Interface for datasources where storage supports efficient random access."""

  def __len__(self) -> int:
    """Number of records in the dataset."""

  def __getitem__(self, key: SupportsIndex) -> Any:
    """Retrieves the record for the given key."""
```

The underlying file format needs to support efficient random access. At the
moment, TFDS relies on [`array_record`](https://github.com/google/array_record).

[`array_record`](https://github.com/google/array_record) is a new file format
derived from [Riegeli](https://github.com/google/riegeli), achieving a new
frontier of IO efficiency. In particular, ArrayRecord supports parallel read,
write, and random access by record index. ArrayRecord builds on top of Riegeli
and supports the same compression algorithms.

[`fashion_mnist`](https://www.tensorflow.org/datasets/catalog/fashion_mnist) is
a common dataset for computer vision. To retrieve an ArrayRecord-based data
source with TFDS, simply use:

In [3]:
ds = tfds.data_source('fashion_mnist')



[1mDownloading and preparing dataset Unknown size (download: Unknown size, generated: Unknown size, total: Unknown size) to /home/kbuilder/tensorflow_datasets/fashion_mnist/3.0.1...[0m


2025-08-06 11:21:51.396764: E external/local_xla/xla/stream_executor/cuda/cuda_fft.cc:467] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered
E0000 00:00:1754479311.424822   11180 cuda_dnn.cc:8579] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered
E0000 00:00:1754479311.434227   11180 cuda_blas.cc:1407] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered
W0000 00:00:1754479311.456546   11180 computation_placer.cc:177] computation placer already registered. Please check linkage and avoid linking the same target more than once.
W0000 00:00:1754479311.456579   11180 computation_placer.cc:177] computation placer already registered. Please check linkage and avoid linking the same target more than once.
W0000 00:00:1754479311.456582   11180 computation_placer.cc:177] computation placer alr

2025-08-06 11:21:55.893755: E external/local_xla/xla/stream_executor/cuda/cuda_platform.cc:51] failed call to cuInit: INTERNAL: CUDA error: Failed call to cuInit: CUDA_ERROR_NO_DEVICE: no CUDA-capable device is detected


[1mDataset fashion_mnist downloaded and prepared to /home/kbuilder/tensorflow_datasets/fashion_mnist/3.0.1. Subsequent calls will reuse this data.[0m


`tfds.data_source` is a convenient wrapper. It is equivalent to:

In [4]:
builder = tfds.builder('fashion_mnist', file_format='array_record')
builder.download_and_prepare()
ds = builder.as_data_source()

This outputs a dictionary of data sources:

```
{
  'train': DataSource(name=fashion_mnist, split='train', decoders=None),
  'test': DataSource(name=fashion_mnist, split='test', decoders=None),
}
```

Once `download_and_prepare` has run, and you generated the record files, we
don't need TensorFlow anymore. Everything will happen in Python/NumPy!

Let's check this by uninstalling TensorFlow and re-loading the data source
in another subprocess:

In [5]:
!pip uninstall -y tensorflow

Found existing installation: tensorflow 2.19.0


Uninstalling tensorflow-2.19.0:


  Successfully uninstalled tensorflow-2.19.0


In [6]:
%%writefile no_tensorflow.py
import os
os.environ.pop('TFDS_DATA_DIR', None)

import tensorflow_datasets as tfds

try:
  import tensorflow as tf
except ImportError:
  print('No TensorFlow found...')

ds = tfds.data_source('fashion_mnist')
print('...but the data source could still be loaded...')
ds['train'][0]
print('...and the records can be decoded.')

Writing no_tensorflow.py


In [7]:
!python no_tensorflow.py

No TensorFlow found...


...but the data source could still be loaded...
...and the records can be decoded.


In future versions, we are also going to make the dataset preparation
TensorFlow-free.

A data source has a length:

In [8]:
len(ds['train'])

60000

Accessing the first element of the dataset:

In [9]:
%%timeit
ds['train'][0]



501 μs ± 5.76 μs per loop (mean ± std. dev. of 7 runs, 1,000 loops each)


...is just as cheap as accessing any other element. This is the definition of
[random access](https://en.wikipedia.org/wiki/Random_access):

In [10]:
%%timeit
ds['train'][1000]

502 μs ± 5.5 μs per loop (mean ± std. dev. of 7 runs, 1,000 loops each)


Features now use NumPy DTypes (rather than TensorFlow DTypes). You can inspect
the features with:

In [11]:
features = tfds.builder('fashion_mnist').info.features

You'll find more information about
[the features in our documentation](https://www.tensorflow.org/datasets/api_docs/python/tfds/features).
Here we can notably retrieve the shape of the images, and the number of classes:

In [12]:
shape = features['image'].shape
num_classes = features['label'].num_classes

## Use in pure Python

You can consume data sources in Python by iterating over them:

In [13]:
for example in ds['train']:
  print(example)
  break

{'image': array([[[  0],
        [  0],
        [  0],
        [  0],
        [  0],
        [  0],
        [  0],
        [  0],
        [  0],
        [ 18],
        [ 77],
        [227],
        [227],
        [208],
        [210],
        [225],
        [216],
        [ 85],
        [ 32],
        [  0],
        [  0],
        [  0],
        [  0],
        [  0],
        [  0],
        [  0],
        [  0],
        [  0]],

       [[  0],
        [  0],
        [  0],
        [  0],
        [  0],
        [  0],
        [  0],
        [ 61],
        [100],
        [ 97],
        [ 80],
        [ 57],
        [117],
        [227],
        [238],
        [115],
        [ 49],
        [ 78],
        [106],
        [108],
        [ 71],
        [  0],
        [  0],
        [  0],
        [  0],
        [  0],
        [  0],
        [  0]],

       [[  0],
        [  0],
        [  0],
        [  0],
        [  0],
        [  0],
        [ 81],
        [105],
        [ 80],
        [ 6

If you inspect elements, you will also notice that all features are already
decoded using NumPy. Behind the scenes, we use [OpenCV](https://opencv.org)
by default because it is fast. If you don't have OpenCV installed, we default
to [Pillow](python-pillow.org) to provide lightweight and fast image
decoding.

```
{
  'image': array([[[0], [0], ..., [0]],
                  [[0], [0], ..., [0]]], dtype=uint8),
  'label': 2,
}
```

**Note**: Currently, the feature is only available for `Tensor`, `Image` and
`Scalar` features. The `Audio` and `Video` features will come soon. Stay tuned!

## Use with PyTorch

PyTorch uses the source/sampler/loader paradigm. In Torch, "data sources" are
called "datasets".
[`torch.utils.data`](https://pytorch.org/docs/stable/data.html) contains all the
details you need to know to build efficient input pipelines in Torch.

TFDS data sources can be used as regular
[map-style datasets](https://pytorch.org/docs/stable/data.html#map-style-datasets).

First we install and import Torch:

In [14]:
!pip install torch

from tqdm import tqdm
import torch

Collecting torch


  Downloading torch-2.7.1-cp310-cp310-manylinux_2_28_x86_64.whl.metadata (29 kB)


Collecting filelock (from torch)
  Downloading filelock-3.18.0-py3-none-any.whl.metadata (2.9 kB)


Collecting sympy>=1.13.3 (from torch)
  Downloading sympy-1.14.0-py3-none-any.whl.metadata (12 kB)


Collecting nvidia-cuda-nvrtc-cu12==12.6.77 (from torch)
  Downloading nvidia_cuda_nvrtc_cu12-12.6.77-py3-none-manylinux2014_x86_64.whl.metadata (1.5 kB)


Collecting nvidia-cuda-runtime-cu12==12.6.77 (from torch)
  Downloading nvidia_cuda_runtime_cu12-12.6.77-py3-none-manylinux2014_x86_64.manylinux_2_17_x86_64.whl.metadata (1.5 kB)


Collecting nvidia-cuda-cupti-cu12==12.6.80 (from torch)
  Downloading nvidia_cuda_cupti_cu12-12.6.80-py3-none-manylinux2014_x86_64.manylinux_2_17_x86_64.whl.metadata (1.6 kB)
Collecting nvidia-cudnn-cu12==9.5.1.17 (from torch)


  Downloading nvidia_cudnn_cu12-9.5.1.17-py3-none-manylinux_2_28_x86_64.whl.metadata (1.6 kB)
Collecting nvidia-cublas-cu12==12.6.4.1 (from torch)


  Downloading nvidia_cublas_cu12-12.6.4.1-py3-none-manylinux2014_x86_64.manylinux_2_17_x86_64.whl.metadata (1.5 kB)
Collecting nvidia-cufft-cu12==11.3.0.4 (from torch)
  Downloading nvidia_cufft_cu12-11.3.0.4-py3-none-manylinux2014_x86_64.manylinux_2_17_x86_64.whl.metadata (1.5 kB)


Collecting nvidia-curand-cu12==10.3.7.77 (from torch)
  Downloading nvidia_curand_cu12-10.3.7.77-py3-none-manylinux2014_x86_64.manylinux_2_17_x86_64.whl.metadata (1.5 kB)


Collecting nvidia-cusolver-cu12==11.7.1.2 (from torch)
  Downloading nvidia_cusolver_cu12-11.7.1.2-py3-none-manylinux2014_x86_64.manylinux_2_17_x86_64.whl.metadata (1.6 kB)


Collecting nvidia-cusparse-cu12==12.5.4.2 (from torch)
  Downloading nvidia_cusparse_cu12-12.5.4.2-py3-none-manylinux2014_x86_64.manylinux_2_17_x86_64.whl.metadata (1.6 kB)


Collecting nvidia-cusparselt-cu12==0.6.3 (from torch)
  Downloading nvidia_cusparselt_cu12-0.6.3-py3-none-manylinux2014_x86_64.whl.metadata (6.8 kB)
Collecting nvidia-nccl-cu12==2.26.2 (from torch)


  Downloading nvidia_nccl_cu12-2.26.2-py3-none-manylinux2014_x86_64.manylinux_2_17_x86_64.whl.metadata (2.0 kB)
Collecting nvidia-nvtx-cu12==12.6.77 (from torch)


  Downloading nvidia_nvtx_cu12-12.6.77-py3-none-manylinux2014_x86_64.manylinux_2_17_x86_64.whl.metadata (1.6 kB)
Collecting nvidia-nvjitlink-cu12==12.6.85 (from torch)
  Downloading nvidia_nvjitlink_cu12-12.6.85-py3-none-manylinux2010_x86_64.manylinux_2_12_x86_64.whl.metadata (1.5 kB)


Collecting nvidia-cufile-cu12==1.11.1.6 (from torch)
  Downloading nvidia_cufile_cu12-1.11.1.6-py3-none-manylinux2014_x86_64.manylinux_2_17_x86_64.whl.metadata (1.5 kB)


Collecting triton==3.3.1 (from torch)
  Downloading triton-3.3.1-cp310-cp310-manylinux_2_27_x86_64.manylinux_2_28_x86_64.whl.metadata (1.5 kB)


Collecting mpmath<1.4,>=1.1.0 (from sympy>=1.13.3->torch)


  Downloading mpmath-1.3.0-py3-none-any.whl.metadata (8.6 kB)


Downloading torch-2.7.1-cp310-cp310-manylinux_2_28_x86_64.whl (821.2 MB)


Downloading nvidia_cublas_cu12-12.6.4.1-py3-none-manylinux2014_x86_64.manylinux_2_17_x86_64.whl (393.1 MB)


Downloading nvidia_cuda_cupti_cu12-12.6.80-py3-none-manylinux2014_x86_64.manylinux_2_17_x86_64.whl (8.9 MB)


Downloading nvidia_cuda_nvrtc_cu12-12.6.77-py3-none-manylinux2014_x86_64.whl (23.7 MB)


Downloading nvidia_cuda_runtime_cu12-12.6.77-py3-none-manylinux2014_x86_64.manylinux_2_17_x86_64.whl (897 kB)
Downloading nvidia_cudnn_cu12-9.5.1.17-py3-none-manylinux_2_28_x86_64.whl (571.0 MB)


Downloading nvidia_cufft_cu12-11.3.0.4-py3-none-manylinux2014_x86_64.manylinux_2_17_x86_64.whl (200.2 MB)


Downloading nvidia_cufile_cu12-1.11.1.6-py3-none-manylinux2014_x86_64.manylinux_2_17_x86_64.whl (1.1 MB)
Downloading nvidia_curand_cu12-10.3.7.77-py3-none-manylinux2014_x86_64.manylinux_2_17_x86_64.whl (56.3 MB)


Downloading nvidia_cusolver_cu12-11.7.1.2-py3-none-manylinux2014_x86_64.manylinux_2_17_x86_64.whl (158.2 MB)


Downloading nvidia_cusparse_cu12-12.5.4.2-py3-none-manylinux2014_x86_64.manylinux_2_17_x86_64.whl (216.6 MB)


Downloading nvidia_cusparselt_cu12-0.6.3-py3-none-manylinux2014_x86_64.whl (156.8 MB)


Downloading nvidia_nccl_cu12-2.26.2-py3-none-manylinux2014_x86_64.manylinux_2_17_x86_64.whl (201.3 MB)


Downloading nvidia_nvjitlink_cu12-12.6.85-py3-none-manylinux2010_x86_64.manylinux_2_12_x86_64.whl (19.7 MB)


Downloading nvidia_nvtx_cu12-12.6.77-py3-none-manylinux2014_x86_64.manylinux_2_17_x86_64.whl (89 kB)


Downloading triton-3.3.1-cp310-cp310-manylinux_2_27_x86_64.manylinux_2_28_x86_64.whl (155.6 MB)


Downloading sympy-1.14.0-py3-none-any.whl (6.3 MB)


Downloading mpmath-1.3.0-py3-none-any.whl (536 kB)
Downloading filelock-3.18.0-py3-none-any.whl (16 kB)


Installing collected packages: nvidia-cusparselt-cu12, mpmath, triton, sympy, nvidia-nvtx-cu12, nvidia-nvjitlink-cu12, nvidia-nccl-cu12, nvidia-curand-cu12, nvidia-cufile-cu12, nvidia-cuda-runtime-cu12, nvidia-cuda-nvrtc-cu12, nvidia-cuda-cupti-cu12, nvidia-cublas-cu12, filelock, nvidia-cusparse-cu12, nvidia-cufft-cu12, nvidia-cudnn-cu12, nvidia-cusolver-cu12, torch


  Attempting uninstall: nvidia-nvjitlink-cu12
    Found existing installation: nvidia-nvjitlink-cu12 12.9.86
    Uninstalling nvidia-nvjitlink-cu12-12.9.86:
      Successfully uninstalled nvidia-nvjitlink-cu12-12.9.86


  Attempting uninstall: nvidia-nccl-cu12
    Found existing installation: nvidia-nccl-cu12 2.27.7
    Uninstalling nvidia-nccl-cu12-2.27.7:
      Successfully uninstalled nvidia-nccl-cu12-2.27.7


  Attempting uninstall: nvidia-curand-cu12
    Found existing installation: nvidia-curand-cu12 10.3.10.19
    Uninstalling nvidia-curand-cu12-10.3.10.19:
      Successfully uninstalled nvidia-curand-cu12-10.3.10.19


  Attempting uninstall: nvidia-cuda-runtime-cu12
    Found existing installation: nvidia-cuda-runtime-cu12 12.9.79


    Uninstalling nvidia-cuda-runtime-cu12-12.9.79:
      Successfully uninstalled nvidia-cuda-runtime-cu12-12.9.79
  Attempting uninstall: nvidia-cuda-nvrtc-cu12


    Found existing installation: nvidia-cuda-nvrtc-cu12 12.9.86
    Uninstalling nvidia-cuda-nvrtc-cu12-12.9.86:
      Successfully uninstalled nvidia-cuda-nvrtc-cu12-12.9.86


  Attempting uninstall: nvidia-cuda-cupti-cu12
    Found existing installation: nvidia-cuda-cupti-cu12 12.9.79
    Uninstalling nvidia-cuda-cupti-cu12-12.9.79:
      Successfully uninstalled nvidia-cuda-cupti-cu12-12.9.79


  Attempting uninstall: nvidia-cublas-cu12
    Found existing installation: nvidia-cublas-cu12 12.9.1.4


    Uninstalling nvidia-cublas-cu12-12.9.1.4:
      Successfully uninstalled nvidia-cublas-cu12-12.9.1.4


  Attempting uninstall: nvidia-cusparse-cu12
    Found existing installation: nvidia-cusparse-cu12 12.5.10.65
    Uninstalling nvidia-cusparse-cu12-12.5.10.65:
      Successfully uninstalled nvidia-cusparse-cu12-12.5.10.65


  Attempting uninstall: nvidia-cufft-cu12
    Found existing installation: nvidia-cufft-cu12 11.4.1.4
    Uninstalling nvidia-cufft-cu12-11.4.1.4:
      Successfully uninstalled nvidia-cufft-cu12-11.4.1.4


  Attempting uninstall: nvidia-cudnn-cu12
    Found existing installation: nvidia-cudnn-cu12 9.11.0.98
    Uninstalling nvidia-cudnn-cu12-9.11.0.98:
      Successfully uninstalled nvidia-cudnn-cu12-9.11.0.98


  Attempting uninstall: nvidia-cusolver-cu12
    Found existing installation: nvidia-cusolver-cu12 11.7.5.82
    Uninstalling nvidia-cusolver-cu12-11.7.5.82:
      Successfully uninstalled nvidia-cusolver-cu12-11.7.5.82


Successfully installed filelock-3.18.0 mpmath-1.3.0 nvidia-cublas-cu12-12.6.4.1 nvidia-cuda-cupti-cu12-12.6.80 nvidia-cuda-nvrtc-cu12-12.6.77 nvidia-cuda-runtime-cu12-12.6.77 nvidia-cudnn-cu12-9.5.1.17 nvidia-cufft-cu12-11.3.0.4 nvidia-cufile-cu12-1.11.1.6 nvidia-curand-cu12-10.3.7.77 nvidia-cusolver-cu12-11.7.1.2 nvidia-cusparse-cu12-12.5.4.2 nvidia-cusparselt-cu12-0.6.3 nvidia-nccl-cu12-2.26.2 nvidia-nvjitlink-cu12-12.6.85 nvidia-nvtx-cu12-12.6.77 sympy-1.14.0 torch-2.7.1 triton-3.3.1


We already defined data sources for training and testing (respectively,
`ds['train']` and `ds['test']`). We can now define the sampler and the loaders:

In [15]:
batch_size = 128
train_sampler = torch.utils.data.RandomSampler(ds['train'], num_samples=5_000)
train_loader = torch.utils.data.DataLoader(
    ds['train'],
    sampler=train_sampler,
    batch_size=batch_size,
)
test_loader = torch.utils.data.DataLoader(
    ds['test'],
    sampler=None,
    batch_size=batch_size,
)

Using PyTorch, we train and evaluate a simple logistic regression on the first
examples:

In [16]:
class LinearClassifier(torch.nn.Module):
  def __init__(self, shape, num_classes):
    super(LinearClassifier, self).__init__()
    height, width, channels = shape
    self.classifier = torch.nn.Linear(height * width * channels, num_classes)

  def forward(self, image):
    image = image.view(image.size()[0], -1).to(torch.float32)
    return self.classifier(image)


model = LinearClassifier(shape, num_classes)
optimizer = torch.optim.Adam(model.parameters())
loss_function = torch.nn.CrossEntropyLoss()

print('Training...')
model.train()
for example in tqdm(train_loader):
  image, label = example['image'], example['label']
  prediction = model(image)
  loss = loss_function(prediction, label)
  optimizer.zero_grad()
  loss.backward()
  optimizer.step()

print('Testing...')
model.eval()
num_examples = 0
true_positives = 0
for example in tqdm(test_loader):
  image, label = example['image'], example['label']
  prediction = model(image)
  num_examples += image.shape[0]
  predicted_label = prediction.argmax(dim=1)
  true_positives += (predicted_label == label).sum().item()
print(f'\nAccuracy: {true_positives/num_examples * 100:.2f}%')

Training...


  0%|          | 0/40 [00:00<?, ?it/s]

  2%|▎         | 1/40 [00:00<00:07,  4.99it/s]

 12%|█▎        | 5/40 [00:00<00:01, 17.78it/s]

 22%|██▎       | 9/40 [00:00<00:01, 24.28it/s]

 32%|███▎      | 13/40 [00:00<00:00, 28.10it/s]

 42%|████▎     | 17/40 [00:00<00:00, 30.31it/s]

 52%|█████▎    | 21/40 [00:00<00:00, 31.57it/s]

 62%|██████▎   | 25/40 [00:00<00:00, 32.61it/s]

 72%|███████▎  | 29/40 [00:01<00:00, 32.92it/s]

 82%|████████▎ | 33/40 [00:01<00:00, 33.50it/s]

 92%|█████████▎| 37/40 [00:01<00:00, 33.93it/s]

100%|██████████| 40/40 [00:01<00:00, 30.50it/s]




Testing...


  0%|          | 0/79 [00:00<?, ?it/s]

  5%|▌         | 4/79 [00:00<00:02, 35.73it/s]

 10%|█         | 8/79 [00:00<00:01, 35.95it/s]

 15%|█▌        | 12/79 [00:00<00:01, 36.01it/s]

 20%|██        | 16/79 [00:00<00:01, 36.02it/s]

 25%|██▌       | 20/79 [00:00<00:01, 36.03it/s]

 30%|███       | 24/79 [00:00<00:01, 36.03it/s]

 35%|███▌      | 28/79 [00:00<00:01, 36.10it/s]

 41%|████      | 32/79 [00:00<00:01, 36.20it/s]

 46%|████▌     | 36/79 [00:00<00:01, 36.19it/s]

 51%|█████     | 40/79 [00:01<00:01, 36.18it/s]

 56%|█████▌    | 44/79 [00:01<00:00, 36.19it/s]

 61%|██████    | 48/79 [00:01<00:00, 36.24it/s]

 66%|██████▌   | 52/79 [00:01<00:00, 36.24it/s]

 71%|███████   | 56/79 [00:01<00:00, 36.16it/s]

 76%|███████▌  | 60/79 [00:01<00:00, 36.11it/s]

 81%|████████  | 64/79 [00:01<00:00, 36.16it/s]

 86%|████████▌ | 68/79 [00:01<00:00, 36.18it/s]

 91%|█████████ | 72/79 [00:01<00:00, 36.28it/s]

 96%|█████████▌| 76/79 [00:02<00:00, 36.29it/s]

100%|██████████| 79/79 [00:02<00:00, 36.54it/s]


Accuracy: 68.18%





## Use with JAX

[Grain](https://github.com/google/grain) is a library for reading data for
training and evaluating JAX models. It's open source, fast and deterministic.
Grain uses the source/sampler/loader paradigm, so we can re-use
`tfds.data_source`:

In [17]:
import grain.python as pygrain
import numpy as np

data_source = tfds.data_source("fashion_mnist", split="train")

# To shuffle the data, use a sampler:
sampler = pygrain.IndexSampler(
    num_records=5,
    num_epochs=1,
    shard_options=pygrain.NoSharding(),
    shuffle=True,
    seed=0,
)

Transformations are defined as classes and can be `BatchTransform`,
`FilterTransform` or `MapTransform`:

In [18]:
class ImageToText(pygrain.MapTransform):
  """Maps an image to text."""

  LABEL_TO_TEXT = {
      0: "zero",
      1: "one",
      2: "two",
      3: "three",
      4: "four",
      5: "five",
      6: "six",
      7: "seven",
      8: "height",
      9: "nine",
  }

  def map(self, element: dict[str, np.ndarray]) -> dict[str, np.ndarray]:
    label = element["label"]
    text = self.LABEL_TO_TEXT[label]
    element["text"] = text
    return element

# You can chain transformations in a list:
operations = [ImageToText()]

Finally, the data loader takes care of orchestrating the loading. You can scale
up with multiprocessing to enjoy both the flexibility of Python and the
performance of a data loader:

In [19]:
loader = pygrain.DataLoader(
    data_source=data_source,
    operations=operations,
    sampler=sampler,
    worker_count=0,  # Scale to multiple workers in multiprocessing
)

for element in loader:
  print(element["text"])

two
one
one
height
four


## Read more

For more information, please refer to [`tfds.data_source`](https://www.tensorflow.org/datasets/api_docs/python/tfds/data_source) API doc.