```diff
@@ -1048,17 +1048,16 @@ def empty_initializer(init, shape=None, dtype=mindspore.float32):
 
     # These are all the pointers of shared tensors.
     tied_params = [names for _, names in ptrs.items() if len(names) > 1]
-
     def load_ckpt(resolved_archive_file):
         if 'ckpt' not in resolved_archive_file:
             if use_safetensors or 'safetensors' in resolved_archive_file:
                 from safetensors.numpy import load_file
                 origin_state_dict = load_file(resolved_archive_file)
                 if use_fp16:
                     logger.warning_once("MindSpore does not support bfloat16 dtype, we will automatically convert to float16")
-                state_dict = {k: Parameter(v.astype(usage_dtype)) for k, v in origin_state_dict.items()}
+                new_state_dict = {k: Parameter(Tensor.from_numpy(v.astype(usage_dtype))) for k, v in origin_state_dict.items()}
             else:
-                state_dict = load(resolved_archive_file)
+                new_state_dict = load(resolved_archive_file)
         else:
             try:
                 state_dict = load_checkpoint(str(resolved_archive_file))
```
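The safetensors branch now builds `new_state_dict` directly: `safetensors.numpy.load_file` returns plain numpy arrays, which are wrapped with `Tensor.from_numpy` (a zero-copy view over the numpy buffer) before being promoted to `Parameter`. A minimal sketch of that conversion path, with an illustrative one-entry state dict standing in for a real checkpoint:

```python
import numpy as np
from mindspore import Parameter, Tensor

# Stand-in for safetensors.numpy.load_file(resolved_archive_file),
# which returns a dict of numpy arrays keyed by parameter name.
origin_state_dict = {"dense.weight": np.ones((2, 3), dtype=np.float32)}

usage_dtype = np.float16  # assumed: the target dtype when use_fp16 is set

# astype() makes the only per-weight copy; Tensor.from_numpy shares
# the resulting buffer instead of copying it again.
new_state_dict = {
    k: Parameter(Tensor.from_numpy(v.astype(usage_dtype)), name=k)
    for k, v in origin_state_dict.items()
}
print(new_state_dict["dense.weight"].dtype)  # Float16
```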
```diff
@@ -1067,12 +1066,12 @@ def load_ckpt(resolved_archive_file):
                     f"Unable to load weights from mindspore checkpoint file '{resolved_archive_file}'. "
                 ) from exc
 
-        new_state_dict = {}
-        for key, value in state_dict.items():
-            key = key.replace('gamma', 'weight').replace('beta', 'bias').replace('embedding_table', 'weight')
-            value.name = value.name.replace('gamma', 'weight').replace('beta', 'bias')\
-                .replace('embedding_table', 'weight')
-            new_state_dict[key] = value
+            new_state_dict = {}
+            for key, value in state_dict.items():
+                key = key.replace('gamma', 'weight').replace('beta', 'bias').replace('embedding_table', 'weight')
+                value.name = value.name.replace('gamma', 'weight').replace('beta', 'bias')\
+                    .replace('embedding_table', 'weight')
+                new_state_dict[key] = value
         return new_state_dict
 
     keys_missing = list(model.parameters_dict().keys())
```
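The renaming loop moves inside the `.ckpt` branch, since the safetensors path now builds `new_state_dict` itself. It normalizes MindSpore's legacy parameter names before loading: `LayerNorm` stores `gamma`/`beta` and `nn.Embedding` stores `embedding_table`, while the model definitions expect the `weight`/`bias` convention. As a standalone helper the loop amounts to the following sketch (`normalize_keys` is not part of the diff):

```python
def normalize_keys(state_dict):
    """Rename legacy MindSpore parameter names to the weight/bias style.

    gamma -> weight, beta -> bias, embedding_table -> weight.
    Note the plain str.replace: any occurrence of these substrings in a
    name is rewritten, not just trailing suffixes.
    """
    renamed = {}
    for key, value in state_dict.items():
        new_key = (key.replace('gamma', 'weight')
                      .replace('beta', 'bias')
                      .replace('embedding_table', 'weight'))
        # Keep the Parameter's own .name attribute in sync with the key.
        value.name = (value.name.replace('gamma', 'weight')
                                .replace('beta', 'bias')
                                .replace('embedding_table', 'weight'))
        renamed[new_key] = value
    return renamed
```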
```diff
@@ -1114,7 +1113,7 @@ def load_param_into_net(model: nn.Cell, param_dict: dict, prefix: str, dtype_gro
         else:
             param_name = pname_in_net
 
-        if id(param) in param_id_set:
+        if param.uuid in param_id_set:
             # for tied params
             if param_name in keys_unexpected:
                 keys_unexpected.remove(param_name)
```
```diff
@@ -1161,7 +1160,7 @@ def load_param_into_net(model: nn.Cell, param_dict: dict, prefix: str, dtype_gro
             param.set_data(new_param)
             keys_unexpected.remove(param_name)
             keys_missing.remove(pname_in_net)
-            param_id_set.add(id(param))
+            param_id_set.add(param.uuid)
         else:
             # fix missing value parameter dtype cast.
             if ms_dtype and ms_dtype != param.dtype:
```
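Both hunks above replace `id(param)` with `param.uuid` when tracking which tied parameters have already been loaded. `id()` is only unique for the lifetime of a CPython object: once a wrapper is garbage-collected, its id can be handed to an unrelated object, and two wrappers around the same underlying tensor report different ids. A stable per-parameter identifier (the `uuid` attribute this diff relies on) avoids both failure modes. The id-reuse hazard in isolation:

```python
import gc

# id() values are recycled: after an object dies, a new object may be
# allocated at the same address and report the same id.
a = object()
stale_id = id(a)
del a
gc.collect()
b = object()
print(id(b) == stale_id)  # may print True; CPython often reuses the slot
```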
```diff
@@ -1358,7 +1357,7 @@ def num_parameters(self, only_trainable=False):
         total = 0
         param_set = set()
         for param in self.get_parameters():
-            param_id = id(param)
+            param_id = param.uuid
             if param_id not in param_set and (only_trainable or param.requires_grad):
                 total += param.size
                 param_set.add(param_id)
```
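With tied weights, `get_parameters()` can yield the same underlying parameter through several module paths, so `num_parameters` deduplicates on `param.uuid` before summing sizes. The counting pattern in isolation, with hypothetical `(stable_id, size)` pairs standing in for MindSpore parameters:

```python
def count_unique(entries):
    """Sum sizes, counting each stable_id once (tied weights share one)."""
    seen, total = set(), 0
    for stable_id, size in entries:
        if stable_id not in seen:
            seen.add(stable_id)
            total += size
    return total

# Two views of one tied embedding (uuid "e1") plus a bias ("b1"):
print(count_unique([("e1", 1000), ("e1", 1000), ("b1", 10)]))  # 1010
```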