Troubles with DiffUNet.forward_denoise() #602

@mathurinm

@mh-nguyen712, @matthieutrs said you could help :)

With @annegnx, we're trying to use DiffUNet for diffusion with the DDPM scheme.
To avoid the various rescalings (by 2 and by $\sqrt{\bar \alpha}$) done in forward_denoise(), we call forward_diffusion() directly. However, the samples we get are not good. Do you see anything wrong in the following MCVE?
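
For reference, the update the MCVE is meant to implement is the standard DDPM ancestral sampling step. Here $\epsilon_\theta$ stands for the first three output channels of forward_diffusion(); treating those channels as the noise prediction is our assumption, not something we have confirmed in the library:

```latex
\begin{aligned}
x_{t-1} &= \frac{1}{\sqrt{\alpha_t}}
           \left( x_t - \frac{\beta_t}{\sqrt{1 - \bar\alpha_t}} \, \epsilon_\theta(x_t, t) \right)
           + \sigma_t z, \qquad z \sim \mathcal{N}(0, I), \\
\alpha_t &= \frac{\bar\alpha_t}{\bar\alpha_{t-1}}, \qquad
\beta_t = 1 - \alpha_t, \qquad
\sigma_t^2 = \frac{1 - \bar\alpha_{t-1}}{1 - \bar\alpha_t} \, \beta_t .
\end{aligned}
```

No noise is added at the last step, i.e. the $\sigma_t z$ term is dropped there.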

import deepinv
import torch

from tqdm import tqdm

device = "cuda" if torch.cuda.is_available() else "cpu"

denoiser = deepinv.models.DiffUNet(large_model=False).to(device)
denoiser.eval()

alphas_bar = (denoiser.get_alpha_prod()[-1]**2).to(device)
res = []


with torch.no_grad():
    # Start from pure Gaussian noise x_T.
    xt = torch.randn([1, 3, 128, 128], device=device)
    for t in tqdm(reversed(range(1, len(alphas_bar)))):
        abar = alphas_bar[t]
        alpha_t = abar / alphas_bar[t - 1]
        beta_t = 1 - alpha_t

        # Keep only the first 3 output channels as the predicted noise.
        noise_pred = denoiser.forward_diffusion(
            xt, torch.tensor([t], device=device))[:, :3]
        noise_coef = beta_t / (1 - abar).sqrt()
        # Standard deviation of the DDPM posterior at step t.
        sigma_t = ((1 - alphas_bar[t - 1]) / (1 - abar) * beta_t).sqrt()
        # Posterior mean of the reverse step.
        xt = (xt - noise_coef * noise_pred) / alpha_t.sqrt()
        if t > 1:
            # Add fresh noise at every step except the last one.
            xt += sigma_t * torch.randn_like(xt)
        # Save one intermediate sample every 100 steps.
        if t % 100 == 0:
            res.append(xt.clone().detach())


import matplotlib.pyplot as plt
fig, axes = plt.subplots(1, 9, figsize=(20, 5))
for xt, ax in zip(res, axes):
    ax.imshow((xt.detach().cpu().squeeze()).permute(1, 2, 0))
plt.tight_layout()
fig.savefig("result.png")

which generates:

Image
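
The step coefficients can also be checked without the network. Below is a minimal pure-Python sketch that rebuilds $\bar\alpha_t$ from a linear beta schedule and computes the coefficients the same way as in the loop above. The schedule values (1e-4 to 2e-2 over 1000 steps, the usual DDPM defaults) and the helper names `linear_alphas_bar` and `step_coefficients` are assumptions made for illustration; we have not confirmed that DiffUNet uses this exact schedule.

```python
import math


def linear_alphas_bar(num_steps=1000, beta_start=1e-4, beta_end=2e-2):
    """Return (betas, alphas_bar) for a linear beta schedule (assumed values)."""
    betas = [beta_start + (beta_end - beta_start) * i / (num_steps - 1)
             for i in range(num_steps)]
    alphas_bar, prod = [], 1.0
    for beta in betas:
        prod *= 1.0 - beta  # running product of alpha_s = 1 - beta_s
        alphas_bar.append(prod)
    return betas, alphas_bar


def step_coefficients(alphas_bar, t):
    """Coefficients of one reverse step, computed as in the sampling loop."""
    abar, abar_prev = alphas_bar[t], alphas_bar[t - 1]
    alpha_t = abar / abar_prev
    beta_t = 1.0 - alpha_t
    noise_coef = beta_t / math.sqrt(1.0 - abar)
    sigma_t = math.sqrt((1.0 - abar_prev) / (1.0 - abar) * beta_t)
    return alpha_t, beta_t, noise_coef, sigma_t


betas, alphas_bar = linear_alphas_bar()
for t in (1, 500, 999):
    alpha_t, beta_t, noise_coef, sigma_t = step_coefficients(alphas_bar, t)
    # beta_t recovered from the ratio should match the schedule entry at index t.
    assert abs(beta_t - betas[t]) < 1e-9
    # The posterior standard deviation never exceeds sqrt(beta_t).
    assert sigma_t <= math.sqrt(beta_t)
```

Under that assumed schedule the assertions pass, which suggests the way alpha_t, beta_t and sigma_t are derived from alphas_bar is at least self-consistent. It says nothing about the input scaling or the timestep convention that forward_diffusion() expects.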

Thanks heaps
