Description
@gokayfem I would do this as a pull request, but I'm a hobbyist and have never used that feature, so I'll post it here.
I made a Gemma 3 node using your framework, and it works great. The large 27B model produces amazing quality, but the 4B is good enough; I've included uncensored (abliterated) versions of all three model sizes.
It offers some extra features: smart memory offloading for the model (the balanced memory mode), and passing the max token limit into the prompt so the VLM knows its length limit and long responses don't get cut off (a short sketch of this follows below, just before the file).
I installed "git+https://github.com/huggingface/[email protected]", which seems to be a requirement, but I haven't tried without it.
Anyway, I hope this makes its way into your awesome project. Thanks so much for making your wonderful nodes. Feel free to close this feature request anytime.
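For reference, the token-limit hint is just a suffix appended to the user prompt; this is a simplified excerpt of the prompt_with_limit logic in the file below, not extra functionality:

    token_limit = max(10, max_new_tokens - 10)  # keep a small buffer under the hard cap
    prompt_with_limit = f"{prompt}\nYou must stay under {token_limit} tokens."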
gemma3.py
import torch
import psutil
import os
import numpy as np
from PIL import Image
from pathlib import Path
import folder_paths
import comfy.model_management as mm
import traceback
from transformers import AutoProcessor, Gemma3ForConditionalGeneration
# Define the directory for saving Gemma-3 files
files_for_gemma3 = Path(folder_paths.folder_names_and_paths["LLavacheckpoints"][0][0]) / "files_for_gemma3"
files_for_gemma3.mkdir(parents=True, exist_ok=True)
# Model VRAM requirements (approximate, in GB)
MODEL_VRAM_REQUIREMENTS = {
    "Gemma-3-4B": 6,
    "Gemma-3-12B": 14,
    "Gemma-3-27B": 28,  # Added the largest model with its VRAM requirement
}

GEMMA3_MODELS = {
    "Gemma-3-4B": "mlabonne/gemma-3-4b-it-abliterated",
    "Gemma-3-12B": "mlabonne/gemma-3-12b-it-abliterated",
    "Gemma-3-27B": "mlabonne/gemma-3-27b-it-abliterated",  # Added the largest model
}
# Add Google's official model IDs
GOOGLE_GEMMA3_MODELS = [
    "google/gemma-3-27b-it",
    "google/gemma-3-12b-it",
    "google/gemma-3-1b-it",
    "google/gemma-3-4b-it"
]
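# Note: google/gemma-3-1b-it is a text-only variant; image input is only supported by the
# 4B/12B/27B models loaded with Gemma3ForConditionalGeneration below.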
# Define memory efficiency configurations
MEMORY_EFFICIENT_CONFIGS = {
    "Balanced (GPU)": {
        "device_map": "auto",
        "torch_dtype": torch.bfloat16
    },
    "Maximum Performance (Full GPU)": {
        "device_map": "cuda",
        "torch_dtype": torch.bfloat16
    },
    "Maximum Savings (CPU Only)": {
        "device_map": "cpu",
        "torch_dtype": torch.float32
    }
}
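# Note: device_map="auto" relies on Hugging Face accelerate to keep as many layers on the GPU
# as fit and offload the remainder to CPU RAM; that is what the "Balanced (GPU)" mode does.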
def tensor2pil(image):
    return Image.fromarray(np.clip(255. * image.cpu().numpy().squeeze(), 0, 255).astype(np.uint8))
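# ComfyUI IMAGE tensors are [batch, height, width, channels] floats in the 0-1 range, so the
# squeeze() above assumes a batch of one image.

# Illustrative inverse helper (an addition for symmetry, not used by the nodes below):
def pil2tensor(image):
    return torch.from_numpy(np.array(image).astype(np.float32) / 255.0).unsqueeze(0)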
class SystemResources:
    @staticmethod
    def get_available_memory():
        """Get available system memory in GB"""
        return psutil.virtual_memory().available / (1024 * 1024 * 1024)

    @staticmethod
    def get_available_vram():
        """Get available VRAM in GB"""
        if not torch.cuda.is_available():
            return 0
        try:
            torch.cuda.empty_cache()  # Clear unused cached memory
            return torch.cuda.get_device_properties(0).total_memory / (1024 * 1024 * 1024)
        except:
            return 0

    @staticmethod
    def check_resources(model_name, memory_mode):
        """Check if system has enough resources for the model"""
        required_vram = MODEL_VRAM_REQUIREMENTS.get(model_name, 0)
        config = MEMORY_EFFICIENT_CONFIGS[memory_mode]

        # Adjust VRAM requirements based on memory mode
        if config["device_map"] == "cpu":
            required_vram = 0  # CPU only mode
        elif config["device_map"] == "auto":
            # Auto device map may use less VRAM by offloading some layers
            required_vram = required_vram * 0.8

        available_vram = SystemResources.get_available_vram()
        available_memory = SystemResources.get_available_memory()

        # Need at least 2GB system memory buffer
        required_system_memory = required_vram + 2

        error_messages = []
        if available_vram < required_vram and config["device_map"] != "cpu":
            error_messages.append(
                f"Insufficient VRAM: Model {model_name} requires approximately {required_vram:.1f}GB VRAM, "
                f"but only {available_vram:.1f}GB available. "
                "Try using 'Maximum Savings (CPU Only)' mode."
            )
        if available_memory < required_system_memory:
            error_messages.append(
                f"Insufficient system memory: Need at least {required_system_memory:.1f}GB, "
                f"but only {available_memory:.1f}GB available"
            )
        return error_messages
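# Usage note (illustrative): check_resources returns a list of human-readable error strings;
# an empty list means the requested model/memory-mode combination looks feasible, e.g.
#     errors = SystemResources.check_resources("Gemma-3-12B", "Balanced (GPU)")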
class Gemma3Predictor:
    def __init__(self, model_name, memory_mode="Balanced (GPU)"):
        # Check system resources
        error_messages = SystemResources.check_resources(model_name, memory_mode)
        if error_messages:
            raise RuntimeError("\n".join(error_messages))

        # Get model configuration
        model_id = GEMMA3_MODELS[model_name]
        config = MEMORY_EFFICIENT_CONFIGS[memory_mode]

        try:
            # Use the ComfyUI model directory for caching
            cache_dir = os.path.join(folder_paths.models_dir, "Gemma3")
            os.makedirs(cache_dir, exist_ok=True)

            print(f"Loading model {model_name} from {model_id}...")
            print(f"Using memory mode: {memory_mode}")
            print(f"Cache directory: {cache_dir}")

            # Load processor and model
            self.processor = AutoProcessor.from_pretrained(
                model_id,
                cache_dir=cache_dir,
                trust_remote_code=True
            )
            self.model = Gemma3ForConditionalGeneration.from_pretrained(
                model_id,
                cache_dir=cache_dir,
                device_map=config["device_map"],
                torch_dtype=config["torch_dtype"],
                trust_remote_code=True
            ).eval()
            print(f"Successfully loaded {model_name} model")
        except Exception as e:
            error_detail = traceback.format_exc()
            print(f"Detailed error: {error_detail}")
            raise RuntimeError(f"Failed to load {model_name} model: {str(e)}\n\nDetails: {error_detail}")
    def generate_predictions(self, prompt, image=None, max_new_tokens=512, temperature=0.7, top_p=0.9):
        """Generate text predictions using the Gemma model"""
        try:
            # Add token limit reminder to the prompt
            token_limit = max(10, max_new_tokens - 10)  # Ensure we don't go negative
            prompt_with_limit = f"{prompt}\nYou must stay under {token_limit} tokens."
            print(f"Added token limit of {token_limit} to prompt")

            # Create the messages format for the model
            messages = [
                {
                    "role": "system",
                    "content": [{"type": "text", "text": "You are a helpful assistant."}]
                }
            ]

            # Add user message with image if provided
            if image is not None:
                image_pil = tensor2pil(image)
                messages.append({
                    "role": "user",
                    "content": [
                        {"type": "image", "image": image_pil},
                        {"type": "text", "text": prompt_with_limit}
                    ]
                })
            else:
                # Text-only prompt
                messages.append({
                    "role": "user",
                    "content": [{"type": "text", "text": prompt_with_limit}]
                })
            # Process the input - explicitly setting dtype to bfloat16 like in the working implementation
            inputs = self.processor.apply_chat_template(
                messages,
                add_generation_prompt=True,
                tokenize=True,
                return_dict=True,
                return_tensors="pt"
            ).to(self.model.device, dtype=torch.bfloat16)  # Explicitly set the dtype

            input_len = inputs["input_ids"].shape[-1]
            print(f"Input length: {input_len} tokens")
            print(f"Generating up to {max_new_tokens} new tokens...")

            # Generate text - use the simpler approach from the working implementation
            with torch.inference_mode():
                # Simplify generation parameters to match the working implementation
                generation = self.model.generate(
                    **inputs,
                    max_new_tokens=max_new_tokens,
                    do_sample=False  # Set to False like in the working implementation
                )

            # Extract the new tokens
            generation = generation[0][input_len:]
            print(f"Generated {len(generation)} tokens")

            # Decode and return the result
            decoded = self.processor.decode(generation, skip_special_tokens=True)

            # Debug the output to see if it's being truncated
            print(f"Output length: {len(decoded)} characters")
            if len(decoded) < 50:
                print(f"Warning: Short output detected: '{decoded}'")
            else:
                print(f"Output preview: '{decoded[:50]}...'")
            return decoded
        except RuntimeError as e:
            if "out of memory" in str(e):
                torch.cuda.empty_cache()
                raise RuntimeError(
                    "Out of VRAM during generation. Try:\n"
                    "1. Using 'Maximum Savings (CPU Only)' mode\n"
                    "2. Reducing max_new_tokens\n"
                    "3. Using a smaller model (e.g., 4B instead of 12B)"
                ) from e
            error_detail = traceback.format_exc()
            print(f"Detailed error during generation: {error_detail}")
            raise RuntimeError(f"Error during generation: {str(e)}\n\nDetails: {error_detail}")
        except Exception as e:
            error_detail = traceback.format_exc()
            print(f"Detailed error during generation: {error_detail}")
            return f"Error during generation: {str(e)}\n\nDetails: {error_detail}"
class Gemma3Node:
    def __init__(self):
        self.predictor = None
        self.current_model = None
        self.current_memory_mode = None
    @classmethod
    def INPUT_TYPES(cls):
        default_prompt = "Write a prompt to recreate every detail of the image and include slang in your vocabulary. Guidelines: Include art style, subject age, expression, pose, background, and camera angle. Start your response without any label or pre-statement or conversation."
        return {
            "required": {
                "text_input": ("STRING", {
                    "multiline": True,
                    "default": default_prompt
                }),
                "model_name": (list(GEMMA3_MODELS.keys()),),
                "memory_mode": (list(MEMORY_EFFICIENT_CONFIGS.keys()), {"default": "Balanced (GPU)"}),
                "max_new_tokens": ("INT", {
                    "default": 256,
                    "min": 1,
                    "max": 4096,
                    "step": 1  # Ensure whole numbers
                }),
                "temperature": ("FLOAT", {
                    "default": 0.8,  # Gemma recommended
                    "min": 0.1,
                    "max": 2.0,
                    "step": 0.1
                }),
                "top_p": ("FLOAT", {
                    "default": 0.95,  # Gemma recommended
                    "min": 0.1,
                    "max": 1.0,
                    "step": 0.05
                })
            },
            "optional": {
                "image": ("IMAGE",),
            }
        }

    RETURN_TYPES = ("STRING",)
    FUNCTION = "generate"
    CATEGORY = "VLM Nodes/Gemma-3"
    def generate(self, text_input, model_name, memory_mode, max_new_tokens=512, temperature=0.8, top_p=0.95, image=None):
        try:
            # Handle the case when memory_mode is passed as an integer or other non-string type
            if not isinstance(memory_mode, str):
                print(f"Warning: memory_mode was passed as a non-string value ({type(memory_mode)}: {memory_mode}). Converting to default 'Balanced (GPU)'")
                memory_mode = "Balanced (GPU)"

            # Validate that memory_mode is one of our expected values
            memory_mode_options = list(MEMORY_EFFICIENT_CONFIGS.keys())
            if memory_mode not in memory_mode_options:
                print(f"Warning: Invalid memory_mode '{memory_mode}'. Valid options are {memory_mode_options}. Using default 'Balanced (GPU)'")
                memory_mode = "Balanced (GPU)"
            print(f"Using memory_mode: {memory_mode}")

            # Initialize or update predictor if model or memory mode changed
            if (self.predictor is None or self.current_model != model_name or
                    self.current_memory_mode != memory_mode):
                # Clean up old model
                if self.predictor is not None:
                    del self.predictor.model
                    del self.predictor.processor
                    torch.cuda.empty_cache()
                try:
                    print(f"Initializing {model_name} with memory mode {memory_mode}...")
                    self.predictor = Gemma3Predictor(model_name, memory_mode)
                    self.current_model = model_name
                    self.current_memory_mode = memory_mode
                except Exception as e:
                    error_detail = traceback.format_exc()
                    print(f"Detailed initialization error: {error_detail}")
                    return (f"Error initializing model: {str(e)}\n\nDetails: {error_detail}",)

            # Generate predictions
            print(f"Generating predictions with {model_name}...")
            result = self.predictor.generate_predictions(
                text_input,
                image=image,
                max_new_tokens=max_new_tokens,
                temperature=temperature,
                top_p=top_p
            )
            print("Generation completed successfully")
            return (result,)
        except Exception as e:
            error_detail = traceback.format_exc()
            print(f"Detailed error in generate method: {error_detail}")
            return (f"Error during generation: {str(e)}\n\nDetails: {error_detail}",)
# Add the two-node implementation from ComfyUI_Gemma3
class Gemma3ModelLoader:
    @classmethod
    def INPUT_TYPES(s):
        return {
            "required": {
                "model_id": (GOOGLE_GEMMA3_MODELS, {"default": "google/gemma-3-4b-it"}),
                "load_local_model": ("BOOLEAN", {"default": False}),
            },
            "optional": {
                "local_gemma3_model_path": ("STRING", {"default": "google/gemma-3-4b-it"}),
            }
        }

    RETURN_TYPES = ("MODEL", "PROCESSOR")
    RETURN_NAMES = ("model", "processor")
    FUNCTION = "load_model"
    CATEGORY = "VLM Nodes/Gemma-3"
    def load_model(self, model_id, load_local_model, *args, **kwargs):
        device = mm.get_torch_device()
        if load_local_model:
            # If loading a local model, directly use the path provided by the user
            model_id = kwargs.get("local_gemma3_model_path", model_id)
        else:
            # If loading a Hugging Face model, download to ComfyUI's model directory
            gemma_dir = os.path.join(folder_paths.models_dir, "Gemma3")
            os.makedirs(gemma_dir, exist_ok=True)

            # Download model to specified directory
            print(f"Loading model from {model_id}...")
            try:
                model = Gemma3ForConditionalGeneration.from_pretrained(
                    model_id, cache_dir=gemma_dir, device_map="auto", trust_remote_code=True
                ).eval().to(device)
                processor = AutoProcessor.from_pretrained(
                    model_id, cache_dir=gemma_dir, trust_remote_code=True
                )
                return (model, processor)
            except Exception as e:
                error_detail = traceback.format_exc()
                print(f"Error loading model {model_id}: {str(e)}\n{error_detail}")
                raise RuntimeError(f"Failed to load {model_id}: {str(e)}")

        # Load local model
        try:
            print(f"Loading local model from {model_id}...")
            model = Gemma3ForConditionalGeneration.from_pretrained(
                model_id, device_map="auto", trust_remote_code=True
            ).eval().to(device)
            processor = AutoProcessor.from_pretrained(model_id, trust_remote_code=True)
            return (model, processor)
        except Exception as e:
            error_detail = traceback.format_exc()
            print(f"Error loading local model {model_id}: {str(e)}\n{error_detail}")
            raise RuntimeError(f"Failed to load local model {model_id}: {str(e)}")
class ApplyGemma3:
    @classmethod
    def INPUT_TYPES(s):
        default_prompt = "Write a prompt to recreate every detail of the image and include slang in your vocabulary. Guidelines: Include art style, subject age, expression, pose, background, and camera angle. Start your response without any label or pre-statement or conversation."
        return {
            "required": {
                "model": ("MODEL",),
                "processor": ("PROCESSOR",),
                "prompt": ("STRING", {"default": default_prompt, "multiline": True}),
                "max_new_tokens": ("INT", {"default": 256, "min": 1, "max": 1000, "step": 1}),
            },
            "optional": {
                "image": ("IMAGE",),  # Image input as optional
            }
        }

    RETURN_TYPES = ("STRING",)
    RETURN_NAMES = ("description",)
    FUNCTION = "apply_gemma3"
    CATEGORY = "VLM Nodes/Gemma-3"
    def apply_gemma3(self, model, processor, prompt, max_new_tokens, image=None):
        # Add token limit reminder to the prompt
        token_limit = max(10, max_new_tokens - 10)  # Ensure we don't go negative
        prompt_with_limit = f"{prompt}\nYou must stay under {token_limit} tokens."
        print(f"Added token limit of {token_limit} to prompt")

        messages = [
            {
                "role": "system",
                "content": [{"type": "text", "text": "You are a helpful assistant."}]
            }
        ]

        # If there is an image input, add the image and prompt to the messages
        if image is not None:
            image_pil = tensor2pil(image)  # Convert ComfyUI's IMAGE format to PIL.Image
            messages.append({
                "role": "user",
                "content": [
                    {"type": "image", "image": image_pil},
                    {"type": "text", "text": prompt_with_limit}
                ]
            })
        else:
            # If there is no image input, only use the text prompt
            messages.append({
                "role": "user",
                "content": [{"type": "text", "text": prompt_with_limit}]
            })

        # Process input
        inputs = processor.apply_chat_template(
            messages, add_generation_prompt=True, tokenize=True,
            return_dict=True, return_tensors="pt"
        ).to(model.device, dtype=torch.bfloat16)

        input_len = inputs["input_ids"].shape[-1]
        print(f"Input length: {input_len} tokens")
        print(f"Generating up to {max_new_tokens} new tokens...")

        # Generate text
        with torch.inference_mode():
            generation = model.generate(**inputs, max_new_tokens=max_new_tokens, do_sample=False)
            generation = generation[0][input_len:]
            print(f"Generated {len(generation)} tokens")

        decoded = processor.decode(generation, skip_special_tokens=True)
        print(f"Output length: {len(decoded)} characters")
        if len(decoded) < 50:
            print(f"Warning: Short output detected: '{decoded}'")
        else:
            print(f"Output preview: '{decoded[:50]}...'")
        return (decoded,)
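# Typical wiring (illustrative): Gemma3ModelLoader outputs (MODEL, PROCESSOR), which plug into
# ApplyGemma3's matching inputs together with an optional IMAGE and a prompt string.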
# Register all nodes
NODE_CLASS_MAPPINGS = {
    "Gemma3Node": Gemma3Node,
    "Gemma3ModelLoader": Gemma3ModelLoader,
    "ApplyGemma3": ApplyGemma3
}

NODE_DISPLAY_NAME_MAPPINGS = {
    "Gemma3Node": "Gemma-3 Model (All-in-One)",
    "Gemma3ModelLoader": "Gemma-3 Model Loader",
    "ApplyGemma3": "Apply Gemma-3"
}
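# How a custom-node package would typically re-export these mappings from its __init__.py so
# ComfyUI discovers the nodes (illustrative sketch; the module path is an assumption):
#     from .gemma3 import NODE_CLASS_MAPPINGS, NODE_DISPLAY_NAME_MAPPINGS
#     __all__ = ["NODE_CLASS_MAPPINGS", "NODE_DISPLAY_NAME_MAPPINGS"]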
