Description
I’m trying to fine-tune the StyleBooth tuner model with SwiftLoRA using a relatively high rank (R=256, LORA_ALPHA=256), a learning rate of 1e-4, an image size of [1024, 1024], and a fixed batch size of 2.
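For reference, this is how I understand the adapter scaling at these settings (a minimal sketch of the generic LoRA formulation; SwiftLoRA's actual implementation may differ in details):

```python
import torch

# Generic LoRA parameterization (sketch, not SwiftLoRA's actual code):
#   W_eff = W_frozen + (alpha / r) * (B @ A)
r, alpha = 256, 256
scaling = alpha / r              # = 1.0, so the rank-256 update is added unscaled
d_out, d_in = 320, 320           # hypothetical dims for one attention projection
A = torch.randn(r, d_in) * 0.01  # A is randomly initialized
B = torch.zeros(d_out, r)        # B starts at zero, so delta_W starts at zero
delta_W = scaling * (B @ A)
print(scaling, delta_W.abs().max().item())  # 1.0, 0.0 at initialization
```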
Starting from this yaml, I changed it as follows:
```yaml
ENV:
  BACKEND: nccl
SOLVER:
  NAME: LatentDiffusionSolver
  RESUME_FROM:
  LOAD_MODEL_ONLY: True
  USE_FSDP: False
  SHARDING_STRATEGY:
  USE_AMP: True
  DTYPE: float16
  CHANNELS_LAST: True
  MAX_STEPS: 210000
  MAX_EPOCHS: -1
  NUM_FOLDS: 1
  ACCU_STEP: 1
  EVAL_INTERVAL: 100
  #
  WORK_DIR: ./scepter/cache/save_data/edit_1024_lora_r256_i
  LOG_FILE: std_log_i.txt
  #
  FILE_SYSTEM:
    NAME: "LocalFs"
    TEMP_DIR: "./scepter/cache/cache_data"
  #
  TUNER:
    -
      NAME: SwiftLoRA
      R: 256
      LORA_ALPHA: 256
      LORA_DROPOUT: 0.0
      BIAS: "none"
      TARGET_MODULES: model.*(to_q|to_k|to_v|to_out.0|net.0.proj|net.2)$
  #
  MODEL:
    NAME: LatentDiffusionEdit
    PARAMETERIZATION: eps
    TIMESTEPS: 1000
    MIN_SNR_GAMMA:
    ZERO_TERMINAL_SNR: False
    #PRETRAINED_MODEL: ms://iic/stylebooth@models/stylebooth-tb-5000-0.bin
    PRETRAINED_MODEL: "./scepter/weights/stylebooth-tb-5000-0.bin"
    IGNORE_KEYS: [ ]
    CONCAT_NO_SCALE_FACTOR: True
    SCALE_FACTOR: 0.18215
    SIZE_FACTOR: 8
    # DEFAULT_N_PROMPT: 'lowres, error, worst quality, low quality, jpeg artifacts, ugly, duplicate, morbid, mutilated, out of frame, extra fingers, mutated hands, poorly drawn hands, poorly drawn face, mutation, deformed, blurry, dehydrated, bad anatomy, bad proportions, extra limbs, cloned face, disfigured, gross proportions, malformed limbs, missing arms, missing legs, extra arms, extra legs, fused fingers, too many fingers, long neck, username, watermark, signature'
    DEFAULT_N_PROMPT:
    SCHEDULE_ARGS:
      "NAME": "scaled_linear"
      "BETA_MIN": 0.00085
      "BETA_MAX": 0.012
    USE_EMA: False
    #
    DIFFUSION_MODEL:
      NAME: DiffusionUNet
      IN_CHANNELS: 8
      OUT_CHANNELS: 4
      MODEL_CHANNELS: 320
      NUM_HEADS: 8
      NUM_RES_BLOCKS: 2
      ATTENTION_RESOLUTIONS: [ 4, 2, 1 ]
      CHANNEL_MULT: [ 1, 2, 4, 4 ]
      CONV_RESAMPLE: True
      DIMS: 2
      USE_CHECKPOINT: False
      USE_SCALE_SHIFT_NORM: False
      RESBLOCK_UPDOWN: False
      USE_SPATIAL_TRANSFORMER: True
      TRANSFORMER_DEPTH: 1
      CONTEXT_DIM: 768
      DISABLE_MIDDLE_SELF_ATTN: False
      USE_LINEAR_IN_TRANSFORMER: False
      PRETRAINED_MODEL:
      IGNORE_KEYS: []
    #
    FIRST_STAGE_MODEL:
      NAME: AutoencoderKL
      EMBED_DIM: 4
      PRETRAINED_MODEL:
      IGNORE_KEYS: []
      BATCH_SIZE: 4
      #
      ENCODER:
        NAME: Encoder
        CH: 128
        OUT_CH: 3
        NUM_RES_BLOCKS: 2
        IN_CHANNELS: 3
        ATTN_RESOLUTIONS: [ ]
        CH_MULT: [ 1, 2, 4, 4 ]
        Z_CHANNELS: 4
        DOUBLE_Z: True
        DROPOUT: 0.0
        RESAMP_WITH_CONV: True
      #
      DECODER:
        NAME: Decoder
        CH: 128
        OUT_CH: 3
        NUM_RES_BLOCKS: 2
        IN_CHANNELS: 3
        ATTN_RESOLUTIONS: [ ]
        CH_MULT: [ 1, 2, 4, 4 ]
        Z_CHANNELS: 4
        DROPOUT: 0.0
        RESAMP_WITH_CONV: True
        GIVE_PRE_END: False
        TANH_OUT: False
    #
    TOKENIZER:
      NAME: ClipTokenizer
      PRETRAINED_PATH: "./scepter/weights/clip-vit-large-patch14"
      LENGTH: 77
      CLEAN: True
    #
    COND_STAGE_MODEL:
      NAME: FrozenCLIPEmbedder
      FREEZE: True
      LAYER: last
      PRETRAINED_MODEL: "./scepter/weights/clip-vit-large-patch14"
    #
    LOSS:
      NAME: ReconstructLoss
      LOSS_TYPE: l2
  #
  SAMPLE_ARGS:
    SAMPLER: ddim
    SAMPLE_STEPS: 50
    SEED: 2023
    GUIDE_SCALE: #7.5
      image: 1.5
      text: 7.5
    GUIDE_RESCALE: 0.5
    DISCRETIZATION: trailing
    IMAGE_SIZE: [1024, 1024]
    RUN_TRAIN_N: False
  #
  OPTIMIZER:
    NAME: AdamW
    LEARNING_RATE: 1e-4
    BETAS: [ 0.9, 0.999 ]
    EPS: 1e-8
    WEIGHT_DECAY: 1e-2
    AMSGRAD: True
  #
  TRAIN_DATA:
    NAME: ImageTextPairFolderDataset
    MODE: train
    DATA_FOLDER: ./scepter/cache/datasets/i_pair
    PROMPT_PREFIX: ""
    REPLACE_STYLE: False
    PIN_MEMORY: True
    BATCH_SIZE: 2
    NUM_WORKERS: 4
    SAMPLER:
      NAME: LoopSampler
    TRANSFORMS:
      - NAME: LoadImageFromFileList
        FILE_KEYS: ['img_path', 'src_path']
        RGB_ORDER: RGB
        BACKEND: pillow
      - NAME: FlexibleResize
        INTERPOLATION: bilinear
        SIZE: [ 1024, 1024 ]
        INPUT_KEY: [ 'img', 'src' ]
        OUTPUT_KEY: [ 'img', 'src' ]
        BACKEND: pillow
      - NAME: FlexibleCenterCrop
        SIZE: [ 1024, 1024 ]
        INPUT_KEY: [ 'img', 'src' ]
        OUTPUT_KEY: [ 'img', 'src' ]
        BACKEND: pillow
      - NAME: ImageToTensor
        INPUT_KEY: [ 'img', 'src' ]
        OUTPUT_KEY: [ 'img', 'src' ]
        BACKEND: pillow
      - NAME: Normalize
        MEAN: [ 0.5, 0.5, 0.5 ]
        STD: [ 0.5, 0.5, 0.5 ]
        INPUT_KEY: [ 'img', 'src' ]
        OUTPUT_KEY: [ 'image', 'condition_cat' ]
        BACKEND: torchvision
      - NAME: Select
        KEYS: [ 'image', 'condition_cat', 'prompt' ]
        META_KEYS: [ 'data_key' ]
  #
  TRAIN_HOOKS:
    -
      NAME: BackwardHook
      PRIORITY: 0
    -
      NAME: LogHook
      LOG_INTERVAL: 50
    -
      NAME: CheckpointHook
      INTERVAL: 10000
      PUSH_TO_HUB: False
    -
      NAME: ProbeDataHook
      PROB_INTERVAL: 10000
  EVAL_DATA:
    NAME: Text2ImageDataset
    MODE: eval
    PROMPT_FILE:
    PROMPT_DATA: [ "Transfer the lighting conditions of the reference image to this image#;#scepter/cache/datasets/i_pair/input/Image000.png" ]
    IMAGE_SIZE: [ 1024, 1024 ]
    FIELDS: [ "prompt", "src_path" ]
    DELIMITER: '#;#'
    PROMPT_PREFIX: ''
    PIN_MEMORY: True
    BATCH_SIZE: 1
    NUM_WORKERS: 4
    TRANSFORMS:
      - NAME: LoadImageFromFileList
        FILE_KEYS: [ 'src_path' ]
        RGB_ORDER: RGB
        BACKEND: pillow
      - NAME: FlexibleResize
        INTERPOLATION: bilinear
        SIZE: [ 1024, 1024 ]
        INPUT_KEY: [ 'src' ]
        OUTPUT_KEY: [ 'src' ]
        BACKEND: pillow
      - NAME: FlexibleCenterCrop
        SIZE: [ 1024, 1024 ]
        INPUT_KEY: [ 'src' ]
        OUTPUT_KEY: [ 'src' ]
        BACKEND: pillow
      - NAME: ImageToTensor
        INPUT_KEY: [ 'src' ]
        OUTPUT_KEY: [ 'src' ]
        BACKEND: pillow
      - NAME: Normalize
        MEAN: [ 0.5, 0.5, 0.5 ]
        STD: [ 0.5, 0.5, 0.5 ]
        INPUT_KEY: [ 'src' ]
        OUTPUT_KEY: [ 'condition_cat' ]
        BACKEND: torchvision
      - NAME: Select
        KEYS: [ 'condition_cat', 'prompt' ]
        META_KEYS: [ 'image_size' ]
  EVAL_HOOKS:
    -
      NAME: ProbeDataHook
      PROB_INTERVAL: 100
      SAVE_LAST: True
      SAVE_NAME_PREFIX: 'step'
      SAVE_PROBE_PREFIX: 'image'
```
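As a sanity check, this is how I interpret the TARGET_MODULES pattern. The module names below are hypothetical SD-UNet-style paths (not taken from scepter); they only illustrate which layers I expect the LoRA to attach to:

```python
import re

# Quick standalone check of the TARGET_MODULES regex from the config above.
# Candidate names are hypothetical SD-UNet-style module paths for illustration.
pattern = re.compile(r"model.*(to_q|to_k|to_v|to_out.0|net.0.proj|net.2)$")
candidates = [
    "model.input_blocks.1.1.transformer_blocks.0.attn1.to_q",      # attention query proj
    "model.input_blocks.1.1.transformer_blocks.0.attn2.to_out.0",  # attention output proj
    "model.input_blocks.1.1.transformer_blocks.0.ff.net.0.proj",   # FFN in-projection
    "model.input_blocks.1.1.transformer_blocks.0.ff.net.2",        # FFN out-projection
    "model.input_blocks.0.0",                                      # plain conv, should not match
]
for name in candidates:
    print(f"{name}: {'matched' if pattern.search(name) else 'not matched'}")
```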
I fine-tune it on a single GPU, but the loss suddenly starts climbing during training and the generated images turn into pure noise. Training log:
```
scepter [INFO] 2025-07-09 14:21:47,031 [File: log.py Function: _print_iter_log at line 71] Stage [train] iter: [50/210000], data_time: 0.6301(0.6301), time: 1.1624(1.1624), loss: 0.0817(0.0817), throughput: 312590/day, all_throughput: 100, pg0_lr: 0.000100, scale: 1.000000, [1mins 9secs 0.02%(3days 9hours 13mins 22secs)]
scepter [INFO] 2025-07-09 14:22:14,665 [File: log.py Function: _print_iter_log at line 71] Stage [train] iter: [100/210000], data_time: 0.009399(0.3198), time: 0.5529(0.8577), loss: 0.0868(0.0843), throughput: 312717/day, all_throughput: 200, pg0_lr: 0.000100, scale: 1.000000, [1mins 37secs 0.05%(2days 8hours 42mins 50secs)]
scepter [INFO] 2025-07-09 14:22:55,082 [File: log.py Function: _print_iter_log at line 71] Stage [train] iter: [150/210000], data_time: 0.2801(0.3065), time: 0.8083(0.8412), loss: 0.1073(0.0920), throughput: 313775/day, all_throughput: 300, pg0_lr: 0.000100, scale: 1.000000, [2mins 17secs 0.07%(2days 5hours 30mins 23secs)]
scepter [INFO] 2025-07-09 14:23:21,633 [File: log.py Function: _print_iter_log at line 71] Stage [train] iter: [200/210000], data_time: 0.0122(0.2329), time: 0.5310(0.7637), loss: 0.0968(0.0932), throughput: 316758/day, all_throughput: 400, pg0_lr: 0.000100, scale: 1.000000, [2mins 44secs 0.10%(1days 23hours 51mins 25secs)]
scepter [INFO] 2025-07-09 14:24:00,463 [File: log.py Function: _print_iter_log at line 71] Stage [train] iter: [250/210000], data_time: 0.2562(0.2376), time: 0.7766(0.7663), loss: 0.1078(0.0961), throughput: 317061/day, all_throughput: 500, pg0_lr: 0.000100, scale: 1.000000, [3mins 23secs 0.12%(1days 23hours 19mins 33secs)]
scepter [INFO] 2025-07-09 14:24:26,817 [File: log.py Function: _print_iter_log at line 71] Stage [train] iter: [300/210000], data_time: 0.0118(0.2000), time: 0.5268(0.7263), loss: 0.0811(0.0936), throughput: 318919/day, all_throughput: 600, pg0_lr: 0.000100, scale: 1.000000, [3mins 49secs 0.14%(1days 20hours 32mins 46secs)]
scepter [INFO] 2025-07-09 14:25:04,047 [File: log.py Function: _print_iter_log at line 71] Stage [train] iter: [350/210000], data_time: 0.2275(0.2039), time: 0.7448(0.7290), loss: 0.0910(0.0932), throughput: 319195/day, all_throughput: 700, pg0_lr: 0.000100, scale: 1.000000, [4mins 26secs 0.17%(1days 20hours 22mins 4secs)]
scepter [INFO] 2025-07-09 14:25:31,323 [File: log.py Function: _print_iter_log at line 71] Stage [train] iter: [400/210000], data_time: 0.0103(0.1797), time: 0.5456(0.7061), loss: 0.1097(0.0953), throughput: 319086/day, all_throughput: 800, pg0_lr: 0.000100, scale: 1.000000, [4mins 53secs 0.19%(1days 18hours 46mins 58secs)]
scepter [INFO] 2025-07-09 14:26:11,375 [File: log.py Function: _print_iter_log at line 71] Stage [train] iter: [450/210000], data_time: 0.2766(0.1905), time: 0.8009(0.7166), loss: 0.1164(0.0976), throughput: 318814/day, all_throughput: 900, pg0_lr: 0.000100, scale: 1.000000, [5mins 33secs 0.21%(1days 19hours 12mins 3secs)]
scepter [INFO] 2025-07-09 14:26:37,859 [File: log.py Function: _print_iter_log at line 71] Stage [train] iter: [500/210000], data_time: 0.0120(0.1726), time: 0.5297(0.6979), loss: 0.1375(0.1016), throughput: 319571/day, all_throughput: 1000, pg0_lr: 0.000100, scale: 1.000000, [6mins 0secs 0.24%(1days 17hours 57mins 14secs)]
scepter [INFO] 2025-07-09 14:27:20,358 [File: log.py Function: _print_iter_log at line 71] Stage [train] iter: [550/210000], data_time: 0.3287(0.1868), time: 0.8498(0.7117), loss: 0.2402(0.1142), throughput: 319437/day, all_throughput: 1100, pg0_lr: 0.000100, scale: 1.000000, [6mins 42secs 0.26%(1days 18hours 37mins 35secs)]
scepter [INFO] 2025-07-09 14:27:46,660 [File: log.py Function: _print_iter_log at line 71] Stage [train] iter: [600/210000], data_time: 0.0126(0.1723), time: 0.5263(0.6963), loss: 0.5065(0.1469), throughput: 320198/day, all_throughput: 1200, pg0_lr: 0.000100, scale: 1.000000, [7mins 9secs 0.29%(1days 17hours 36mins 53secs)]
```
- Are there any recommended tricks for stable training with a high LoRA rank?
- Is there any recommended gradient clipping or norm scaling I should add? (See the sketch at the end of this issue for what I mean.)
- Should I adjust any part of the UNet config for the combination of large resolution and large LoRA rank?
Any advice would be appreciated!
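For the second question, this is the kind of clipping I have in mind (a plain-PyTorch sketch on a toy module; I'm not sure whether LatentDiffusionSolver or BackwardHook already exposes an equivalent option in the YAML):

```python
import torch

# Toy illustration of gradient-norm clipping before the optimizer step
# (plain PyTorch; not scepter's hook API).
model = torch.nn.Linear(8, 8)
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-4)

x = torch.randn(2, 8)
loss = model(x).pow(2).mean()
loss.backward()
torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)  # clip total grad norm to 1.0
optimizer.step()
optimizer.zero_grad()
```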
