Engine object does not retain state.output when using an IterableDataset #3372

@mtauraso

Description

🐛 Engine object does not retain state.output when using an IterableDataset

When using PyTorch Ignite with an IterableDataset, state.output on the engine.state object is consistently None at the end of an epoch. When using a map-style dataset, state.output contains the return value of the training step at the end of the epoch.

This can be reproduced with the following Python snippet:

import torch
from torch import nn
from torch.utils.data import DataLoader, IterableDataset
from torchvision.datasets import MNIST
from torchvision.models import resnet18
from torchvision.transforms import Compose, Normalize, ToTensor

from ignite.engine import create_supervised_trainer

#device = torch.device("cuda")
device = torch.device("mps")
#device = torch.device("cpu")

#use_iterable = False
use_iterable = True

class Net(torch.nn.Module):
    def __init__(self):
        super(Net, self).__init__()
    
        self.model = resnet18(num_classes=10)

        self.model.conv1 = nn.Conv2d(
            1, 64, kernel_size=3, padding=1, bias=False
        )

    def forward(self, x):
        return self.model(x)
model = Net().to(device)

data_transform = Compose([ToTensor(), Normalize((0.1307,), (0.3081,))])
class IterableMnist(IterableDataset):
    mnist = MNIST(download=True, root=".", transform=data_transform, train=False)
    def __iter__(self):
        for item in IterableMnist.mnist:
            yield item

mnist = IterableMnist() if use_iterable else MNIST(download=True, root=".", transform=data_transform, train=False)

train_loader = DataLoader(mnist, batch_size=128)

optimizer = torch.optim.RMSprop(model.parameters(), lr=0.005)
criterion = nn.CrossEntropyLoss()

trainer = create_supervised_trainer(model, optimizer, criterion, device)

state = trainer.run(train_loader, max_epochs=1)
print(state.output)

To run, set the device appropriately at the top, then set use_iterable to True or False. When use_iterable is True, MNIST is wrapped in an IterableDataset and None is printed. When use_iterable is False, the usual map-style MNIST Dataset is used and the loss value from the last training step is printed.

I would expect state.output to behave the same regardless of whether an IterableDataset or a map-style Dataset is provided. It seems to me that map-style datasets (Dataset) produce the correct behavior and iterator-style datasets (IterableDataset) do not.

I believe the fix is to stop resetting state.output in these two places:

    self.state.batch = self.state.output = None

    self.state.batch = self.state.output = None

It appears that when the loop in the _run_once_on_dataset_* functions is terminated by a StopIteration exception, self.state.output is reset unnecessarily.
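To illustrate the suspected mechanism, here is a simplified, self-contained sketch (not Ignite's actual code; `run_epoch` and its internals are hypothetical stand-ins). With a map-style dataset the epoch length is known, so the loop exits normally after the last iteration; with an iterable dataset the epoch only ends when next() raises StopIteration, and if the exception handler resets the output, the last value is lost:

```python
# Simplified sketch of the suspected mechanism (hypothetical, not Ignite's code).
def run_epoch(batches, epoch_length=None):
    output = None
    it = iter(batches)
    iteration = 0
    while True:
        try:
            batch = next(it)
        except StopIteration:
            # Suspected culprit: resetting output here discards the last
            # training-step result when the dataset length is unknown
            # (the IterableDataset case).
            output = None
            break
        iteration += 1
        output = batch * 2  # stand-in for the model update step
        if epoch_length is not None and iteration == epoch_length:
            break  # map-style case: loop exits before the reset ever runs

    return output

print(run_epoch([1, 2, 3], epoch_length=3))  # map-style: prints 6
print(run_epoch([1, 2, 3]))                  # iterable: prints None
```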

Environment

  • PyTorch Version (e.g., 1.4): 2.4.0
  • Ignite Version (e.g., 0.3.0): 0.5.1
  • OS (e.g., Linux): OSX
  • How you installed Ignite (conda, pip, source): Installed from pip inside a conda environment
  • Python version: 3.10
  • Any other relevant information:
