Skip to content

Commit c746607

Browse files
authored
fix: roll not supported on CPU (#932)
1 parent 4abb7f7 commit c746607

File tree

6 files changed

+2
-8
lines changed

6 files changed

+2
-8
lines changed

llm/inference/bark/inference.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,6 @@
22
import mindspore
33
from IPython.display import Audio
44
from mindnlp.transformers.models.bark import BarkModel, BarkProcessor
5-
mindspore.set_context(pynative_synchronize=True, device_target="CPU")
65

76

87
voice_preset = None

tests/ut/transformers/models/bark/test_modeling_bark.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -58,7 +58,6 @@
5858
BarkSemanticModel,
5959
)
6060

61-
# mindspore.set_context(pynative_synchronize=True)
6261

6362
class BarkSemanticModelTester:
6463
def __init__(

tests/ut/transformers/models/big_bird/test_modeling_big_bird.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -49,7 +49,6 @@
4949
BIG_BIRD_PRETRAINED_MODEL_ARCHIVE_LIST,
5050
)
5151

52-
# mindspore.set_context(pynative_synchronize=True)
5352

5453
class BigBirdModelTester:
5554
def __init__(

tests/ut/transformers/models/chatglm/test_modeling_graph_chatglm.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,6 @@
99

1010
from mindspore._c_expression import _framework_profiler_step_start
1111
from mindspore._c_expression import _framework_profiler_step_end
12-
# mindspore.set_context(pynative_synchronize=True)
1312
# mindspore.set_context(mode=mindspore.GRAPH_MODE, jit_syntax_level=mindspore.LAX, save_graphs=True, save_graphs_path="./graph")
1413

1514
def set_random_seed(seed):

tests/ut/transformers/models/mixtral/test_modeling_mixtral.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -37,7 +37,6 @@
3737

3838
from mindnlp.transformers import MixtralForCausalLM, MixtralForSequenceClassification, MixtralModel
3939

40-
# mindspore.set_context(pynative_synchronize=True)
4140

4241
class MixtralModelTester:
4342
# Copied from tests.models.mistral.test_modeling_mistral.MistralModelTester.__init__

tests/ut/transformers/models/reformer/test_modeling_reformer.py

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -45,7 +45,6 @@
4545
ReformerTokenizer,
4646
)
4747

48-
mindspore.set_context(pynative_synchronize=True)
4948

5049
class ReformerModelTester:
5150
def __init__(
@@ -275,8 +274,8 @@ def create_and_check_reformer_model_with_attn_mask(
275274
[half_input_ids, ids_tensor((self.batch_size, half_seq_len), self.vocab_size)],
276275
axis=-1,
277276
)
278-
input_ids_roll = ops.roll(input_ids_roll, roll, dims=-1)
279-
attn_mask_roll = ops.roll(attn_mask, roll, dims=-1)
277+
input_ids_roll = mindspore.numpy.roll(input_ids_roll, roll, axis=-1)
278+
attn_mask_roll = mindspore.numpy.roll(attn_mask, roll, axis=-1)
280279

281280
output_padded = model(input_ids_padded, attention_mask=attn_mask)[0][:, :half_seq_len]
282281
output_padded_rolled = model(input_ids_roll, attention_mask=attn_mask_roll)[0][:, roll : half_seq_len + roll]

0 commit comments

Comments
 (0)