
tedlasai/learn2refocus


Learning to Refocus with Video Diffusion Models

SaiKiran Tedla, Zhoutong Zhang, Xuaner Zhang, Shumian Xin
Adobe Research & York University

🤗 Demo

📂 Data


📌 Citation

If you use our dataset, code, or model, please cite:

@inproceedings{Tedla2025Refocus,
  title={{Learning to Refocus with Video Diffusion Models}},
  author={Tedla, SaiKiran and Zhang, Zhoutong and Zhang, Xuaner and Xin, Shumian},
  booktitle={Proceedings of the ACM SIGGRAPH Asia Conference},
  year={2025}
}

🚀 Getting Started

This guide explains how to train and evaluate our video diffusion model for refocusing. Note that we also provide a simple Hugging Face demo for quickly testing our method.


🔧 Environment Setup

conda env create -f setup/environment.yaml
conda activate refocus
python setup/download_svd_weights.py
python setup/download_checkpoints.py
  • Install PyTorch and all dependencies listed in the YAML file.
  • Create a Weights & Biases (wandb) account for experiment tracking.
    Update the wandb credentials in training/configs with your login info.
  • After running setup/download_svd_weights.py, you should have a folder named svdh at the project root containing the Stable Video Diffusion model weights.
  • After running setup/download_checkpoints.py, you should have a folder named checkpoints/checkpoints-200000 at the project root containing our fine-tuned weights.
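After both download scripts finish, a quick check like the one below can confirm that the weight folders are where this guide says they should be. It is a minimal sketch that only tests for the two folder names listed above and does not verify their contents; missing_setup_dirs is a hypothetical helper, not part of the repository.

```python
from pathlib import Path

# Folders this README says the two setup scripts create at the project root.
EXPECTED_DIRS = [
    Path("svdh"),                                # Stable Video Diffusion weights
    Path("checkpoints") / "checkpoints-200000",  # fine-tuned refocusing weights
]


def missing_setup_dirs(project_root: str = ".") -> list[str]:
    """Hypothetical helper: return expected weight folders absent under project_root."""
    root = Path(project_root)
    return [str(d) for d in EXPECTED_DIRS if not (root / d).is_dir()]


if __name__ == "__main__":
    missing = missing_setup_dirs()
    if missing:
        print("Missing:", ", ".join(missing))
    else:
        print("All weight folders found.")
```

Run it from the project root; an empty result means both folders exist.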

🧪 Testing (In-the-Wild)

We provide two ways to run inference. The first is a simple script that takes an image path and the input image's focal position, an integer from 0 to 8 that corresponds to iPhone API lens positions 0 to 0.8 (focus distances from 5 cm to infinity). You can set --output_dir to any path of your choice.

conda activate refocus
python ./simple_inference.py --image_path /datasets/sai/focal-burst-learning/svd/photos/img_0.jpg --input_focal_position 0 --output_dir outputs/simple_inference
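The relationship between the focal-position index and the iPhone lens position, plus the command above, can be sketched in Python. This is a minimal sketch, assuming the linear mapping lens_position = index / 10 implied by the endpoints stated above; focal_to_lens_position and simple_inference_cmd are hypothetical helpers, not functions in the repository.

```python
import sys

# Assumption: focal-position indices 0..8 map linearly onto iPhone lens
# positions 0.0..0.8 (index / 10), inferred from the endpoints in this README.
MAX_FOCAL_POSITION = 8


def focal_to_lens_position(focal_position: int) -> float:
    """Hypothetical helper: convert a focal-position index to a lens position."""
    if not isinstance(focal_position, int) or not 0 <= focal_position <= MAX_FOCAL_POSITION:
        raise ValueError(f"focal_position must be an integer in [0, {MAX_FOCAL_POSITION}]")
    return focal_position / 10


def simple_inference_cmd(image_path: str, focal_position: int,
                         output_dir: str = "outputs/simple_inference") -> list[str]:
    """Hypothetical helper: build (but do not run) the simple_inference.py command."""
    focal_to_lens_position(focal_position)  # raises ValueError for out-of-range indices
    return [
        sys.executable, "./simple_inference.py",
        "--image_path", image_path,
        "--input_focal_position", str(focal_position),
        "--output_dir", output_dir,
    ]
```

For example, subprocess.run(simple_inference_cmd("photos/img_0.jpg", 0), check=True) would launch the same kind of run from a Python script (the image path is a placeholder).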

The second method uses Hugging Face Accelerate and supports multi-GPU runs.

To test on real-world photos, place your images in the photos/ directory and run the following (this requires about 20-25 GB of memory, depending on image size):

conda activate refocus
accelerate launch --config_file training/configs/accelerator_config.yaml \
  --multi_gpu training/svd_runner.py \
  --config training/configs/outside_photos.yaml

Results and visualizations will be saved to the directory specified by output_dir (default: output_dir/outside_photos/).
Each output folder contains the generated focal stacks corresponding to the input image’s focal positions.

Note: To run with less memory, lower the value of max_pixels in utils.py (currently max_pixels = 500000). The model will then run at a lower resolution.
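The sketch below illustrates how a pixel budget such as max_pixels typically translates into a lower resolution: when the image exceeds the budget, both sides are scaled by sqrt(max_pixels / (width * height)), which preserves the aspect ratio. It is not the code in utils.py; fit_to_pixel_budget is a hypothetical helper, and rounding each side down to a multiple of 64 is an assumption.

```python
import math


def fit_to_pixel_budget(width: int, height: int, max_pixels: int = 500_000,
                        multiple: int = 64) -> tuple[int, int]:
    """Hypothetical helper: shrink (width, height) to fit within max_pixels.

    Keeps the aspect ratio; assumes sides are rounded down to a multiple of 64.
    The budget holds as long as max_pixels >= multiple ** 2.
    """
    # Only downscale: images already under the budget keep a scale of 1.
    scale = min(1.0, math.sqrt(max_pixels / (width * height)))
    # Round each side down so the rounded size still fits the budget,
    # but never below a single block of `multiple` pixels.
    new_w = max(multiple, int(width * scale) // multiple * multiple)
    new_h = max(multiple, int(height * scale) // multiple * multiple)
    return new_w, new_h
```

Under these assumptions, a 4032x3024 photo maps to 768x576. Halving max_pixels shrinks each side by a factor of about 1/sqrt(2), which is why lowering it reduces memory use.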


🧪 Testing (Focal Stack Dataset)

accelerate launch --config_file training/configs/accelerator_config.yaml --multi_gpu training/svd_runner.py --config training/configs/focal_stacks_test.yaml

Each output folder contains the generated focal stacks corresponding to the input image’s focal positions. Results and visualizations will be saved to the directory specified by output_dir (default: output_dir/focal_stacks_test/).


Dataset

Our dataset includes raw DNGs from all five cameras, along with rendered images at both full and midsize resolutions. We provide ZIP archives for each portion of the dataset.

  1. fullsize_dng
    Contains ZIP files with the raw DNG images and associated capture metadata.
    Because the files are large, each camera's data is split across five ZIP files.

  2. fullsize_undistorted
    Contains full-resolution rendered images for each camera, corrected for focal breathing and radial distortion to ensure a consistent field of view (FOV).

  3. midsize_undistorted
    Contains midsize (896x640) images from the center camera only. This is the data we used to train and test our model.

For training or testing, you only need the midsize_undistorted folder.
Place this folder anywhere on your computer and set the data_folder field in your YAML configuration to point to it.

Example: data_folder: "/datasets/sai/scenes_merged_midsize_undistorted"


🏋️‍♂️ Training

Set the following fields in your YAML config (feel free to change other paths to match your setup):

data_folder: 
splits_dir: 
wandb_project: "RefocusingSVD"
run_name: "focal_stacks_train"

To train our model, run:

accelerate launch --config_file training/configs/accelerator_config.yaml --multi_gpu training/svd_runner.py --config training/configs/focal_stacks_train.yaml

Checkpoints will be saved to outputs/focal_stacks_train.
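The released weights live in a folder named checkpoints-200000, which suggests a "prefix-step" naming scheme. Assuming training writes folders with a similar name (checkpoint-N or checkpoints-N) somewhere under outputs/focal_stacks_train, which is not confirmed here, this hypothetical helper finds the one with the highest step. It compares steps numerically, so checkpoints-20000 beats checkpoints-900 even though it sorts lower as a string.

```python
import re
from pathlib import Path
from typing import Optional

# Assumed naming scheme: "checkpoint-<step>" or "checkpoints-<step>",
# mirroring the released "checkpoints-200000" folder.
CKPT_PATTERN = re.compile(r"^checkpoints?-(\d+)$")


def latest_checkpoint(run_dir: str = "outputs/focal_stacks_train") -> Optional[Path]:
    """Hypothetical helper: return the checkpoint folder with the highest step, or None."""
    root = Path(run_dir)
    if not root.is_dir():
        return None
    best_step, best_path = -1, None
    # Search recursively in case checkpoints are nested (e.g. checkpoints/checkpoints-N).
    for path in root.rglob("*"):
        match = CKPT_PATTERN.match(path.name)
        if match and path.is_dir() and int(match.group(1)) > best_step:
            best_step, best_path = int(match.group(1)), path
    return best_path
```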


📜 Notes

  • Checkpoints are available on the project page.
  • Dataset download links will be added soon.
  • We use extra/compute_metrics.py to compute all metrics for this project.

📨 Contact

For questions or issues, please reach out through the project page or contact Sai Tedla.

About

Code for Learning to Refocus with Video Diffusion Models - SIGGRAPH Asia 2025
