
Commit 9d8dee0
feat(api-nodes): add Nano Banana Pro
1 parent cb96d4d

3 files changed, +214 -9 lines changed

comfy_api_nodes/apis/gemini_api.py
Lines changed: 3 additions & 2 deletions

@@ -68,7 +68,7 @@ class GeminiTextPart(BaseModel):
 
 
 class GeminiContent(BaseModel):
-    parts: list[GeminiPart] = Field(...)
+    parts: list[GeminiPart] = Field([])
     role: GeminiRole = Field(..., examples=["user"])
 
 
@@ -120,7 +120,7 @@ class GeminiGenerationConfig(BaseModel):
 
 class GeminiImageConfig(BaseModel):
     aspectRatio: str | None = Field(None)
-    resolution: str | None = Field(None)
+    imageSize: str | None = Field(None)
 
 
 class GeminiImageGenerationConfig(GeminiGenerationConfig):
@@ -227,3 +227,4 @@ class GeminiGenerateContentResponse(BaseModel):
    candidates: list[GeminiCandidate] | None = Field(None)
    promptFeedback: GeminiPromptFeedback | None = Field(None)
    usageMetadata: GeminiUsageMetadata | None = Field(None)
+   modelVersion: str | None = Field(None)
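Note: a minimal sketch of what these schema changes allow, assuming the models are pydantic v2 and the names above are importable from this module; the sample values are made up.

    from comfy_api_nodes.apis.gemini_api import (
        GeminiContent,
        GeminiGenerateContentResponse,
        GeminiRole,
    )

    # `parts` now defaults to an empty list, presumably so a content block
    # that arrives without a "parts" key still validates.
    content = GeminiContent(role=GeminiRole.user)
    assert content.parts == []

    # `modelVersion` is surfaced on the response so callers (e.g. the new
    # calculate_tokens_price helper in nodes_gemini.py) can pick a price table.
    resp = GeminiGenerateContentResponse.model_validate(
        {"modelVersion": "gemini-3-pro-image-preview"}
    )
    assert resp.modelVersion == "gemini-3-pro-image-preview"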

comfy_api_nodes/nodes_gemini.py
Lines changed: 200 additions & 5 deletions

@@ -29,11 +29,13 @@
     GeminiMimeType,
     GeminiPart,
     GeminiRole,
+    Modality,
 )
 from comfy_api_nodes.util import (
     ApiEndpoint,
     audio_to_base64_string,
     bytesio_to_image_tensor,
+    get_number_of_images,
     sync_op,
     tensor_to_base64_string,
     validate_string,
@@ -147,6 +149,49 @@ def get_image_from_response(response: GeminiGenerateContentResponse) -> torch.Te
     return torch.cat(image_tensors, dim=0)
 
 
+def calculate_tokens_price(response: GeminiGenerateContentResponse) -> float | None:
+    if not response.modelVersion:
+        return None
+    # Define prices (cost per 1,000,000 tokens), see https://cloud.google.com/vertex-ai/generative-ai/pricing
+    if response.modelVersion in ("gemini-2.5-pro-preview-05-06", "gemini-2.5-pro"):
+        input_tokens_price = 1.25
+        output_text_tokens_price = 10.0
+        output_image_tokens_price = 0.0
+    elif response.modelVersion in (
+        "gemini-2.5-flash-preview-04-17",
+        "gemini-2.5-flash",
+    ):
+        input_tokens_price = 0.30
+        output_text_tokens_price = 2.50
+        output_image_tokens_price = 0.0
+    elif response.modelVersion in (
+        "gemini-2.5-flash-image-preview",
+        "gemini-2.5-flash-image",
+    ):
+        input_tokens_price = 0.30
+        output_text_tokens_price = 2.50
+        output_image_tokens_price = 30.0
+    elif response.modelVersion == "gemini-3-pro-preview":
+        input_tokens_price = 2
+        output_text_tokens_price = 12.0
+        output_image_tokens_price = 0.0
+    elif response.modelVersion == "gemini-3-pro-image-preview":
+        input_tokens_price = 2
+        output_text_tokens_price = 12.0
+        output_image_tokens_price = 120.0
+    else:
+        return None
+    final_price = response.usageMetadata.promptTokenCount * input_tokens_price
+    for i in response.usageMetadata.candidatesTokensDetails:
+        if i.modality == Modality.IMAGE:
+            final_price += output_image_tokens_price * i.tokenCount  # for Nano Banana models
+        else:
+            final_price += output_text_tokens_price * i.tokenCount
+    if response.usageMetadata.thoughtsTokenCount:
+        final_price += output_text_tokens_price * response.usageMetadata.thoughtsTokenCount
+    return final_price / 1_000_000.0
+
+
 class GeminiNode(IO.ComfyNode):
     """
     Node to generate text responses from a Gemini model.
@@ -314,6 +359,7 @@ async def execute(
                 ]
             ),
             response_model=GeminiGenerateContentResponse,
+            price_extractor=calculate_tokens_price,
         )
 
         output_text = get_text_from_response(response)
@@ -476,6 +522,13 @@ def define_schema(cls):
                    "or otherwise generates 1:1 squares.",
                    optional=True,
                ),
+               IO.Combo.Input(
+                   "response_modalities",
+                   options=["IMAGE+TEXT", "IMAGE"],
+                   tooltip="Choose 'IMAGE' for image-only output, or "
+                   "'IMAGE+TEXT' to return both the generated image and a text response.",
+                   optional=True,
+               ),
            ],
            outputs=[
                IO.Image.Output(),
@@ -498,6 +551,7 @@ async def execute(
         images: torch.Tensor | None = None,
         files: list[GeminiPart] | None = None,
         aspect_ratio: str = "auto",
+        response_modalities: str = "IMAGE+TEXT",
     ) -> IO.NodeOutput:
         validate_string(prompt, strip_whitespace=True, min_length=1)
         parts: list[GeminiPart] = [GeminiPart(text=prompt)]
@@ -520,17 +574,16 @@ async def execute(
                    GeminiContent(role=GeminiRole.user, parts=parts),
                ],
                generationConfig=GeminiImageGenerationConfig(
-                   responseModalities=["TEXT", "IMAGE"],
+                   responseModalities=(["IMAGE"] if response_modalities == "IMAGE" else ["TEXT", "IMAGE"]),
                    imageConfig=None if aspect_ratio == "auto" else image_config,
                ),
            ),
            response_model=GeminiGenerateContentResponse,
+           price_extractor=calculate_tokens_price,
        )
 
-       output_image = get_image_from_response(response)
        output_text = get_text_from_response(response)
        if output_text:
-           # Not a true chat history like the OpenAI Chat node. It is emulated so the frontend can show a copy button.
            render_spec = {
                "node_id": cls.hidden.unique_id,
                "component": "ChatHistoryWidget",
@@ -551,9 +604,150 @@ async def execute(
                "display_component",
                render_spec,
            )
+       return IO.NodeOutput(get_image_from_response(response), output_text)
+
+
+class GeminiImage2(IO.ComfyNode):
 
-       output_text = output_text or "Empty response from Gemini model..."
-       return IO.NodeOutput(output_image, output_text)
+    @classmethod
+    def define_schema(cls):
+        return IO.Schema(
+            node_id="GeminiImage2Node",
+            display_name="Nano Banana Pro (Google Gemini Image)",
+            category="api node/image/Gemini",
+            description="Generate or edit images synchronously via Google Vertex API.",
+            inputs=[
+                IO.String.Input(
+                    "prompt",
+                    multiline=True,
+                    tooltip="Text prompt describing the image to generate or the edits to apply. "
+                    "Include any constraints, styles, or details the model should follow.",
+                    default="",
+                ),
+                IO.Combo.Input(
+                    "model",
+                    options=["gemini-3-pro-image-preview"],
+                ),
+                IO.Int.Input(
+                    "seed",
+                    default=42,
+                    min=0,
+                    max=0xFFFFFFFFFFFFFFFF,
+                    control_after_generate=True,
+                    tooltip="When the seed is fixed to a specific value, the model makes a best effort to provide "
+                    "the same response for repeated requests. Deterministic output isn't guaranteed. "
+                    "Also, changing the model or parameter settings, such as the temperature, "
+                    "can cause variations in the response even when you use the same seed value. "
+                    "By default, a random seed value is used.",
+                ),
+                IO.Combo.Input(
+                    "aspect_ratio",
+                    options=["auto", "1:1", "2:3", "3:2", "3:4", "4:3", "4:5", "5:4", "9:16", "16:9", "21:9"],
+                    default="auto",
+                    tooltip="If set to 'auto', matches your input image's aspect ratio; "
+                    "if no image is provided, generates a 1:1 square.",
+                ),
+                IO.Combo.Input(
+                    "resolution",
+                    options=["1K", "2K", "4K"],
+                    tooltip="Target output resolution. For 2K/4K the native Gemini upscaler is used.",
+                ),
+                IO.Combo.Input(
+                    "response_modalities",
+                    options=["IMAGE+TEXT", "IMAGE"],
+                    tooltip="Choose 'IMAGE' for image-only output, or "
+                    "'IMAGE+TEXT' to return both the generated image and a text response.",
+                ),
+                IO.Image.Input(
+                    "images",
+                    optional=True,
+                    tooltip="Optional reference image(s). "
+                    "To include multiple images, use the Batch Images node (up to 14).",
+                ),
+                IO.Custom("GEMINI_INPUT_FILES").Input(
+                    "files",
+                    optional=True,
+                    tooltip="Optional file(s) to use as context for the model. "
+                    "Accepts inputs from the Gemini Generate Content Input Files node.",
+                ),
+            ],
+            outputs=[
+                IO.Image.Output(),
+                IO.String.Output(),
+            ],
+            hidden=[
+                IO.Hidden.auth_token_comfy_org,
+                IO.Hidden.api_key_comfy_org,
+                IO.Hidden.unique_id,
+            ],
+            is_api_node=True,
+        )
+
+    @classmethod
+    async def execute(
+        cls,
+        prompt: str,
+        model: str,
+        seed: int,
+        aspect_ratio: str,
+        resolution: str,
+        response_modalities: str,
+        images: torch.Tensor | None = None,
+        files: list[GeminiPart] | None = None,
+    ) -> IO.NodeOutput:
+        validate_string(prompt, strip_whitespace=True, min_length=1)
+
+        parts: list[GeminiPart] = [GeminiPart(text=prompt)]
+        if images is not None:
+            if get_number_of_images(images) > 6:
+                raise ValueError("The current maximum number of supported images is 14.")
+            parts.extend(create_image_parts(images))
+        if files is not None:
+            parts.extend(files)
+
+        image_config = GeminiImageConfig(imageSize=resolution)
+        if aspect_ratio != "auto":
+            image_config.aspectRatio = aspect_ratio
+
+        response = await sync_op(
+            cls,
+            ApiEndpoint(path=f"{GEMINI_BASE_ENDPOINT}/{model}", method="POST"),
+            data=GeminiImageGenerateContentRequest(
+                contents=[
+                    GeminiContent(role=GeminiRole.user, parts=parts),
+                ],
+                generationConfig=GeminiImageGenerationConfig(
+                    responseModalities=(["IMAGE"] if response_modalities == "IMAGE" else ["TEXT", "IMAGE"]),
+                    imageConfig=image_config,
+                ),
+            ),
+            response_model=GeminiGenerateContentResponse,
+            price_extractor=calculate_tokens_price,
+        )
+
+        output_text = get_text_from_response(response)
+        if output_text:
+            render_spec = {
+                "node_id": cls.hidden.unique_id,
+                "component": "ChatHistoryWidget",
+                "props": {
+                    "history": json.dumps(
+                        [
+                            {
+                                "prompt": prompt,
+                                "response": output_text,
+                                "response_id": str(uuid.uuid4()),
+                                "timestamp": time.time(),
+                            }
                        ]
+                    ),
+                },
+            }
+            PromptServer.instance.send_sync(
+                "display_component",
+                render_spec,
+            )
+        return IO.NodeOutput(get_image_from_response(response), output_text)
 
 
 class GeminiExtension(ComfyExtension):
@@ -562,6 +756,7 @@ async def get_node_list(self) -> list[type[IO.ComfyNode]]:
         return [
             GeminiNode,
             GeminiImage,
+            GeminiImage2,
             GeminiInputFiles,
         ]
 
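Note: a rough worked example of the new pricing logic. The token counts below are invented; the per-million rates come from the calculate_tokens_price table above for gemini-3-pro-image-preview (Nano Banana Pro).

    # Hypothetical usage metadata for one gemini-3-pro-image-preview call.
    prompt_tokens = 1_000      # input tokens,            $2 / 1M
    image_out_tokens = 1_120   # IMAGE-modality output,   $120 / 1M
    text_out_tokens = 50       # TEXT-modality output,    $12 / 1M
    thoughts_tokens = 300      # thinking tokens, billed at the text rate

    price = (
        prompt_tokens * 2
        + image_out_tokens * 120.0
        + text_out_tokens * 12.0
        + thoughts_tokens * 12.0
    ) / 1_000_000.0
    print(price)  # 0.1406 -> rendered in the progress text as "Price: $0.1406"

If modelVersion is missing or not in the table, the extractor returns None and no price line is shown.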

comfy_api_nodes/util/client.py
Lines changed: 11 additions & 2 deletions

@@ -63,6 +63,7 @@ class _RequestConfig:
     estimated_total: Optional[int] = None
     final_label_on_success: Optional[str] = "Completed"
     progress_origin_ts: Optional[float] = None
+    price_extractor: Optional[Callable[[dict[str, Any]], Optional[float]]] = None
 
 
 @dataclass
@@ -87,6 +88,7 @@ async def sync_op(
     endpoint: ApiEndpoint,
     *,
     response_model: Type[M],
+    price_extractor: Optional[Callable[[M], Optional[float]]] = None,
     data: Optional[BaseModel] = None,
     files: Optional[Union[dict[str, Any], list[tuple[str, Any]]]] = None,
     content_type: str = "application/json",
@@ -104,6 +106,7 @@ async def sync_op(
     raw = await sync_op_raw(
         cls,
         endpoint,
+        price_extractor=_wrap_model_extractor(response_model, price_extractor),
         data=data,
         files=files,
         content_type=content_type,
@@ -175,6 +178,7 @@ async def sync_op_raw(
     cls: type[IO.ComfyNode],
     endpoint: ApiEndpoint,
     *,
+    price_extractor: Optional[Callable[[dict[str, Any]], Optional[float]]] = None,
     data: Optional[Union[dict[str, Any], BaseModel]] = None,
     files: Optional[Union[dict[str, Any], list[tuple[str, Any]]]] = None,
     content_type: str = "application/json",
@@ -216,6 +220,7 @@ async def sync_op_raw(
         estimated_total=estimated_duration,
         final_label_on_success=final_label_on_success,
         progress_origin_ts=progress_origin_ts,
+        price_extractor=price_extractor,
     )
     return await _request_base(cfg, expect_binary=as_binary)
 
@@ -425,7 +430,8 @@ def _display_text(
         display_lines.append(f"Status: {status.capitalize() if isinstance(status, str) else status}")
     if price is not None:
         p = f"{float(price):,.4f}".rstrip("0").rstrip(".")
-        display_lines.append(f"Price: ${p}")
+        if p != "0":
+            display_lines.append(f"Price: ${p}")
     if text is not None:
         display_lines.append(text)
     if display_lines:
@@ -581,6 +587,7 @@ async def _monitor(stop_evt: asyncio.Event, start_ts: float):
     delay = cfg.retry_delay
     operation_succeeded: bool = False
     final_elapsed_seconds: Optional[int] = None
+    extracted_price: Optional[float] = None
     while True:
         attempt += 1
         stop_event = asyncio.Event()
@@ -768,6 +775,8 @@ async def _monitor(stop_evt: asyncio.Event, start_ts: float):
                except json.JSONDecodeError:
                    payload = {"_raw": text}
                response_content_to_log = payload if isinstance(payload, dict) else text
+               with contextlib.suppress(Exception):
+                   extracted_price = cfg.price_extractor(payload) if cfg.price_extractor else None
                operation_succeeded = True
                final_elapsed_seconds = int(time.monotonic() - start_time)
                try:
@@ -872,7 +881,7 @@ async def _monitor(stop_evt: asyncio.Event, start_ts: float):
                    else int(time.monotonic() - start_time)
                ),
                estimated_total=cfg.estimated_total,
-               price=None,
+               price=extracted_price,
                is_queued=False,
                processing_elapsed_seconds=final_elapsed_seconds,
            )
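Note on the plumbing: sync_op accepts a typed extractor (Callable[[M], Optional[float]]) while _RequestConfig stores a dict-based one, and _wrap_model_extractor (already present in client.py, not part of this diff) bridges the two. The sketch below is a hypothetical guess at that bridge, for orientation only; the real helper may differ.

    from typing import Any, Callable, Optional, Type, TypeVar
    from pydantic import BaseModel

    M = TypeVar("M", bound=BaseModel)

    # Hypothetical sketch of what a wrapper like _wrap_model_extractor might do.
    def _wrap_model_extractor(
        response_model: Type[M],
        price_extractor: Optional[Callable[[M], Optional[float]]],
    ) -> Optional[Callable[[dict[str, Any]], Optional[float]]]:
        if price_extractor is None:
            return None

        def _extract(payload: dict[str, Any]) -> Optional[float]:
            # Parse the raw JSON payload into the typed response model, then price it.
            return price_extractor(response_model.model_validate(payload))

        return _extract

Any exception raised while pricing is swallowed by the contextlib.suppress(Exception) block added above, so a faulty extractor never fails the request; it only leaves the price line out of the progress display.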
