Skip to content

rw-ocrl/ftdinosaur-inference

Repository files navigation

FT-DINOSAUR Inference

This repository contains inference code for the FT-DINOSAUR model introduced in Zero-Shot Object-Centric Representation Learning.

Setup

This package only requires torch and torchvision as dependencies. Simply install the package using pip:

pip install .

We recommend using an environment manager like virtualenv or conda.

Quick Start

Import the package and print the list of defined models:

from ftdinosaur_inference import build_dinosaur
print(build_dinosaur.list_models)

We provide pre-trained checkpoints. Run the following to get a list of available checkpoints

print(build_dinosaur.list_checkpoints())

A model can be built by passing the model name to the build function:

model_name = "dinosaur_base_patch14_518_topk3.coco_dv2_ft_s7_300k+10k"
model = build_dinosaur.build(model_name)

By default, this will load a pre-trained checkpoint if it exists.

Then define the pre-processing pipeline that turns a numpy or pillow image into a torch tensor usable as the model's input:

preproc = build_dinosaur.build_preprocessing(model_name)

run the model:

from PIL import Image
image = Image.open("./images/gizmos.jpg").convert("RGB")
inp = preproc(image).unsqueeze(0)
outp = model(inp)

That's it! The model returns a dictionary with the output of the different modules. See the DINOSAUR module for an overview.

Example Notebook

An example notebook with mask visualization is contained in here. Install with the notebook optional dependencies and start a Jupyter server to view it:

pip install .[notebook]
jupyter notebook

List of Model Checkpoints

The model key follows the convention <model-architecture>.<checkpoint>.

Model Model Key ViT Input Size Description
FT-DINOSAUR dinosaur_small_patch14_224_topk3.coco_dv2_ft_s7_300k small14 224x224 Top-k decoding
FT-DINOSAUR dinosaur_small_patch14_518_topk3.coco_dv2_ft_s7_300k+10k small14 518x518 Top-k decoding, hi-res finetuned
FT-DINOSAUR dinosaur_base_patch14_224_topk3.coco_dv2_ft_s7_300k base14 224x224 Top-k decoding.
FT-DINOSAUR dinosaur_base_patch14_518_topk3.coco_dv2_ft_s7_300k+10k base14 518x518 Top-k decoding, hi-res finetuned

Development

Development dependencies can be installed using

pip install .[dev]

Tests can be run using pytest. The codebase uses pre-commit to manage linting.

Citation

If you make use of this work, please cite us:

@article{Didolkar2024ZeroShotOCRL,
  title={Zero-Shot Object-Centric Representation Learning},
  author={Didolkar, Aniket and Zadaianchuk, Andrii and Goyal, Anirudh and Mozer, Mike and Bengio, Yoshua and Martius, Georg and Seitzer, Maximilian},
  year={2024},
  journal={arXiv:2408.09162},
  url={https://arxiv.org/abs/2408.09162}
}

Licenses

This codebase is released under Apache License 2.0.

This codebase adapts parts of the timm library and the OCLF library. The concerned source files contain a note of their origin. Both are released under Apache License 2.0. We like to thank their authors!

The example image is from the COCO dataset, originally from Flickr, under CC BY-NC-SA 2.0 license. The "Gizmos" image is from Greff, van Steenkiste, Schmidhuber, 2020: On the Binding Problem in Artificial Neural Networks, under CC BY-SA 4.0 license.

About

No description, website, or topics provided.

Resources

License

Stars

Watchers

Forks

Releases

No releases published

Packages

 
 
 

Contributors