Skip to content

Commit fdb923c

Browse files
committed
fix Trainer.train raise div not support inplace error
1 parent b0f7439 commit fdb923c

File tree

3 files changed

+10
-5
lines changed

3 files changed

+10
-5
lines changed

mindtorch/_apis/npu.py

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1716,9 +1716,14 @@ def outer(input, other):
17161716

17171717
def addcmul(input, tensor1, tensor2, value=1.0):
17181718
if use_pyboost() and not ON_ORANGE_PI:
1719-
return pyboost.addcmul_op(input, tensor1, tensor2, value)
1719+
return pyboost.addcmul_op(input, tensor1, tensor2, mindspore.Tensor(value))
17201720
return legacy.add(mul(mul(tensor1, tensor2), value), input)
17211721

1722+
def addcdiv(input, tensor1, tensor2, value=1.0):
1723+
if use_pyboost() and not ON_ORANGE_PI:
1724+
return pyboost.addcdiv_op(input, tensor1, tensor2, mindspore.Tensor(value))
1725+
return legacy.add(div(mul(tensor1, tensor2), value), input)
1726+
17221727
def prelu(input, weight):
17231728
if use_pyboost():
17241729
return pyboost.prelu_op(input, weight)

mindtorch/_tensor.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -340,10 +340,10 @@ def __abs__(self):
340340
return ops.abs(self)
341341

342342
def __imul__(self, other):
343-
return self.copy_(ops.mul(self, other))
343+
return ops.mul(self, other)
344344

345345
def __itruediv__(self, other):
346-
return self.copy_(ops.div(self, other))
346+
return ops.div(self, other)
347347

348348
def __pow__(self, other):
349349
return ops.pow(self, other)

setup.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -64,7 +64,7 @@ def run(self):
6464
_create_namespace_links() # 安装后创建链接
6565

6666

67-
version = '0.5.0'
67+
version = '0.5.1'
6868
cur_dir = os.path.dirname(os.path.realpath(__file__))
6969
pkg_dir = os.path.join(cur_dir, 'build')
7070

@@ -170,7 +170,7 @@ def run(self):
170170
},
171171

172172
install_requires=[
173-
'mindspore>=2.5.0',
173+
'mindspore>=2.5.0, <=2.7.0',
174174
'tqdm',
175175
'requests',
176176
'accelerate>=1.6.0', # hf dependency

0 commit comments

Comments
 (0)