@@ -6,7 +6,7 @@
 # Stability AI Non-Commercial Research Community License Agreement, dated November 28, 2023.
 # For more information, see https://stability.ai/use-policy.
 
-from diffusers import StableCascadeDecoderPipeline, StableCascadePriorPipeline
+from diffusers import StableCascadeDecoderPipeline, StableCascadePriorPipeline, StableCascadeUNet
 import gradio as gr
 import json
 import os
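
Note: StableCascadeUNet joins the imports but is not referenced in any hunk shown here; presumably it is used elsewhere in the file to load a UNet on its own. A minimal sketch of that pattern, assuming the standard diffusers API and the "prior_lite" subfolder shipped with the checkpoint (both are assumptions, not taken from this commit):

    # Hypothetical sketch -- StableCascadeUNet usage is not visible in this diff.
    # Assumes the diffusers API and the "prior_lite" subfolder of the checkpoint.
    import torch
    from diffusers import StableCascadePriorPipeline, StableCascadeUNet

    prior_unet = StableCascadeUNet.from_pretrained(
        "stabilityai/stable-cascade-prior",
        subfolder="prior_lite",       # smaller variant of the prior UNet
        torch_dtype=torch.bfloat16,
    )
    prior = StableCascadePriorPipeline.from_pretrained(
        "stabilityai/stable-cascade-prior",
        prior=prior_unet,             # override the pipeline's default prior UNet
        torch_dtype=torch.bfloat16,
    )
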
@@ -24,9 +24,9 @@
 def load_model(model_name):
     # Load model from disk every time it's needed
     if model_name == "prior":
-        model = StableCascadePriorPipeline.from_pretrained("stabilityai/stable-cascade-prior", variant="bf16", torch_dtype=dtype).to(device)
+        model = StableCascadePriorPipeline.from_pretrained("stabilityai/stable-cascade-prior", variant="bf16", torch_dtype=dtype, use_safetensors=True).to(device)
     elif model_name == "decoder":
-        model = StableCascadeDecoderPipeline.from_pretrained("stabilityai/stable-cascade", variant="bf16", torch_dtype=torch.float16).to(device)
+        model = StableCascadeDecoderPipeline.from_pretrained("stabilityai/stable-cascade", variant="bf16", torch_dtype=dtype, use_safetensors=True).to(device)
     else:
         raise ValueError(f"Unknown model name: {model_name}")
     return model
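
Note: this hunk fixes a dtype mismatch (the decoder loaded bf16-variant weights with torch_dtype=torch.float16) and opts both pipelines into safetensors loading. The module-level device and dtype globals are not shown in the diff; a plausible setup, assuming CUDA when available:

    # Assumed module-level globals (not shown in this diff).
    import torch

    device = "cuda" if torch.cuda.is_available() else "cpu"
    # bfloat16 matches the variant="bf16" weights; fall back to float32 on CPU
    dtype = torch.bfloat16 if device == "cuda" else torch.float32
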
@@ -79,26 +79,23 @@ def generate_images(prompt, height, width, negative_prompt, guidance_scale, num_
         num_images_per_prompt=int(num_images_per_prompt),
         generator=generator,
     )
-    del prior  # Explicitly delete the model to help with memory management
-    torch.cuda.empty_cache()  # Clear the CUDA cache to free up unused memory
 
     # Load, use, and discard the decoder model
     decoder = load_model("decoder")
     decoder.enable_model_cpu_offload()
     decoder_output = decoder(
-        image_embeddings=prior_output.image_embeddings.to(torch.float16),
+        image_embeddings=prior_output.image_embeddings.to(dtype),
         prompt=cleaned_prompt,
         negative_prompt=negative_prompt,
-        guidance_scale=0.0,
+        guidance_scale=1.9,  # Classifier-free guidance is enabled when guidance_scale > 1
         num_inference_steps=calculated_steps_decoder,
         output_type="pil",
         generator=generator,
     ).images
-    del decoder  # Explicitly delete the model to help with memory management
-    torch.cuda.empty_cache()  # Clear the CUDA cache to free up unused memory
-
+
     metadata_embedded = {
         "parameters": "Stable Cascade",
+        "scheduler": "DDPMWuerstchenScheduler",
         "prompt": cleaned_prompt,
         "negative_prompt": negative_prompt,
         "width": int(width),
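
Note: two behavioral changes in this hunk. Decoder guidance moves from 0.0 (classifier-free guidance disabled) to 1.9, and the prior's embeddings are cast to the shared dtype instead of hard-coded float16. The manual del / torch.cuda.empty_cache() calls are dropped, presumably because enable_model_cpu_offload() already returns submodules to the CPU after each forward pass. A self-contained sketch of the two-stage handoff this function performs, assuming the standard diffusers Stable Cascade API (the prompt and step counts are illustrative, not from this commit):

    # Sketch of the prior -> decoder handoff, assuming the diffusers Stable
    # Cascade API; the prompt and step counts here are illustrative only.
    import torch
    from diffusers import StableCascadeDecoderPipeline, StableCascadePriorPipeline

    dtype = torch.bfloat16
    prompt = "a photo of a cat"

    prior = StableCascadePriorPipeline.from_pretrained(
        "stabilityai/stable-cascade-prior", variant="bf16", torch_dtype=dtype
    )
    prior.enable_model_cpu_offload()  # keeps submodules on CPU except while running
    prior_output = prior(prompt=prompt, guidance_scale=4.0, num_inference_steps=20)

    decoder = StableCascadeDecoderPipeline.from_pretrained(
        "stabilityai/stable-cascade", variant="bf16", torch_dtype=dtype
    )
    decoder.enable_model_cpu_offload()
    images = decoder(
        image_embeddings=prior_output.image_embeddings.to(dtype),
        prompt=prompt,
        guidance_scale=1.9,           # > 1 enables classifier-free guidance
        num_inference_steps=10,
        output_type="pil",
    ).images
    images[0].save("cat.png")

With offloading enabled, only the submodule currently running sits on the GPU, which trades some speed for a much smaller peak VRAM footprint.
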
@@ -190,8 +187,8 @@ def configure_ui():
             height = gr.Slider(minimum=512, maximum=2048, step=1, value=1024, label="Image Height")
         with gr.Column():
             # components in central column
-            num_inference_steps = gr.Slider(minimum=1, maximum=150, step=1, value=30, label="Steps")
-            num_images_per_prompt = gr.Number(label="Number of Images per Prompt (Currently, the system can only generate one image at a time. Please leave the 'Images per Prompt' setting at 1 until this issue is fixed.)", value=1)
+            num_inference_steps = gr.Slider(minimum=1, maximum=150, step=1, value=54, label="Steps")
+            num_images_per_prompt = gr.Number(label="Number of Images per Prompt", value=2)
         with gr.Column():
             # components in right column
             guidance_scale = gr.Slider(minimum=1, maximum=20, step=0.5, value=4.0, label="Guidance Scale")
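
Note: the UI defaults rise from 30 to 54 steps and from 1 to 2 images per prompt, and the single-image warning is dropped from the label, consistent with multi-image generation now working. For reference, a standalone sketch of this three-column layout, assuming Gradio's Blocks API (the enclosing Blocks/Row context is assumed, not shown in the diff):

    # Minimal standalone reconstruction of the control layout; the enclosing
    # gr.Blocks()/gr.Row() context is assumed, not shown in the diff.
    import gradio as gr

    with gr.Blocks() as demo:
        with gr.Row():
            with gr.Column():
                height = gr.Slider(minimum=512, maximum=2048, step=1, value=1024, label="Image Height")
            with gr.Column():
                num_inference_steps = gr.Slider(minimum=1, maximum=150, step=1, value=54, label="Steps")
                num_images_per_prompt = gr.Number(label="Number of Images per Prompt", value=2)
            with gr.Column():
                guidance_scale = gr.Slider(minimum=1, maximum=20, step=0.5, value=4.0, label="Guidance Scale")

    demo.launch()
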