|
394 | 394 | ")\n", |
395 | 395 | "model.to(device)\n", |
396 | 396 | "\n", |
| 397 | + "num_train_timesteps = 1000\n", |
397 | 398 | "scheduler = DDPMScheduler(\n", |
398 | | - " num_train_timesteps=1000,\n", |
| 399 | + " num_train_timesteps=num_train_timesteps,\n", |
399 | 400 | ")\n", |
400 | 401 | "\n", |
401 | 402 | "optimizer = torch.optim.Adam(params=model.parameters(), lr=2.5e-5)\n", |
|
433 | 434 | "\n", |
434 | 435 | " \"\"\"\n", |
435 | 436 | "\n", |
436 | | - " def __init__(self, condition_name: Optional[str] = None):\n", |
| 437 | + " def __init__(self, num_train_timesteps: int, condition_name: Optional[str] = None):\n", |
437 | 438 | " self.condition_name = condition_name\n", |
| 439 | + " self.num_train_timesteps = num_train_timesteps\n", |
438 | 440 | "\n", |
439 | 441 | " def get_noise(self, images):\n", |
440 | 442 | " \"\"\"Returns the noise tensor for input tensor `images`, override this for different noise distributions.\"\"\"\n", |
441 | 443 | " return torch.randn_like(images)\n", |
442 | 444 | "\n", |
| 445 | + " def get_timesteps(self, images):\n", |
| 446 | + " return torch.randint(0, self.num_train_timesteps, (images.shape[0],), device=images.device).long()\n", |
| 447 | + "\n", |
443 | 448 | " def __call__(\n", |
444 | 449 | " self,\n", |
445 | 450 | " batchdata: Dict[str, torch.Tensor],\n", |
|
449 | 454 | " ):\n", |
450 | 455 | " images, _ = default_prepare_batch(batchdata, device, non_blocking, **kwargs)\n", |
451 | 456 | " noise = self.get_noise(images).to(device, non_blocking=non_blocking, **kwargs)\n", |
| 457 | + " timesteps = self.get_timesteps(images).to(device, non_blocking=non_blocking, **kwargs)\n", |
452 | 458 | "\n", |
453 | | - " kwargs = {\"noise\": noise}\n", |
| 459 | + " kwargs = {\"noise\": noise, \"timesteps\": timesteps}\n", |
454 | 460 | "\n", |
455 | 461 | " if self.condition_name is not None and isinstance(batchdata, Mapping):\n", |
456 | 462 | " kwargs[\"conditioning\"] = batchdata[self.condition_name].to(device, non_blocking=non_blocking, **kwargs)\n", |
|
2159 | 2165 | " val_data_loader=val_loader,\n", |
2160 | 2166 | " network=model,\n", |
2161 | 2167 | " inferer=inferer,\n", |
2162 | | - " prepare_batch=DiffusionPrepareBatch(),\n", |
| 2168 | + " prepare_batch=DiffusionPrepareBatch(num_train_timesteps=num_train_timesteps),\n", |
2163 | 2169 | " key_val_metric={\"val_mean_abs_error\": MeanAbsoluteError(output_transform=from_engine([\"pred\", \"label\"]))},\n", |
2164 | 2170 | " val_handlers=val_handlers,\n", |
2165 | 2171 | ")\n", |
|
2178 | 2184 | " optimizer=optimizer,\n", |
2179 | 2185 | " loss_function=torch.nn.MSELoss(),\n", |
2180 | 2186 | " inferer=inferer,\n", |
2181 | | - " prepare_batch=DiffusionPrepareBatch(),\n", |
| 2187 | + " prepare_batch=DiffusionPrepareBatch(num_train_timesteps=num_train_timesteps),\n", |
2182 | 2188 | " key_train_metric={\"train_acc\": MeanSquaredError(output_transform=from_engine([\"pred\", \"label\"]))},\n", |
2183 | 2189 | " train_handlers=train_handlers,\n", |
2184 | 2190 | ")\n", |
|
0 commit comments