Compare commits

...

1 Commits

Author SHA1 Message Date
b9caa336a0 Made everything unranked 2024-09-12 15:01:52 -07:00
8 changed files with 71 additions and 70 deletions

View File

@ -28,6 +28,7 @@ MPSStream::MPSStream(Stream stream) : _stream(stream) {
_executionDescriptor.enableCommitAndContinue = _enableCommitAndContinue;
// Choose level which optimizes for GPU
[_compilationDescriptor disableTypeInference];
_compilationDescriptor.optimizationLevel = MPSGraphOptimizationLevel0;
_executionDescriptor.compilationDescriptor = _compilationDescriptor;
}

View File

@ -76,7 +76,7 @@ static inline std::string scalarToMetalTypeString(const Tensor& t) {
NSArray<NSNumber*>* getTensorAxes(const Tensor& t);
NSArray<NSNumber*>* getTensorAxes(const IntArrayRef& sizes, at::OptionalIntArrayRef dim);
std::string getMPSShapeString(MPSShape* shape);
std::string getTensorsStringKey(const TensorList& tensors, bool short_dtype = true, bool exclude_shape = false);
std::string getTensorsStringKey(const TensorList& tensors, bool short_dtype = true, bool exclude_shape = true);
std::string getArrayRefString(const IntArrayRef s);
// use has_storage() on the returned tensor to determine if src actually is a view
Tensor gatherViewTensor(const at::Tensor& src, at::Tensor& dst);

View File

@ -691,7 +691,7 @@ MPSGraphTensor* mpsGraphUnrankedPlaceHolder(MPSGraph* mpsGraph, MPSDataType data
}
MPSGraphTensor* mpsGraphRankedPlaceHolder(MPSGraph* mpsGraph, MPSDataType dataType, MPSShape* mpsShape) {
return [mpsGraph placeholderWithShape:mpsShape dataType:dataType name:nil];
return [mpsGraph placeholderWithShape:nil dataType:dataType name:nil];
}
MPSGraphTensor* mpsGraphRankedPlaceHolder(MPSGraph* mpsGraph, const Tensor& tensor) {

View File

@ -62,7 +62,7 @@ Tensor relu_mps(const Tensor& self) {
@autoreleasepool {
string key = "relu" + getTensorsStringKey({self});
auto cachedGraph = LookUpOrCreateCachedGraph<CachedGraph>(key, [&](auto mpsGraph, auto newCachedGraph) {
MPSGraphTensor* inputTensor = mpsGraphRankedPlaceHolder(mpsGraph, self);
MPSGraphTensor* inputTensor = mpsGraphUnrankedPlaceHolder(mpsGraph, getMPSDataType(self));
// passing selector of reLUWithTensor on the mpsGraph object
MPSGraphTensor* outputTensor = [mpsGraph reLUWithTensor:inputTensor name:nil];
@ -103,7 +103,7 @@ Tensor& relu_mps_(Tensor& self) {
@autoreleasepool {
string key = "relu_" + getTensorsStringKey({self});
auto cachedGraph = LookUpOrCreateCachedGraph<CachedGraph>(key, [&](auto mpsGraph, auto newCachedGraph) {
MPSGraphTensor* inputTensor = mpsGraphRankedPlaceHolder(mpsGraph, self);
MPSGraphTensor* inputTensor = mpsGraphUnrankedPlaceHolder(mpsGraph, getMPSDataType(self));
// passing selector of reLUWithTensor on the mpsGraph object
MPSGraphTensor* outputTensor = [mpsGraph reLUWithTensor:inputTensor name:nil];
@ -144,7 +144,7 @@ TORCH_IMPL_FUNC(leaky_relu_out_mps)(const Tensor& self, const Scalar& negative_s
@autoreleasepool {
string key = "leaky_relu" + getTensorsStringKey({self}) + ":" + std::to_string(negative_slope.to<double>());
auto cachedGraph = LookUpOrCreateCachedGraph<CachedGraph>(key, [&](auto mpsGraph, auto newCachedGraph) {
MPSGraphTensor* inputTensor = mpsGraphRankedPlaceHolder(mpsGraph, self);
MPSGraphTensor* inputTensor = mpsGraphUnrankedPlaceHolder(mpsGraph, getMPSDataType(self));
MPSGraphTensor* negSlopeTensor = [mpsGraph constantWithScalar:negative_slope.to<double>()
shape:@[ @1 ]
@ -195,8 +195,8 @@ TORCH_IMPL_FUNC(leaky_relu_backward_out_mps)
string key = "leaky_relu_backward" + getTensorsStringKey({self, grad_output}) + ":" +
std::to_string(negative_slope.to<double>());
auto cachedGraph = LookUpOrCreateCachedGraph<CachedGraph>(key, [&](auto mpsGraph, auto newCachedGraph) {
MPSGraphTensor* inputTensor = mpsGraphRankedPlaceHolder(mpsGraph, self);
MPSGraphTensor* gradOutputTensor = mpsGraphRankedPlaceHolder(mpsGraph, grad_output);
MPSGraphTensor* inputTensor = mpsGraphUnrankedPlaceHolder(mpsGraph, getMPSDataType(self));
MPSGraphTensor* gradOutputTensor = mpsGraphUnrankedPlaceHolder(mpsGraph, getMPSDataType(grad_output));
MPSGraphTensor* negSlopeTensor = [mpsGraph constantWithScalar:negative_slope.to<double>()
shape:@[ @1 ]
@ -243,7 +243,7 @@ TORCH_IMPL_FUNC(log_softmax_mps_out)
@autoreleasepool {
string key = "log_softmax_mps_out" + getTensorsStringKey({self}) + ":" + std::to_string(dim);
auto cachedGraph = LookUpOrCreateCachedGraph<CachedGraph>(key, [&](auto mpsGraph, auto newCachedGraph) {
MPSGraphTensor* inputTensor = mpsGraphRankedPlaceHolder(mpsGraph, self);
MPSGraphTensor* inputTensor = mpsGraphUnrankedPlaceHolder(mpsGraph, getMPSDataType(self));
MPSGraphTensor* maximumsTensor = [mpsGraph reductionMaximumWithTensor:inputTensor axis:dim name:nil];
MPSGraphTensor* inputTensorSubMax = [mpsGraph subtractionWithPrimaryTensor:inputTensor
@ -334,7 +334,7 @@ std::tuple<Tensor&, Tensor&> log_sigmoid_forward_out_mps(const Tensor& self, Ten
@autoreleasepool {
string key = "log_sigmoid_forward_out:" + getTensorsStringKey({self});
auto cachedGraph = LookUpOrCreateCachedGraph<CachedGraph>(key, [&](auto mpsGraph, auto newCachedGraph) {
MPSGraphTensor* inputTensor = mpsGraphRankedPlaceHolder(mpsGraph, self);
MPSGraphTensor* inputTensor = mpsGraphUnrankedPlaceHolder(mpsGraph, getMPSDataType(self));
MPSGraphTensor* zeroTensor = [mpsGraph constantWithScalar:0.0 shape:@[ @1 ] dataType:inputTensor.dataType];
MPSGraphTensor* minTensor = [mpsGraph minimumWithPrimaryTensor:inputTensor secondaryTensor:zeroTensor name:nil];
MPSGraphTensor* absInputTensor = [mpsGraph absoluteWithTensor:inputTensor name:nil];
@ -393,8 +393,8 @@ Tensor& log_sigmoid_backward_mps_out(const Tensor& grad_output,
@autoreleasepool {
string key = "log_sigmoid_backward_out:" + getTensorsStringKey({self, grad_output});
auto cachedGraph = LookUpOrCreateCachedGraph<CachedGraph>(key, [&](auto mpsGraph, auto newCachedGraph) {
MPSGraphTensor* inputTensor = mpsGraphRankedPlaceHolder(mpsGraph, self);
MPSGraphTensor* gradOutputTensor = mpsGraphRankedPlaceHolder(mpsGraph, grad_output);
MPSGraphTensor* inputTensor = mpsGraphUnrankedPlaceHolder(mpsGraph, getMPSDataType(self));
MPSGraphTensor* gradOutputTensor = mpsGraphUnrankedPlaceHolder(mpsGraph, getMPSDataType(grad_output));
MPSGraphTensor* zeroTensor = [mpsGraph constantWithScalar:0.0 shape:@[ @1 ] dataType:inputTensor.dataType];
MPSGraphTensor* oneTensor = [mpsGraph constantWithScalar:1.0 shape:@[ @1 ] dataType:inputTensor.dataType];
MPSGraphTensor* negOneTensor = [mpsGraph constantWithScalar:-1.0 shape:@[ @1 ] dataType:inputTensor.dataType];
@ -542,7 +542,7 @@ TORCH_IMPL_FUNC(threshold_out_mps)
":" + std::to_string(value.to<double>());
auto cachedGraph = LookUpOrCreateCachedGraph<CachedGraph>(key, [&](auto mpsGraph, auto newCachedGraph) {
MPSGraphTensor* inputTensor = mpsGraphRankedPlaceHolder(mpsGraph, self);
MPSGraphTensor* inputTensor = mpsGraphUnrankedPlaceHolder(mpsGraph, getMPSDataType(self));
MPSGraphTensor* thresholdTensor = [mpsGraph constantWithScalar:threshold.to<double>()
shape:@[ @1 ]
@ -589,8 +589,8 @@ TORCH_IMPL_FUNC(threshold_backward_out_mps)
"threshold_backward_out_mps" + getTensorsStringKey({self, grad}) + ":" + std::to_string(threshold.to<double>());
auto cachedGraph = LookUpOrCreateCachedGraph<CachedGraph>(key, [&](auto mpsGraph, auto newCachedGraph) {
MPSGraphTensor* inputTensor = mpsGraphRankedPlaceHolder(mpsGraph, self);
MPSGraphTensor* gradTensor = mpsGraphRankedPlaceHolder(mpsGraph, grad);
MPSGraphTensor* inputTensor = mpsGraphUnrankedPlaceHolder(mpsGraph, getMPSDataType(self));
MPSGraphTensor* gradTensor = mpsGraphUnrankedPlaceHolder(mpsGraph, getMPSDataType(grad));
MPSGraphTensor* thresholdTensor = [mpsGraph constantWithScalar:threshold.to<double>()
shape:@[ @1 ]
@ -684,7 +684,7 @@ TORCH_IMPL_FUNC(gelu_out_mps)(const Tensor& self, c10::string_view approximate,
@autoreleasepool {
const auto key = "gelu_out_mps" + getTensorsStringKey({self}) + ":" + gelutype_to_string(approximate_type);
auto cachedGraph = LookUpOrCreateCachedGraph<CachedGraph>(key, [&](auto mpsGraph, auto newCachedGraph) {
MPSGraphTensor* inputTensor = mpsGraphRankedPlaceHolder(mpsGraph, getMPSDataType(self), getMPSShape(self));
MPSGraphTensor* inputTensor = mpsGraphUnrankedPlaceHolder(mpsGraph, getMPSDataType(self));
MPSGraphTensor* outputTensor = nil;
if (approximate_type == GeluType::Tanh) {
@ -735,8 +735,8 @@ TORCH_IMPL_FUNC(gelu_backward_out_mps)
auto cachedGraph = LookUpOrCreateCachedGraph<CachedGraph>(key, [&](auto mpsGraph, auto newCachedGraph) {
auto dataType = getMPSDataType(self);
MPSGraphTensor* gradTensor = mpsGraphRankedPlaceHolder(mpsGraph, getMPSDataType(grad), getMPSShape(grad));
MPSGraphTensor* inputTensor = mpsGraphRankedPlaceHolder(mpsGraph, dataType, getMPSShape(self));
MPSGraphTensor* gradTensor = mpsGraphUnrankedPlaceHolder(mpsGraph, getMPSDataType(grad));
MPSGraphTensor* inputTensor = mpsGraphUnrankedPlaceHolder(mpsGraph, dataType);
MPSGraphTensor* outputTensor = nil;
if (approximate_type == GeluType::Tanh) {
constexpr float kBeta = M_SQRT2 * M_2_SQRTPI * (0.5f);
@ -838,7 +838,7 @@ static void elu_variants_out_mps(const Tensor& self,
std::to_string(scale.to<double>()) + ":" + std::to_string(input_scale.to<double>());
auto cachedGraph = LookUpOrCreateCachedGraph<CachedGraph>(key, [&](auto mpsGraph, auto newCachedGraph) {
MPSGraphTensor* inputTensor = mpsGraphRankedPlaceHolder(mpsGraph, self);
MPSGraphTensor* inputTensor = mpsGraphUnrankedPlaceHolder(mpsGraph, getMPSDataType(self));
// scale * (max(0, x) + min(0, alpha * (exp(input_scale * x) - 1) ))
@ -928,8 +928,8 @@ TORCH_IMPL_FUNC(elu_backward_out_mps)
std::to_string(input_scale.to<double>()) + ":" + std::to_string(is_result);
auto cachedGraph = LookUpOrCreateCachedGraph<CachedGraph>(key, [&](auto mpsGraph, auto newCachedGraph) {
MPSGraphTensor* gradOutputTensor = mpsGraphRankedPlaceHolder(mpsGraph, grad_output);
MPSGraphTensor* selfOrResultTensor = mpsGraphRankedPlaceHolder(mpsGraph, self_or_result);
MPSGraphTensor* gradOutputTensor = mpsGraphUnrankedPlaceHolder(mpsGraph, getMPSDataType(grad_output));
MPSGraphTensor* selfOrResultTensor = mpsGraphUnrankedPlaceHolder(mpsGraph, getMPSDataType(self_or_result));
MPSGraphTensor* lessThanZeroGradTensor = nil;
if (is_result) {
@ -1020,7 +1020,7 @@ TORCH_IMPL_FUNC(glu_out_mps)(const Tensor& self, const int64_t dim, const Tensor
@autoreleasepool {
string key = "glu_out_mps" + getTensorsStringKey({self}) + ":" + std::to_string(dim);
auto cachedGraph = LookUpOrCreateCachedGraph<CachedGraph>(key, [&](auto mpsGraph, auto newCachedGraph) {
MPSGraphTensor* inputTensor = mpsGraphRankedPlaceHolder(mpsGraph, getMPSDataType(self), getMPSShape(self));
MPSGraphTensor* inputTensor = mpsGraphUnrankedPlaceHolder(mpsGraph, getMPSDataType(self));
NSArray<MPSGraphTensor*>* outputTensorsArray = [mpsGraph splitTensor:inputTensor
numSplits:2
axis:wrap_dim
@ -1062,9 +1062,9 @@ Tensor& glu_backward_mps_out(const Tensor& grad_output, const Tensor& self, cons
@autoreleasepool {
string key = "glu_backward_mps_out" + getTensorsStringKey({grad_output, self}) + ":" + std::to_string(dim);
auto cachedGraph = LookUpOrCreateCachedGraph<CachedGraph>(key, [&](auto mpsGraph, auto newCachedGraph) {
MPSGraphTensor* inputTensor = mpsGraphRankedPlaceHolder(mpsGraph, getMPSDataType(self), getMPSShape(self));
MPSGraphTensor* inputTensor = mpsGraphUnrankedPlaceHolder(mpsGraph, getMPSDataType(self));
MPSGraphTensor* gradOutputTensor =
mpsGraphRankedPlaceHolder(mpsGraph, getMPSDataType(grad_output), getMPSShape(grad_output));
mpsGraphUnrankedPlaceHolder(mpsGraph, getMPSDataType(grad_output));
NSArray<MPSGraphTensor*>* inputTensorsArray = [mpsGraph splitTensor:inputTensor
numSplits:2
axis:wrap_dim
@ -1147,7 +1147,7 @@ TORCH_IMPL_FUNC(softplus_out_mps)
std::to_string(threshold.to<double>());
auto cachedGraph = LookUpOrCreateCachedGraph<CachedGraph>(key, [&](auto mpsGraph, auto newCachedGraph) {
MPSGraphTensor* inputTensor = mpsGraphRankedPlaceHolder(mpsGraph, self);
MPSGraphTensor* inputTensor = mpsGraphUnrankedPlaceHolder(mpsGraph, getMPSDataType(self));
MPSGraphTensor* betaTensor = mpsGraphScalarPlaceHolder(mpsGraph, inputTensor.dataType);
@ -1218,9 +1218,9 @@ TORCH_IMPL_FUNC(softplus_backward_out_mps)
std::to_string(beta.to<double>()) + ":" + std::to_string(threshold.to<double>());
auto cachedGraph = LookUpOrCreateCachedGraph<CachedGraph>(key, [&](auto mpsGraph, auto newCachedGraph) {
MPSGraphTensor* gradOutputTensor = mpsGraphRankedPlaceHolder(mpsGraph, grad_output);
MPSGraphTensor* gradOutputTensor = mpsGraphUnrankedPlaceHolder(mpsGraph, getMPSDataType(grad_output));
MPSGraphTensor* inputTensor = mpsGraphRankedPlaceHolder(mpsGraph, self);
MPSGraphTensor* inputTensor = mpsGraphUnrankedPlaceHolder(mpsGraph, getMPSDataType(self));
MPSGraphTensor* betaTensor = mpsGraphScalarPlaceHolder(mpsGraph, inputTensor.dataType);
@ -1292,7 +1292,7 @@ TORCH_IMPL_FUNC(mish_out_mps)
string key = "mish_out_mps:" + getTensorsStringKey({self});
auto cachedGraph = LookUpOrCreateCachedGraph<CachedGraph>(key, [&](auto mpsGraph, auto newCachedGraph) {
MPSGraphTensor* inputTensor = mpsGraphRankedPlaceHolder(mpsGraph, self);
MPSGraphTensor* inputTensor = mpsGraphUnrankedPlaceHolder(mpsGraph, getMPSDataType(self));
MPSGraphTensor* expTensor = [mpsGraph exponentWithTensor:inputTensor name:nil];
MPSGraphTensor* softplusTensor = at::native::mps::log1p(mpsGraph, expTensor);
@ -1337,8 +1337,8 @@ Tensor mish_backward_mps(const Tensor& grad_output, const Tensor& self) {
string key = "mish_backward_out_mps:" + getTensorsStringKey({grad_output, self});
auto cachedGraph = LookUpOrCreateCachedGraph<CachedGraph>(key, [&](auto mpsGraph, auto newCachedGraph) {
MPSGraphTensor* gradOutputTensor = mpsGraphRankedPlaceHolder(mpsGraph, grad_output);
MPSGraphTensor* inputTensor = mpsGraphRankedPlaceHolder(mpsGraph, self);
MPSGraphTensor* gradOutputTensor = mpsGraphUnrankedPlaceHolder(mpsGraph, getMPSDataType(grad_output));
MPSGraphTensor* inputTensor = mpsGraphUnrankedPlaceHolder(mpsGraph, getMPSDataType(self));
// grad_input = grad_output * (tanh(softplus(x)) + input * sigmoid(x) * (1 - tanh(softplus(x)) ^ 2)
@ -1404,7 +1404,7 @@ TORCH_IMPL_FUNC(softshrink_out_mps)
string key = "softshrink_out_mps:" + getTensorsStringKey({self}) + ":" + std::to_string(lambd.to<double>());
auto cachedGraph = LookUpOrCreateCachedGraph<CachedGraph>(key, [&](auto mpsGraph, auto newCachedGraph) {
MPSGraphTensor* inputTensor = mpsGraphRankedPlaceHolder(mpsGraph, self);
MPSGraphTensor* inputTensor = mpsGraphUnrankedPlaceHolder(mpsGraph, getMPSDataType(self));
MPSGraphTensor* lambdTensor = mpsGraphScalarPlaceHolder(mpsGraph, inputTensor.dataType);
MPSGraphTensor* negativeLambdTensor = [mpsGraph negativeWithTensor:lambdTensor name:nil];
@ -1473,8 +1473,8 @@ static void shrink_backward_out_mps(const Tensor& grad_output,
string key = op_name + ":" + getTensorsStringKey({self}) + ":" + std::to_string(lambd.to<double>());
auto cachedGraph = LookUpOrCreateCachedGraph<CachedGraph>(key, [&](auto mpsGraph, auto newCachedGraph) {
MPSGraphTensor* gradOutputTensor = mpsGraphRankedPlaceHolder(mpsGraph, grad_output);
MPSGraphTensor* inputTensor = mpsGraphRankedPlaceHolder(mpsGraph, self);
MPSGraphTensor* gradOutputTensor = mpsGraphUnrankedPlaceHolder(mpsGraph, getMPSDataType(grad_output));
MPSGraphTensor* inputTensor = mpsGraphUnrankedPlaceHolder(mpsGraph, getMPSDataType(self));
MPSGraphTensor* lambdTensor = mpsGraphScalarPlaceHolder(mpsGraph, inputTensor.dataType);
MPSGraphTensor* negativeLambdTensor = [mpsGraph negativeWithTensor:lambdTensor name:nil];
@ -1542,9 +1542,9 @@ Tensor prelu_mps(const Tensor& self, const Tensor& weight_) {
string key = "prelu_mps:" + getTensorsStringKey({self, weight_});
auto cachedGraph = LookUpOrCreateCachedGraph<CachedGraph>(key, [&](auto mpsGraph, auto newCachedGraph) {
MPSGraphTensor* inputTensor = mpsGraphRankedPlaceHolder(mpsGraph, self);
MPSGraphTensor* inputTensor = mpsGraphUnrankedPlaceHolder(mpsGraph, getMPSDataType(self));
MPSGraphTensor* weightTensor = mpsGraphRankedPlaceHolder(mpsGraph, weight_);
MPSGraphTensor* weightTensor = mpsGraphUnrankedPlaceHolder(mpsGraph, getMPSDataType(weight_));
MPSGraphTensor* zeroTensor = [mpsGraph constantWithScalar:0.0 shape:@[ @1 ] dataType:getMPSDataType(self)];
MPSGraphTensor* reluTensor = [mpsGraph reLUWithTensor:inputTensor name:nil];
@ -1598,11 +1598,11 @@ std::tuple<Tensor, Tensor> prelu_backward_mps(const Tensor& grad_output, const T
string key = "prelu_backward_mps:" + getTensorsStringKey({grad_output, self, weight_});
auto cachedGraph = LookUpOrCreateCachedGraph<CachedGraph>(key, [&](auto mpsGraph, auto newCachedGraph) {
MPSGraphTensor* gradOutputTensor = mpsGraphRankedPlaceHolder(mpsGraph, grad_output);
MPSGraphTensor* gradOutputTensor = mpsGraphUnrankedPlaceHolder(mpsGraph, getMPSDataType(grad_output));
MPSGraphTensor* inputTensor = mpsGraphRankedPlaceHolder(mpsGraph, self);
MPSGraphTensor* inputTensor = mpsGraphUnrankedPlaceHolder(mpsGraph, getMPSDataType(self));
MPSGraphTensor* weightTensor = mpsGraphRankedPlaceHolder(mpsGraph, weight_);
MPSGraphTensor* weightTensor = mpsGraphUnrankedPlaceHolder(mpsGraph, getMPSDataType(weight_));
MPSGraphTensor* zeroTensor = [mpsGraph constantWithScalar:0.0 shape:@[ @1 ] dataType:inputTensor.dataType];
MPSGraphTensor* weightedGradOutputTensor = [mpsGraph multiplicationWithPrimaryTensor:weightTensor
@ -1657,7 +1657,7 @@ TORCH_IMPL_FUNC(silu_out_mps)(const Tensor& self, const Tensor& result) {
string key = "silu_out_mps:" + getTensorsStringKey({self});
auto cachedGraph = LookUpOrCreateCachedGraph<CachedGraph>(key, [&](auto mpsGraph, auto newCachedGraph) {
MPSGraphTensor* inputTensor = mpsGraphRankedPlaceHolder(mpsGraph, self);
MPSGraphTensor* inputTensor = mpsGraphUnrankedPlaceHolder(mpsGraph, getMPSDataType(self));
MPSGraphTensor* unitTensor = [mpsGraph constantWithScalar:1.0 shape:@[ @1 ] dataType:getMPSDataType(self)];
MPSGraphTensor* negativeInput = [mpsGraph negativeWithTensor:inputTensor name:nil];
@ -1697,8 +1697,8 @@ TORCH_IMPL_FUNC(silu_backward_out_mps)
string key = "silu_out_backward_mps:" + getTensorsStringKey({grad_output});
auto cachedGraph = LookUpOrCreateCachedGraph<CachedGraph>(key, [&](auto mpsGraph, auto newCachedGraph) {
MPSGraphTensor* inputTensor = mpsGraphRankedPlaceHolder(mpsGraph, self);
MPSGraphTensor* gradOutputTensor = mpsGraphRankedPlaceHolder(mpsGraph, grad_output);
MPSGraphTensor* inputTensor = mpsGraphUnrankedPlaceHolder(mpsGraph, getMPSDataType(self));
MPSGraphTensor* gradOutputTensor = mpsGraphUnrankedPlaceHolder(mpsGraph, getMPSDataType(grad_output));
MPSGraphTensor* unitTensor = [mpsGraph constantWithScalar:1.0 shape:@[ @1 ] dataType:getMPSDataType(grad_output)];
MPSGraphTensor* negativeInput = [mpsGraph negativeWithTensor:inputTensor name:nil];
@ -1753,7 +1753,7 @@ TORCH_IMPL_FUNC(hardsigmoid_out_mps)(const Tensor& self, const Tensor& result) {
string key = "hardsigmoid_out_mps:" + getTensorsStringKey({self});
auto cachedGraph = LookUpOrCreateCachedGraph<CachedGraph>(key, [&](auto mpsGraph, auto newCachedGraph) {
MPSGraphTensor* inputTensor = mpsGraphRankedPlaceHolder(mpsGraph, self);
MPSGraphTensor* inputTensor = mpsGraphUnrankedPlaceHolder(mpsGraph, getMPSDataType(self));
MPSGraphTensor* zeroTensor = [mpsGraph constantWithScalar:0.0 shape:@[ @1 ] dataType:getMPSDataType(self)];
MPSGraphTensor* threeTensor = [mpsGraph constantWithScalar:3.0 shape:@[ @1 ] dataType:getMPSDataType(self)];
MPSGraphTensor* sixTensor = [mpsGraph constantWithScalar:6.0 shape:@[ @1 ] dataType:getMPSDataType(self)];
@ -1794,8 +1794,8 @@ TORCH_IMPL_FUNC(hardsigmoid_backward_out_mps)
string key = "hardsigmoid_backward_out_mps:" + getTensorsStringKey({self});
auto cachedGraph = LookUpOrCreateCachedGraph<CachedGraph>(key, [&](auto mpsGraph, auto newCachedGraph) {
MPSGraphTensor* inputTensor = mpsGraphRankedPlaceHolder(mpsGraph, self);
MPSGraphTensor* gradOutputTensor = mpsGraphRankedPlaceHolder(mpsGraph, grad_output);
MPSGraphTensor* inputTensor = mpsGraphUnrankedPlaceHolder(mpsGraph, getMPSDataType(self));
MPSGraphTensor* gradOutputTensor = mpsGraphUnrankedPlaceHolder(mpsGraph, getMPSDataType(grad_output));
MPSGraphTensor* zeroTensor = [mpsGraph constantWithScalar:0.0 shape:@[ @1 ] dataType:getMPSDataType(self)];
MPSGraphTensor* highTensor = [mpsGraph constantWithScalar:3.0 shape:@[ @1 ] dataType:getMPSDataType(self)];
MPSGraphTensor* lowTensor = [mpsGraph constantWithScalar:-3.0 shape:@[ @1 ] dataType:getMPSDataType(self)];
@ -1867,8 +1867,8 @@ Tensor& hardtanh_backward_out_mps(const Tensor& grad_output,
std::to_string(min.to<double>()) + ":" + std::to_string(max.to<double>());
auto cachedGraph = LookUpOrCreateCachedGraph<CachedGraph>(key, [&](auto mpsGraph, auto newCachedGraph) {
MPSGraphTensor* gradOutputTensor = mpsGraphRankedPlaceHolder(mpsGraph, grad_output);
MPSGraphTensor* inputTensor = mpsGraphRankedPlaceHolder(mpsGraph, self);
MPSGraphTensor* gradOutputTensor = mpsGraphUnrankedPlaceHolder(mpsGraph, getMPSDataType(grad_output));
MPSGraphTensor* inputTensor = mpsGraphUnrankedPlaceHolder(mpsGraph, getMPSDataType(self));
// TODO: Compute gradient
MPSGraphTensor* unitTensor = [mpsGraph constantWithScalar:1.0f
@ -1942,7 +1942,7 @@ Tensor& hardswish_out_mps(const Tensor& self, Tensor& output) {
@autoreleasepool {
string key = "hardswish_out_mps" + getTensorsStringKey({self});
auto cachedGraph = LookUpOrCreateCachedGraph<CachedGraph>(key, [&](auto mpsGraph, auto newCachedGraph) {
MPSGraphTensor* inputTensor = mpsGraphRankedPlaceHolder(mpsGraph, self);
MPSGraphTensor* inputTensor = mpsGraphUnrankedPlaceHolder(mpsGraph, getMPSDataType(self));
MPSGraphTensor* zeroTensor = [mpsGraph constantWithScalar:0.0f shape:@[ @1 ] dataType:getMPSDataType(self)];
@ -2027,8 +2027,8 @@ Tensor hardswish_backward_mps(const Tensor& grad_output, const Tensor& self) {
@autoreleasepool {
string key = "hardswish_backward_mps" + getTensorsStringKey({self});
auto cachedGraph = LookUpOrCreateCachedGraph<CachedGraph>(key, [&](auto mpsGraph, auto newCachedGraph) {
MPSGraphTensor* gradOutputTensor = mpsGraphRankedPlaceHolder(mpsGraph, grad_output);
MPSGraphTensor* inputTensor = mpsGraphRankedPlaceHolder(mpsGraph, self);
MPSGraphTensor* gradOutputTensor = mpsGraphUnrankedPlaceHolder(mpsGraph, getMPSDataType(grad_output));
MPSGraphTensor* inputTensor = mpsGraphUnrankedPlaceHolder(mpsGraph, getMPSDataType(self));
MPSGraphTensor* zeroTensor = [mpsGraph constantWithScalar:0.0f
shape:@[ @1 ]

View File

@ -60,9 +60,9 @@ std::tuple<Tensor, Tensor> _scaled_dot_product_attention_math_mps(const Tensor&
auto mkey = __func__ + getTensorsStringKey({query, key, value}) + ":" + std::to_string(is_causal) + ":" +
std::to_string(attn_mask.has_value());
auto cachedGraph = LookUpOrCreateCachedGraph<CachedGraph>(mkey, [&](auto mpsGraph, auto graph) {
auto qTensor = mpsGraphRankedPlaceHolder(mpsGraph, query);
auto kTensor = mpsGraphRankedPlaceHolder(mpsGraph, key);
auto vTensor = mpsGraphRankedPlaceHolder(mpsGraph, value);
auto qTensor = mpsGraphUnrankedPlaceHolder(mpsGraph, getMPSDataType(query));
auto kTensor = mpsGraphUnrankedPlaceHolder(mpsGraph, getMPSDataType(key));
auto vTensor = mpsGraphUnrankedPlaceHolder(mpsGraph, getMPSDataType(value));
auto kT = [mpsGraph transposeTensor:kTensor dimension:2 withDimension:3 name:nil];
auto scaleTensor = [mpsGraph constantWithScalar:scale_factor shape:getMPSShape({1}) dataType:MPSDataTypeFloat32];
@ -96,7 +96,7 @@ std::tuple<Tensor, Tensor> _scaled_dot_product_attention_math_mps(const Tensor&
falsePredicateTensor:minusInf
name:nil];
} else if (attn_mask) {
graph->maskTensor = mpsGraphRankedPlaceHolder(mpsGraph, *attn_mask);
graph->maskTensor = mpsGraphUnrankedPlaceHolder(mpsGraph, getMPSDataType(*attn_mask));
maskedMM = [mpsGraph additionWithPrimaryTensor:maskedMM secondaryTensor:graph->maskTensor name:nil];
}
auto sm = [mpsGraph softMaxWithTensor:maskedMM axis:3 name:nil];

View File

@ -67,8 +67,8 @@ Tensor _mps_linear(const Tensor& input, const Tensor& weight_arg, const std::opt
@autoreleasepool {
string key = "mps_linear" + getTensorsStringKey({input, weight, bias});
auto cachedGraph = LookUpOrCreateCachedGraph<CachedGraph>(key, [&](auto* mpsGraph, auto* newCachedGraph) {
MPSGraphTensor* inputTensor = mpsGraphRankedPlaceHolder(mpsGraph, input);
MPSGraphTensor* weightTensor = mpsGraphRankedPlaceHolder(mpsGraph, weight);
MPSGraphTensor* inputTensor = mpsGraphUnrankedPlaceHolder(mpsGraph, getMPSDataType(input));
MPSGraphTensor* weightTensor = mpsGraphUnrankedPlaceHolder(mpsGraph, getMPSDataType(weight));
MPSGraphTensor* weightTransposeTensor = [mpsGraph transposeTensor:weightTensor
dimension:-1

View File

@ -211,7 +211,7 @@ static void reduction_out_mps(const Tensor& input_t,
auto cachedGraph = LookUpOrCreateCachedGraph<CachedGraph>(key, [&](auto mpsGraph, auto newCachedGraph) {
auto inputScalarType = input_t.scalar_type();
MPSGraphTensor* inputTensor = mpsGraphRankedPlaceHolder(mpsGraph, getMPSDataType(input_t), mpsShape);
MPSGraphTensor* inputTensor = mpsGraphUnrankedPlaceHolder(mpsGraph,getMPSDataType(input_t));
MPSGraphTensor* castInputTensor = inputTensor;
MPSDataType inputCastType = MPSDataTypeInvalid;
if (dtype.has_value() &&
@ -350,10 +350,10 @@ static void impl_func_norm_mps(const Tensor& input_tensor,
keepdim_info + ":" + toString(in_dtype);
auto cachedGraph = LookUpOrCreateCachedGraph<MPSBinaryCachedGraph>(key, [&](auto mpsGraph, auto newCachedGraph) {
newCachedGraph->inputTensor_ = mpsGraphRankedPlaceHolder(mpsGraph, input_tensor);
newCachedGraph->inputTensor_ = mpsGraphUnrankedPlaceHolder(mpsGraph, getMPSDataType(input_tensor));
if (cdist) {
newCachedGraph->otherTensor_ = mpsGraphRankedPlaceHolder(mpsGraph, other_tensor);
newCachedGraph->otherTensor_ = mpsGraphUnrankedPlaceHolder(mpsGraph, getMPSDataType(other_tensor));
}
MPSGraphTensor* inputTensor = cdist
@ -564,7 +564,7 @@ static Tensor std_var_common_impl_mps(const Tensor& input_t,
string([ns_key UTF8String]) + ":" + bessel_corrected + ":" + std::to_string(correction_value);
auto cachedGraph = LookUpOrCreateCachedGraph<CachedGraph>(key, [&](auto mpsGraph, auto newCachedGraph) {
MPSGraphTensor* inputTensor = mpsGraphRankedPlaceHolder(mpsGraph, input_t);
MPSGraphTensor* inputTensor = mpsGraphUnrankedPlaceHolder(mpsGraph, getMPSDataType(input_t));
MPSGraphTensor* outputVarTensor = [mpsGraph varianceOfTensor:inputTensor axes:wrappedAxes name:nil];
MPSGraphTensor* outputTensor = nil;
@ -611,7 +611,7 @@ static Tensor min_max_mps_impl(const Tensor& input_t, MPSReductionType reduction
@autoreleasepool {
string key = func_name + getTensorsStringKey(input_t);
CachedGraph* cachedGraph = LookUpOrCreateCachedGraph<CachedGraph>(key, [&](auto mpsGraph, auto newCachedGraph) {
MPSGraphTensor* inputTensor = mpsGraphRankedPlaceHolder(mpsGraph, input_t);
MPSGraphTensor* inputTensor = mpsGraphUnrankedPlaceHolder(mpsGraph, getMPSDataType(input_t));
MPSGraphTensor* castOutputTensor = nil;
MPSGraphTensor* castInputTensor =
@ -688,7 +688,7 @@ static void min_max_out_mps(const Tensor& input_t,
@autoreleasepool {
string key = func_name + getTensorsStringKey({input_t, indices_t}) + ":" + std::to_string(dim_);
auto cachedGraph = LookUpOrCreateCachedGraph<CachedGraph>(key, [&](auto mpsGraph, auto newCachedGraph) {
MPSGraphTensor* inputTensor = mpsGraphRankedPlaceHolder(mpsGraph, input_t);
MPSGraphTensor* inputTensor = mpsGraphUnrankedPlaceHolder(mpsGraph, getMPSDataType(input_t));
MPSGraphTensor* outputTensor = nil;
MPSGraphTensor* castInputTensor =
castToIHFTypes(mpsGraph, inputTensor, input_t, /*includesInt64=*/macOS13_3_plus);
@ -849,7 +849,7 @@ static void argmax_argmin_out_mps(const Tensor& input_t,
auto cachedGraph = LookUpOrCreateCachedGraph<CachedGraph>(key, [&](auto mpsGraph, auto newCachedGraph) {
auto inputScalarType = input_t.scalar_type();
MPSGraphTensor* inputTensor =
mpsGraphRankedPlaceHolder(mpsGraph, getMPSDataType(inputScalarType), apparent_in_shape);
mpsGraphUnrankedPlaceHolder(mpsGraph, getMPSDataType(inputScalarType));
MPSGraphTensor* argreduceOutTensor = nil;
MPSGraphTensor* castInputTensor = inputTensor;
@ -1203,7 +1203,7 @@ static void all_any_common_impl_mps(const Tensor& input_t,
@autoreleasepool {
string key = op_name + "_out_mps:" + getTensorsStringKey(input_t) + ":" + std::to_string(dim_);
auto cachedGraph = LookUpOrCreateCachedGraph<CachedGraph>(key, [&](auto mpsGraph, auto newCachedGraph) {
auto inputTensor = mpsGraphRankedPlaceHolder(mpsGraph, input_t);
auto inputTensor = mpsGraphUnrankedPlaceHolder(mpsGraph, getMPSDataType(input_t));
auto castInputTensor = castToIHFTypes(mpsGraph, inputTensor, input_t, /*includesInt64=*/macOS13_3_plus);
// reductionOrWithTensor:axis: will throw an internal assert if number of dimensions is more than 4
@ -1277,7 +1277,7 @@ TORCH_IMPL_FUNC(any_all_out_mps)(const Tensor& input_t, const Tensor& output_t)
@autoreleasepool {
string key = string("any_all_out_mps:") + getTensorsStringKey(input_t);
auto cachedGraph = LookUpOrCreateCachedGraph<CachedGraph>(key, [&](auto mpsGraph, auto newCachedGraph) {
auto inputTensor = mpsGraphRankedPlaceHolder(mpsGraph, input_t);
auto inputTensor = mpsGraphUnrankedPlaceHolder(mpsGraph, getMPSDataType(input_t));
auto castInputTensor = castToIHFTypes(mpsGraph, inputTensor, input_t, /*includesInt64=*/macOS13_3_plus);
// reductionOrWithTensor:axes: will throw an internal assert if number of dimensions is more than 4
// See https://github.com/pytorch/pytorch/issues/95538
@ -1328,7 +1328,7 @@ TORCH_IMPL_FUNC(all_all_out_mps)(const Tensor& input_t, const Tensor& output_t)
@autoreleasepool {
string key = string("all_all_out_mps:") + getTensorsStringKey(input_t);
auto cachedGraph = LookUpOrCreateCachedGraph<CachedGraph>(key, [&](auto mpsGraph, auto newCachedGraph) {
auto inputTensor = mpsGraphRankedPlaceHolder(mpsGraph, input_t);
auto inputTensor = mpsGraphUnrankedPlaceHolder(mpsGraph, getMPSDataType(input_t));
auto castInputTensor = castToIHFTypes(mpsGraph, inputTensor, input_t, /*includesInt64=*/macOS13_3_plus);
// reductionAndWithTensor:axes: will throw an internal assert if number of dimensions is more than 4
// See https://github.com/pytorch/pytorch/issues/95538
@ -1426,7 +1426,7 @@ Tensor median_mps(const Tensor& input_t) {
@autoreleasepool {
string key = "median_mps:" + getMPSTypeString(input_t) + getTensorsStringKey(input_t);
auto cachedGraph = LookUpOrCreateCachedGraph<CachedGraph>(key, [&](auto mpsGraph, auto newCachedGraph) {
auto inputTensor = mpsGraphRankedPlaceHolder(mpsGraph, input_t);
auto inputTensor = mpsGraphUnrankedPlaceHolder(mpsGraph, getMPSDataType(input_t));
MPSGraphTensor* castInputTensor =
castToIHFTypes(mpsGraph, inputTensor, input_t, /*includesInt64=*/macOS13_3_plus);
@ -1499,7 +1499,7 @@ static void median_out_mps(const Tensor& input_t,
string key = func_name + ":" + std::to_string(dim_) + ":" + getTensorsStringKey(input_t) + ":" +
getTensorsStringKey(indices_t);
auto cachedGraph = LookUpOrCreateCachedGraph<CachedGraph>(key, [&](auto mpsGraph, auto newCachedGraph) {
MPSGraphTensor* inputTensor = mpsGraphRankedPlaceHolder(mpsGraph, input_t);
MPSGraphTensor* inputTensor = mpsGraphUnrankedPlaceHolder(mpsGraph, getMPSDataType(input_t));
MPSGraphTensor* castInputTensor =
castToIHFTypes(mpsGraph, inputTensor, input_t, /*includesInt64=*/macOS13_3_plus);

View File

@ -86,7 +86,7 @@ static void unary_op_noresize(const Tensor& self, const Tensor& output_, std::st
@autoreleasepool {
string key = op_name + getTensorsStringKey({self, output});
auto cachedGraph = LookUpOrCreateCachedGraph<MPSUnaryCachedGraph>(key, [&](auto mpsGraph, auto newCachedGraph) {
newCachedGraph->inputTensor_ = mpsGraphRankedPlaceHolder(mpsGraph, self);
newCachedGraph->inputTensor_ = mpsGraphUnrankedPlaceHolder(mpsGraph, getMPSDataType(self));
MPSGraphTensor* castTensor = newCachedGraph->inputTensor_;
// Integer input must be cast to float if output is float
if (isIntegralType(self.scalar_type(), true) && isFloatingType(output.scalar_type())) {
@ -373,9 +373,9 @@ TORCH_IMPL_FUNC(logit_backward_out_mps)
(eps.has_value() ? std::to_string(eps.value()) : "-1") + "]";
auto cachedGraph = LookUpOrCreateCachedGraph<CachedGraph>(key, [&](auto mpsGraph, auto newCachedGraph) {
MPSGraphTensor* inputTensor = mpsGraphRankedPlaceHolder(mpsGraph, input);
MPSGraphTensor* gradOutputTensor = mpsGraphRankedPlaceHolder(mpsGraph, grad_output);
MPSGraphTensor* outputTensor = mpsGraphRankedPlaceHolder(mpsGraph, grad_input);
MPSGraphTensor* inputTensor = mpsGraphUnrankedPlaceHolder(mpsGraph, getMPSDataType(input));
MPSGraphTensor* gradOutputTensor = mpsGraphUnrankedPlaceHolder(mpsGraph, getMPSDataType(grad_output));
MPSGraphTensor* outputTensor = mpsGraphUnrankedPlaceHolder(mpsGraph, getMPSDataType(grad_input));
MPSGraphTensor* zeroTensor = [mpsGraph constantWithScalar:0.0 shape:@[ @1 ] dataType:inputTensor.dataType];
MPSGraphTensor* oneTensor = [mpsGraph constantWithScalar:1.0 shape:@[ @1 ] dataType:inputTensor.dataType];
MPSGraphTensor* lowTensor = [mpsGraph constantWithScalar:eps_ shape:@[ @1 ] dataType:inputTensor.dataType];