mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 21:14:14 +08:00
[BE] Use fmt::format
to define Conv key (#162925)
Also use `getArrayRefString` instead of having separate cases for 2D and 3D Conv Pull Request resolved: https://github.com/pytorch/pytorch/pull/162925 Approved by: https://github.com/Skylion007 ghstack dependencies: #162921
This commit is contained in:
committed by
PyTorch MergeBot
parent
7fe1f5ea49
commit
76e5df3866
@ -6,6 +6,7 @@
|
||||
#include <ATen/ops/_mps_convolution_transpose_native.h>
|
||||
#include <ATen/ops/mps_convolution_backward_native.h>
|
||||
#include <ATen/ops/mps_convolution_transpose_backward_native.h>
|
||||
#include <fmt/format.h>
|
||||
|
||||
namespace at::native {
|
||||
|
||||
@ -172,18 +173,6 @@ static Tensor _mps_convolution_impl(const Tensor& input_t_,
|
||||
if (bias_defined)
|
||||
bias_shape = bias_opt.value().sizes();
|
||||
|
||||
std::string mem_format_key;
|
||||
switch (memory_format) {
|
||||
case at::MemoryFormat::Contiguous:
|
||||
mem_format_key = "Contiguous";
|
||||
break;
|
||||
case at::MemoryFormat::ChannelsLast:
|
||||
mem_format_key = "ChannelsLast";
|
||||
break;
|
||||
default:
|
||||
assert(0 && "Check should have been done earlier\n");
|
||||
}
|
||||
|
||||
std::string bias_shape_key;
|
||||
if (bias_defined) {
|
||||
bias_shape_key = std::to_string(bias_shape[0]);
|
||||
@ -191,20 +180,16 @@ static Tensor _mps_convolution_impl(const Tensor& input_t_,
|
||||
bias_shape_key = "nobias";
|
||||
}
|
||||
|
||||
std::string key;
|
||||
if (is3DConv) {
|
||||
key = "mps_3d_convolution:" + std::to_string(stride[0]) + ":" + std::to_string(stride[1]) + ":" +
|
||||
std::to_string(stride[2]) + ":" + std::to_string(dilation[0]) + ":" + std::to_string(dilation[1]) + ":" +
|
||||
std::to_string(dilation[2]) + ":" + std::to_string(padding[0]) + ":" + std::to_string(padding[1]) + ":" +
|
||||
std::to_string(padding[2]) + ":" + std::to_string(groups) + ":" + mem_format_key +
|
||||
mps::getTensorsStringKey({input_t, weight_t}) + ":" + std::to_string(bias_defined) + ":" + bias_shape_key;
|
||||
|
||||
} else {
|
||||
key = "mps_convolution:" + std::to_string(stride[0]) + ":" + std::to_string(stride[1]) + ":" +
|
||||
std::to_string(dilation[0]) + ":" + std::to_string(dilation[1]) + ":" + std::to_string(padding[0]) + ":" +
|
||||
std::to_string(padding[1]) + ":" + std::to_string(groups) + ":" + mem_format_key +
|
||||
mps::getTensorsStringKey({input_t, weight_t}) + ":" + std::to_string(bias_defined) + ":" + bias_shape_key;
|
||||
}
|
||||
std::string key = fmt::format("mps_{}convolution:{}:{}:{}:{}:{}:{}:{}:{}",
|
||||
is3DConv ? "3d_" : "",
|
||||
getArrayRefString(stride),
|
||||
getArrayRefString(dilation),
|
||||
getArrayRefString(padding),
|
||||
groups,
|
||||
is_channels_last,
|
||||
mps::getTensorsStringKey({input_t, weight_t}),
|
||||
bias_defined,
|
||||
bias_shape_key);
|
||||
|
||||
MPSShape* inputShape = mps::getMPSShape(input_t, memory_format);
|
||||
MPSShape* outputShape = mps::getMPSShape(output_t, memory_format);
|
||||
@ -386,33 +371,15 @@ static Tensor mps_convolution_backward_input(IntArrayRef input_size,
|
||||
@autoreleasepool {
|
||||
MPSStream* stream = getCurrentMPSStream();
|
||||
|
||||
std::string mem_format_key;
|
||||
switch (memory_format) {
|
||||
case at::MemoryFormat::Contiguous:
|
||||
mem_format_key = "Contiguous";
|
||||
break;
|
||||
case at::MemoryFormat::ChannelsLast:
|
||||
mem_format_key = "ChannelsLast";
|
||||
break;
|
||||
default:
|
||||
assert(0 && "Check should have been done earlier\n");
|
||||
}
|
||||
|
||||
MPSShape* mps_input_shape = getMPSShape(input_size);
|
||||
std::string key;
|
||||
if (is3DConv) {
|
||||
key = "mps_3d_convolution_backward_input:" + std::to_string(stride[0]) + ":" + std::to_string(stride[1]) + ":" +
|
||||
":" + std::to_string(stride[2]) + std::to_string(dilation[0]) + ":" + std::to_string(dilation[1]) + ":" +
|
||||
std::to_string(dilation[2]) + ":" + std::to_string(padding[0]) + ":" + std::to_string(padding[1]) + ":" +
|
||||
std::to_string(padding[2]) + ":" + std::to_string(groups) + ":" + mem_format_key +
|
||||
getTensorsStringKey({grad_output_t, weight_t});
|
||||
|
||||
} else {
|
||||
key = "mps_convolution_backward_input:" + std::to_string(stride[0]) + ":" + std::to_string(stride[1]) + ":" +
|
||||
std::to_string(dilation[0]) + ":" + std::to_string(dilation[1]) + ":" + std::to_string(padding[0]) + ":" +
|
||||
std::to_string(padding[1]) + ":" + std::to_string(groups) + ":" + mem_format_key +
|
||||
getTensorsStringKey({grad_output_t, weight_t});
|
||||
}
|
||||
std::string key = fmt::format("mps_{}_convolution_backward_input:{}:{}:{}:{}:{}:{}",
|
||||
is3DConv ? "3d_" : "",
|
||||
getArrayRefString(stride),
|
||||
getArrayRefString(dilation),
|
||||
getArrayRefString(padding),
|
||||
groups,
|
||||
is_channels_last,
|
||||
getTensorsStringKey({grad_output_t, weight_t}));
|
||||
auto cachedGraph = LookUpOrCreateCachedGraph<CachedGraph>(key, [&](auto mpsGraph, auto newCachedGraph) {
|
||||
auto gradOutputTensor = mpsGraphRankedPlaceHolder(mpsGraph, grad_output_t);
|
||||
auto weightTensor = mpsGraphRankedPlaceHolder(mpsGraph, weight_t);
|
||||
@ -537,19 +504,13 @@ static Tensor mps_convolution_backward_weights(IntArrayRef weight_size,
|
||||
MPSStream* stream = getCurrentMPSStream();
|
||||
|
||||
MPSShape* mps_weight_shape = getMPSShape(weight_size);
|
||||
std::string key;
|
||||
if (is3DConv) {
|
||||
key = "mps_3d_convolution_backward_weights:" + std::to_string(stride[0]) + ":" + std::to_string(stride[1]) + ":" +
|
||||
std::to_string(stride[2]) + ":" + std::to_string(dilation[0]) + ":" + std::to_string(dilation[1]) + ":" +
|
||||
std::to_string(dilation[2]) + ":" + std::to_string(padding[0]) + ":" + std::to_string(padding[1]) + ":" +
|
||||
std::to_string(padding[2]) + ":" + std::to_string(groups) + ":" +
|
||||
getTensorsStringKey({grad_output_t, input_t, grad_weight_t});
|
||||
} else {
|
||||
key = "mps_convolution_backward_weights:" + std::to_string(stride[0]) + ":" + std::to_string(stride[1]) + ":" +
|
||||
std::to_string(dilation[0]) + ":" + std::to_string(dilation[1]) + ":" + std::to_string(padding[0]) + ":" +
|
||||
std::to_string(padding[1]) + ":" + std::to_string(groups) + ":" +
|
||||
getTensorsStringKey({grad_output_t, input_t, grad_weight_t});
|
||||
}
|
||||
std::string key = fmt::format("mps_{}convolution_backward_weights:{}:{}:{}:{}:{}",
|
||||
is3DConv ? "3d_" : "",
|
||||
getArrayRefString(stride),
|
||||
getArrayRefString(dilation),
|
||||
getArrayRefString(padding),
|
||||
groups,
|
||||
getTensorsStringKey({grad_output_t, input_t, grad_weight_t}));
|
||||
auto cachedGraph = LookUpOrCreateCachedGraph<CachedGraph>(key, [&](auto mpsGraph, auto newCachedGraph) {
|
||||
MPSShape* inputShape = getMPSShape(input_t);
|
||||
bool isDepthwiseConv =
|
||||
|
Reference in New Issue
Block a user