mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 21:14:14 +08:00
Revert "[MPS] Fix binary ops between int32 tensor with int64 scalar (#80220)"
This reverts commit a6556efd5c4a624cccb437bbeb093d1df8b1448b. Reverted https://github.com/pytorch/pytorch/pull/80220 on behalf of https://github.com/malfet due to Did not push the final version of commit
This commit is contained in:
@ -38,7 +38,7 @@ void binaryOpTensor(const Tensor& self, const Tensor& other, const Scalar& alpha
|
||||
|
||||
MPSGraphCache* cache_ = MPSGraphCache::getInstance();
|
||||
@autoreleasepool {
|
||||
string key = op_name + getTensorsStringKey({self, other, output}, /*use_scalar_value*/ false);
|
||||
string key = op_name + getTensorsStringKey({self, other}, /*use_scalar_value*/ false);
|
||||
BinaryOpCachedGraph* cachedGraph = static_cast<BinaryOpCachedGraph *>(cache_->LookUp(key));
|
||||
|
||||
if(!cachedGraph) {
|
||||
@ -62,9 +62,6 @@ void binaryOpTensor(const Tensor& self, const Tensor& other, const Scalar& alpha
|
||||
secondaryCastTensor = castMPSTensor(mpsGraph, newCachedGraph->secondaryTensor, common_dtype);
|
||||
}
|
||||
newCachedGraph->outputTensor = binaryBlock(newCachedGraph, primaryCastTensor, secondaryCastTensor);
|
||||
if (output.scalar_type() != common_dtype) {
|
||||
newCachedGraph->outputTensor = castMPSTensor(mpsGraph, newCachedGraph->outputTensor, output.scalar_type());
|
||||
}
|
||||
}
|
||||
return newCachedGraph;
|
||||
});
|
||||
|
||||
Reference in New Issue
Block a user