Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
104 changes: 90 additions & 14 deletions comfy/context_windows.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,26 +51,36 @@ def execute(self, calc_cond_batch: Callable, model: BaseModel, conds: list[list[


class IndexListContextWindow(ContextWindowABC):
def __init__(self, index_list: list[int], dim: int=0):
def __init__(self, index_list: list[int], dim: int=0, total_frames: int=0):
self.index_list = index_list
self.context_length = len(index_list)
self.dim = dim
self.total_frames = total_frames
self.center_ratio = (min(index_list) + max(index_list)) / (2 * total_frames)

def get_tensor(self, full: torch.Tensor, device=None, dim=None) -> torch.Tensor:
def get_tensor(self, full: torch.Tensor, device=None, dim=None, retain_index_list=[]) -> torch.Tensor:
if dim is None:
dim = self.dim
if dim == 0 and full.shape[dim] == 1:
return full
idx = [slice(None)] * dim + [self.index_list]
return full[idx].to(device)
idx = tuple([slice(None)] * dim + [self.index_list])
window = full[idx]
if retain_index_list:
idx = tuple([slice(None)] * dim + [retain_index_list])
window[idx] = full[idx]
return window.to(device)

def add_window(self, full: torch.Tensor, to_add: torch.Tensor, dim=None) -> torch.Tensor:
if dim is None:
dim = self.dim
idx = [slice(None)] * dim + [self.index_list]
idx = tuple([slice(None)] * dim + [self.index_list])
full[idx] += to_add
return full

def get_region_index(self, num_regions: int) -> int:
region_idx = int(self.center_ratio * num_regions)
return min(max(region_idx, 0), num_regions - 1)


class IndexListCallbacks:
EVALUATE_CONTEXT_WINDOWS = "evaluate_context_windows"
Expand All @@ -94,7 +104,8 @@ class ContextFuseMethod:

ContextResults = collections.namedtuple("ContextResults", ['window_idx', 'sub_conds_out', 'sub_conds', 'window'])
class IndexListContextHandler(ContextHandlerABC):
def __init__(self, context_schedule: ContextSchedule, fuse_method: ContextFuseMethod, context_length: int=1, context_overlap: int=0, context_stride: int=1, closed_loop=False, dim=0):
def __init__(self, context_schedule: ContextSchedule, fuse_method: ContextFuseMethod, context_length: int=1, context_overlap: int=0, context_stride: int=1,
closed_loop: bool=False, dim:int=0, freenoise: bool=False, cond_retain_index_list: list[int]=[], split_conds_to_windows: bool=False):
self.context_schedule = context_schedule
self.fuse_method = fuse_method
self.context_length = context_length
Expand All @@ -103,13 +114,18 @@ def __init__(self, context_schedule: ContextSchedule, fuse_method: ContextFuseMe
self.closed_loop = closed_loop
self.dim = dim
self._step = 0
self.freenoise = freenoise
self.cond_retain_index_list = [int(x.strip()) for x in cond_retain_index_list.split(",")] if cond_retain_index_list else []
self.split_conds_to_windows = split_conds_to_windows

self.callbacks = {}

def should_use_context(self, model: BaseModel, conds: list[list[dict]], x_in: torch.Tensor, timestep: torch.Tensor, model_options: dict[str]) -> bool:
# for now, assume first dim is batch - should have stored on BaseModel in actual implementation
if x_in.size(self.dim) > self.context_length:
logging.info(f"Using context windows {self.context_length} for {x_in.size(self.dim)} frames.")
logging.info(f"Using context windows {self.context_length} with overlap {self.context_overlap} for {x_in.size(self.dim)} frames.")
if self.cond_retain_index_list:
logging.info(f"Retaining original cond for indexes: {self.cond_retain_index_list}")
return True
return False

Expand All @@ -123,6 +139,11 @@ def get_resized_cond(self, cond_in: list[dict], x_in: torch.Tensor, window: Inde
return None
# reuse or resize cond items to match context requirements
resized_cond = []
# if multiple conds, split based on primary region
if self.split_conds_to_windows and len(cond_in) > 1:
region = window.get_region_index(len(cond_in))
logging.info(f"Splitting conds to windows; using region {region} for window {window[0]}-{window[-1]} with center ratio {window.center_ratio:.3f}")
cond_in = [cond_in[region]]
# cond object is a list containing a dict - outer list is irrelevant, so just loop through it
for actual_cond in cond_in:
resized_actual_cond = actual_cond.copy()
Expand All @@ -146,12 +167,19 @@ def get_resized_cond(self, cond_in: list[dict], x_in: torch.Tensor, window: Inde
# when in dictionary, look for tensors and CONDCrossAttn [comfy/conds.py] (has cond attr that is a tensor)
for cond_key, cond_value in new_cond_item.items():
if isinstance(cond_value, torch.Tensor):
if cond_value.ndim < self.dim and cond_value.size(0) == x_in.size(self.dim):
if (self.dim < cond_value.ndim and cond_value(self.dim) == x_in.size(self.dim)) or \
(cond_value.ndim < self.dim and cond_value.size(0) == x_in.size(self.dim)):
new_cond_item[cond_key] = window.get_tensor(cond_value, device)
# Handle audio_embed (temporal dim is 1)
elif cond_key == "audio_embed" and hasattr(cond_value, "cond") and isinstance(cond_value.cond, torch.Tensor):
audio_cond = cond_value.cond
if audio_cond.ndim > 1 and audio_cond.size(1) == x_in.size(self.dim):
new_cond_item[cond_key] = cond_value._copy_with(window.get_tensor(audio_cond, device, dim=1))
# if has cond that is a Tensor, check if needs to be subset
elif hasattr(cond_value, "cond") and isinstance(cond_value.cond, torch.Tensor):
if cond_value.cond.ndim < self.dim and cond_value.cond.size(0) == x_in.size(self.dim):
new_cond_item[cond_key] = cond_value._copy_with(window.get_tensor(cond_value.cond, device))
if (self.dim < cond_value.cond.ndim and cond_value.cond.size(self.dim) == x_in.size(self.dim)) or \
(cond_value.cond.ndim < self.dim and cond_value.cond.size(0) == x_in.size(self.dim)):
new_cond_item[cond_key] = cond_value._copy_with(window.get_tensor(cond_value.cond, device, retain_index_list=self.cond_retain_index_list))
elif cond_key == "num_video_frames": # for SVD
new_cond_item[cond_key] = cond_value._copy_with(cond_value.cond)
new_cond_item[cond_key].cond = window.context_length
Expand All @@ -164,7 +192,7 @@ def get_resized_cond(self, cond_in: list[dict], x_in: torch.Tensor, window: Inde
return resized_cond

def set_step(self, timestep: torch.Tensor, model_options: dict[str]):
mask = torch.isclose(model_options["transformer_options"]["sample_sigmas"], timestep, rtol=0.0001)
mask = torch.isclose(model_options["transformer_options"]["sample_sigmas"], timestep[0], rtol=0.0001)
matches = torch.nonzero(mask)
if torch.numel(matches) == 0:
raise Exception("No sample_sigmas matched current timestep; something went wrong.")
Expand All @@ -173,7 +201,7 @@ def set_step(self, timestep: torch.Tensor, model_options: dict[str]):
def get_context_windows(self, model: BaseModel, x_in: torch.Tensor, model_options: dict[str]) -> list[IndexListContextWindow]:
full_length = x_in.size(self.dim) # TODO: choose dim based on model
context_windows = self.context_schedule.func(full_length, self, model_options)
context_windows = [IndexListContextWindow(window, dim=self.dim) for window in context_windows]
context_windows = [IndexListContextWindow(window, dim=self.dim, total_frames=full_length) for window in context_windows]
return context_windows

def execute(self, calc_cond_batch: Callable, model: BaseModel, conds: list[list[dict]], x_in: torch.Tensor, timestep: torch.Tensor, model_options: dict[str]):
Expand Down Expand Up @@ -250,8 +278,8 @@ def combine_context_window_results(self, x_in: torch.Tensor, sub_conds_out, sub_
prev_weight = (bias_total / (bias_total + bias))
new_weight = (bias / (bias_total + bias))
# account for dims of tensors
idx_window = [slice(None)] * self.dim + [idx]
pos_window = [slice(None)] * self.dim + [pos]
idx_window = tuple([slice(None)] * self.dim + [idx])
pos_window = tuple([slice(None)] * self.dim + [pos])
# apply new values
conds_final[i][idx_window] = conds_final[i][idx_window] * prev_weight + sub_conds_out[i][pos_window] * new_weight
biases_final[i][idx] = bias_total + bias
Expand Down Expand Up @@ -287,6 +315,28 @@ def create_prepare_sampling_wrapper(model: ModelPatcher):
)


def _sampler_sample_wrapper(executor, guider, sigmas, extra_args, callback, noise, *args, **kwargs):
model_options = extra_args.get("model_options", None)
if model_options is None:
raise Exception("model_options not found in sampler_sample_wrapper; this should never happen, something went wrong.")
handler: IndexListContextHandler = model_options.get("context_handler", None)
if handler is None:
raise Exception("context_handler not found in sampler_sample_wrapper; this should never happen, something went wrong.")
if not handler.freenoise:
return executor(guider, sigmas, extra_args, callback, noise, *args, **kwargs)
noise = apply_freenoise(noise, handler.dim, handler.context_length, handler.context_overlap, extra_args["seed"])

return executor(guider, sigmas, extra_args, callback, noise, *args, **kwargs)


def create_sampler_sample_wrapper(model: ModelPatcher):
model.add_wrapper_with_key(
comfy.patcher_extension.WrappersMP.SAMPLER_SAMPLE,
"ContextWindows_sampler_sample",
_sampler_sample_wrapper
)


def match_weights_to_dim(weights: list[float], x_in: torch.Tensor, dim: int, device=None) -> torch.Tensor:
total_dims = len(x_in.shape)
weights_tensor = torch.Tensor(weights).to(device=device)
Expand Down Expand Up @@ -538,3 +588,29 @@ def shift_window_to_end(window: list[int], num_frames: int):
for i in range(len(window)):
# 2) add end_delta to each val to slide windows to end
window[i] = window[i] + end_delta


# https://github.com/Kosinkadink/ComfyUI-AnimateDiff-Evolved/blob/90fb1331201a4b29488089e4fbffc0d82cc6d0a9/animatediff/sample_settings.py#L465
def apply_freenoise(noise: torch.Tensor, dim: int, context_length: int, context_overlap: int, seed: int):
logging.info("Context windows: Applying FreeNoise")
generator = torch.Generator(device='cpu').manual_seed(seed)
latent_video_length = noise.shape[dim]
delta = context_length - context_overlap

for start_idx in range(0, latent_video_length - context_length, delta):
place_idx = start_idx + context_length

actual_delta = min(delta, latent_video_length - place_idx)
if actual_delta <= 0:
break

list_idx = torch.randperm(actual_delta, generator=generator, device='cpu') + start_idx

source_slice = [slice(None)] * noise.ndim
source_slice[dim] = list_idx
target_slice = [slice(None)] * noise.ndim
target_slice[dim] = slice(place_idx, place_idx + actual_delta)

noise[tuple(target_slice)] = noise[tuple(source_slice)]

return noise
22 changes: 18 additions & 4 deletions comfy_extras/nodes_context_windows.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,9 @@ def define_schema(cls) -> io.Schema:
io.Boolean.Input("closed_loop", default=False, tooltip="Whether to close the context window loop; only applicable to looped schedules."),
io.Combo.Input("fuse_method", options=comfy.context_windows.ContextFuseMethods.LIST_STATIC, default=comfy.context_windows.ContextFuseMethods.PYRAMID, tooltip="The method to use to fuse the context windows."),
io.Int.Input("dim", min=0, max=5, default=0, tooltip="The dimension to apply the context windows to."),
io.Boolean.Input("freenoise", default=False, tooltip="Whether to apply FreeNoise noise shuffling, improves window blending."),
#io.String.Input("cond_retain_index_list", default="", tooltip="List of latent indices to retain in the conditioning tensors for each window, for example setting this to '0' will use the initial start image for each window."),
#io.Boolean.Input("split_conds_to_windows", default=False, tooltip="Whether to split multiple conditionings (created by ConditionCombine) to each window based on region index."),
],
outputs=[
io.Model.Output(tooltip="The model with context windows applied during sampling."),
Expand All @@ -34,7 +37,8 @@ def define_schema(cls) -> io.Schema:
)

@classmethod
def execute(cls, model: io.Model.Type, context_length: int, context_overlap: int, context_schedule: str, context_stride: int, closed_loop: bool, fuse_method: str, dim: int) -> io.Model:
def execute(cls, model: io.Model.Type, context_length: int, context_overlap: int, context_schedule: str, context_stride: int, closed_loop: bool, fuse_method: str, dim: int, freenoise: bool,
cond_retain_index_list: list[int]=[], split_conds_to_windows: bool=False) -> io.Model:
model = model.clone()
model.model_options["context_handler"] = comfy.context_windows.IndexListContextHandler(
context_schedule=comfy.context_windows.get_matching_context_schedule(context_schedule),
Expand All @@ -43,9 +47,15 @@ def execute(cls, model: io.Model.Type, context_length: int, context_overlap: int
context_overlap=context_overlap,
context_stride=context_stride,
closed_loop=closed_loop,
dim=dim)
dim=dim,
freenoise=freenoise,
cond_retain_index_list=cond_retain_index_list,
split_conds_to_windows=split_conds_to_windows
)
# make memory usage calculation only take into account the context window latents
comfy.context_windows.create_prepare_sampling_wrapper(model)
if freenoise: # no other use for this wrapper at this time
comfy.context_windows.create_sampler_sample_wrapper(model)
return io.NodeOutput(model)

class WanContextWindowsManualNode(ContextWindowsManualNode):
Expand All @@ -68,14 +78,18 @@ def define_schema(cls) -> io.Schema:
io.Int.Input("context_stride", min=1, default=1, tooltip="The stride of the context window; only applicable to uniform schedules."),
io.Boolean.Input("closed_loop", default=False, tooltip="Whether to close the context window loop; only applicable to looped schedules."),
io.Combo.Input("fuse_method", options=comfy.context_windows.ContextFuseMethods.LIST_STATIC, default=comfy.context_windows.ContextFuseMethods.PYRAMID, tooltip="The method to use to fuse the context windows."),
io.Boolean.Input("freenoise", default=False, tooltip="Whether to apply FreeNoise noise shuffling, improves window blending."),
#io.String.Input("cond_retain_index_list", default="", tooltip="List of latent indices to retain in the conditioning tensors for each window, for example setting this to '0' will use the initial start image for each window."),
#io.Boolean.Input("split_conds_to_windows", default=False, tooltip="Whether to split multiple conditionings (created by ConditionCombine) to each window based on region index."),
]
return schema

@classmethod
def execute(cls, model: io.Model.Type, context_length: int, context_overlap: int, context_schedule: str, context_stride: int, closed_loop: bool, fuse_method: str) -> io.Model:
def execute(cls, model: io.Model.Type, context_length: int, context_overlap: int, context_schedule: str, context_stride: int, closed_loop: bool, fuse_method: str, freenoise: bool,
cond_retain_index_list: list[int]=[], split_conds_to_windows: bool=False) -> io.Model:
context_length = max(((context_length - 1) // 4) + 1, 1) # at least length 1
context_overlap = max(((context_overlap - 1) // 4) + 1, 0) # at least overlap 0
return super().execute(model, context_length, context_overlap, context_schedule, context_stride, closed_loop, fuse_method, dim=2)
return super().execute(model, context_length, context_overlap, context_schedule, context_stride, closed_loop, fuse_method, dim=2, freenoise=freenoise, cond_retain_index_list=cond_retain_index_list, split_conds_to_windows=split_conds_to_windows)


class ContextWindowsExtension(ComfyExtension):
Expand Down