
Commit 70bc6be

support jetmoe & fix python id() caused bugs (#998)
1 parent 232b49b commit 70bc6be

File tree

20 files changed: 1706 additions, 12 deletions

llm/inference/jetmoe/run_jetmoe.py

Lines changed: 21 additions & 0 deletions
@@ -0,0 +1,21 @@
+import mindspore
+from mindnlp.transformers import pipeline, AutoTokenizer, AutoModelForCausalLM
+# Initialize the model and tokenizer
+model_name = "jetmoe/jetmoe-8b-chat"
+tokenizer = AutoTokenizer.from_pretrained(model_name)
+model = AutoModelForCausalLM.from_pretrained(model_name, ms_dtype=mindspore.float16)
+# Encode input context
+messages = [
+    {
+        "role": "system",
+        "content": "You are a friendly chatbot",
+    },
+    {"role": "user", "content": "How many helicopters can a human eat in one sitting?"},
+]
+tokenized_chat = tokenizer.apply_chat_template(messages, tokenize=True, add_generation_prompt=True, return_tensors="ms")
+print(tokenized_chat)
+# Generate text
+output = model.generate(tokenized_chat, max_length=500, num_return_sequences=1, no_repeat_ngram_size=2)
+# Decode the generated text
+generated_text = tokenizer.decode(output[0], skip_special_tokens=True)
+print(generated_text)
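
Note: the script prints the full decoded sequence, which for a decoder-only model includes the chat-template prompt as well as the reply. As an optional follow-up, not part of the commit and assuming the usual generate() output layout, the prompt tokens can be sliced off before decoding:

# Hypothetical follow-up to run_jetmoe.py: decode only the newly generated
# tokens by skipping the prompt portion echoed back by generate().
prompt_length = tokenized_chat.shape[1]   # number of prompt tokens
reply = tokenizer.decode(output[0][prompt_length:], skip_special_tokens=True)
print(reply)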

mindnlp/diffusers/__init__.py

Whitespace-only changes.

mindnlp/diffusers/loaders/__init__.py

Whitespace-only changes.

mindnlp/diffusers/models/__init__.py

Whitespace-only changes.

mindnlp/diffusers/pipelines/__init__.py

Whitespace-only changes.

mindnlp/diffusers/schedulers/__init__.py

Whitespace-only changes.

mindnlp/diffusers/utils/__init__.py

Whitespace-only changes.

mindnlp/injection.py

Lines changed: 8 additions & 0 deletions
@@ -19,7 +19,9 @@
 from typing import OrderedDict
 from functools import reduce, partial
 import math
+from uuid import uuid4
 from packaging import version
+
 import numpy as np
 import mindspore
 import mindspore.common.dtype as mstype
@@ -410,6 +412,12 @@ def _initialize(self, init_method):
 
 Parameter.initialize = _initialize
 
+old_param_init = Parameter.__init__
+def _param_new_init(self, default_input, name=None, requires_grad=True, layerwise_parallel=False, parallel_optimizer=True):
+    old_param_init(self, default_input, name, requires_grad, layerwise_parallel, parallel_optimizer)
+    self.uuid = uuid4().hex
+
+Parameter.__init__ = _param_new_init
 
 old_repeat = Tensor.repeat
 def new_repeat_interleave(input, repeats, axis=None):
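
The Parameter.__init__ patch above exists because CPython's id() is only unique among objects that are alive at the same time: once a Parameter is garbage-collected, a later object can be allocated at the same address and inherit its id, which is the class of bug the commit title refers to. A minimal, framework-free sketch of the hazard (illustrative only, not mindnlp code):

# Illustrative only: why id()-based bookkeeping can misfire after garbage collection.
class FakeParam:
    pass

seen_ids = set()

p1 = FakeParam()
seen_ids.add(id(p1))       # remember the object by its memory address
del p1                     # p1 is freed; CPython may reuse that address

p2 = FakeParam()           # a brand-new, unrelated object...
print(id(p2) in seen_ids)  # ...can print True — a false "already seen" hit

# The patched Parameter.__init__ instead stamps each instance with
# self.uuid = uuid4().hex, a key that is never reused after the object dies.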

mindnlp/transformers/modeling_utils.py

Lines changed: 11 additions & 12 deletions
@@ -1048,17 +1048,16 @@ def empty_initializer(init, shape=None, dtype=mindspore.float32):
 
     # These are all the pointers of shared tensors.
     tied_params = [names for _, names in ptrs.items() if len(names) > 1]
-
     def load_ckpt(resolved_archive_file):
         if 'ckpt' not in resolved_archive_file:
             if use_safetensors or 'safetensors' in resolved_archive_file:
                 from safetensors.numpy import load_file
                 origin_state_dict = load_file(resolved_archive_file)
                 if use_fp16:
                     logger.warning_once("MindSpore do not support bfloat16 dtype, we will automaticlly convert to float16")
-                state_dict = {k: Parameter(v.astype(usage_dtype)) for k, v in origin_state_dict.items()}
+                new_state_dict = {k: Parameter(Tensor.from_numpy(v.astype(usage_dtype))) for k, v in origin_state_dict.items()}
             else:
-                state_dict = load(resolved_archive_file)
+                new_state_dict = load(resolved_archive_file)
         else:
             try:
                 state_dict = load_checkpoint(str(resolved_archive_file))
@@ -1067,12 +1066,12 @@ def load_ckpt(resolved_archive_file):
                     f"Unable to load weights from mindspore checkpoint file '{resolved_archive_file}'. "
                 ) from exc
 
-        new_state_dict = {}
-        for key, value in state_dict.items():
-            key = key.replace('gamma', 'weight').replace('beta', 'bias').replace('embedding_table', 'weight')
-            value.name = value.name.replace('gamma', 'weight').replace('beta', 'bias')\
-                .replace('embedding_table', 'weight')
-            new_state_dict[key] = value
+            new_state_dict = {}
+            for key, value in state_dict.items():
+                key = key.replace('gamma', 'weight').replace('beta', 'bias').replace('embedding_table', 'weight')
+                value.name = value.name.replace('gamma', 'weight').replace('beta', 'bias')\
+                    .replace('embedding_table', 'weight')
+                new_state_dict[key] = value
         return new_state_dict
 
     keys_missing = list(model.parameters_dict().keys())
@@ -1114,7 +1113,7 @@ def load_param_into_net(model: nn.Cell, param_dict: dict, prefix: str, dtype_gro
         else:
             param_name = pname_in_net
 
-        if id(param) in param_id_set:
+        if param.uuid in param_id_set:
             # for tied params
             if param_name in keys_unexpected:
                 keys_unexpected.remove(param_name)
@@ -1161,7 +1160,7 @@ def load_param_into_net(model: nn.Cell, param_dict: dict, prefix: str, dtype_gro
             param.set_data(new_param)
            keys_unexpected.remove(param_name)
            keys_missing.remove(pname_in_net)
-            param_id_set.add(id(param))
+            param_id_set.add(param.uuid)
        else:
            # fix missing value parameter dtype cast.
            if ms_dtype and ms_dtype != param.dtype:
@@ -1358,7 +1357,7 @@ def num_parameters(self, only_trainable=False):
         total = 0
         param_set = set()
         for param in self.get_parameters():
-            param_id = id(param)
+            param_id = param.uuid
             if param_id not in param_set and (only_trainable or param.requires_grad):
                 total += param.size
                 param_set.add(param_id)
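
For context, a minimal sketch of what the num_parameters change above relies on: tied weights (one tensor exposed under several names) share a single uuid and are counted once, while two distinct tensors always get distinct keys even if one happens to reuse the other's memory address. FakeParam and count_parameters below are illustrative stand-ins, not the mindnlp API:

import uuid

class FakeParam:
    """Stand-in for a Parameter carrying the injected uuid attribute."""
    def __init__(self, size):
        self.size = size
        self.requires_grad = True
        self.uuid = uuid.uuid4().hex   # stable identity, independent of id()

def count_parameters(params):
    seen, total = set(), 0
    for p in params:
        if p.uuid in seen:             # tied aliases share one uuid: count once
            continue
        seen.add(p.uuid)
        total += p.size
    return total

embed = FakeParam(1000)
lm_head = embed                            # tied: same object under two names
print(count_parameters([embed, lm_head]))  # prints 1000, not 2000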

mindnlp/transformers/models/__init__.py

Lines changed: 3 additions & 0 deletions
@@ -71,6 +71,7 @@
     gpt_pangu,
     graphormer,
     hubert,
+    jetmoe,
     layoutlm,
     layoutlmv2,
     llama,
@@ -163,6 +164,7 @@
 from .gpt2 import *
 from .graphormer import *
 from .hubert import *
+from .jetmoe import *
 from .layoutlm import *
 from .layoutlmv2 import *
 from .llama import *
@@ -255,6 +257,7 @@
 __all__.extend(gpt2.__all__)
 __all__.extend(graphormer.__all__)
 __all__.extend(hubert.__all__)
+__all__.extend(jetmoe.__all__)
 __all__.extend(layoutlm.__all__)
 __all__.extend(layoutlmv2.__all__)
 __all__.extend(llama.__all__)
