Skip to content
This repository was archived by the owner on Feb 7, 2025. It is now read-only.

Commit 3dc9b86

Browse files
authored
Add v_prediction and get_velocity method (#134)
* Add v_prediction option and get_velocity method
  Signed-off-by: Walter Hugo Lopez Pinaya <[email protected]>
* [WIP] Add changes in the inferer to use v_prediction
  Signed-off-by: Walter Hugo Lopez Pinaya <[email protected]>
* Add v_prediction tutorial (#134)
  Signed-off-by: Walter Hugo Lopez Pinaya <[email protected]>
* Change inferer usage to be compatible with new version (#134)
  Signed-off-by: Walter Hugo Lopez Pinaya <[email protected]>
* Update tutorials (#134)
  Signed-off-by: Walter Hugo Lopez Pinaya <[email protected]>
  Signed-off-by: Walter Hugo Lopez Pinaya <[email protected]>
1 parent 6151176 commit 3dc9b86

14 files changed

+1464
-28
lines changed

generative/inferers/inferer.py

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -39,6 +39,7 @@ def __call__(
3939
inputs: torch.Tensor,
4040
diffusion_model: Callable[..., torch.Tensor],
4141
noise: torch.Tensor,
42+
timesteps: torch.Tensor,
4243
condition: Optional[torch.Tensor] = None,
4344
) -> torch.Tensor:
4445
"""
@@ -48,10 +49,9 @@ def __call__(
4849
inputs: Input image to which noise is added.
4950
diffusion_model: diffusion model.
5051
noise: random noise, of the same shape as the input.
52+
timesteps: random timesteps.
5153
condition: Conditioning for network input.
5254
"""
53-
num_timesteps = self.scheduler.num_train_timesteps
54-
timesteps = torch.randint(0, num_timesteps, (inputs.shape[0],), device=inputs.device).long()
5555
noisy_image = self.scheduler.add_noise(original_samples=inputs, noise=noise, timesteps=timesteps)
5656
prediction = diffusion_model(x=noisy_image, timesteps=timesteps, context=condition)
5757

@@ -123,6 +123,7 @@ def __call__(
123123
autoencoder_model: Callable[..., torch.Tensor],
124124
diffusion_model: Callable[..., torch.Tensor],
125125
noise: torch.Tensor,
126+
timesteps: torch.Tensor,
126127
condition: Optional[torch.Tensor] = None,
127128
) -> torch.Tensor:
128129
"""
@@ -133,6 +134,7 @@ def __call__(
133134
autoencoder_model: first stage model.
134135
diffusion_model: diffusion model.
135136
noise: random noise, of the same shape as the latent representation.
137+
timesteps: random timesteps.
136138
condition: conditioning for network input.
137139
"""
138140
with torch.no_grad():
@@ -142,6 +144,7 @@ def __call__(
142144
inputs=latent,
143145
diffusion_model=diffusion_model,
144146
noise=noise,
147+
timesteps=timesteps,
145148
condition=condition,
146149
)
147150

generative/networks/schedulers/ddim.py

Lines changed: 36 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -55,6 +55,9 @@ class DDIMScheduler(nn.Module):
5555
steps_offset: an offset added to the inference steps. You can use a combination of `steps_offset=1` and
5656
`set_alpha_to_one=False`, to make the last step use step 0 for the previous alpha product, as done in
5757
stable diffusion.
58+
prediction_type: prediction type of the scheduler function, one of `epsilon` (predicting the noise of the
59+
diffusion process), `sample` (directly predicting the denoised sample) or `v_prediction` (see section 2.4
60+
https://imagen.research.google/video/paper.pdf)
5861
"""
5962

6063
def __init__(
@@ -66,6 +69,7 @@ def __init__(
6669
clip_sample: bool = True,
6770
set_alpha_to_one: bool = True,
6871
steps_offset: int = 0,
72+
prediction_type: str = "epsilon",
6973
) -> None:
7074
super().__init__()
7175
self.beta_schedule = beta_schedule
@@ -79,6 +83,12 @@ def __init__(
7983
else:
8084
raise NotImplementedError(f"{beta_schedule} does is not implemented for {self.__class__}")
8185

86+
if prediction_type.lower() not in ["epsilon", "sample", "v_prediction"]:
87+
raise ValueError(
88+
f"prediction_type given as {prediction_type} must be one of `epsilon`, `sample`, or" " `v_prediction`"
89+
)
90+
91+
self.prediction_type = prediction_type
8292
self.num_train_timesteps = num_train_timesteps
8393
self.alphas = 1.0 - self.betas
8494
self.alphas_cumprod = torch.cumprod(self.alphas, dim=0)
@@ -171,7 +181,14 @@ def step(
171181

172182
# 3. compute predicted original sample from predicted noise also called
173183
# "predicted x_0" of formula (12) from https://arxiv.org/pdf/2010.02502.pdf
174-
pred_original_sample = (sample - beta_prod_t ** (0.5) * model_output) / alpha_prod_t ** (0.5)
184+
if self.prediction_type == "epsilon":
185+
pred_original_sample = (sample - beta_prod_t ** (0.5) * model_output) / alpha_prod_t ** (0.5)
186+
elif self.prediction_type == "sample":
187+
pred_original_sample = model_output
188+
elif self.prediction_type == "v_prediction":
189+
pred_original_sample = (alpha_prod_t**0.5) * sample - (beta_prod_t**0.5) * model_output
190+
# convert the v-prediction into the equivalent predicted noise (epsilon)
191+
model_output = (alpha_prod_t**0.5) * model_output + (beta_prod_t**0.5) * sample
175192

176193
# 4. Clip "predicted x_0"
177194
if self.clip_sample:
@@ -231,3 +248,21 @@ def add_noise(
231248

232249
noisy_samples = sqrt_alpha_cumprod * original_samples + sqrt_one_minus_alpha_prod * noise
233250
return noisy_samples
251+
252+
def get_velocity(self, sample: torch.Tensor, noise: torch.Tensor, timesteps: torch.Tensor) -> torch.Tensor:
    """
    Compute the velocity for the given sample, noise and timesteps.

    The result is ``sqrt(alphas_cumprod[t]) * noise - sqrt(1 - alphas_cumprod[t]) * sample``.

    Args:
        sample: tensor the velocity is computed for (assumed batch-first; TODO confirm against callers).
        noise: noise tensor, broadcastable against ``sample``.
        timesteps: integer indices into ``self.alphas_cumprod`` (assumed one per batch element).

    Returns:
        The velocity tensor, in the dtype and on the device of ``sample``.
    """
    # Keep the stored schedule on the sample's device, but cast only a local copy to the
    # sample's dtype: overwriting the attribute with e.g. a float16 copy would permanently
    # degrade the schedule for every later call.
    self.alphas_cumprod = self.alphas_cumprod.to(device=sample.device)
    alphas_cumprod = self.alphas_cumprod.to(dtype=sample.dtype)
    timesteps = timesteps.to(sample.device)

    alpha_prod_t = alphas_cumprod[timesteps].flatten()

    # Append singleton dimensions so the per-timestep coefficients broadcast over the
    # remaining dimensions of `sample`.
    broadcast_shape = (-1,) + (1,) * (sample.ndim - 1)
    sqrt_alpha_prod = (alpha_prod_t**0.5).reshape(broadcast_shape)
    sqrt_one_minus_alpha_prod = ((1 - alpha_prod_t) ** 0.5).reshape(broadcast_shape)

    velocity = sqrt_alpha_prod * noise - sqrt_one_minus_alpha_prod * sample
    return velocity

generative/networks/schedulers/ddpm.py

Lines changed: 30 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -61,6 +61,7 @@ def __init__(
6161
beta_schedule: str = "linear",
6262
variance_type: str = "fixed_small",
6363
clip_sample: bool = True,
64+
prediction_type: str = "epsilon",
6465
) -> None:
6566
super().__init__()
6667
self.beta_schedule = beta_schedule
@@ -74,6 +75,13 @@ def __init__(
7475
else:
7576
raise NotImplementedError(f"{beta_schedule} does is not implemented for {self.__class__}")
7677

78+
if prediction_type.lower() not in ["epsilon", "sample", "v_prediction"]:
79+
raise ValueError(
80+
f"prediction_type given as {prediction_type} must be one of `epsilon`, `sample`, or" " `v_prediction`"
81+
)
82+
83+
self.prediction_type = prediction_type
84+
7785
self.num_train_timesteps = num_train_timesteps
7886
self.alphas = 1.0 - self.betas
7987
self.alphas_cumprod = torch.cumprod(self.alphas, dim=0)
@@ -170,10 +178,12 @@ def step(
170178

171179
# 2. compute predicted original sample from predicted noise also called
172180
# "predicted x_0" of formula (15) from https://arxiv.org/pdf/2006.11239.pdf
173-
if predict_epsilon:
181+
if self.prediction_type == "epsilon":
174182
pred_original_sample = (sample - beta_prod_t ** (0.5) * model_output) / alpha_prod_t ** (0.5)
175-
else:
183+
elif self.prediction_type == "sample":
176184
pred_original_sample = model_output
185+
elif self.prediction_type == "v_prediction":
186+
pred_original_sample = (alpha_prod_t**0.5) * sample - (beta_prod_t**0.5) * model_output
177187

178188
# 3. Clip "predicted x_0"
179189
if self.clip_sample:
@@ -233,3 +243,21 @@ def add_noise(
233243

234244
noisy_samples = sqrt_alpha_cumprod * original_samples + sqrt_one_minus_alpha_prod * noise
235245
return noisy_samples
246+
247+
def get_velocity(self, sample: torch.Tensor, noise: torch.Tensor, timesteps: torch.Tensor) -> torch.Tensor:
    """
    Compute the velocity for the given sample, noise and timesteps.

    The result is ``sqrt(alphas_cumprod[t]) * noise - sqrt(1 - alphas_cumprod[t]) * sample``.

    Args:
        sample: tensor the velocity is computed for (assumed batch-first; TODO confirm against callers).
        noise: noise tensor, broadcastable against ``sample``.
        timesteps: integer indices into ``self.alphas_cumprod`` (assumed one per batch element).

    Returns:
        The velocity tensor, in the dtype and on the device of ``sample``.
    """
    # Move the stored schedule to the sample's device, but never overwrite it with a
    # dtype-cast copy: a single low-precision call would otherwise reduce the precision
    # of the schedule for all subsequent calls.
    self.alphas_cumprod = self.alphas_cumprod.to(device=sample.device)
    schedule = self.alphas_cumprod.to(dtype=sample.dtype)
    timesteps = timesteps.to(sample.device)

    alpha_prod_t = schedule[timesteps].flatten()

    # One leading dimension for the selected timesteps, singleton dimensions for the rest,
    # so the coefficients broadcast against `sample`.
    trailing_ones = (1,) * (sample.ndim - 1)
    sqrt_alpha_prod = (alpha_prod_t**0.5).reshape((-1,) + trailing_ones)
    sqrt_one_minus_alpha_prod = ((1 - alpha_prod_t) ** 0.5).reshape((-1,) + trailing_ones)

    velocity = sqrt_alpha_prod * noise - sqrt_one_minus_alpha_prod * sample
    return velocity

tests/test_diffusion_inferer.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -63,7 +63,8 @@ def test_call(self, model_params, input_shape):
6363
)
6464
inferer = DiffusionInferer(scheduler=scheduler)
6565
scheduler.set_timesteps(num_inference_steps=10)
66-
sample = inferer(inputs=input, noise=noise, diffusion_model=model)
66+
timesteps = torch.randint(0, scheduler.num_train_timesteps, (input_shape[0],), device=input.device).long()
67+
sample = inferer(inputs=input, noise=noise, diffusion_model=model, timesteps=timesteps)
6768
self.assertEqual(sample.shape, input_shape)
6869

6970
@parameterized.expand(TEST_CASES)

tests/test_latent_diffusion_inferer.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -98,7 +98,10 @@ def test_prediction_shape(self, model_type, autoencoder_params, stage_2_params,
9898
)
9999
inferer = LatentDiffusionInferer(scheduler=scheduler, scale_factor=1.0)
100100
scheduler.set_timesteps(num_inference_steps=10)
101-
prediction = inferer(inputs=input, autoencoder_model=autoencoder_model, diffusion_model=stage_2, noise=noise)
101+
timesteps = torch.randint(0, scheduler.num_train_timesteps, (input_shape[0],), device=input.device).long()
102+
prediction = inferer(
103+
inputs=input, autoencoder_model=autoencoder_model, diffusion_model=stage_2, noise=noise, timesteps=timesteps
104+
)
102105
self.assertEqual(prediction.shape, latent_shape)
103106

104107
@parameterized.expand(TEST_CASES)

tutorials/generative/2d_ddpm/2d_ddpm_tutorial.ipynb

Lines changed: 10 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -780,8 +780,13 @@
780780
" # Generate random noise\n",
781781
" noise = torch.randn_like(images).to(device)\n",
782782
"\n",
783+
" # Create timesteps\n",
784+
" timesteps = torch.randint(\n",
785+
" 0, inferer.scheduler.num_train_timesteps, (images.shape[0],), device=images.device\n",
786+
" ).long()\n",
787+
"\n",
783788
" # Get model prediction\n",
784-
" noise_pred = inferer(inputs=images, diffusion_model=model, noise=noise)\n",
789+
" noise_pred = inferer(inputs=images, diffusion_model=model, noise=noise, timesteps=timesteps)\n",
785790
"\n",
786791
" loss = F.mse_loss(noise_pred.float(), noise.float())\n",
787792
"\n",
@@ -806,7 +811,10 @@
806811
" with torch.no_grad():\n",
807812
" with autocast(enabled=True):\n",
808813
" noise = torch.randn_like(images).to(device)\n",
809-
" noise_pred = inferer(inputs=images, diffusion_model=model, noise=noise)\n",
814+
" timesteps = torch.randint(\n",
815+
" 0, inferer.scheduler.num_train_timesteps, (images.shape[0],), device=images.device\n",
816+
" ).long()\n",
817+
" noise_pred = inferer(inputs=images, diffusion_model=model, noise=noise, timesteps=timesteps)\n",
810818
" val_loss = F.mse_loss(noise_pred.float(), noise.float())\n",
811819
"\n",
812820
" val_epoch_loss += val_loss.item()\n",

tutorials/generative/2d_ddpm/2d_ddpm_tutorial.py

Lines changed: 10 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -207,8 +207,13 @@
207207
# Generate random noise
208208
noise = torch.randn_like(images).to(device)
209209

210+
# Create timesteps
211+
timesteps = torch.randint(
212+
0, inferer.scheduler.num_train_timesteps, (images.shape[0],), device=images.device
213+
).long()
214+
210215
# Get model prediction
211-
noise_pred = inferer(inputs=images, diffusion_model=model, noise=noise)
216+
noise_pred = inferer(inputs=images, diffusion_model=model, noise=noise, timesteps=timesteps)
212217

213218
loss = F.mse_loss(noise_pred.float(), noise.float())
214219

@@ -233,7 +238,10 @@
233238
with torch.no_grad():
234239
with autocast(enabled=True):
235240
noise = torch.randn_like(images).to(device)
236-
noise_pred = inferer(inputs=images, diffusion_model=model, noise=noise)
241+
timesteps = torch.randint(
242+
0, inferer.scheduler.num_train_timesteps, (images.shape[0],), device=images.device
243+
).long()
244+
noise_pred = inferer(inputs=images, diffusion_model=model, noise=noise, timesteps=timesteps)
237245
val_loss = F.mse_loss(noise_pred.float(), noise.float())
238246

239247
val_epoch_loss += val_loss.item()

tutorials/generative/2d_ddpm/2d_ddpm_tutorial_ignite.ipynb

Lines changed: 11 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -394,8 +394,9 @@
394394
")\n",
395395
"model.to(device)\n",
396396
"\n",
397+
"num_train_timesteps = 1000\n",
397398
"scheduler = DDPMScheduler(\n",
398-
" num_train_timesteps=1000,\n",
399+
" num_train_timesteps=num_train_timesteps,\n",
399400
")\n",
400401
"\n",
401402
"optimizer = torch.optim.Adam(params=model.parameters(), lr=2.5e-5)\n",
@@ -433,13 +434,17 @@
433434
"\n",
434435
" \"\"\"\n",
435436
"\n",
436-
" def __init__(self, condition_name: Optional[str] = None):\n",
437+
" def __init__(self, num_train_timesteps: int, condition_name: Optional[str] = None):\n",
437438
" self.condition_name = condition_name\n",
439+
" self.num_train_timesteps = num_train_timesteps\n",
438440
"\n",
439441
" def get_noise(self, images):\n",
440442
" \"\"\"Returns the noise tensor for input tensor `images`, override this for different noise distributions.\"\"\"\n",
441443
" return torch.randn_like(images)\n",
442444
"\n",
445+
" def get_timesteps(self, images):\n",
446+
" return torch.randint(0, self.num_train_timesteps, (images.shape[0],), device=images.device).long()\n",
447+
"\n",
443448
" def __call__(\n",
444449
" self,\n",
445450
" batchdata: Dict[str, torch.Tensor],\n",
@@ -449,8 +454,9 @@
449454
" ):\n",
450455
" images, _ = default_prepare_batch(batchdata, device, non_blocking, **kwargs)\n",
451456
" noise = self.get_noise(images).to(device, non_blocking=non_blocking, **kwargs)\n",
457+
" timesteps = self.get_timesteps(images).to(device, non_blocking=non_blocking, **kwargs)\n",
452458
"\n",
453-
" kwargs = {\"noise\": noise}\n",
459+
" kwargs = {\"noise\": noise, \"timesteps\": timesteps}\n",
454460
"\n",
455461
" if self.condition_name is not None and isinstance(batchdata, Mapping):\n",
456462
" kwargs[\"conditioning\"] = batchdata[self.condition_name].to(device, non_blocking=non_blocking, **kwargs)\n",
@@ -2159,7 +2165,7 @@
21592165
" val_data_loader=val_loader,\n",
21602166
" network=model,\n",
21612167
" inferer=inferer,\n",
2162-
" prepare_batch=DiffusionPrepareBatch(),\n",
2168+
" prepare_batch=DiffusionPrepareBatch(num_train_timesteps=num_train_timesteps),\n",
21632169
" key_val_metric={\"val_mean_abs_error\": MeanAbsoluteError(output_transform=from_engine([\"pred\", \"label\"]))},\n",
21642170
" val_handlers=val_handlers,\n",
21652171
")\n",
@@ -2178,7 +2184,7 @@
21782184
" optimizer=optimizer,\n",
21792185
" loss_function=torch.nn.MSELoss(),\n",
21802186
" inferer=inferer,\n",
2181-
" prepare_batch=DiffusionPrepareBatch(),\n",
2187+
" prepare_batch=DiffusionPrepareBatch(num_train_timesteps=num_train_timesteps),\n",
21822188
" key_train_metric={\"train_acc\": MeanSquaredError(output_transform=from_engine([\"pred\", \"label\"]))},\n",
21832189
" train_handlers=train_handlers,\n",
21842190
")\n",

tutorials/generative/2d_ddpm/2d_ddpm_tutorial_ignite.py

Lines changed: 11 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -177,8 +177,9 @@
177177
)
178178
model.to(device)
179179

180+
num_train_timesteps = 1000
180181
scheduler = DDPMScheduler(
181-
num_train_timesteps=1000,
182+
num_train_timesteps=num_train_timesteps,
182183
)
183184

184185
optimizer = torch.optim.Adam(params=model.parameters(), lr=2.5e-5)
@@ -203,13 +204,17 @@ class DiffusionPrepareBatch(PrepareBatch):
203204
204205
"""
205206

206-
def __init__(self, condition_name: Optional[str] = None):
207+
def __init__(self, num_train_timesteps: int, condition_name: Optional[str] = None):
207208
self.condition_name = condition_name
209+
self.num_train_timesteps = num_train_timesteps
208210

209211
def get_noise(self, images):
210212
"""Returns the noise tensor for input tensor `images`, override this for different noise distributions."""
211213
return torch.randn_like(images)
212214

215+
def get_timesteps(self, images):
216+
return torch.randint(0, self.num_train_timesteps, (images.shape[0],), device=images.device).long()
217+
213218
def __call__(
214219
self,
215220
batchdata: Dict[str, torch.Tensor],
@@ -219,8 +224,9 @@ def __call__(
219224
):
220225
images, _ = default_prepare_batch(batchdata, device, non_blocking, **kwargs)
221226
noise = self.get_noise(images).to(device, non_blocking=non_blocking, **kwargs)
227+
timesteps = self.get_timesteps(images).to(device, non_blocking=non_blocking, **kwargs)
222228

223-
kwargs = {"noise": noise}
229+
kwargs = {"noise": noise, "timesteps": timesteps}
224230

225231
if self.condition_name is not None and isinstance(batchdata, Mapping):
226232
kwargs["conditioning"] = batchdata[self.condition_name].to(device, non_blocking=non_blocking, **kwargs)
@@ -244,7 +250,7 @@ def __call__(
244250
val_data_loader=val_loader,
245251
network=model,
246252
inferer=inferer,
247-
prepare_batch=DiffusionPrepareBatch(),
253+
prepare_batch=DiffusionPrepareBatch(num_train_timesteps=num_train_timesteps),
248254
key_val_metric={"val_mean_abs_error": MeanAbsoluteError(output_transform=from_engine(["pred", "label"]))},
249255
val_handlers=val_handlers,
250256
)
@@ -263,7 +269,7 @@ def __call__(
263269
optimizer=optimizer,
264270
loss_function=torch.nn.MSELoss(),
265271
inferer=inferer,
266-
prepare_batch=DiffusionPrepareBatch(),
272+
prepare_batch=DiffusionPrepareBatch(num_train_timesteps=num_train_timesteps),
267273
key_train_metric={"train_acc": MeanSquaredError(output_transform=from_engine(["pred", "label"]))},
268274
train_handlers=train_handlers,
269275
)

0 commit comments

Comments
 (0)