Skip to content

Commit c3c6fa5

Browse files
committed
add get_frame_count and get_frame_rate methods to VideoInput class
1 parent cbd68e3 commit c3c6fa5

File tree

3 files changed

+106
-9
lines changed

3 files changed

+106
-9
lines changed

comfy_api/latest/_input/video_types.py

Lines changed: 28 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
from __future__ import annotations
22
from abc import ABC, abstractmethod
3+
from fractions import Fraction
34
from typing import Optional, Union, IO
45
import io
56
import av
@@ -72,6 +73,33 @@ def get_duration(self) -> float:
7273
frame_count = components.images.shape[0]
7374
return float(frame_count / components.frame_rate)
7475

76+
def get_frame_count(self) -> int:
77+
"""
78+
Returns the number of frames in the video.
79+
80+
Default implementation uses :meth:`get_components`, which may require
81+
loading all frames into memory. File-based implementations should
82+
override this method and use container/stream metadata instead.
83+
84+
Returns:
85+
Total number of frames as an integer.
86+
"""
87+
return int(self.get_components().images.shape[0])
88+
89+
def get_frame_rate(self) -> Fraction:
90+
"""
91+
Returns the frame rate of the video.
92+
93+
Default implementation materializes the video into memory via
94+
`get_components()`. Subclasses that can inspect the underlying
95+
container (e.g. `VideoFromFile`) should override this with a more
96+
efficient implementation.
97+
98+
Returns:
99+
Frame rate as a Fraction.
100+
"""
101+
return self.get_components().frame_rate
102+
75103
def get_container_format(self) -> str:
76104
"""
77105
Returns the container format of the video (e.g., 'mp4', 'mov', 'avi').

comfy_api/latest/_input_impl/video_types.py

Lines changed: 72 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -121,6 +121,71 @@ def get_duration(self) -> float:
121121

122122
raise ValueError(f"Could not determine duration for file '{self.__file}'")
123123

124+
def get_frame_count(self) -> int:
125+
"""
126+
Returns the number of frames in the video without materializing them as
127+
torch tensors.
128+
"""
129+
if isinstance(self.__file, io.BytesIO):
130+
self.__file.seek(0)
131+
132+
with av.open(self.__file, mode="r") as container:
133+
video_stream = self._get_first_video_stream(container)
134+
# 1. Prefer the frames field if available
135+
if video_stream.frames and video_stream.frames > 0:
136+
return int(video_stream.frames)
137+
138+
# 2. Try to estimate from duration and average_rate using only metadata
139+
if container.duration is not None and video_stream.average_rate:
140+
duration_seconds = float(container.duration / av.time_base)
141+
estimated_frames = int(round(duration_seconds * float(video_stream.average_rate)))
142+
if estimated_frames > 0:
143+
return estimated_frames
144+
145+
if (
146+
getattr(video_stream, "duration", None) is not None
147+
and getattr(video_stream, "time_base", None) is not None
148+
and video_stream.average_rate
149+
):
150+
duration_seconds = float(video_stream.duration * video_stream.time_base)
151+
estimated_frames = int(round(duration_seconds * float(video_stream.average_rate)))
152+
if estimated_frames > 0:
153+
return estimated_frames
154+
155+
# 3. Last resort: decode frames and count them (streaming)
156+
frame_count = 0
157+
container.seek(0)
158+
for packet in container.demux(video_stream):
159+
for _ in packet.decode():
160+
frame_count += 1
161+
162+
if frame_count == 0:
163+
raise ValueError(f"Could not determine frame count for file '{self.__file}'")
164+
return frame_count
165+
166+
def get_frame_rate(self) -> Fraction:
167+
"""
168+
Returns the average frame rate of the video using container metadata
169+
without decoding all frames.
170+
"""
171+
if isinstance(self.__file, io.BytesIO):
172+
self.__file.seek(0)
173+
174+
with av.open(self.__file, mode="r") as container:
175+
video_stream = self._get_first_video_stream(container)
176+
# Preferred: use PyAV's average_rate (usually already a Fraction-like)
177+
if video_stream.average_rate:
178+
return Fraction(video_stream.average_rate)
179+
180+
# Fallback: estimate from frames + duration if available
181+
if video_stream.frames and container.duration:
182+
duration_seconds = float(container.duration / av.time_base)
183+
if duration_seconds > 0:
184+
return Fraction(video_stream.frames / duration_seconds).limit_denominator()
185+
186+
# Last resort: match get_components_internal default
187+
return Fraction(1)
188+
124189
def get_container_format(self) -> str:
125190
"""
126191
Returns the container format of the video (e.g., 'mp4', 'mov', 'avi').
@@ -238,6 +303,13 @@ def save_to(
238303
packet.stream = stream_map[packet.stream]
239304
output_container.mux(packet)
240305

306+
def _get_first_video_stream(self, container: InputContainer):
307+
video_stream = next((s for s in container.streams if s.type == "video"), None)
308+
if video_stream is None:
309+
raise ValueError(f"No video stream found in file '{self.__file}'")
310+
return video_stream
311+
312+
241313
class VideoFromComponents(VideoInput):
242314
"""
243315
Class representing video input from tensors.

comfy_api_nodes/nodes_topaz.py

Lines changed: 6 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -5,8 +5,7 @@
55
import torch
66
from typing_extensions import override
77

8-
from comfy_api.input.video_types import VideoInput
9-
from comfy_api.latest import IO, ComfyExtension
8+
from comfy_api.latest import IO, ComfyExtension, Input
109
from comfy_api_nodes.apis import topaz_api
1110
from comfy_api_nodes.util import (
1211
ApiEndpoint,
@@ -282,7 +281,7 @@ def define_schema(cls):
282281
@classmethod
283282
async def execute(
284283
cls,
285-
video: VideoInput,
284+
video: Input.Video,
286285
upscaler_enabled: bool,
287286
upscaler_model: str,
288287
upscaler_resolution: str,
@@ -297,12 +296,10 @@ async def execute(
297296
) -> IO.NodeOutput:
298297
if upscaler_enabled is False and interpolation_enabled is False:
299298
raise ValueError("There is nothing to do: both upscaling and interpolation are disabled.")
299+
validate_container_format_is_mp4(video)
300300
src_width, src_height = video.get_dimensions()
301-
video_components = video.get_components()
302-
src_frame_rate = int(video_components.frame_rate)
301+
src_frame_rate = video.get_frame_rate()
303302
duration_sec = video.get_duration()
304-
estimated_frames = int(duration_sec * src_frame_rate)
305-
validate_container_format_is_mp4(video)
306303
src_video_stream = video.get_stream_source()
307304
target_width = src_width
308305
target_height = src_height
@@ -338,8 +335,8 @@ async def execute(
338335
container="mp4",
339336
size=get_fs_object_size(src_video_stream),
340337
duration=int(duration_sec),
341-
frameCount=estimated_frames,
342-
frameRate=src_frame_rate,
338+
frameCount=video.get_frame_count(),
339+
frameRate=int(src_frame_rate),
343340
resolution=topaz_api.Resolution(width=src_width, height=src_height),
344341
),
345342
filters=filters,

0 commit comments

Comments
 (0)