Revert "[MPS] Fix binary ops between int32 tensor with int64 scalar (#80220)"

This reverts commit a6556efd5c4a624cccb437bbeb093d1df8b1448b.

Reverted https://github.com/pytorch/pytorch/pull/80220 on behalf of https://github.com/malfet due to Did not push the final version of commit
This commit is contained in:
PyTorch MergeBot
2022-06-24 22:33:55 +00:00
parent e98e7fe428
commit fdd3e20935

View File

@ -38,7 +38,7 @@ void binaryOpTensor(const Tensor& self, const Tensor& other, const Scalar& alpha
MPSGraphCache* cache_ = MPSGraphCache::getInstance();
@autoreleasepool {
string key = op_name + getTensorsStringKey({self, other, output}, /*use_scalar_value*/ false);
string key = op_name + getTensorsStringKey({self, other}, /*use_scalar_value*/ false);
BinaryOpCachedGraph* cachedGraph = static_cast<BinaryOpCachedGraph *>(cache_->LookUp(key));
if(!cachedGraph) {
@ -62,9 +62,6 @@ void binaryOpTensor(const Tensor& self, const Tensor& other, const Scalar& alpha
secondaryCastTensor = castMPSTensor(mpsGraph, newCachedGraph->secondaryTensor, common_dtype);
}
newCachedGraph->outputTensor = binaryBlock(newCachedGraph, primaryCastTensor, secondaryCastTensor);
if (output.scalar_type() != common_dtype) {
newCachedGraph->outputTensor = castMPSTensor(mpsGraph, newCachedGraph->outputTensor, output.scalar_type());
}
}
return newCachedGraph;
});