src/liger_kernel/transformers/monkey_patch.py (24 additions, 10 deletions)
@@ -393,8 +393,9 @@ def apply_liger_kernel_to_llava(
     from transformers.models.llava import modeling_llava

     if cross_entropy:
-        logger.warning(TRANSFORMER_DEPRECATION_WARNING)
-        modeling_llava.nn.CrossEntropyLoss = LigerCrossEntropyLoss
+        from transformers.loss.loss_utils import nn
+
+        nn.functional.cross_entropy = liger_cross_entropy
     if fused_linear_cross_entropy:
         if transformer_version >= version.parse("4.52.0"):
             if model is not None:
@@ -494,7 +495,9 @@ def apply_liger_kernel_to_llama4(
         modeling_llama4.Llama4TextMLP = LigerSwiGLUMLP

     if cross_entropy:
-        modeling_llama4.CrossEntropyLoss = LigerCrossEntropyLoss
+        from transformers.loss.loss_utils import nn
+
+        nn.functional.cross_entropy = liger_cross_entropy

     if fused_linear_cross_entropy:
         modeling_llama4.Llama4ForCausalLM.forward = llama4_lce_forward
@@ -686,7 +689,9 @@ def apply_liger_kernel_to_mistral(
     if rms_norm:
         modeling_mistral.MistralRMSNorm = LigerRMSNorm
     if cross_entropy:
-        modeling_mistral.CrossEntropyLoss = LigerCrossEntropyLoss
+        from transformers.loss.loss_utils import nn
+
+        nn.functional.cross_entropy = liger_cross_entropy
     if fused_linear_cross_entropy:
         if transformer_version >= version.parse("4.49.0"):
             if model is not None:
)

if cross_entropy:
modeling_gemma3.nn.CrossEntropyLoss = LigerCrossEntropyLoss
from transformers.loss.loss_utils import nn

nn.functional.cross_entropy = liger_cross_entropy

if fused_linear_cross_entropy:
if model is not None:
Expand Down Expand Up @@ -1201,7 +1208,9 @@ def apply_liger_kernel_to_paligemma(
)
# Handle loss function
if cross_entropy:
modeling_paligemma.nn.CrossEntropyLoss = LigerCrossEntropyLoss
from transformers.loss.loss_utils import nn

nn.functional.cross_entropy = liger_cross_entropy
if fused_linear_cross_entropy:
if transformer_version >= version.parse(SUPPORTED_TRANSFORMER_VERSION):
if model is not None:
@@ -1502,7 +1511,9 @@ def apply_liger_kernel_to_qwen2_vl(
     if layer_norm and model is None:
         modeling_qwen2_vl.LayerNorm = LigerLayerNorm
     if cross_entropy:
-        modeling_qwen2_vl.CrossEntropyLoss = LigerCrossEntropyLoss
+        from transformers.loss.loss_utils import nn
+
+        nn.functional.cross_entropy = liger_cross_entropy
     if fused_linear_cross_entropy:
         if model is not None:
             model.forward = MethodType(qwen2_vl_lce_forward, model)
@@ -1593,7 +1604,9 @@ def apply_liger_kernel_to_qwen2_5_vl(
     if rms_norm:
         modeling_qwen2_5_vl.Qwen2RMSNorm = LigerRMSNorm
     if cross_entropy:
-        modeling_qwen2_5_vl.CrossEntropyLoss = LigerCrossEntropyLoss
+        from transformers.loss.loss_utils import nn
+
+        nn.functional.cross_entropy = liger_cross_entropy
     if fused_linear_cross_entropy:
         if model is not None:
             model.forward = MethodType(qwen2_5_vl_lce_forward, model)
@@ -2063,8 +2076,9 @@ def apply_liger_kernel_to_internvl(
     from liger_kernel.transformers.model.internvl import lce_forward as internvl_lce_forward

     if cross_entropy:
-        logger.warning(TRANSFORMER_DEPRECATION_WARNING)
-        modeling_internvl.nn.CrossEntropyLoss = LigerCrossEntropyLoss
+        from transformers.loss.loss_utils import nn
+
+        nn.functional.cross_entropy = liger_cross_entropy
     if fused_linear_cross_entropy:
         modeling_internvl.InternVLForConditionalGeneration.forward = internvl_lce_forward
     if rms_norm:
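Every hunk applies the same fix: newer transformers releases compute loss through transformers.loss.loss_utils rather than by instantiating nn.CrossEntropyLoss on the modeling modules, so swapping the class attribute on modeling_* no longer has any effect. Rebinding nn.functional.cross_entropy through the nn module that loss_utils imports does take effect, because loss_utils resolves the function at call time. A minimal sketch of the mechanism; liger_cross_entropy here is a hypothetical stand-in with the same signature as torch.nn.functional.cross_entropy:

import torch.nn.functional as F

# Capture the original before rebinding, so the stand-in does not
# recurse into itself once the attribute is replaced.
_original_ce = F.cross_entropy

# Hypothetical stand-in for Liger's kernel-backed loss; the real
# liger_cross_entropy matches torch.nn.functional.cross_entropy's signature.
def liger_cross_entropy(input, target, **kwargs):
    return _original_ce(input, target, **kwargs)

# loss_utils does `import torch.nn as nn` and looks up
# nn.functional.cross_entropy at call time, so rebinding the attribute
# redirects every loss computation that goes through it.
from transformers.loss.loss_utils import nn

nn.functional.cross_entropy = liger_cross_entropy

Because nn here is torch.nn itself, the rebinding is process-global, which is exactly the side effect the new test file below has to work around.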
test/transformers/test_monkey_patch_cross_entropy.py (new file, 147 additions)
@@ -0,0 +1,147 @@
"""
Test cross_entropy monkey patches for all supported models.

Note: This test uses subprocess isolation because cross_entropy patches modify
a global function (transformers.loss.loss_utils.nn.functional.cross_entropy).
Once any model's apply function has patched it, the change is visible to every
subsequent test in the same process, making it impossible to verify individual
model patches independently.

By running each test in a separate Python process, we ensure that:
1. Each model's patch is tested in isolation
2. Failures can be correctly attributed to specific models
3. The test suite can detect when a patch is incorrectly targeting the wrong object

Trade-off: ~20x slower (60s vs 3s) but provides accurate per-model validation.
"""

import importlib
import inspect
import subprocess
import sys

import pytest
import transformers

from packaging import version

transformer_version = version.parse(transformers.__version__)
SUPPORTED_TRANSFORMER_VERSION = "4.46.1"


def _extract_model_configs():
    from liger_kernel.transformers.monkey_patch import MODEL_TYPE_TO_APPLY_LIGER_FN

    configs = []
    seen_functions = set()

    for model_type, apply_fn in MODEL_TYPE_TO_APPLY_LIGER_FN.items():
        if apply_fn in seen_functions:
            continue
        seen_functions.add(apply_fn)

        fn_name = apply_fn.__name__
        model_name = fn_name.replace("apply_liger_kernel_to_", "")

        sig = inspect.signature(apply_fn)
        if "cross_entropy" not in sig.parameters:
            continue

        transformers_module = f"transformers.models.{model_name}"

        configs.append(
            {
                "name": model_name,
                "module": transformers_module,
                "apply_fn_name": fn_name,
            }
        )

    return configs


MODEL_CONFIGS = _extract_model_configs()


def is_model_available(module_name):
    try:
        importlib.import_module(module_name)
        return True
    except ImportError:
        return False


def should_skip_model(model_config):
    if transformer_version < version.parse(SUPPORTED_TRANSFORMER_VERSION):
        return True, f"transformers version {transformer_version} < {SUPPORTED_TRANSFORMER_VERSION}"
    if not is_model_available(model_config["module"]):
        return True, f"{model_config['name']} not available"
    return False, None


ISOLATED_TEST_SCRIPT = '''
import sys
import torch.nn.functional

def test_single_model_patch():
    from liger_kernel.transformers import monkey_patch

    apply_fn_name = "{apply_fn_name}"
    model_name = "{model_name}"

    from transformers.loss import loss_utils
    original_ce = torch.nn.functional.cross_entropy

    if loss_utils.nn.functional.cross_entropy is not original_ce:
        print(f"FAIL: cross_entropy was already patched before testing {{model_name}}")
        sys.exit(1)

    apply_fn = getattr(monkey_patch, apply_fn_name)

    try:
        apply_fn(cross_entropy=True, fused_linear_cross_entropy=False)
    except Exception as e:
        print(f"FAIL: Failed to apply patch: {{e}}")
        sys.exit(1)

    patched_ce = loss_utils.nn.functional.cross_entropy

    if patched_ce is original_ce:
        print("FAIL: cross_entropy was not patched")
        sys.exit(1)

    if "liger" not in patched_ce.__module__.lower():
        print(f"FAIL: cross_entropy module is {{patched_ce.__module__}}, expected liger")
        sys.exit(1)

    print(f"PASS: {{model_name}} patched correctly to {{patched_ce.__module__}}")
    sys.exit(0)

if __name__ == "__main__":
    test_single_model_patch()
'''


@pytest.mark.parametrize("model_config", MODEL_CONFIGS, ids=[m["name"] for m in MODEL_CONFIGS])
def test_cross_entropy_patch(model_config):
    should_skip, reason = should_skip_model(model_config)
    if should_skip:
        pytest.skip(reason)

    script = ISOLATED_TEST_SCRIPT.format(
        apply_fn_name=model_config["apply_fn_name"],
        model_name=model_config["name"],
    )

    result = subprocess.run(
        [sys.executable, "-c", script],
        capture_output=True,
        text=True,
        timeout=30,
    )

    output = result.stdout + result.stderr

    if result.returncode != 0:
        pytest.fail(f"{model_config['name']} test failed:\n{output}")

    assert "PASS" in output, f"{model_config['name']}: Unexpected output:\n{output}"