Gemma 3 support #148

@LoFiApostasy

Description

@gokayfem I would do this as a pull request, but I'm a hobbyist and have never used that feature, so I'll post it here.

I made a Gemma 3 node using your framework and it works great. The large 27B model outputs amazing quality, but the 4B is enough; I've included uncensored (abliterated) versions of all three model sizes.

It adds a couple of extra features: a smart memory-offloading option (the balanced memory mode) and passing the max token limit into the prompt, so the VLM knows its length budget and long responses are less likely to be cut off.
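To give the gist before the full file: the balanced mode just lets Accelerate spread layers across GPU and CPU (`device_map="auto"` with bfloat16), and the token-limit feature simply appends the budget to the prompt. A condensed sketch of those two pieces (the model ID and numbers are only examples; the real logic is in the file below):

```python
# Condensed sketch of the two extra features; the full node code below does the real work.
import torch
from transformers import AutoProcessor, Gemma3ForConditionalGeneration

model_id = "mlabonne/gemma-3-4b-it-abliterated"  # same 4B ID the node uses

# "Balanced (GPU)" mode: let Accelerate place layers on GPU and spill to CPU if needed.
model = Gemma3ForConditionalGeneration.from_pretrained(
    model_id, device_map="auto", torch_dtype=torch.bfloat16
).eval()
processor = AutoProcessor.from_pretrained(model_id)

# Token-limit hint: tell the model its budget so replies are less likely to be cut off.
max_new_tokens = 256
prompt = "Describe the image in detail."
prompt_with_limit = f"{prompt}\nYou must stay under {max_new_tokens - 10} tokens."
```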

I installed "git+https://github.com/huggingface/[email protected]", which seems to be a requirement, but I haven't tried without it.
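If you're unsure whether your installed transformers build is new enough, a quick sanity check (not an exact version pin) is simply whether the Gemma 3 class imports:

```python
# Sanity check: the node needs a transformers build that ships the Gemma 3 classes.
try:
    from transformers import Gemma3ForConditionalGeneration  # noqa: F401
    print("Gemma 3 support is available")
except ImportError:
    print("This transformers build has no Gemma 3 support; install a newer or preview build")
```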

Anyway, I hope this makes its way into your awesome project. Thanks so much for making these wonderful nodes. Feel free to close this feature request anytime.

gemma3.py

import torch
import psutil
import os
import numpy as np
from PIL import Image
from pathlib import Path
import folder_paths
import comfy.model_management as mm
import traceback
from transformers import AutoProcessor, Gemma3ForConditionalGeneration

# Define the directory for saving Gemma-3 files
files_for_gemma3 = Path(folder_paths.folder_names_and_paths["LLavacheckpoints"][0][0]) / "files_for_gemma3"
files_for_gemma3.mkdir(parents=True, exist_ok=True)

# Model VRAM requirements (approximate, in GB)
MODEL_VRAM_REQUIREMENTS = {
    "Gemma-3-4B": 6,
    "Gemma-3-12B": 14,
    "Gemma-3-27B": 28,  # Added the largest model with its VRAM requirement
}

GEMMA3_MODELS = {
    "Gemma-3-4B": "mlabonne/gemma-3-4b-it-abliterated",
    "Gemma-3-12B": "mlabonne/gemma-3-12b-it-abliterated",
    "Gemma-3-27B": "mlabonne/gemma-3-27b-it-abliterated",  # Added the largest model
}

# Add Google's official model IDs
GOOGLE_GEMMA3_MODELS = [
    "google/gemma-3-27b-it",
    "google/gemma-3-12b-it", 
    "google/gemma-3-1b-it", 
    "google/gemma-3-4b-it"
]

# Define memory efficiency configurations
MEMORY_EFFICIENT_CONFIGS = {
    "Balanced (GPU)": {
        "device_map": "auto",
        "torch_dtype": torch.bfloat16
    },
    "Maximum Performance (Full GPU)": {
        "device_map": "cuda",
        "torch_dtype": torch.bfloat16
    },
    "Maximum Savings (CPU Only)": {
        "device_map": "cpu",
        "torch_dtype": torch.float32
    }
}

def tensor2pil(image):
    return Image.fromarray(np.clip(255. * image.cpu().numpy().squeeze(), 0, 255).astype(np.uint8))

class SystemResources:
    @staticmethod
    def get_available_memory():
        """Get available system memory in GB"""
        return psutil.virtual_memory().available / (1024 * 1024 * 1024)

    @staticmethod
    def get_available_vram():
        """Get available VRAM in GB"""
        if not torch.cuda.is_available():
            return 0
        
        try:
            torch.cuda.empty_cache()  # Clear unused cached memory
            return torch.cuda.get_device_properties(0).total_memory / (1024 * 1024 * 1024)
        except:
            return 0

    @staticmethod
    def check_resources(model_name, memory_mode):
        """Check if system has enough resources for the model"""
        required_vram = MODEL_VRAM_REQUIREMENTS.get(model_name, 0)
        config = MEMORY_EFFICIENT_CONFIGS[memory_mode]
        
        # Adjust VRAM requirements based on memory mode
        if config["device_map"] == "cpu":
            required_vram = 0  # CPU only mode
        elif config["device_map"] == "auto":
            # Auto device map may use less VRAM by offloading some layers
            required_vram = required_vram * 0.8
            
        available_vram = SystemResources.get_available_vram()
        available_memory = SystemResources.get_available_memory()
        
        # Need at least 2GB system memory buffer
        required_system_memory = required_vram + 2
        
        error_messages = []
        if available_vram < required_vram and config["device_map"] != "cpu":
            error_messages.append(
                f"Insufficient VRAM: Model {model_name} requires approximately {required_vram:.1f}GB VRAM, "
                f"but only {available_vram:.1f}GB available. "
                "Try using 'Maximum Savings (CPU Only)' mode."
            )
        
        if available_memory < required_system_memory:
            error_messages.append(
                f"Insufficient system memory: Need at least {required_system_memory:.1f}GB, "
                f"but only {available_memory:.1f}GB available"
            )
            
        return error_messages

class Gemma3Predictor:
    def __init__(self, model_name, memory_mode="Balanced (GPU)"):
        # Check system resources
        error_messages = SystemResources.check_resources(model_name, memory_mode)
        if error_messages:
            raise RuntimeError("\n".join(error_messages))
            
        # Get model configuration
        model_id = GEMMA3_MODELS[model_name]
        config = MEMORY_EFFICIENT_CONFIGS[memory_mode]
        
        try:
            # Use the ComfyUI model directory for caching
            cache_dir = os.path.join(folder_paths.models_dir, "Gemma3")
            os.makedirs(cache_dir, exist_ok=True)
            
            print(f"Loading model {model_name} from {model_id}...")
            print(f"Using memory mode: {memory_mode}")
            print(f"Cache directory: {cache_dir}")
            
            # Load processor and model
            self.processor = AutoProcessor.from_pretrained(
                model_id,
                cache_dir=cache_dir,
                trust_remote_code=True
            )
            
            self.model = Gemma3ForConditionalGeneration.from_pretrained(
                model_id,
                cache_dir=cache_dir,
                device_map=config["device_map"],
                torch_dtype=config["torch_dtype"],
                trust_remote_code=True
            ).eval()
            
            print(f"Successfully loaded {model_name} model")
        except Exception as e:
            error_detail = traceback.format_exc()
            print(f"Detailed error: {error_detail}")
            raise RuntimeError(f"Failed to load {model_name} model: {str(e)}\n\nDetails: {error_detail}")

    def generate_predictions(self, prompt, image=None, max_new_tokens=512, temperature=0.7, top_p=0.9):
        """Generate text predictions using the Gemma model"""
        try:
            # Add token limit reminder to the prompt
            token_limit = max(10, max_new_tokens - 10)  # Ensure we don't go negative
            prompt_with_limit = f"{prompt}\nYou must stay under {token_limit} tokens."
            print(f"Added token limit of {token_limit} to prompt")
            
            # Create the messages format for the model
            messages = [
                {
                    "role": "system",
                    "content": [{"type": "text", "text": "You are a helpful assistant."}]
                }
            ]

            # Add user message with image if provided
            if image is not None:
                image_pil = tensor2pil(image)
                messages.append({
                    "role": "user",
                    "content": [
                        {"type": "image", "image": image_pil},
                        {"type": "text", "text": prompt_with_limit}
                    ]
                })
            else:
                # Text-only prompt
                messages.append({
                    "role": "user",
                    "content": [{"type": "text", "text": prompt_with_limit}]
                })

            # Process the input - explicitly setting dtype to bfloat16 like in the working implementation
            inputs = self.processor.apply_chat_template(
                messages, 
                add_generation_prompt=True, 
                tokenize=True,
                return_dict=True, 
                return_tensors="pt"
            ).to(self.model.device, dtype=torch.bfloat16)  # Explicitly set the dtype

            input_len = inputs["input_ids"].shape[-1]
            
            print(f"Input length: {input_len} tokens")
            print(f"Generating up to {max_new_tokens} new tokens...")
            
            # Generate text - Use the simpler approach like in the working implementation
            with torch.inference_mode():
                # Simplify generation parameters to match the working implementation
                generation = self.model.generate(
                    **inputs,
                    max_new_tokens=max_new_tokens,
                    do_sample=False  # Set to False like in the working implementation
                )
                
                # Extract the new tokens
                generation = generation[0][input_len:]
                
                print(f"Generated {len(generation)} tokens")

            # Decode and return the result
            decoded = self.processor.decode(generation, skip_special_tokens=True)
            
            # Debug the output to see if it's being truncated
            print(f"Output length: {len(decoded)} characters")
            if len(decoded) < 50:
                print(f"Warning: Short output detected: '{decoded}'")
            else:
                print(f"Output preview: '{decoded[:50]}...'")
                
            return decoded
            
        except RuntimeError as e:
            if "out of memory" in str(e):
                torch.cuda.empty_cache()
                raise RuntimeError(
                    "Out of VRAM during generation. Try:\n"
                    "1. Using 'Maximum Savings (CPU Only)' mode\n"
                    "2. Reducing max_new_tokens\n"
                    "3. Using a smaller model (e.g., 4B instead of 12B)"
                ) from e
            error_detail = traceback.format_exc()
            print(f"Detailed error during generation: {error_detail}")
            raise RuntimeError(f"Error during generation: {str(e)}\n\nDetails: {error_detail}")
        except Exception as e:
            error_detail = traceback.format_exc()
            print(f"Detailed error during generation: {error_detail}")
            return f"Error during generation: {str(e)}\n\nDetails: {error_detail}"

class Gemma3Node:
    def __init__(self):
        self.predictor = None
        self.current_model = None
        self.current_memory_mode = None

    @classmethod
    def INPUT_TYPES(cls):
        default_prompt = "Write a prompt to recreate every detail of the image and include slang in your vocabulary. Guidelines: Include art style, subject age, expression, pose, background, and camera angle. Start your response without any label or pre-statement or conversation."

        return {
            "required": {
                "text_input": ("STRING", {
                    "multiline": True,
                    "default": default_prompt
                }),
                "model_name": (list(GEMMA3_MODELS.keys()),),
                "memory_mode": (list(MEMORY_EFFICIENT_CONFIGS.keys()), {"default": "Balanced (GPU)"}),
                "max_new_tokens": ("INT", {
                    "default": 256,
                    "min": 1,
                    "max": 4096,
                    "step": 1  # Ensure whole numbers
                }),
                "temperature": ("FLOAT", {
                    "default": 0.8,  # Gemma recommended
                    "min": 0.1,
                    "max": 2.0,
                    "step": 0.1
                }),
                "top_p": ("FLOAT", {
                    "default": 0.95,  # Gemma recommended
                    "min": 0.1,
                    "max": 1.0,
                    "step": 0.05
                })
            },
            "optional": {
                "image": ("IMAGE",),
            }
        }

    RETURN_TYPES = ("STRING",)
    FUNCTION = "generate"
    CATEGORY = "VLM Nodes/Gemma-3"

    def generate(self, text_input, model_name, memory_mode, max_new_tokens=512, temperature=0.8, top_p=0.95, image=None):
        try:
            # Handle the case when memory_mode is passed as an integer or other non-string type
            if not isinstance(memory_mode, str):
                print(f"Warning: memory_mode was passed as a non-string value ({type(memory_mode)}: {memory_mode}). Converting to default 'Balanced (GPU)'")
                memory_mode = "Balanced (GPU)"
            
            # Validate that memory_mode is one of our expected values
            memory_mode_options = list(MEMORY_EFFICIENT_CONFIGS.keys())
            if memory_mode not in memory_mode_options:
                print(f"Warning: Invalid memory_mode '{memory_mode}'. Valid options are {memory_mode_options}. Using default 'Balanced (GPU)'")
                memory_mode = "Balanced (GPU)"
                
            print(f"Using memory_mode: {memory_mode}")
            
            # Initialize or update predictor if model or memory mode changed
            if (self.predictor is None or self.current_model != model_name or 
                self.current_memory_mode != memory_mode):
                
                # Clean up old model
                if self.predictor is not None:
                    del self.predictor.model
                    del self.predictor.processor
                    torch.cuda.empty_cache()
                    
                try:
                    print(f"Initializing {model_name} with memory mode {memory_mode}...")
                    self.predictor = Gemma3Predictor(model_name, memory_mode)
                    self.current_model = model_name
                    self.current_memory_mode = memory_mode
                except Exception as e:
                    error_detail = traceback.format_exc()
                    print(f"Detailed initialization error: {error_detail}")
                    return (f"Error initializing model: {str(e)}\n\nDetails: {error_detail}",)

            # Generate predictions
            print(f"Generating predictions with {model_name}...")
            result = self.predictor.generate_predictions(
                text_input,
                image=image,
                max_new_tokens=max_new_tokens,
                temperature=temperature,
                top_p=top_p
            )
            
            print(f"Generation completed successfully")
            return (result,)
            
        except Exception as e:
            error_detail = traceback.format_exc()
            print(f"Detailed error in generate method: {error_detail}")
            return (f"Error during generation: {str(e)}\n\nDetails: {error_detail}",)

# Add the two-node implementation from ComfyUI_Gemma3
class Gemma3ModelLoader:
    @classmethod
    def INPUT_TYPES(s):
        return {
            "required": {
                "model_id": (GOOGLE_GEMMA3_MODELS, {"default": "google/gemma-3-4b-it"}),
                "load_local_model": ("BOOLEAN", {"default": False}),
            },
            "optional": {
                "local_gemma3_model_path": ("STRING", {"default": "google/gemma-3-4b-it"}),
            }
        }

    RETURN_TYPES = ("MODEL", "PROCESSOR")
    RETURN_NAMES = ("model", "processor")
    FUNCTION = "load_model"
    CATEGORY = "VLM Nodes/Gemma-3"

    def load_model(self, model_id, load_local_model, *args, **kwargs):
        device = mm.get_torch_device()
        if load_local_model:
            # If loading a local model, directly use the path provided by the user
            model_id = kwargs.get("local_gemma3_model_path", model_id)
        else:
            # If loading a Hugging Face model, download to ComfyUI's model directory
            gemma_dir = os.path.join(folder_paths.models_dir, "Gemma3")
            os.makedirs(gemma_dir, exist_ok=True)

            # Download model to specified directory
            print(f"Loading model from {model_id}...")
            try:
                model = Gemma3ForConditionalGeneration.from_pretrained(
                    model_id, cache_dir=gemma_dir, device_map="auto", trust_remote_code=True
                ).eval().to(device)
                processor = AutoProcessor.from_pretrained(
                    model_id, cache_dir=gemma_dir, trust_remote_code=True
                )
                return (model, processor)
            except Exception as e:
                error_detail = traceback.format_exc()
                print(f"Error loading model {model_id}: {str(e)}\n{error_detail}")
                raise RuntimeError(f"Failed to load {model_id}: {str(e)}")

        # Load local model
        try:
            print(f"Loading local model from {model_id}...")
            model = Gemma3ForConditionalGeneration.from_pretrained(
                model_id, device_map="auto", trust_remote_code=True
            ).eval().to(device)
            processor = AutoProcessor.from_pretrained(model_id, trust_remote_code=True)
            return (model, processor)
        except Exception as e:
            error_detail = traceback.format_exc()
            print(f"Error loading local model {model_id}: {str(e)}\n{error_detail}")
            raise RuntimeError(f"Failed to load local model {model_id}: {str(e)}")


class ApplyGemma3:
    @classmethod
    def INPUT_TYPES(s):
        default_prompt = "Write a prompt to recreate every detail of the image and include slang in your vocabulary. Guidelines: Include art style, subject age, expression, pose, background, and camera angle. Start your response without any label or pre-statement or conversation."
        
        return {
            "required": {
                "model": ("MODEL",),
                "processor": ("PROCESSOR",),
                "prompt": ("STRING", {"default": default_prompt, "multiline": True}),
                "max_new_tokens": ("INT", {"default": 256, "min": 1, "max": 1000, "step": 1}),
            },
            "optional": {
                "image": ("IMAGE",),  # Image input as optional
            }
        }

    RETURN_TYPES = ("STRING",)
    RETURN_NAMES = ("description",)
    FUNCTION = "apply_gemma3"
    CATEGORY = "VLM Nodes/Gemma-3"

    def apply_gemma3(self, model, processor, prompt, max_new_tokens, image=None):
        # Add token limit reminder to the prompt
        token_limit = max(10, max_new_tokens - 10)  # Ensure we don't go negative
        prompt_with_limit = f"{prompt}\nYou must stay under {token_limit} tokens."
        print(f"Added token limit of {token_limit} to prompt")
        
        messages = [
            {
                "role": "system",
                "content": [{"type": "text", "text": "You are a helpful assistant."}]
            }
        ]

        # If there is an image input, add the image and prompt to the messages
        if image is not None:
            image_pil = tensor2pil(image)  # Convert ComfyUI's IMAGE format to PIL.Image
            messages.append({
                "role": "user",
                "content": [
                    {"type": "image", "image": image_pil},
                    {"type": "text", "text": prompt_with_limit}
                ]
            })
        else:
            # If there is no image input, only use text prompt
            messages.append({
                "role": "user",
                "content": [{"type": "text", "text": prompt_with_limit}]
            })

        # Process input
        inputs = processor.apply_chat_template(
            messages, add_generation_prompt=True, tokenize=True,
            return_dict=True, return_tensors="pt"
        ).to(model.device, dtype=torch.bfloat16)

        input_len = inputs["input_ids"].shape[-1]
        print(f"Input length: {input_len} tokens")
        print(f"Generating up to {max_new_tokens} new tokens...")

        # Generate text
        with torch.inference_mode():
            generation = model.generate(**inputs, max_new_tokens=max_new_tokens, do_sample=False)
            generation = generation[0][input_len:]
            print(f"Generated {len(generation)} tokens")

        decoded = processor.decode(generation, skip_special_tokens=True)
        print(f"Output length: {len(decoded)} characters")
        
        if len(decoded) < 50:
            print(f"Warning: Short output detected: '{decoded}'")
        else:
            print(f"Output preview: '{decoded[:50]}...'")
            
        return (decoded,)

# Register all nodes
NODE_CLASS_MAPPINGS = {
    "Gemma3Node": Gemma3Node,
    "Gemma3ModelLoader": Gemma3ModelLoader,
    "ApplyGemma3": ApplyGemma3
}

NODE_DISPLAY_NAME_MAPPINGS = {
    "Gemma3Node": "Gemma-3 Model (All-in-One)",
    "Gemma3ModelLoader": "Gemma-3 Model Loader",
    "ApplyGemma3": "Apply Gemma-3"
}
