SaiKiran Tedla, Zhoutong Zhang, Xuaner Zhang, Shumian Xin
Adobe Research & York University
🤗 Demo
📂 Data
If you use our dataset, code, or model, please cite:
@inproceedings{Tedla2025Refocus,
title={{Learning to Refocus with Video Diffusion Models}},
author={{Tedla, SaiKiran and Zhang, Zhoutong and Zhang, Xuaner and Xin, Shumian}},
booktitle={{Proceedings of the ACM SIGGRAPH Asia Conference}},
year={{2025}}
}

This guide explains how to train and evaluate our video diffusion model for refocusing. Note that we also provide a simple Hugging Face demo for quickly testing our method.
conda env create -f setup/environment.yaml
conda activate refocus
python setup/download_svd_weights.py
python setup/download_checkpoints.py

- Install PyTorch and all dependencies listed in the YAML file.
- Create a Weights & Biases (wandb) account for experiment tracking.
Update the wandb credentials in training/configs with your login info.
- After running setup/download_svd_weights.py, you should have a folder named svd at the project root containing the Stable Video Diffusion model weights.
- After running setup/download_checkpoints.py, you should have a folder named checkpoints/checkpoints-200000 at the project root containing our finetuned weights.
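As an optional sanity check that the environment and downloads are in place, the snippet below is one way to verify them. It is our own sketch, not part of the released setup scripts:

```python
# Optional sanity check: confirm PyTorch sees a GPU and the downloaded folders exist.
from pathlib import Path

import torch

print("CUDA available:", torch.cuda.is_available())
for folder in ("svd", "checkpoints/checkpoints-200000"):
    print(folder, "->", "found" if Path(folder).is_dir() else "MISSING")
```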
We provide two ways to run inference. The first is a simple inference script that takes an image path and an input focal position (an integer from 0 to 8), corresponding to iPhone lens position API values 0 to 0.8 (focus distances from roughly 5 cm to infinity). You can also set output_dir to any path of your choice.
conda activate refocus
python ./simple_inference.py --image_path /datasets/sai/focal-burst-learning/svd/photos/img_0.jpg --input_focal_position 0 --output_dir outputs/simple_inference
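For reference, the integer --input_focal_position appears to map linearly onto the iPhone lens position API value described above. The helper below is a hedged sketch of that assumed mapping, not a function from simple_inference.py:

```python
# Assumed mapping: integer focal position (0-8) -> iPhone lens position API value (0.0-0.8).
def focal_index_to_lens_position(focal_index: int) -> float:
    if not 0 <= focal_index <= 8:
        raise ValueError("focal_index must be between 0 and 8")
    return focal_index / 10.0  # 0 ~ near focus (about 5 cm), 8 ~ infinity
```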
The second method uses accelerate and supports multi-GPU inference. To test on real-world photos, place your images in the photos/ directory and run the following (requires about 20-25 GB of GPU memory, depending on image size):
conda activate refocus
accelerate launch --config_file training/configs/accelerator_config.yaml \
--multi_gpu training/svd_runner.py \
--config training/configs/outside_photos.yaml

Results and visualizations will be saved to the directory specified by output_dir (default: output_dir/outside_photos/).
Each output folder contains the generated focal stacks corresponding to the input image’s focal positions.
Note: To run with less memory, lower the max_pixels value (default 500000) in utils.py. This runs the model at a lower resolution.
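To illustrate how such a pixel cap typically translates into a working resolution, here is a small sketch; the exact resize logic lives in utils.py and may differ:

```python
# Illustration only: scale an image so width * height stays under max_pixels,
# preserving aspect ratio. The actual logic in utils.py may differ.
import math

def capped_size(width: int, height: int, max_pixels: int = 500_000) -> tuple[int, int]:
    if width * height <= max_pixels:
        return width, height
    scale = math.sqrt(max_pixels / (width * height))
    return int(width * scale), int(height * scale)

print(capped_size(4032, 3024))  # a 12 MP photo -> roughly (816, 612)
```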
accelerate launch --config_file training/configs/accelerator_config.yaml --multi_gpu training/svd_runner.py --config training/configs/focal_stacks_test.yaml

Each output folder contains the generated focal stacks corresponding to the input image's focal positions.
Results and visualizations will be saved to the directory specified by output_dir (default: output_dir/focal_stacks_test/).
Our dataset includes raw DNGs from all five cameras, along with rendered images at both full and midsize resolutions. We provide ZIP archives for each portion of the dataset.
- fullsize_dng
  Contains ZIP files with the raw DNG images and associated capture metadata. Each camera has five ZIP files due to the large file sizes.
- fullsize_undistorted
  Contains full-resolution rendered images for each camera, corrected for focal breathing and radial distortion to ensure a consistent field of view (FOV).
- midsize_undistorted
  Contains the midsize (896x640) images for the center camera only. This is what we used for training and testing our model.
For training or testing, you only need the midsize_undistorted folder.
Place this folder anywhere on your computer, and set the data_folder field in your YAML configuration to point to this path.
Example: data_folder: "/datasets/sai/scenes_merged_midsize_undistorted"
Set the following paths in your YAML config (feel free to change other paths to match your configuration):
data_folder:
splits_dir:
wandb_project: "RefocusingSVD"
run_name: "focal_stacks_train"

To train our model, run:
accelerate launch --config_file training/configs/accelerator_config.yaml --multi_gpu training/svd_runner.py --config training/configs/focal_stacks_train.yaml

Checkpoints will be saved to outputs/focal_stacks_train.
- Checkpoints are available on the project page.
- Dataset download links will be added soon.
- We utilize extra/compute_metrics.py to compute all metrics for this project.
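As an example of the kind of image-quality metric such a script typically reports, here is a minimal PSNR sketch; the actual set of metrics used is defined in extra/compute_metrics.py:

```python
# Minimal PSNR sketch for comparing a generated frame against ground truth (values in [0, 1]).
import numpy as np

def psnr(pred: np.ndarray, target: np.ndarray, max_val: float = 1.0) -> float:
    mse = np.mean((pred.astype(np.float64) - target.astype(np.float64)) ** 2)
    if mse == 0:
        return float("inf")
    return 10.0 * np.log10((max_val ** 2) / mse)
```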
For questions or issues, please reach out through the project page or contact Sai Tedla.