
ChristophReich1996/SMURF


SMURF: Self-Teaching Multi-Frame Unsupervised RAFT with Full-Image Warping


Unofficial PyTorch (inference only) reimplementation of the CVPR 2021 paper SMURF: Self-Teaching Multi-Frame Unsupervised RAFT with Full-Image Warping by Austin Stone et al. (Google Research).


Figure taken from the paper.

Requirements

To perform inference with the SMURF RAFT model, only PyTorch and TorchVision are required. See requirements.txt for details.

Note: converting the original checkpoints additionally requires TensorFlow 2.0.

Port weights

As the official SMURF implementation does not provide a license for its checkpoints, this repo does not include the converted PyTorch checkpoints. However, you can convert the official TensorFlow checkpoints with the provided convert_weights_to_pt.py script.

You can download the original checkpoints here.

To convert an original checkpoint (e.g., the KITTI checkpoint) to PyTorch, run:

python convert_weights_to_pt.py --tf_checkpoint "path/to/checkpoint/smurf-kitti-smurf-ckpt-31" --pt_checkpoint_path "smurf_kitti.pt"
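We have not reproduced what convert_weights_to_pt.py does internally. One step that TensorFlow-to-PyTorch weight ports typically need is reordering convolution kernels: TensorFlow stores Conv2D kernels as (height, width, in_channels, out_channels), while PyTorch's nn.Conv2d expects (out_channels, in_channels, height, width). The sketch below shows only that reordering on NumPy arrays, such as those returned by `tf.train.load_checkpoint(path).get_tensor(name)`. The helper name `tf_to_torch_layout` is hypothetical, and it assumes every 4-D array is a standard (non-depthwise) Conv2D kernel.

```python
import numpy as np


def tf_to_torch_layout(value: np.ndarray) -> np.ndarray:
    """Reorder a TensorFlow weight array into the layout PyTorch expects.

    Hypothetical helper. Assumption: every 4-D array is a standard Conv2D
    kernel stored by TensorFlow as (H, W, C_in, C_out). All other arrays
    (biases, normalization parameters, ...) are returned unchanged.
    """
    if value.ndim == 4:
        # TF Conv2D (H, W, C_in, C_out) -> PyTorch Conv2d (C_out, C_in, H, W)
        return np.ascontiguousarray(np.transpose(value, (3, 2, 0, 1)))
    # Non-convolutional parameters keep their shape.
    return value
```

Each converted array could then be wrapped with `torch.from_numpy` and stored in a state dict under the matching PyTorch parameter name; that name mapping is model-specific and not shown here.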

Perform inference

To load the converted checkpoint and perform inference, you can run:

from typing import List

import torch
import torch.nn as nn
import torchvision
from torch import Tensor

from smurf import raft_smurf

# Load images as uint8 RGB tensors of shape (3, H, W)
image1: Tensor = torchvision.io.read_image("toy_data/reds/00000000.png", mode=torchvision.io.ImageReadMode.RGB)
image2: Tensor = torchvision.io.read_image("toy_data/reds/00000004.png", mode=torchvision.io.ImageReadMode.RGB)
# Normalize images to the pixel range of [-1, 1]
image1 = 2.0 * (image1 / 255.0) - 1.0
image2 = 2.0 * (image2 / 255.0) - 1.0
# Init SMURF RAFT model and switch it to evaluation mode
model: nn.Module = raft_smurf(checkpoint="smurf_kitti.pt")
model.eval()
# Predict optical flow (image[None] adds a batch dimension; no gradients are needed for inference)
with torch.no_grad():
    optical_flow: List[Tensor] = model(image1[None], image2[None])

The resulting flow should look like this:


A full inference script with flow visualization is provided in perform_inference.py.
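We have not checked which color coding perform_inference.py uses. A simple, common way to visualize a flow field is to encode flow direction as hue and flow magnitude as brightness. The NumPy sketch below does exactly that. The helper name `flow_to_rgb` and the (2, H, W) input layout (horizontal component first) are assumptions.

```python
import numpy as np


def flow_to_rgb(flow: np.ndarray) -> np.ndarray:
    """Map a (2, H, W) flow field to an (H, W, 3) uint8 RGB image (hypothetical helper).

    Hue encodes the direction of (u, v), brightness encodes the magnitude
    normalized by the largest magnitude in the image, saturation is fixed at 1.
    """
    u = flow[0].astype(np.float64)
    v = flow[1].astype(np.float64)
    magnitude = np.hypot(u, v)
    # Flow angle mapped to a hue in [0, 1); pure +x motion gets hue 0 (red).
    hue = (np.arctan2(v, u) / (2.0 * np.pi)) % 1.0
    max_mag = magnitude.max()
    value = magnitude / max_mag if max_mag > 0 else np.zeros_like(magnitude)

    # Vectorized HSV -> RGB conversion with saturation s = 1.
    h6 = hue * 6.0
    sector = np.floor(h6).astype(int) % 6
    frac = h6 - np.floor(h6)
    p = np.zeros_like(value)      # v * (1 - s)
    q = value * (1.0 - frac)      # v * (1 - s * f)
    t = value * frac              # v * (1 - s * (1 - f))
    r = np.choose(sector, [value, q, p, p, t, value])
    g = np.choose(sector, [t, value, value, q, p, p])
    b = np.choose(sector, [p, p, t, value, value, q])
    rgb = np.stack([r, g, b], axis=-1)
    return np.round(rgb * 255.0).astype(np.uint8)
```

Assuming the model output follows TorchVision RAFT's (N, 2, H, W) layout, `flow_to_rgb(optical_flow[-1][0].cpu().numpy())` would visualize the final estimate for the single image pair from the inference example. TorchVision also provides `torchvision.utils.flow_to_image`, which implements the color-wheel coding commonly used for optical flow.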

Implementation details

This implementation is mainly based on TorchVision's RAFT implementation (BSD 3-Clause License). However, the official SMURF RAFT implementation differs from TorchVision's RAFT in a few minor details, so TorchVision's implementation has been modified here to match the official SMURF RAFT implementation. For more details, please refer to this GitHub issue.
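TorchVision's RAFT expects input height and width to be divisible by 8 and raises an error otherwise. We assume, without having verified it, that this port inherits that constraint. For frames whose size is not a multiple of 8, a sketch like the following could pad the inputs and crop the predicted flow back afterwards. The helpers `pad_to_multiple` and `unpad` are hypothetical and not part of this repository.

```python
from typing import Tuple

import torch
import torch.nn.functional as F
from torch import Tensor


def pad_to_multiple(image: Tensor, multiple: int = 8) -> Tuple[Tensor, Tuple[int, int, int, int]]:
    """Replicate-pad an (N, C, H, W) batch so that H and W are divisible by `multiple`.

    Returns the padded batch and the (left, right, top, bottom) padding so the
    predicted flow can be cropped back to the original size.
    """
    h, w = image.shape[-2:]
    pad_h = (-h) % multiple
    pad_w = (-w) % multiple
    # Split the padding as evenly as possible between opposite sides.
    padding = (pad_w // 2, pad_w - pad_w // 2, pad_h // 2, pad_h - pad_h // 2)
    # F.pad takes padding for the last dimension first: (left, right, top, bottom).
    return F.pad(image, padding, mode="replicate"), padding


def unpad(flow: Tensor, padding: Tuple[int, int, int, int]) -> Tensor:
    """Crop an (N, 2, H, W) flow tensor back to the unpadded spatial size."""
    left, right, top, bottom = padding
    h, w = flow.shape[-2:]
    return flow[..., top:h - bottom, left:w - right]
```

With the inference example, usage would look like `padded1, pad = pad_to_multiple(image1[None])`, the same for `image2`, and then `flow = unpad(model(padded1, padded2)[-1], pad)`, assuming the last list element is the final flow estimate as in TorchVision's RAFT.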

Reference

@inproceedings{Stone2021,
    title={SMURF: Self-Teaching Multi-Frame Unsupervised RAFT with Full-Image Warping},
    author={Stone, Austin and Maurer, Daniel and Ayvaci, Alper and Angelova, Anelia and Jonschkowski, Rico},
    booktitle={CVPR},
    year={2021}
}
@inproceedings{Nah2019,
    title={NTIRE 2019 Challenge on Video Deblurring and Super-Resolution: Dataset and Study},
    author={Nah, Seungjun and Baik, Sungyong and Hong, Seokil and Moon, Gyeongsik and Son, Sanghyun and Timofte, Radu and Lee, Kyoung Mu},
    booktitle={CVPRW},
    year={2019}
}

About

PyTorch port (inference only) of the paper "SMURF: Self-Teaching Multi-Frame Unsupervised RAFT with Full-Image Warping" [CVPR 2021].
