Skip to content

Commit 7f1f949

Browse files
committed
Allow freenoise to work on other dims, handle 4D batch timestep
Refactor Freenoise function. And fix batch handling as timesteps seem to be expanded to batch size now.
1 parent 473cb8d commit 7f1f949

File tree

1 file changed

+19
-17
lines changed

1 file changed

+19
-17
lines changed

comfy/context_windows.py

Lines changed: 19 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -192,7 +192,7 @@ def get_resized_cond(self, cond_in: list[dict], x_in: torch.Tensor, window: Inde
192192
return resized_cond
193193

194194
def set_step(self, timestep: torch.Tensor, model_options: dict[str]):
195-
mask = torch.isclose(model_options["transformer_options"]["sample_sigmas"], timestep, rtol=0.0001)
195+
mask = torch.isclose(model_options["transformer_options"]["sample_sigmas"], timestep[0], rtol=0.0001)
196196
matches = torch.nonzero(mask)
197197
if torch.numel(matches) == 0:
198198
raise Exception("No sample_sigmas matched current timestep; something went wrong.")
@@ -324,7 +324,7 @@ def _sampler_sample_wrapper(executor, guider, sigmas, extra_args, callback, nois
324324
raise Exception("context_handler not found in sampler_sample_wrapper; this should never happen, something went wrong.")
325325
if not handler.freenoise:
326326
return executor(guider, sigmas, extra_args, callback, noise, *args, **kwargs)
327-
noise = apply_freenoise(noise, handler.context_length, handler.context_overlap, extra_args["seed"])
327+
noise = apply_freenoise(noise, handler.dim, handler.context_length, handler.context_overlap, extra_args["seed"])
328328

329329
return executor(guider, sigmas, extra_args, callback, noise, *args, **kwargs)
330330

@@ -591,24 +591,26 @@ def shift_window_to_end(window: list[int], num_frames: int):
591591

592592

593593
# https://github.com/Kosinkadink/ComfyUI-AnimateDiff-Evolved/blob/90fb1331201a4b29488089e4fbffc0d82cc6d0a9/animatediff/sample_settings.py#L465
594-
def apply_freenoise(noise: torch.Tensor, context_length: int, context_overlap: int, seed: int):
594+
def apply_freenoise(noise: torch.Tensor, dim: int, context_length: int, context_overlap: int, seed: int):
595595
logging.info("Context windows: Applying FreeNoise")
596-
generator = torch.manual_seed(seed)
597-
latent_video_length = noise.shape[2]
596+
generator = torch.Generator(device='cpu').manual_seed(seed)
597+
latent_video_length = noise.shape[dim]
598598
delta = context_length - context_overlap
599-
for start_idx in range(0, latent_video_length-context_length, delta):
599+
600+
for start_idx in range(0, latent_video_length - context_length, delta):
600601
place_idx = start_idx + context_length
601-
if place_idx >= latent_video_length:
602-
break
603-
end_idx = place_idx - 1
604602

605-
if end_idx + delta >= latent_video_length:
606-
final_delta = latent_video_length - place_idx
607-
list_idx = torch.tensor(list(range(start_idx,start_idx+final_delta)), device=torch.device("cpu"), dtype=torch.long)
608-
list_idx = list_idx[torch.randperm(final_delta, generator=generator)]
609-
noise[:, :, place_idx:place_idx + final_delta] = noise[:, :, list_idx]
603+
actual_delta = min(delta, latent_video_length - place_idx)
604+
if actual_delta <= 0:
610605
break
611-
list_idx = torch.tensor(list(range(start_idx,start_idx+delta)), device=torch.device("cpu"), dtype=torch.long)
612-
list_idx = list_idx[torch.randperm(delta, generator=generator)]
613-
noise[:, :, place_idx:place_idx + delta] = noise[:, :, list_idx]
606+
607+
list_idx = torch.randperm(actual_delta, generator=generator, device='cpu') + start_idx
608+
609+
source_slice = [slice(None)] * noise.ndim
610+
source_slice[dim] = list_idx
611+
target_slice = [slice(None)] * noise.ndim
612+
target_slice[dim] = slice(place_idx, place_idx + actual_delta)
613+
614+
noise[tuple(target_slice)] = noise[tuple(source_slice)]
615+
614616
return noise

0 commit comments

Comments
 (0)