Skip to content

Commit 0e55e36

Browse files
authored
[Model] Support Llama4 (#3336)
Added support for Llama 4
1 parent 8fa69dc commit 0e55e36

File tree

8 files changed

+1045
-0
lines changed

8 files changed

+1045
-0
lines changed

python/mlc_llm/conversation_template/llama.py

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,27 @@
44

55
from .registry import ConvTemplateRegistry
66

7+
# Llama4 - same as Llama3.1 except naming has changed slightly
8+
ConvTemplateRegistry.register_conv_template(
9+
Conversation(
10+
name="llama-4",
11+
system_template="",
12+
system_message="",
13+
roles={
14+
"user": "<|header_start|>user",
15+
"assistant": "<|header_start|>assistant",
16+
"tool": "<|header_start|>ipython",
17+
},
18+
seps=["<|eot|>"],
19+
role_content_sep="<|header_end|>\n\n",
20+
role_empty_sep="<|header_end|>\n\n",
21+
stop_str=[],
22+
stop_token_ids=[200001, 200007, 200008], # "<|end_of_text|>", "<|eom|>", "<|eot|>"
23+
system_prefix_token_ids=[200000], # "<|begin_of_text|>"
24+
add_role_after_system_message=False,
25+
)
26+
)
27+
728
# Llama3.1 -- same as Llama3 except stop token ids and stop str
829
ConvTemplateRegistry.register_conv_template(
930
Conversation(

python/mlc_llm/interface/gen_config.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -262,6 +262,7 @@ def gen_config( # pylint: disable=too-many-locals,too-many-arguments,too-many-b
262262
# FIXME: Copy RWKV tokenizer file # pylint: disable=fixme
263263

264264
CONV_TEMPLATES = {
265+
"llama-4",
265266
"llama-3",
266267
"llama-3_1",
267268
"chatml",

python/mlc_llm/model/llama4/__init__.py

Whitespace-only changes.
Lines changed: 119 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,119 @@
1+
"""
2+
This file specifies how MLC's Llama parameter maps from other formats, for example HuggingFace
3+
PyTorch, HuggingFace safetensors.
4+
"""
5+
6+
import functools
7+
8+
import numpy as np
9+
10+
from mlc_llm.loader import ExternMapping
11+
from mlc_llm.quantization import Quantization
12+
13+
from .llama4_model import Llama4Config, Llama4ForCausalLM
14+
15+
16+
def huggingface(model_config: Llama4Config, quantization: Quantization) -> ExternMapping:
17+
"""Returns a parameter mapping that maps from the names of MLC LLM parameters to
18+
the names of HuggingFace PyTorch parameters.
19+
20+
Parameters
21+
----------
22+
model_config : Llama4Config
23+
The configuration of the Llama model.
24+
25+
quantization : Quantization
26+
The quantization configuration.
27+
28+
Returns
29+
-------
30+
param_map : ExternMapping
31+
The parameter mapping from MLC to HuggingFace PyTorch.
32+
"""
33+
model = Llama4ForCausalLM(model_config)
34+
if quantization is not None:
35+
model.to(quantization.model_dtype)
36+
_, _named_params, _ = model.export_tvm( # type: ignore[misc]
37+
spec=model.get_default_spec(),
38+
allow_extern=True,
39+
)
40+
named_parameters = dict(_named_params)
41+
42+
mapping = ExternMapping()
43+
44+
for i in range(model_config.text_config.num_hidden_layers):
45+
# Add shared expert weights
46+
mlp = f"model.layers.{i}.feed_forward.shared_expert"
47+
mlc_name = f"{mlp}.gate_up_proj.weight"
48+
mlc_param = named_parameters[mlc_name]
49+
mapping.add_mapping(
50+
mlc_name,
51+
[
52+
f"language_model.{mlp}.gate_proj.weight",
53+
f"language_model.{mlp}.up_proj.weight",
54+
],
55+
functools.partial(
56+
lambda gate, up, dtype: np.concatenate([gate, up], axis=0).astype(dtype),
57+
dtype=mlc_param.dtype,
58+
),
59+
)
60+
61+
# Add router weights
62+
mlp = f"model.layers.{i}.feed_forward"
63+
mlc_name = f"{mlp}.router.router.weight"
64+
hf_name = f"language_model.{mlp}.router.weight"
65+
mlc_param = named_parameters[mlc_name]
66+
mapping.add_mapping(
67+
mlc_name,
68+
[
69+
hf_name,
70+
],
71+
functools.partial(
72+
lambda x, dtype: x.astype(dtype),
73+
dtype=mlc_param.dtype,
74+
),
75+
)
76+
77+
# Add experts weights
78+
mlp = f"model.layers.{i}.feed_forward"
79+
hf_name = f"language_model.{mlp}.experts.gate_up_proj"
80+
mlc_name = f"{mlp}.experts.gate_up_proj"
81+
mlc_param = named_parameters[mlc_name]
82+
mapping.add_mapping(
83+
mlc_name,
84+
[
85+
hf_name,
86+
],
87+
functools.partial(
88+
lambda x, dtype: x.astype(dtype),
89+
dtype=mlc_param.dtype,
90+
),
91+
)
92+
93+
mlp = f"model.layers.{i}.feed_forward"
94+
mlc_name = f"{mlp}.experts.down_proj"
95+
hf_name = f"language_model.{mlp}.experts.down_proj"
96+
97+
mlc_param = named_parameters[mlc_name]
98+
mapping.add_mapping(
99+
mlc_name,
100+
[
101+
hf_name,
102+
],
103+
functools.partial(
104+
lambda x, dtype: x.astype(dtype),
105+
dtype=mlc_param.dtype,
106+
),
107+
)
108+
109+
for mlc_name, mlc_param in named_parameters.items():
110+
if mlc_name not in mapping.param_map:
111+
mapping.add_mapping(
112+
mlc_name,
113+
[f"language_model.{mlc_name}"],
114+
functools.partial(
115+
lambda x, dtype: x.astype(dtype),
116+
dtype=mlc_param.dtype,
117+
),
118+
)
119+
return mapping

0 commit comments

Comments
 (0)