Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
71 changes: 38 additions & 33 deletions src/liger_kernel/transformers/monkey_patch.py
Original file line number Diff line number Diff line change
Expand Up @@ -430,7 +430,7 @@ def apply_liger_kernel_to_llava(
f"These parameters are not supported by {text_model_name}. Enter the remaining {list(text_kwargs.keys())} except for {list(remain_params)}\n"
f"Parameters accepted by {text_model_name}: {list(accept_params.keys())}"
)
text_kwargs["model"] = model.language_model
text_kwargs["model"] = model.model.language_model
text_liger_fn(**text_kwargs)
elif text_model_name not in MODEL_TYPE_TO_APPLY_LIGER_FN:
logger.warning(f"{text_model_name} is not supported by Liger kernel.")
Expand All @@ -445,7 +445,7 @@ def apply_liger_kernel_to_llava(
f"These parameters are not supported by {vision_model_name}. Enter the remaining {list(vision_kwargs.keys())} except for {list(remain_params)}\n"
f"Parameters accepted by {vision_model_name}: {list(accept_params.keys())}"
)
vision_kwargs["model"] = model.vision_tower
vision_kwargs["model"] = model.model.vision_tower
vision_liger_fn(**vision_kwargs)
elif vision_model_name not in MODEL_TYPE_TO_APPLY_LIGER_FN:
logger.warning(f"{vision_model_name} is not supported by Liger kernel.")
Expand Down Expand Up @@ -615,8 +615,8 @@ def apply_liger_kernel_to_mllama(
# instance variables that reference already-instantiated modules

if isinstance(model, MllamaForConditionalGeneration):
language_model: MllamaForCausalLM = model.language_model
vision_model: MllamaVisionModel = model.vision_model
language_model: MllamaForCausalLM = model.model.language_model
vision_model: MllamaVisionModel = model.model.vision_model
if isinstance(language_model, MllamaForCausalLM):
text_model: MllamaTextModel = language_model.model
else:
Expand Down Expand Up @@ -1118,8 +1118,8 @@ def apply_liger_kernel_to_gemma3(
# instance variables that reference already-instantiated modules

if isinstance(model, Gemma3ForConditionalGeneration):
if isinstance(model.vision_tower, SiglipVisionModel):
vision_tower = model.vision_tower
if isinstance(model.model.vision_tower, SiglipVisionModel):
vision_tower = model.model.vision_tower

_patch_layer_norm_module(vision_tower.vision_model.post_layernorm)

Expand All @@ -1132,15 +1132,15 @@ def apply_liger_kernel_to_gemma3(
raise TypeError("The vision tower must be SiglipVisionModel")

if rms_norm:
_patch_rms_norm_module_for_gemma3(model.multi_modal_projector.mm_soft_emb_norm)
_patch_rms_norm_module_for_gemma3(model.model.multi_modal_projector.mm_soft_emb_norm)

apply_liger_kernel_to_gemma3_text(
rope=rope,
cross_entropy=False,
fused_linear_cross_entropy=False,
rms_norm=rms_norm,
geglu=geglu,
model=model.language_model,
model=model.model.language_model,
)

else:
Expand Down Expand Up @@ -1228,7 +1228,7 @@ def apply_liger_kernel_to_paligemma(
if not isinstance(model, PaliGemmaForConditionalGeneration):
raise TypeError("model have to be of type PaliGemmaForConditionalGeneration")

vision_tower: SiglipVisionModel = model.vision_tower
vision_tower: SiglipVisionModel = model.model.vision_tower

_patch_layer_norm_module(vision_tower.vision_model.post_layernorm)

Expand All @@ -1238,7 +1238,7 @@ def apply_liger_kernel_to_paligemma(
_patch_layer_norm_module(layer.layer_norm1)
_patch_layer_norm_module(layer.layer_norm2)

language_model = model.language_model
language_model = model.model.language_model

if isinstance(language_model, (GemmaForCausalLM, GemmaModel)):
apply_liger_kernel_to_gemma(
Expand Down Expand Up @@ -1520,11 +1520,10 @@ def apply_liger_kernel_to_qwen2_vl(
if model is not None:
# The model instance already exists, so we need to additionally patch the
# instance variables that reference already-instantiated modules

if isinstance(model, (Qwen2VLForConditionalGeneration, Qwen2VLModel)):
# Note: language_model and visual properties can be accessed through the conditional-generation class for BC (backward compatibility).
# Not sure if it is subject to changes in the future.
# Reference: https://github.com/huggingface/transformers/blob/v4.52.4/src/transformers/models/qwen2_vl/modeling_qwen2_vl.py#L1698
if isinstance(model, Qwen2VLForConditionalGeneration):
text_model: Qwen2VLTextModel = model.model.language_model
vision_model: Qwen2VisionTransformerPretrainedModel = model.model.visual
elif isinstance(model, Qwen2VLModel):
text_model: Qwen2VLTextModel = model.language_model
vision_model: Qwen2VisionTransformerPretrainedModel = model.visual
elif isinstance(model, Qwen2VLTextModel):
Expand Down Expand Up @@ -1611,11 +1610,10 @@ def apply_liger_kernel_to_qwen2_5_vl(
if model is not None:
# The model instance already exists, so we need to additionally patch the
# instance variables that reference already-instantiated modules

if isinstance(model, (Qwen2_5_VLForConditionalGeneration, Qwen2_5_VLModel)):
# Note: language_model and visual properties can be accessed through the conditional-generation class for BC (backward compatibility).
# Not sure if it is subject to changes in the future.
# Reference: https://github.com/huggingface/transformers/blob/v4.52.4/src/transformers/models/qwen2_5_vl/modeling_qwen2_5_vl.py#L1823
if isinstance(model, Qwen2_5_VLForConditionalGeneration):
text_model: Qwen2_5_VLTextModel = model.model.language_model
vision_model: Qwen2_5_VisionTransformerPretrainedModel = model.model.visual
elif isinstance(model, Qwen2_5_VLModel):
text_model: Qwen2_5_VLTextModel = model.language_model
vision_model: Qwen2_5_VisionTransformerPretrainedModel = model.visual
elif isinstance(model, Qwen2_5_VLTextModel):
Expand All @@ -1629,7 +1627,7 @@ def apply_liger_kernel_to_qwen2_5_vl(

if vision_model is not None:
# Patch Qwen2_5_VisionTransformerPretrainedModel
for vision_block in model.visual.blocks:
for vision_block in vision_model.blocks:
if rms_norm:
_patch_rms_norm_module(vision_block.norm1)
_patch_rms_norm_module(vision_block.norm2)
Expand Down Expand Up @@ -1698,7 +1696,9 @@ def apply_liger_kernel_to_qwen3_vl(
modeling_qwen3_vl.Qwen3VLForConditionalGeneration.forward = qwen3_vl_lce_forward

if model is not None and rms_norm:
if isinstance(model, (Qwen3VLForConditionalGeneration, Qwen3VLModel)):
if isinstance(model, Qwen3VLForConditionalGeneration):
text_model: Qwen3VLTextModel = model.model.language_model
elif isinstance(model, Qwen3VLModel):
text_model: Qwen3VLTextModel = model.language_model
elif isinstance(model, Qwen3VLTextModel):
text_model = model
Expand Down Expand Up @@ -1773,7 +1773,9 @@ def apply_liger_kernel_to_qwen3_vl_moe(
modeling_qwen3_vl_moe.Qwen3VLMoeForConditionalGeneration.forward = qwen3_vl_moe_lce_forward

if model is not None and rms_norm:
if isinstance(model, (Qwen3VLMoeForConditionalGeneration, Qwen3VLMoeModel)):
if isinstance(model, Qwen3VLMoeForConditionalGeneration):
text_model: Qwen3VLMoeTextModel = model.model.language_model
elif isinstance(model, Qwen3VLMoeModel):
text_model: Qwen3VLMoeTextModel = model.language_model
elif isinstance(model, Qwen3VLMoeTextModel):
text_model = model
Expand Down Expand Up @@ -2118,10 +2120,10 @@ def apply_liger_kernel_to_glm4v(
if model is not None:
# The model instance already exists, so we need to additionally patch the
# instance variables that reference already-instantiated modules
if isinstance(model, (Glm4vForConditionalGeneration, Glm4vModel)):
# Note: language_model and visual properties can be accessed through the conditional-generation class for BC (backward compatibility).
# Not sure if it is subject to changes in the future.
# Reference: https://github.com/huggingface/transformers/blob/main/src/transformers/models/glm4v/modeling_glm4v.py#L1305
if isinstance(model, Glm4vForConditionalGeneration):
text_model: Glm4vTextModel = model.model.language_model
vision_model: Glm4vVisionModel = model.model.visual
elif isinstance(model, Glm4vModel):
text_model: Glm4vTextModel = model.language_model
vision_model: Glm4vVisionModel = model.visual
elif isinstance(model, Glm4vTextModel):
Expand Down Expand Up @@ -2208,10 +2210,11 @@ def apply_liger_kernel_to_glm4v_moe(
if model is not None:
# The model instance already exists, so we need to additionally patch the
# instance variables that reference already-instantiated modules
if isinstance(model, (Glm4vMoeForConditionalGeneration, Glm4vMoeModel)):
# Note: language_model and visual properties can be accessed through the conditional-generation class for BC (backward compatibility).
# Not sure if it is subject to changes in the future.
# Reference: https://github.com/huggingface/transformers/blob/main/src/transformers/models/glm4v_moe/modeling_glm4v_moe.py#L337
if isinstance(model, Glm4vMoeForConditionalGeneration):
text_model: Glm4vMoeTextModel = model.model.language_model
vision_model: Glm4vMoeVisionModel = model.model.visual
Glm4vMoeTextMoE = modeling_glm4v_moe.Glm4vMoeTextMoE
elif isinstance(model, Glm4vMoeModel):
text_model: Glm4vMoeTextModel = model.language_model
vision_model: Glm4vMoeVisionModel = model.visual
Glm4vMoeTextMoE = modeling_glm4v_moe.Glm4vMoeTextMoE
Expand Down Expand Up @@ -2314,8 +2317,10 @@ def apply_liger_kernel_to_internvl(
if model is not None:
# The model instance already exists, so we need to additionally patch the
# instance variables that reference already-instantiated modules
if isinstance(model, (InternVLForConditionalGeneration, InternVLModel)):
# NOTE: language_model and vision_tower properties can be accessed through the conditional-generation class.
if isinstance(model, InternVLForConditionalGeneration):
text_model = model.model.language_model
vision_model: InternVLVisionModel = model.model.vision_tower
elif isinstance(model, InternVLModel):
text_model = model.language_model
vision_model: InternVLVisionModel = model.vision_tower
else:
Expand Down