
Training StyleBooth becomes unstable with high LoRA rank #115

@rukaeto

Description

I'm trying to fine-tune the StyleBooth tuner model using SwiftLoRA with a relatively high rank (R=256, LoRA alpha=256), a learning rate of 1e-4, an image size of [1024, 1024], and a fixed batch size of 2.

Starting from this YAML, I modified it as follows:

ENV:
  BACKEND: nccl
SOLVER:
  NAME: LatentDiffusionSolver
  RESUME_FROM:
  LOAD_MODEL_ONLY: True
  USE_FSDP: False
  SHARDING_STRATEGY:
  USE_AMP: True
  DTYPE: float16
  CHANNELS_LAST: True
  MAX_STEPS: 210000
  MAX_EPOCHS: -1
  NUM_FOLDS: 1
  ACCU_STEP: 1
  EVAL_INTERVAL: 100
  #
  WORK_DIR: ./scepter/cache/save_data/edit_1024_lora_r256_i
  LOG_FILE: std_log_i.txt
  #
  FILE_SYSTEM:
    NAME: "LocalFs"
    TEMP_DIR: "./scepter/cache/cache_data"
  #
  TUNER:
    -
      NAME: SwiftLoRA
      R: 256
      LORA_ALPHA: 256
      LORA_DROPOUT: 0.0
      BIAS: "none"
      TARGET_MODULES: model.*(to_q|to_k|to_v|to_out.0|net.0.proj|net.2)$
  #
  MODEL:
    NAME: LatentDiffusionEdit
    PARAMETERIZATION: eps
    TIMESTEPS: 1000
    MIN_SNR_GAMMA: 
    ZERO_TERMINAL_SNR: False
    #PRETRAINED_MODEL: ms://iic/stylebooth@models/stylebooth-tb-5000-0.bin
    PRETRAINED_MODEL: "./scepter/weights/stylebooth-tb-5000-0.bin"
    IGNORE_KEYS: [ ]
    CONCAT_NO_SCALE_FACTOR: True
    SCALE_FACTOR: 0.18215
    SIZE_FACTOR: 8
    # DEFAULT_N_PROMPT: 'lowres, error, worst quality, low quality, jpeg artifacts, ugly, duplicate, morbid, mutilated, out of frame, extra fingers, mutated hands, poorly drawn hands, poorly drawn face, mutation, deformed, blurry, dehydrated, bad anatomy, bad proportions, extra limbs, cloned face, disfigured, gross proportions, malformed limbs, missing arms, missing legs, extra arms, extra legs, fused fingers, too many fingers, long neck, username, watermark, signature'
    DEFAULT_N_PROMPT:
    SCHEDULE_ARGS:
      "NAME": "scaled_linear"
      "BETA_MIN": 0.00085
      "BETA_MAX": 0.012
    USE_EMA: False
    #
    DIFFUSION_MODEL:
      NAME: DiffusionUNet
      IN_CHANNELS: 8
      OUT_CHANNELS: 4
      MODEL_CHANNELS: 320
      NUM_HEADS: 8
      NUM_RES_BLOCKS: 2
      ATTENTION_RESOLUTIONS: [ 4, 2, 1 ]
      CHANNEL_MULT: [ 1, 2, 4, 4 ]
      CONV_RESAMPLE: True
      DIMS: 2
      USE_CHECKPOINT: False
      USE_SCALE_SHIFT_NORM: False
      RESBLOCK_UPDOWN: False
      USE_SPATIAL_TRANSFORMER: True
      TRANSFORMER_DEPTH: 1
      CONTEXT_DIM: 768
      DISABLE_MIDDLE_SELF_ATTN: False
      USE_LINEAR_IN_TRANSFORMER: False
      PRETRAINED_MODEL:
      IGNORE_KEYS: []
    #
    FIRST_STAGE_MODEL:
      NAME: AutoencoderKL
      EMBED_DIM: 4
      PRETRAINED_MODEL:
      IGNORE_KEYS: []
      BATCH_SIZE: 4
      #
      ENCODER:
        NAME: Encoder
        CH: 128
        OUT_CH: 3
        NUM_RES_BLOCKS: 2
        IN_CHANNELS: 3
        ATTN_RESOLUTIONS: [ ]
        CH_MULT: [ 1, 2, 4, 4 ]
        Z_CHANNELS: 4
        DOUBLE_Z: True
        DROPOUT: 0.0
        RESAMP_WITH_CONV: True
      #
      DECODER:
        NAME: Decoder
        CH: 128
        OUT_CH: 3
        NUM_RES_BLOCKS: 2
        IN_CHANNELS: 3
        ATTN_RESOLUTIONS: [ ]
        CH_MULT: [ 1, 2, 4, 4 ]
        Z_CHANNELS: 4
        DROPOUT: 0.0
        RESAMP_WITH_CONV: True
        GIVE_PRE_END: False
        TANH_OUT: False
    #
    TOKENIZER:
      NAME: ClipTokenizer
      PRETRAINED_PATH: "./scepter/weights/clip-vit-large-patch14"
      LENGTH: 77
      CLEAN: True
    #
    COND_STAGE_MODEL:
      NAME: FrozenCLIPEmbedder
      FREEZE: True
      LAYER: last
      PRETRAINED_MODEL: "./scepter/weights/clip-vit-large-patch14"

    #
    LOSS:
      NAME: ReconstructLoss
      LOSS_TYPE: l2
  #
  SAMPLE_ARGS:
    SAMPLER: ddim
    SAMPLE_STEPS: 50
    SEED: 2023
    GUIDE_SCALE: #7.5
      image: 1.5
      text: 7.5
    GUIDE_RESCALE: 0.5
    DISCRETIZATION: trailing
    IMAGE_SIZE: [1024, 1024]
    RUN_TRAIN_N: False
  #
  OPTIMIZER:
    NAME: AdamW
    LEARNING_RATE: 1e-4
    BETAS: [ 0.9, 0.999 ]
    EPS: 1e-8
    WEIGHT_DECAY: 1e-2
    AMSGRAD: True
  #
  TRAIN_DATA:
    NAME: ImageTextPairFolderDataset
    MODE: train
    DATA_FOLDER: ./scepter/cache/datasets/i_pair 
    PROMPT_PREFIX: ""
    REPLACE_STYLE: False
    PIN_MEMORY: True
    BATCH_SIZE: 2
    NUM_WORKERS: 4
    SAMPLER:
      NAME: LoopSampler
    TRANSFORMS:
      - NAME: LoadImageFromFileList
        FILE_KEYS: ['img_path', 'src_path']
        RGB_ORDER: RGB
        BACKEND: pillow
      - NAME: FlexibleResize
        INTERPOLATION: bilinear
        SIZE: [ 1024, 1024 ]
        INPUT_KEY: [ 'img', 'src' ]
        OUTPUT_KEY: [ 'img', 'src' ]
        BACKEND: pillow
      - NAME: FlexibleCenterCrop
        SIZE: [ 1024, 1024 ]
        INPUT_KEY: [ 'img', 'src' ]
        OUTPUT_KEY: [ 'img', 'src' ]
        BACKEND: pillow
      - NAME: ImageToTensor
        INPUT_KEY: [ 'img', 'src' ]
        OUTPUT_KEY: [ 'img', 'src' ]
        BACKEND: pillow
      - NAME: Normalize
        MEAN: [ 0.5,  0.5,  0.5 ]
        STD: [ 0.5,  0.5,  0.5 ]
        INPUT_KEY: [ 'img', 'src' ]
        OUTPUT_KEY: [ 'image', 'condition_cat' ]
        BACKEND: torchvision
      - NAME: Select
        KEYS: [ 'image', 'condition_cat', 'prompt' ]
        META_KEYS: [ 'data_key' ]
  #
  TRAIN_HOOKS:
    -
      NAME: BackwardHook
      PRIORITY: 0
    -
      NAME: LogHook
      LOG_INTERVAL: 50
    -
      NAME: CheckpointHook
      INTERVAL: 10000
      PUSH_TO_HUB : False
    -
      NAME: ProbeDataHook
      PROB_INTERVAL: 10000

  EVAL_DATA:
    NAME: Text2ImageDataset
    MODE: eval
    PROMPT_FILE:
    PROMPT_DATA: [ "Transfer the lighting conditions of the reference image to this image#;#scepter/cache/datasets/i_pair/input/Image000.png" ]
    IMAGE_SIZE: [ 1024, 1024 ]

    FIELDS: [ "prompt", "src_path" ]
    DELIMITER: '#;#'
    PROMPT_PREFIX: ''
    PIN_MEMORY: True
    BATCH_SIZE: 1
    NUM_WORKERS: 4
    TRANSFORMS:
      - NAME: LoadImageFromFileList
        FILE_KEYS: [ 'src_path' ]
        RGB_ORDER: RGB
        BACKEND: pillow
      - NAME: FlexibleResize
        INTERPOLATION: bilinear
        SIZE: [ 1024, 1024 ]
        INPUT_KEY: [ 'src' ]
        OUTPUT_KEY: [ 'src' ]
        BACKEND: pillow
      - NAME: FlexibleCenterCrop
        SIZE: [ 1024, 1024 ]
        INPUT_KEY: [ 'src' ]
        OUTPUT_KEY: [ 'src' ]
        BACKEND: pillow
      - NAME: ImageToTensor
        INPUT_KEY: [ 'src' ]
        OUTPUT_KEY: [ 'src' ]
        BACKEND: pillow
      - NAME: Normalize
        MEAN: [ 0.5,  0.5,  0.5 ]
        STD: [ 0.5,  0.5,  0.5 ]
        INPUT_KEY: [ 'src' ]
        OUTPUT_KEY: [ 'condition_cat' ]
        BACKEND: torchvision
      - NAME: Select
        KEYS: [ 'condition_cat', 'prompt' ]
        META_KEYS: [ 'image_size' ]
  EVAL_HOOKS:
    -
      NAME: ProbeDataHook
      PROB_INTERVAL: 100
      SAVE_LAST: True
      SAVE_NAME_PREFIX: 'step'
      SAVE_PROBE_PREFIX: 'image'
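For context on the update magnitude this config produces: with R: 256 and LORA_ALPHA: 256 the standard LoRA scaling factor alpha/r stays at 1.0, so the adapter's contribution is not damped at all even though each targeted projection now carries far more trainable parameters than a typical rank-16 adapter. A minimal sketch of the standard LoRA arithmetic (plain PyTorch; the 320-dim projection size is illustrative, matching the UNet base width, not pulled from SwiftLoRA internals):

```python
import torch

# Standard LoRA update: W_eff = W + (alpha / r) * (B @ A).
# Sizes are hypothetical, for illustration only.
d, r, alpha = 320, 256, 256
scaling = alpha / r            # 1.0 here: alpha does not shrink the update at high rank
A = torch.randn(r, d) * 0.01   # LoRA "down" matrix (normally Kaiming-initialized)
B = torch.zeros(d, r)          # LoRA "up" matrix (zero-init so training starts at W)
delta_W = scaling * (B @ A)    # added onto the frozen weight at every step

print(scaling)                 # 1.0
print(A.numel() + B.numel())   # 163840 trainable params for this single projection
```

A common rule of thumb when raising r is to hold alpha fixed (e.g. 16 or 32) or lower the learning rate, so the effective update (alpha/r) * B@A does not grow with rank; I'm unsure whether that applies here too.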

I'm fine-tuning on a single GPU, but partway through training the loss suddenly increases and the generated images turn into pure noise. In the log below, the running average loss climbs from ~0.08 to ~0.15 within the first 600 steps, with the instantaneous loss spiking to 0.5065:

scepter [INFO] 2025-07-09 14:21:47,031 [File: log.py Function: _print_iter_log at line 71] Stage [train] iter: [50/210000], data_time: 0.6301(0.6301), time: 1.1624(1.1624), loss: 0.0817(0.0817), throughput: 312590/day, all_throughput: 100, pg0_lr: 0.000100, scale: 1.000000, [1mins 9secs 0.02%(3days 9hours 13mins 22secs)]
scepter [INFO] 2025-07-09 14:22:14,665 [File: log.py Function: _print_iter_log at line 71] Stage [train] iter: [100/210000], data_time: 0.009399(0.3198), time: 0.5529(0.8577), loss: 0.0868(0.0843), throughput: 312717/day, all_throughput: 200, pg0_lr: 0.000100, scale: 1.000000, [1mins 37secs 0.05%(2days 8hours 42mins 50secs)]
scepter [INFO] 2025-07-09 14:22:55,082 [File: log.py Function: _print_iter_log at line 71] Stage [train] iter: [150/210000], data_time: 0.2801(0.3065), time: 0.8083(0.8412), loss: 0.1073(0.0920), throughput: 313775/day, all_throughput: 300, pg0_lr: 0.000100, scale: 1.000000, [2mins 17secs 0.07%(2days 5hours 30mins 23secs)]
scepter [INFO] 2025-07-09 14:23:21,633 [File: log.py Function: _print_iter_log at line 71] Stage [train] iter: [200/210000], data_time: 0.0122(0.2329), time: 0.5310(0.7637), loss: 0.0968(0.0932), throughput: 316758/day, all_throughput: 400, pg0_lr: 0.000100, scale: 1.000000, [2mins 44secs 0.10%(1days 23hours 51mins 25secs)]
scepter [INFO] 2025-07-09 14:24:00,463 [File: log.py Function: _print_iter_log at line 71] Stage [train] iter: [250/210000], data_time: 0.2562(0.2376), time: 0.7766(0.7663), loss: 0.1078(0.0961), throughput: 317061/day, all_throughput: 500, pg0_lr: 0.000100, scale: 1.000000, [3mins 23secs 0.12%(1days 23hours 19mins 33secs)]
scepter [INFO] 2025-07-09 14:24:26,817 [File: log.py Function: _print_iter_log at line 71] Stage [train] iter: [300/210000], data_time: 0.0118(0.2000), time: 0.5268(0.7263), loss: 0.0811(0.0936), throughput: 318919/day, all_throughput: 600, pg0_lr: 0.000100, scale: 1.000000, [3mins 49secs 0.14%(1days 20hours 32mins 46secs)]
scepter [INFO] 2025-07-09 14:25:04,047 [File: log.py Function: _print_iter_log at line 71] Stage [train] iter: [350/210000], data_time: 0.2275(0.2039), time: 0.7448(0.7290), loss: 0.0910(0.0932), throughput: 319195/day, all_throughput: 700, pg0_lr: 0.000100, scale: 1.000000, [4mins 26secs 0.17%(1days 20hours 22mins 4secs)]
scepter [INFO] 2025-07-09 14:25:31,323 [File: log.py Function: _print_iter_log at line 71] Stage [train] iter: [400/210000], data_time: 0.0103(0.1797), time: 0.5456(0.7061), loss: 0.1097(0.0953), throughput: 319086/day, all_throughput: 800, pg0_lr: 0.000100, scale: 1.000000, [4mins 53secs 0.19%(1days 18hours 46mins 58secs)]
scepter [INFO] 2025-07-09 14:26:11,375 [File: log.py Function: _print_iter_log at line 71] Stage [train] iter: [450/210000], data_time: 0.2766(0.1905), time: 0.8009(0.7166), loss: 0.1164(0.0976), throughput: 318814/day, all_throughput: 900, pg0_lr: 0.000100, scale: 1.000000, [5mins 33secs 0.21%(1days 19hours 12mins 3secs)]
scepter [INFO] 2025-07-09 14:26:37,859 [File: log.py Function: _print_iter_log at line 71] Stage [train] iter: [500/210000], data_time: 0.0120(0.1726), time: 0.5297(0.6979), loss: 0.1375(0.1016), throughput: 319571/day, all_throughput: 1000, pg0_lr: 0.000100, scale: 1.000000, [6mins 0secs 0.24%(1days 17hours 57mins 14secs)]
scepter [INFO] 2025-07-09 14:27:20,358 [File: log.py Function: _print_iter_log at line 71] Stage [train] iter: [550/210000], data_time: 0.3287(0.1868), time: 0.8498(0.7117), loss: 0.2402(0.1142), throughput: 319437/day, all_throughput: 1100, pg0_lr: 0.000100, scale: 1.000000, [6mins 42secs 0.26%(1days 18hours 37mins 35secs)]
scepter [INFO] 2025-07-09 14:27:46,660 [File: log.py Function: _print_iter_log at line 71] Stage [train] iter: [600/210000], data_time: 0.0126(0.1723), time: 0.5263(0.6963), loss: 0.5065(0.1469), throughput: 320198/day, all_throughput: 1200, pg0_lr: 0.000100, scale: 1.000000, [7mins 9secs 0.29%(1days 17hours 36mins 53secs)]

  1. Are there any recommended tricks for stable training with high LoRA rank?
  2. Is there any recommended gradient clipping or norm scaling I should add?
  3. Should I adjust any part of the UNet config for large resolution + large LoRA rank?

Any advice would be appreciated!
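For question 2, the kind of clipping I have in mind is the standard PyTorch pattern below (a generic sketch with a placeholder model and data, not SCEPTER's BackwardHook API):

```python
import torch
import torch.nn as nn

# Minimal sketch of gradient clipping in a training step. clip_grad_norm_
# rescales all gradients in-place so their global L2 norm is at most max_norm,
# and returns the pre-clipping norm (useful for spotting spikes before divergence).
model = nn.Linear(8, 4)        # placeholder for the tuned LoRA parameters
opt = torch.optim.AdamW(model.parameters(), lr=1e-4)

x, y = torch.randn(2, 8), torch.randn(2, 4)
loss = nn.functional.mse_loss(model(x), y)
loss.backward()

total_norm = torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
opt.step()
opt.zero_grad()
```

If I add this myself, I assume that with USE_AMP: True the gradients would need to be unscaled first (scaler.unscale_(optimizer)) before clipping, so the threshold applies to the true gradient norms rather than the loss-scaled ones.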
