Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
28 changes: 28 additions & 0 deletions comfy_api/latest/_input/video_types.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
from __future__ import annotations
from abc import ABC, abstractmethod
from fractions import Fraction
from typing import Optional, Union, IO
import io
import av
Expand Down Expand Up @@ -72,6 +73,33 @@ def get_duration(self) -> float:
frame_count = components.images.shape[0]
return float(frame_count / components.frame_rate)

def get_frame_count(self) -> int:
"""
Returns the number of frames in the video.

Default implementation uses :meth:`get_components`, which may require
loading all frames into memory. File-based implementations should
override this method and use container/stream metadata instead.

Returns:
Total number of frames as an integer.
"""
return int(self.get_components().images.shape[0])

def get_frame_rate(self) -> Fraction:
"""
Returns the frame rate of the video.

Default implementation materializes the video into memory via
`get_components()`. Subclasses that can inspect the underlying
container (e.g. `VideoFromFile`) should override this with a more
efficient implementation.

Returns:
Frame rate as a Fraction.
"""
return self.get_components().frame_rate

def get_container_format(self) -> str:
"""
Returns the container format of the video (e.g., 'mp4', 'mov', 'avi').
Expand Down
72 changes: 72 additions & 0 deletions comfy_api/latest/_input_impl/video_types.py
Original file line number Diff line number Diff line change
Expand Up @@ -121,6 +121,71 @@ def get_duration(self) -> float:

raise ValueError(f"Could not determine duration for file '{self.__file}'")

def get_frame_count(self) -> int:
"""
Returns the number of frames in the video without materializing them as
torch tensors.
"""
if isinstance(self.__file, io.BytesIO):
self.__file.seek(0)

with av.open(self.__file, mode="r") as container:
video_stream = self._get_first_video_stream(container)
# 1. Prefer the frames field if available
if video_stream.frames and video_stream.frames > 0:
return int(video_stream.frames)

# 2. Try to estimate from duration and average_rate using only metadata
if container.duration is not None and video_stream.average_rate:
duration_seconds = float(container.duration / av.time_base)
estimated_frames = int(round(duration_seconds * float(video_stream.average_rate)))
if estimated_frames > 0:
return estimated_frames

if (
getattr(video_stream, "duration", None) is not None
and getattr(video_stream, "time_base", None) is not None
and video_stream.average_rate
):
duration_seconds = float(video_stream.duration * video_stream.time_base)
estimated_frames = int(round(duration_seconds * float(video_stream.average_rate)))
if estimated_frames > 0:
return estimated_frames

# 3. Last resort: decode frames and count them (streaming)
frame_count = 0
container.seek(0)
for packet in container.demux(video_stream):
for _ in packet.decode():
frame_count += 1

if frame_count == 0:
raise ValueError(f"Could not determine frame count for file '{self.__file}'")
return frame_count

def get_frame_rate(self) -> Fraction:
"""
Returns the average frame rate of the video using container metadata
without decoding all frames.
"""
if isinstance(self.__file, io.BytesIO):
self.__file.seek(0)

with av.open(self.__file, mode="r") as container:
video_stream = self._get_first_video_stream(container)
# Preferred: use PyAV's average_rate (usually already a Fraction-like)
if video_stream.average_rate:
return Fraction(video_stream.average_rate)

# Fallback: estimate from frames + duration if available
if video_stream.frames and container.duration:
duration_seconds = float(container.duration / av.time_base)
if duration_seconds > 0:
return Fraction(video_stream.frames / duration_seconds).limit_denominator()

# Last resort: match get_components_internal default
return Fraction(1)

def get_container_format(self) -> str:
"""
Returns the container format of the video (e.g., 'mp4', 'mov', 'avi').
Expand Down Expand Up @@ -238,6 +303,13 @@ def save_to(
packet.stream = stream_map[packet.stream]
output_container.mux(packet)

def _get_first_video_stream(self, container: InputContainer):
video_stream = next((s for s in container.streams if s.type == "video"), None)
if video_stream is None:
raise ValueError(f"No video stream found in file '{self.__file}'")
return video_stream


class VideoFromComponents(VideoInput):
"""
Class representing video input from tensors.
Expand Down
13 changes: 5 additions & 8 deletions comfy_api_nodes/nodes_topaz.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,8 +5,7 @@
import torch
from typing_extensions import override

from comfy_api.input.video_types import VideoInput
from comfy_api.latest import IO, ComfyExtension
from comfy_api.latest import IO, ComfyExtension, Input
from comfy_api_nodes.apis import topaz_api
from comfy_api_nodes.util import (
ApiEndpoint,
Expand Down Expand Up @@ -282,7 +281,7 @@ def define_schema(cls):
@classmethod
async def execute(
cls,
video: VideoInput,
video: Input.Video,
upscaler_enabled: bool,
upscaler_model: str,
upscaler_resolution: str,
Expand All @@ -297,12 +296,10 @@ async def execute(
) -> IO.NodeOutput:
if upscaler_enabled is False and interpolation_enabled is False:
raise ValueError("There is nothing to do: both upscaling and interpolation are disabled.")
validate_container_format_is_mp4(video)
src_width, src_height = video.get_dimensions()
video_components = video.get_components()
src_frame_rate = int(video_components.frame_rate)
src_frame_rate = int(video.get_frame_rate())
duration_sec = video.get_duration()
estimated_frames = int(duration_sec * src_frame_rate)
validate_container_format_is_mp4(video)
src_video_stream = video.get_stream_source()
target_width = src_width
target_height = src_height
Expand Down Expand Up @@ -338,7 +335,7 @@ async def execute(
container="mp4",
size=get_fs_object_size(src_video_stream),
duration=int(duration_sec),
frameCount=estimated_frames,
frameCount=video.get_frame_count(),
frameRate=src_frame_rate,
resolution=topaz_api.Resolution(width=src_width, height=src_height),
),
Expand Down