Closed
Description
@mh-nguyen712 @matthieutrs said you could help :)
With @annegnx we're trying to use DiffUNet for diffusion with the DDPM scheme.
To circumvent the various scalings by 2, we call forward_diffusion() directly. However, the results are not good. Do you see anything wrong in the following MCVE?
import deepinv
import torch
from tqdm import tqdm

device = "cuda" if torch.cuda.is_available() else "cpu"
denoiser = deepinv.models.DiffUNet(large_model=False).to(device)
denoiser.eval()
alphas_bar = (denoiser.get_alpha_prod()[-1]**2).to(device)
res = []

with torch.no_grad():
    xt = torch.randn([1, 3, 128, 128], device=device)
    for t in tqdm(reversed(range(1, len(alphas_bar)))):
        abar = alphas_bar[t]
        alpha_t = abar / alphas_bar[t - 1]
        beta_t = 1 - alpha_t
        noise_pred = denoiser.forward_diffusion(
            xt, torch.tensor([t]).to(device))[:, :3]
        noise_coef = beta_t / (1 - abar).sqrt()
        sigma_t = ((1 - alphas_bar[t - 1]) / (1 - abar) * beta_t).sqrt()
        if t > 1:
            xt = (xt - noise_coef * noise_pred) / alpha_t.sqrt()
            noise = torch.randn_like(noise_pred)
            xt += sigma_t * noise
        else:
            xt = (xt - noise_coef * noise_pred) / alpha_t.sqrt()
        if t % 100 == 0:
            res.append(xt.clone().detach())

import matplotlib.pyplot as plt

fig, axes = plt.subplots(1, 9, figsize=(20, 5))
for xt, ax in zip(res, axes):
    ax.imshow((xt.detach().cpu().squeeze()).permute(1, 2, 0))
plt.tight_layout()
fig.savefig("result.png")

which generates:
Thanks heaps
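PS: in case it helps narrow things down, here is a minimal sketch of the coefficient arithmetic from the loop above, checked in isolation without the model. It assumes a linear beta schedule with 1000 steps from 1e-4 to 0.02 (the usual DDPM default), which may not be what DiffUNet uses internally. The names T, betas and ddpm_coefficients are hypothetical and only serve this illustration.

```python
import math

# Assumption: linear beta schedule, 1000 steps from 1e-4 to 0.02.
# The schedule actually used inside DiffUNet may differ.
T = 1000
betas = [1e-4 + (0.02 - 1e-4) * i / (T - 1) for i in range(T)]

# Cumulative products: alphas_bar[t] = prod_{s <= t} (1 - betas[s]).
alphas_bar = []
prod = 1.0
for b in betas:
    prod *= 1.0 - b
    alphas_bar.append(prod)


def ddpm_coefficients(t):
    """Return (alpha_t, noise_coef, sigma_t) for t >= 1, mirroring the MCVE loop."""
    abar = alphas_bar[t]
    alpha_t = abar / alphas_bar[t - 1]
    beta_t = 1.0 - alpha_t
    noise_coef = beta_t / math.sqrt(1.0 - abar)
    sigma_t = math.sqrt((1.0 - alphas_bar[t - 1]) / (1.0 - abar) * beta_t)
    return alpha_t, noise_coef, sigma_t


# The ratio of consecutive cumulative products should recover 1 - beta_t.
max_err = max(abs((1.0 - ddpm_coefficients(t)[0]) - betas[t]) for t in range(1, T))
print(f"max |beta_t - recovered beta_t| = {max_err:.2e}")
```

Under that assumed schedule the recovered betas agree with the originals up to floating-point error, so this check only covers the schedule algebra and not how the network output is scaled or interpreted.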