mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 21:14:14 +08:00
[1/N] Deprecate c10::string_view and at::string (#151972)
The calls of `c10::string_view` in the code base are replaced by `std::string_view`. The calls of `at::string` are replaced by `std::string` Pull Request resolved: https://github.com/pytorch/pytorch/pull/151972 Approved by: https://github.com/malfet
This commit is contained in:
@ -2,6 +2,7 @@
|
||||
#include <gtest/gtest.h>
|
||||
|
||||
using namespace c10;
|
||||
using std::string;
|
||||
|
||||
// NOLINTBEGIN(performance-move-const-arg, bugprone-use-after-move, *analyzer*Move)
|
||||
TEST(ListTestIValueBasedList, givenEmptyList_whenCallingEmpty_thenReturnsTrue) {
|
||||
|
@ -519,7 +519,7 @@ TEST(OperatorRegistrationTestLegacyFunctionBasedKernel, givenKernelWithDictInput
|
||||
EXPECT_EQ(2, captured_dict_size);
|
||||
}
|
||||
|
||||
string kernelWithDictInputWithOutput(Dict<string, string> input1) {
|
||||
std::string kernelWithDictInputWithOutput(Dict<string, string> input1) {
|
||||
return input1.at("key2");
|
||||
}
|
||||
|
||||
@ -581,7 +581,7 @@ TEST(OperatorRegistrationTestLegacyFunctionBasedKernel, givenKernelWithUnordered
|
||||
EXPECT_EQ(2, captured_dict_size);
|
||||
}
|
||||
|
||||
string kernelWithUnorderedMapInputWithOutput(std::unordered_map<string, string> input1) {
|
||||
std::string kernelWithUnorderedMapInputWithOutput(std::unordered_map<string, string> input1) {
|
||||
return input1.at("key2");
|
||||
}
|
||||
|
||||
|
@ -468,7 +468,7 @@ TEST(OperatorRegistrationTestFunctionBasedKernel, givenKernelWithDictInput_witho
|
||||
EXPECT_EQ(2, captured_dict_size);
|
||||
}
|
||||
|
||||
string kernelWithDictInputWithOutput(Dict<string, string> input1) {
|
||||
std::string kernelWithDictInputWithOutput(Dict<string, string> input1) {
|
||||
return input1.at("key2");
|
||||
}
|
||||
|
||||
|
@ -463,7 +463,7 @@ TEST(OperatorRegistrationTestFunctorBasedKernel, givenKernelWithDictInput_withou
|
||||
}
|
||||
|
||||
struct KernelWithDictInputWithOutput final : OperatorKernel {
|
||||
string operator()(Dict<string, string> input1) {
|
||||
std::string operator()(Dict<string, std::string> input1) {
|
||||
return input1.at("key2");
|
||||
}
|
||||
};
|
||||
@ -475,7 +475,7 @@ TEST(OperatorRegistrationTestFunctorBasedKernel, givenKernelWithDictInput_withOu
|
||||
auto op = c10::Dispatcher::singleton().findSchema({"_test::dict_input", ""});
|
||||
ASSERT_TRUE(op.has_value());
|
||||
|
||||
Dict<string, string> dict;
|
||||
Dict<string, std::string> dict;
|
||||
dict.insert("key1", "value1");
|
||||
dict.insert("key2", "value2");
|
||||
auto outputs = callOp(*op, dict);
|
||||
@ -484,7 +484,7 @@ TEST(OperatorRegistrationTestFunctorBasedKernel, givenKernelWithDictInput_withOu
|
||||
}
|
||||
|
||||
struct KernelWithDictOutput final : OperatorKernel {
|
||||
Dict<string, string> operator()(Dict<string, string> input) {
|
||||
Dict<string, std::string> operator()(Dict<string, std::string> input) {
|
||||
return input;
|
||||
}
|
||||
};
|
||||
@ -496,12 +496,12 @@ TEST(OperatorRegistrationTestFunctorBasedKernel, givenKernelWithDictOutput_whenR
|
||||
auto op = c10::Dispatcher::singleton().findSchema({"_test::dict_output", ""});
|
||||
ASSERT_TRUE(op.has_value());
|
||||
|
||||
Dict<string, string> dict;
|
||||
Dict<string, std::string> dict;
|
||||
dict.insert("key1", "value1");
|
||||
dict.insert("key2", "value2");
|
||||
auto outputs = callOp(*op, dict);
|
||||
EXPECT_EQ(1, outputs.size());
|
||||
auto output = c10::impl::toTypedDict<string, string>(outputs[0].toGenericDict());
|
||||
auto output = c10::impl::toTypedDict<string, std::string>(outputs[0].toGenericDict());
|
||||
|
||||
EXPECT_EQ(2, output.size());
|
||||
EXPECT_EQ("value1", output.at("key1"));
|
||||
@ -520,7 +520,7 @@ private:
|
||||
};
|
||||
|
||||
struct KernelWithTupleInput final : OperatorKernel {
|
||||
string operator()(std::tuple<string, int64_t, double> input1) {
|
||||
std::string operator()(std::tuple<string, int64_t, double> input1) {
|
||||
return std::get<0>(input1);
|
||||
}
|
||||
};
|
||||
|
@ -22,7 +22,7 @@ namespace at::native {
|
||||
|
||||
// NOTE: To use this fallback, `clone` and `copy_` should fully understand and be able to correctly handle the semantic of your math bit.
|
||||
struct MathOpFallback {
|
||||
MathOpFallback(DispatchKey key_, string op_name_) : key(key_), op_name(std::move(op_name_)) {}
|
||||
MathOpFallback(DispatchKey key_, std::string op_name_) : key(key_), op_name(std::move(op_name_)) {}
|
||||
virtual bool is_bit_set(const Tensor&) = 0;
|
||||
void fallback_impl(const c10::OperatorHandle& op, DispatchKeySet dispatch_keys, torch::jit::Stack* stack) {
|
||||
/*
|
||||
@ -151,7 +151,7 @@ struct MathOpFallback {
|
||||
virtual ~MathOpFallback() = default;
|
||||
|
||||
DispatchKey key;
|
||||
string op_name;
|
||||
std::string op_name;
|
||||
};
|
||||
|
||||
} // namespace at::native
|
||||
|
@ -1735,7 +1735,7 @@ Tensor& index_select_out_cuda(
|
||||
int64_t dim,
|
||||
const Tensor& index,
|
||||
Tensor& out) {
|
||||
static constexpr string_view DIM_WARNING =
|
||||
static constexpr std::string_view DIM_WARNING =
|
||||
"Tensor too large or too many (> 25) dimensions";
|
||||
TORCH_CHECK(
|
||||
at::cuda::check_device({out, self, index}),
|
||||
|
@ -772,7 +772,7 @@ void dispatch_bfloat16_gemm_wmma(CUDABLAS_GEMM_ARGTYPES(at::BFloat16)) {
|
||||
template <>
|
||||
void gemm_internal_ck<at::BFloat16>(CUDABLAS_GEMM_ARGTYPES(at::BFloat16)) {
|
||||
auto dprops = at::cuda::getCurrentDeviceProperties();
|
||||
c10::string_view arch(dprops->gcnArchName);
|
||||
std::string_view arch(dprops->gcnArchName);
|
||||
if (arch == "gfx1100") {
|
||||
dispatch_bfloat16_gemm_wmma(CUDABLAS_GEMM_ARGS(at::BFloat16));
|
||||
} else{
|
||||
|
@ -101,11 +101,11 @@ void quantized_matmul(
|
||||
std::optional<at::Tensor> other, // extra input for binary-post-op
|
||||
double other_scale,
|
||||
int64_t other_zero_point,
|
||||
const c10::string_view& binary_post_op,
|
||||
const std::string_view& binary_post_op,
|
||||
double binary_alpha,
|
||||
const c10::string_view& unary_post_op,
|
||||
const std::string_view& unary_post_op,
|
||||
torch::List<std::optional<at::Scalar>>& unary_post_op_args,
|
||||
c10::string_view unary_post_op_algorithm,
|
||||
std::string_view unary_post_op_algorithm,
|
||||
bool m2_trans) {
|
||||
// [Note] Quantized Matrix Multiplication at XPU
|
||||
// The following code integrates oneDNN quantized gemm. The quantization
|
||||
|
@ -156,11 +156,11 @@ void quantized_matmul(
|
||||
std::optional<at::Tensor> other, // extra input for binary-post-op
|
||||
double other_scale,
|
||||
int64_t other_zero_point,
|
||||
const c10::string_view& binary_post_op,
|
||||
const std::string_view& binary_post_op,
|
||||
double binary_alpha,
|
||||
const c10::string_view& unary_post_op,
|
||||
const std::string_view& unary_post_op,
|
||||
torch::List<std::optional<at::Scalar>>& unary_post_op_args,
|
||||
c10::string_view unary_post_op_algorithm,
|
||||
std::string_view unary_post_op_algorithm,
|
||||
bool m2_trnas);
|
||||
|
||||
void gpu_float_sdpa(
|
||||
|
@ -151,11 +151,11 @@ static Tensor q_linear_pointwise_binary(
|
||||
std::optional<c10::ScalarType> output_dtype,
|
||||
double other_scale,
|
||||
int64_t other_zero_point,
|
||||
c10::string_view binary_post_op,
|
||||
std::string_view binary_post_op,
|
||||
double binary_alpha,
|
||||
c10::string_view unary_post_op,
|
||||
std::string_view unary_post_op,
|
||||
torch::List<std::optional<at::Scalar>> unary_post_op_args,
|
||||
c10::string_view unary_post_op_algorithm) {
|
||||
std::string_view unary_post_op_algorithm) {
|
||||
TORCH_CHECK(
|
||||
act.device() == weight.device() &&
|
||||
act.device() == weight_scales.device() &&
|
||||
@ -222,11 +222,11 @@ static Tensor q_linear_pointwise_binary_tensor(
|
||||
std::optional<c10::ScalarType> output_dtype,
|
||||
double other_scale,
|
||||
int64_t other_zero_point,
|
||||
c10::string_view binary_post_op,
|
||||
std::string_view binary_post_op,
|
||||
double binary_alpha,
|
||||
c10::string_view unary_post_op,
|
||||
std::string_view unary_post_op,
|
||||
torch::List<std::optional<at::Scalar>> unary_post_op_args,
|
||||
c10::string_view unary_post_op_algorithm) {
|
||||
std::string_view unary_post_op_algorithm) {
|
||||
return q_linear_pointwise_binary(
|
||||
act,
|
||||
act_scale.item().toDouble(),
|
||||
|
@ -156,7 +156,7 @@ MPSGraphTensor* mpsGraphRankedPlaceHolder(MPSGraph* mpsGraph, const TensorBase&
|
||||
MPSGraphTensor* mpsGraphScalarPlaceHolder(MPSGraph* mpsGraph, MPSDataType dataType);
|
||||
MPSGraphTensor* mpsGraphScalarPlaceHolder(MPSGraph* mpsGraph, const Scalar& scalar);
|
||||
|
||||
string get_mem_format_string(c10::MemoryFormat memory_format);
|
||||
std::string get_mem_format_string(c10::MemoryFormat memory_format);
|
||||
|
||||
using MPSCacheKey = uint64_t;
|
||||
|
||||
|
@ -760,8 +760,8 @@ MPSGraphTensor* convertNHWCtoNCHW(MPSGraph* mpsGraph, MPSGraphTensor* tensor) {
|
||||
name:nil];
|
||||
}
|
||||
|
||||
string get_mem_format_string(c10::MemoryFormat memory_format) {
|
||||
string mem_format_key;
|
||||
std::string get_mem_format_string(c10::MemoryFormat memory_format) {
|
||||
std::string mem_format_key;
|
||||
switch (memory_format) {
|
||||
case at::MemoryFormat::Contiguous:
|
||||
mem_format_key = "Contiguous";
|
||||
|
@ -60,7 +60,7 @@ Tensor relu_mps(const Tensor& self) {
|
||||
|
||||
MPSStream* stream = getCurrentMPSStream();
|
||||
@autoreleasepool {
|
||||
string key = "relu" + getTensorsStringKey({self});
|
||||
std::string key = "relu" + getTensorsStringKey({self});
|
||||
auto cachedGraph = LookUpOrCreateCachedGraph<CachedGraph>(key, [&](auto mpsGraph, auto newCachedGraph) {
|
||||
MPSGraphTensor* inputTensor = mpsGraphRankedPlaceHolder(mpsGraph, self);
|
||||
// passing selector of reLUWithTensor on the mpsGraph object
|
||||
@ -101,7 +101,7 @@ Tensor& relu_mps_(Tensor& self) {
|
||||
MPSStream* stream = getCurrentMPSStream();
|
||||
|
||||
@autoreleasepool {
|
||||
string key = "relu_" + getTensorsStringKey({self});
|
||||
std::string key = "relu_" + getTensorsStringKey({self});
|
||||
auto cachedGraph = LookUpOrCreateCachedGraph<CachedGraph>(key, [&](auto mpsGraph, auto newCachedGraph) {
|
||||
MPSGraphTensor* inputTensor = mpsGraphRankedPlaceHolder(mpsGraph, self);
|
||||
// passing selector of reLUWithTensor on the mpsGraph object
|
||||
@ -142,7 +142,7 @@ TORCH_IMPL_FUNC(leaky_relu_out_mps)(const Tensor& self, const Scalar& negative_s
|
||||
Tensor output_ = at::empty_like(self, executeGatherOp ? MemoryFormat::Contiguous : MemoryFormat::Preserve);
|
||||
|
||||
@autoreleasepool {
|
||||
string key = "leaky_relu" + getTensorsStringKey({self}) + ":" + std::to_string(negative_slope.to<double>());
|
||||
std::string key = "leaky_relu" + getTensorsStringKey({self}) + ":" + std::to_string(negative_slope.to<double>());
|
||||
auto cachedGraph = LookUpOrCreateCachedGraph<CachedGraph>(key, [&](auto mpsGraph, auto newCachedGraph) {
|
||||
MPSGraphTensor* inputTensor = mpsGraphRankedPlaceHolder(mpsGraph, self);
|
||||
|
||||
@ -192,7 +192,7 @@ TORCH_IMPL_FUNC(leaky_relu_backward_out_mps)
|
||||
Tensor output_ = at::empty_like(self, self.suggest_memory_format());
|
||||
|
||||
@autoreleasepool {
|
||||
string key = "leaky_relu_backward" + getTensorsStringKey({self, grad_output}) + ":" +
|
||||
std::string key = "leaky_relu_backward" + getTensorsStringKey({self, grad_output}) + ":" +
|
||||
std::to_string(negative_slope.to<double>());
|
||||
auto cachedGraph = LookUpOrCreateCachedGraph<CachedGraph>(key, [&](auto mpsGraph, auto newCachedGraph) {
|
||||
MPSGraphTensor* inputTensor = mpsGraphRankedPlaceHolder(mpsGraph, self);
|
||||
@ -241,7 +241,7 @@ TORCH_IMPL_FUNC(log_softmax_mps_out)
|
||||
MPSStream* stream = at::mps::getCurrentMPSStream();
|
||||
|
||||
@autoreleasepool {
|
||||
string key = "log_softmax_mps_out" + getTensorsStringKey({self}) + ":" + std::to_string(dim);
|
||||
std::string key = "log_softmax_mps_out" + getTensorsStringKey({self}) + ":" + std::to_string(dim);
|
||||
auto cachedGraph = LookUpOrCreateCachedGraph<CachedGraph>(key, [&](auto mpsGraph, auto newCachedGraph) {
|
||||
MPSGraphTensor* inputTensor = mpsGraphRankedPlaceHolder(mpsGraph, self);
|
||||
|
||||
@ -284,7 +284,7 @@ TORCH_IMPL_FUNC(log_softmax_backward_mps_out)
|
||||
MPSStream* stream = at::mps::getCurrentMPSStream();
|
||||
|
||||
@autoreleasepool {
|
||||
string key = "log_softmax_backward_mps_out:" + getMPSTypeString(grad_output) + ":" + std::to_string(dim);
|
||||
std::string key = "log_softmax_backward_mps_out:" + getMPSTypeString(grad_output) + ":" + std::to_string(dim);
|
||||
auto cachedGraph = LookUpOrCreateCachedGraph<CachedGraph>(key, [&](auto mpsGraph, auto newCachedGraph) {
|
||||
MPSGraphTensor* gradOutputTensor = mpsGraphUnrankedPlaceHolder(mpsGraph, getMPSDataType(grad_output));
|
||||
MPSGraphTensor* outputTensor = mpsGraphUnrankedPlaceHolder(mpsGraph, getMPSDataType(output));
|
||||
@ -332,7 +332,7 @@ std::tuple<Tensor&, Tensor&> log_sigmoid_forward_out_mps(const Tensor& self, Ten
|
||||
Tensor output_ = at::empty_like(self, executeGatherOp ? MemoryFormat::Contiguous : MemoryFormat::Preserve);
|
||||
|
||||
@autoreleasepool {
|
||||
string key = "log_sigmoid_forward_out:" + getTensorsStringKey({self});
|
||||
std::string key = "log_sigmoid_forward_out:" + getTensorsStringKey({self});
|
||||
auto cachedGraph = LookUpOrCreateCachedGraph<CachedGraph>(key, [&](auto mpsGraph, auto newCachedGraph) {
|
||||
MPSGraphTensor* inputTensor = mpsGraphRankedPlaceHolder(mpsGraph, self);
|
||||
MPSGraphTensor* zeroTensor = [mpsGraph constantWithScalar:0.0 shape:@[ @1 ] dataType:inputTensor.dataType];
|
||||
@ -391,7 +391,7 @@ Tensor& log_sigmoid_backward_mps_out(const Tensor& grad_output,
|
||||
Tensor grad_input_ = at::empty_like(self, executeGatherOp ? MemoryFormat::Contiguous : MemoryFormat::Preserve);
|
||||
|
||||
@autoreleasepool {
|
||||
string key = "log_sigmoid_backward_out:" + getTensorsStringKey({self, grad_output});
|
||||
std::string key = "log_sigmoid_backward_out:" + getTensorsStringKey({self, grad_output});
|
||||
auto cachedGraph = LookUpOrCreateCachedGraph<CachedGraph>(key, [&](auto mpsGraph, auto newCachedGraph) {
|
||||
MPSGraphTensor* inputTensor = mpsGraphRankedPlaceHolder(mpsGraph, self);
|
||||
MPSGraphTensor* gradOutputTensor = mpsGraphRankedPlaceHolder(mpsGraph, grad_output);
|
||||
@ -459,7 +459,7 @@ TORCH_IMPL_FUNC(sigmoid_backward_out_mps)(const Tensor& grad_output, const Tenso
|
||||
MPSStream* stream = getCurrentMPSStream();
|
||||
|
||||
@autoreleasepool {
|
||||
string key = "sigmoid_backward_out_mps:" + getMPSTypeString(grad_output);
|
||||
std::string key = "sigmoid_backward_out_mps:" + getMPSTypeString(grad_output);
|
||||
auto cachedGraph = LookUpOrCreateCachedGraph<CachedGraph>(key, [&](auto mpsGraph, auto newCachedGraph) {
|
||||
MPSGraphTensor* gradOutputTensor = mpsGraphUnrankedPlaceHolder(mpsGraph, getMPSDataType(grad_output));
|
||||
MPSGraphTensor* outputTensor = mpsGraphUnrankedPlaceHolder(mpsGraph, getMPSDataType(output));
|
||||
@ -501,7 +501,7 @@ TORCH_IMPL_FUNC(tanh_backward_out_mps)(const Tensor& grad_output, const Tensor&
|
||||
MPSStream* stream = getCurrentMPSStream();
|
||||
|
||||
@autoreleasepool {
|
||||
string key = "tanh_backward_out_mps:" + getMPSTypeString(grad_output);
|
||||
std::string key = "tanh_backward_out_mps:" + getMPSTypeString(grad_output);
|
||||
auto cachedGraph = LookUpOrCreateCachedGraph<CachedGraph>(key, [&](auto mpsGraph, auto newCachedGraph) {
|
||||
MPSGraphTensor* gradOutputTensor = mpsGraphUnrankedPlaceHolder(mpsGraph, getMPSDataType(grad_output));
|
||||
MPSGraphTensor* outputTensor = mpsGraphUnrankedPlaceHolder(mpsGraph, getMPSDataType(output));
|
||||
@ -538,7 +538,7 @@ TORCH_IMPL_FUNC(threshold_out_mps)
|
||||
MPSStream* stream = getCurrentMPSStream();
|
||||
|
||||
@autoreleasepool {
|
||||
string key = "threshold_out_mps" + getTensorsStringKey({self}) + ":" + std::to_string(threshold.to<double>()) +
|
||||
std::string key = "threshold_out_mps" + getTensorsStringKey({self}) + ":" + std::to_string(threshold.to<double>()) +
|
||||
":" + std::to_string(value.to<double>());
|
||||
|
||||
auto cachedGraph = LookUpOrCreateCachedGraph<CachedGraph>(key, [&](auto mpsGraph, auto newCachedGraph) {
|
||||
@ -585,7 +585,7 @@ TORCH_IMPL_FUNC(threshold_backward_out_mps)
|
||||
MPSStream* stream = getCurrentMPSStream();
|
||||
|
||||
@autoreleasepool {
|
||||
string key =
|
||||
std::string key =
|
||||
"threshold_backward_out_mps" + getTensorsStringKey({self, grad}) + ":" + std::to_string(threshold.to<double>());
|
||||
|
||||
auto cachedGraph = LookUpOrCreateCachedGraph<CachedGraph>(key, [&](auto mpsGraph, auto newCachedGraph) {
|
||||
@ -815,7 +815,7 @@ static void elu_variants_out_mps(const Tensor& self,
|
||||
const Scalar& scale,
|
||||
const Scalar& input_scale,
|
||||
const Tensor& result,
|
||||
string func_name) {
|
||||
std::string func_name) {
|
||||
using namespace mps;
|
||||
using CachedGraph = MPSUnaryCachedGraph;
|
||||
|
||||
@ -834,7 +834,7 @@ static void elu_variants_out_mps(const Tensor& self,
|
||||
MPSStream* stream = getCurrentMPSStream();
|
||||
|
||||
@autoreleasepool {
|
||||
string key = func_name + ":" + getTensorsStringKey({self}) + ":" + std::to_string(alpha.to<double>()) + ":" +
|
||||
std::string key = func_name + ":" + getTensorsStringKey({self}) + ":" + std::to_string(alpha.to<double>()) + ":" +
|
||||
std::to_string(scale.to<double>()) + ":" + std::to_string(input_scale.to<double>());
|
||||
|
||||
auto cachedGraph = LookUpOrCreateCachedGraph<CachedGraph>(key, [&](auto mpsGraph, auto newCachedGraph) {
|
||||
@ -923,7 +923,7 @@ TORCH_IMPL_FUNC(elu_backward_out_mps)
|
||||
MPSStream* stream = getCurrentMPSStream();
|
||||
|
||||
@autoreleasepool {
|
||||
string key = "elu_backward_out_mps:" + getTensorsStringKey({grad_output, self_or_result}) + ":" +
|
||||
std::string key = "elu_backward_out_mps:" + getTensorsStringKey({grad_output, self_or_result}) + ":" +
|
||||
std::to_string(alpha.to<double>()) + ":" + std::to_string(scale.to<double>()) + ":" +
|
||||
std::to_string(input_scale.to<double>()) + ":" + std::to_string(is_result);
|
||||
|
||||
@ -1018,7 +1018,7 @@ TORCH_IMPL_FUNC(glu_out_mps)(const Tensor& self, const int64_t dim, const Tensor
|
||||
MPSStream* stream = getCurrentMPSStream();
|
||||
|
||||
@autoreleasepool {
|
||||
string key = "glu_out_mps" + getTensorsStringKey({self}) + ":" + std::to_string(dim);
|
||||
std::string key = "glu_out_mps" + getTensorsStringKey({self}) + ":" + std::to_string(dim);
|
||||
auto cachedGraph = LookUpOrCreateCachedGraph<CachedGraph>(key, [&](auto mpsGraph, auto newCachedGraph) {
|
||||
MPSGraphTensor* inputTensor = mpsGraphRankedPlaceHolder(mpsGraph, getMPSDataType(self), getMPSShape(self));
|
||||
NSArray<MPSGraphTensor*>* outputTensorsArray = [mpsGraph splitTensor:inputTensor
|
||||
@ -1060,7 +1060,7 @@ Tensor& glu_backward_mps_out(const Tensor& grad_output, const Tensor& self, cons
|
||||
MPSStream* stream = getCurrentMPSStream();
|
||||
|
||||
@autoreleasepool {
|
||||
string key = "glu_backward_mps_out" + getTensorsStringKey({grad_output, self}) + ":" + std::to_string(dim);
|
||||
std::string key = "glu_backward_mps_out" + getTensorsStringKey({grad_output, self}) + ":" + std::to_string(dim);
|
||||
auto cachedGraph = LookUpOrCreateCachedGraph<CachedGraph>(key, [&](auto mpsGraph, auto newCachedGraph) {
|
||||
MPSGraphTensor* inputTensor = mpsGraphRankedPlaceHolder(mpsGraph, getMPSDataType(self), getMPSShape(self));
|
||||
MPSGraphTensor* gradOutputTensor =
|
||||
@ -1143,8 +1143,8 @@ TORCH_IMPL_FUNC(softplus_out_mps)
|
||||
MPSScalar threshold_scalar = getMPSScalar(threshold, self.scalar_type());
|
||||
|
||||
@autoreleasepool {
|
||||
string key = "softplus_out_mps:" + getTensorsStringKey({self}) + ":" + std::to_string(beta.to<double>()) + ":" +
|
||||
std::to_string(threshold.to<double>());
|
||||
std::string key = "softplus_out_mps:" + getTensorsStringKey({self}) + ":" + std::to_string(beta.to<double>()) +
|
||||
":" + std::to_string(threshold.to<double>());
|
||||
|
||||
auto cachedGraph = LookUpOrCreateCachedGraph<CachedGraph>(key, [&](auto mpsGraph, auto newCachedGraph) {
|
||||
MPSGraphTensor* inputTensor = mpsGraphRankedPlaceHolder(mpsGraph, self);
|
||||
@ -1214,7 +1214,7 @@ TORCH_IMPL_FUNC(softplus_backward_out_mps)
|
||||
MPSStream* stream = getCurrentMPSStream();
|
||||
|
||||
@autoreleasepool {
|
||||
string key = "softplus_backward_out_mps:" + getTensorsStringKey({grad_output, self}) + ":" +
|
||||
std::string key = "softplus_backward_out_mps:" + getTensorsStringKey({grad_output, self}) + ":" +
|
||||
std::to_string(beta.to<double>()) + ":" + std::to_string(threshold.to<double>());
|
||||
|
||||
auto cachedGraph = LookUpOrCreateCachedGraph<CachedGraph>(key, [&](auto mpsGraph, auto newCachedGraph) {
|
||||
@ -1289,7 +1289,7 @@ TORCH_IMPL_FUNC(mish_out_mps)
|
||||
Tensor result_ = at::empty_like(self, executeGatherOp ? MemoryFormat::Contiguous : MemoryFormat::Preserve);
|
||||
|
||||
@autoreleasepool {
|
||||
string key = "mish_out_mps:" + getTensorsStringKey({self});
|
||||
std::string key = "mish_out_mps:" + getTensorsStringKey({self});
|
||||
|
||||
auto cachedGraph = LookUpOrCreateCachedGraph<CachedGraph>(key, [&](auto mpsGraph, auto newCachedGraph) {
|
||||
MPSGraphTensor* inputTensor = mpsGraphRankedPlaceHolder(mpsGraph, self);
|
||||
@ -1334,7 +1334,7 @@ Tensor mish_backward_mps(const Tensor& grad_output, const Tensor& self) {
|
||||
MPSStream* stream = getCurrentMPSStream();
|
||||
|
||||
@autoreleasepool {
|
||||
string key = "mish_backward_out_mps:" + getTensorsStringKey({grad_output, self});
|
||||
std::string key = "mish_backward_out_mps:" + getTensorsStringKey({grad_output, self});
|
||||
|
||||
auto cachedGraph = LookUpOrCreateCachedGraph<CachedGraph>(key, [&](auto mpsGraph, auto newCachedGraph) {
|
||||
MPSGraphTensor* gradOutputTensor = mpsGraphRankedPlaceHolder(mpsGraph, grad_output);
|
||||
@ -1401,7 +1401,7 @@ TORCH_IMPL_FUNC(softshrink_out_mps)
|
||||
MPSStream* stream = getCurrentMPSStream();
|
||||
|
||||
@autoreleasepool {
|
||||
string key = "softshrink_out_mps:" + getTensorsStringKey({self}) + ":" + std::to_string(lambd.to<double>());
|
||||
std::string key = "softshrink_out_mps:" + getTensorsStringKey({self}) + ":" + std::to_string(lambd.to<double>());
|
||||
|
||||
auto cachedGraph = LookUpOrCreateCachedGraph<CachedGraph>(key, [&](auto mpsGraph, auto newCachedGraph) {
|
||||
MPSGraphTensor* inputTensor = mpsGraphRankedPlaceHolder(mpsGraph, self);
|
||||
@ -1476,7 +1476,7 @@ static void shrink_backward_out_mps(const Tensor& grad_output,
|
||||
MPSStream* stream = getCurrentMPSStream();
|
||||
|
||||
@autoreleasepool {
|
||||
string key = op_name + ":" + getTensorsStringKey({self}) + ":" + std::to_string(lambd.to<double>());
|
||||
std::string key = op_name + ":" + getTensorsStringKey({self}) + ":" + std::to_string(lambd.to<double>());
|
||||
|
||||
auto cachedGraph = LookUpOrCreateCachedGraph<CachedGraph>(key, [&](auto mpsGraph, auto newCachedGraph) {
|
||||
MPSGraphTensor* gradOutputTensor = mpsGraphRankedPlaceHolder(mpsGraph, grad_output);
|
||||
@ -1545,7 +1545,7 @@ Tensor prelu_mps(const Tensor& self, const Tensor& weight_) {
|
||||
MPSStream* stream = getCurrentMPSStream();
|
||||
|
||||
@autoreleasepool {
|
||||
string key = "prelu_mps:" + getTensorsStringKey({self, weight_});
|
||||
std::string key = "prelu_mps:" + getTensorsStringKey({self, weight_});
|
||||
|
||||
auto cachedGraph = LookUpOrCreateCachedGraph<CachedGraph>(key, [&](auto mpsGraph, auto newCachedGraph) {
|
||||
MPSGraphTensor* inputTensor = mpsGraphRankedPlaceHolder(mpsGraph, self);
|
||||
@ -1601,7 +1601,7 @@ std::tuple<Tensor, Tensor> prelu_backward_mps(const Tensor& grad_output, const T
|
||||
MPSStream* stream = getCurrentMPSStream();
|
||||
|
||||
@autoreleasepool {
|
||||
string key = "prelu_backward_mps:" + getTensorsStringKey({grad_output, self, weight_});
|
||||
std::string key = "prelu_backward_mps:" + getTensorsStringKey({grad_output, self, weight_});
|
||||
|
||||
auto cachedGraph = LookUpOrCreateCachedGraph<CachedGraph>(key, [&](auto mpsGraph, auto newCachedGraph) {
|
||||
MPSGraphTensor* gradOutputTensor = mpsGraphRankedPlaceHolder(mpsGraph, grad_output);
|
||||
@ -1665,7 +1665,7 @@ TORCH_IMPL_FUNC(silu_out_mps)(const Tensor& self, const Tensor& result) {
|
||||
Tensor result_ = at::empty_like(self, executeGatherOp ? MemoryFormat::Contiguous : MemoryFormat::Preserve);
|
||||
|
||||
@autoreleasepool {
|
||||
string key = "silu_out_mps:" + getTensorsStringKey({self});
|
||||
std::string key = "silu_out_mps:" + getTensorsStringKey({self});
|
||||
|
||||
auto cachedGraph = LookUpOrCreateCachedGraph<CachedGraph>(key, [&](auto mpsGraph, auto newCachedGraph) {
|
||||
MPSGraphTensor* inputTensor = mpsGraphRankedPlaceHolder(mpsGraph, self);
|
||||
@ -1709,7 +1709,7 @@ TORCH_IMPL_FUNC(silu_backward_out_mps)
|
||||
MPSStream* stream = getCurrentMPSStream();
|
||||
|
||||
@autoreleasepool {
|
||||
string key = "silu_out_backward_mps:" + getTensorsStringKey({grad_output});
|
||||
std::string key = "silu_out_backward_mps:" + getTensorsStringKey({grad_output});
|
||||
|
||||
auto cachedGraph = LookUpOrCreateCachedGraph<CachedGraph>(key, [&](auto mpsGraph, auto newCachedGraph) {
|
||||
MPSGraphTensor* inputTensor = mpsGraphRankedPlaceHolder(mpsGraph, self);
|
||||
@ -1765,7 +1765,7 @@ TORCH_IMPL_FUNC(hardsigmoid_out_mps)(const Tensor& self, const Tensor& result) {
|
||||
MPSStream* stream = getCurrentMPSStream();
|
||||
|
||||
@autoreleasepool {
|
||||
string key = "hardsigmoid_out_mps:" + getTensorsStringKey({self});
|
||||
std::string key = "hardsigmoid_out_mps:" + getTensorsStringKey({self});
|
||||
|
||||
auto cachedGraph = LookUpOrCreateCachedGraph<CachedGraph>(key, [&](auto mpsGraph, auto newCachedGraph) {
|
||||
MPSGraphTensor* inputTensor = mpsGraphRankedPlaceHolder(mpsGraph, self);
|
||||
@ -1806,7 +1806,7 @@ TORCH_IMPL_FUNC(hardsigmoid_backward_out_mps)
|
||||
MPSStream* stream = getCurrentMPSStream();
|
||||
|
||||
@autoreleasepool {
|
||||
string key = "hardsigmoid_backward_out_mps:" + getTensorsStringKey({self});
|
||||
std::string key = "hardsigmoid_backward_out_mps:" + getTensorsStringKey({self});
|
||||
|
||||
auto cachedGraph = LookUpOrCreateCachedGraph<CachedGraph>(key, [&](auto mpsGraph, auto newCachedGraph) {
|
||||
MPSGraphTensor* inputTensor = mpsGraphRankedPlaceHolder(mpsGraph, self);
|
||||
@ -1878,7 +1878,7 @@ Tensor& hardtanh_backward_out_mps(const Tensor& grad_output,
|
||||
MPSStream* stream = getCurrentMPSStream();
|
||||
|
||||
@autoreleasepool {
|
||||
string key = "hardtanh_backward_out_mps:" + getTensorsStringKey({grad_output}) + ":" +
|
||||
std::string key = "hardtanh_backward_out_mps:" + getTensorsStringKey({grad_output}) + ":" +
|
||||
std::to_string(min.to<double>()) + ":" + std::to_string(max.to<double>());
|
||||
|
||||
auto cachedGraph = LookUpOrCreateCachedGraph<CachedGraph>(key, [&](auto mpsGraph, auto newCachedGraph) {
|
||||
@ -1955,7 +1955,7 @@ Tensor& hardswish_out_mps(const Tensor& self, Tensor& output) {
|
||||
MPSStream* stream = at::mps::getCurrentMPSStream();
|
||||
|
||||
@autoreleasepool {
|
||||
string key = "hardswish_out_mps" + getTensorsStringKey({self});
|
||||
std::string key = "hardswish_out_mps" + getTensorsStringKey({self});
|
||||
auto cachedGraph = LookUpOrCreateCachedGraph<CachedGraph>(key, [&](auto mpsGraph, auto newCachedGraph) {
|
||||
MPSGraphTensor* inputTensor = mpsGraphRankedPlaceHolder(mpsGraph, self);
|
||||
|
||||
@ -2040,7 +2040,7 @@ Tensor hardswish_backward_mps(const Tensor& grad_output, const Tensor& self) {
|
||||
}
|
||||
|
||||
@autoreleasepool {
|
||||
string key = "hardswish_backward_mps" + getTensorsStringKey({self});
|
||||
std::string key = "hardswish_backward_mps" + getTensorsStringKey({self});
|
||||
auto cachedGraph = LookUpOrCreateCachedGraph<CachedGraph>(key, [&](auto mpsGraph, auto newCachedGraph) {
|
||||
MPSGraphTensor* gradOutputTensor = mpsGraphRankedPlaceHolder(mpsGraph, grad_output);
|
||||
MPSGraphTensor* inputTensor = mpsGraphRankedPlaceHolder(mpsGraph, self);
|
||||
|
@ -118,7 +118,7 @@ static void binaryOpTensor(const Tensor& self,
|
||||
}
|
||||
|
||||
@autoreleasepool {
|
||||
string key = op_name + getTensorsStringKey({self, other, output_});
|
||||
std::string key = op_name + getTensorsStringKey({self, other, output_});
|
||||
auto cachedGraph = LookUpOrCreateCachedGraph<BinaryOpCachedGraph>(key, [&](auto mpsGraph, auto newCachedGraph) {
|
||||
newCachedGraph->primaryTensor =
|
||||
mpsGraphRankedPlaceHolder(mpsGraph, getMPSScalarType(inputDataType), getMPSShape(self));
|
||||
@ -200,7 +200,7 @@ static void div_mode_template(const Tensor& self,
|
||||
const Tensor& other,
|
||||
std::optional<std::string_view> rounding_mode,
|
||||
const Tensor& output,
|
||||
const string op_name) {
|
||||
const std::string& op_name) {
|
||||
if (rounding_mode.has_value() && *rounding_mode == "trunc") {
|
||||
TORCH_CHECK(self.scalar_type() != ScalarType::Half, "MPS: does not support trunc_divide op with float16 input");
|
||||
}
|
||||
|
@ -64,7 +64,7 @@ Tensor dot_mps(const Tensor& self, const Tensor& other) {
|
||||
MPSStream* stream = at::mps::getCurrentMPSStream();
|
||||
|
||||
@autoreleasepool {
|
||||
string key = "dot_mps" + getTensorsStringKey({self, other});
|
||||
std::string key = "dot_mps" + getTensorsStringKey({self, other});
|
||||
|
||||
auto cachedGraph = LookUpOrCreateCachedGraph<CachedGraph>(key, [&](auto mpsGraph, auto newCachedGraph) {
|
||||
MPSGraphTensor* selfTensor = mpsGraphRankedPlaceHolder(mpsGraph, self);
|
||||
@ -143,7 +143,7 @@ static Tensor& addmv_out_mps_impl(const Tensor& self,
|
||||
Tensor matMulVec = at::mm(mat, vec.unsqueeze(1)).squeeze(1);
|
||||
|
||||
@autoreleasepool {
|
||||
string key = "addmv_out_mps_impl" + getTensorsStringKey({self, matMulVec}) + ":" +
|
||||
std::string key = "addmv_out_mps_impl" + getTensorsStringKey({self, matMulVec}) + ":" +
|
||||
std::to_string(beta_.toDouble()) + ":" + std::to_string(alpha_.toDouble());
|
||||
auto cachedGraph = LookUpOrCreateCachedGraph<CachedGraph>(key, [&](auto mpsGraph, auto newCachedGraph) {
|
||||
MPSGraphTensor* matMulVecTensor = mpsGraphRankedPlaceHolder(mpsGraph, matMulVec);
|
||||
|
@ -33,7 +33,7 @@ static Tensor& fill_scalar_mps_impl(Tensor& self, const Scalar& value) {
|
||||
};
|
||||
|
||||
@autoreleasepool {
|
||||
string key = "fill_scalar_mps_impl" + getTensorsStringKey(self) + ":" + std::to_string(value.toDouble());
|
||||
std::string key = "fill_scalar_mps_impl" + getTensorsStringKey(self) + ":" + std::to_string(value.toDouble());
|
||||
|
||||
auto cachedGraph = LookUpOrCreateCachedGraph<CachedGraph>(key, [&](auto mpsGraph, auto newCachedGraph) {
|
||||
MPSGraphTensor* inputTensor = mpsGraphScalarPlaceHolder(mpsGraph, getMPSDataType(self.scalar_type()));
|
||||
|
@ -190,7 +190,7 @@ static Tensor _mps_convolution_impl(const Tensor& input_t_,
|
||||
if (bias_defined)
|
||||
bias_shape = bias_opt.value().sizes();
|
||||
|
||||
string mem_format_key;
|
||||
std::string mem_format_key;
|
||||
switch (memory_format) {
|
||||
case at::MemoryFormat::Contiguous:
|
||||
mem_format_key = "Contiguous";
|
||||
@ -202,14 +202,14 @@ static Tensor _mps_convolution_impl(const Tensor& input_t_,
|
||||
assert(0 && "Check should have been done earlier\n");
|
||||
}
|
||||
|
||||
string bias_shape_key;
|
||||
std::string bias_shape_key;
|
||||
if (bias_defined) {
|
||||
bias_shape_key = std::to_string(bias_shape[0]);
|
||||
} else {
|
||||
bias_shape_key = "nobias";
|
||||
}
|
||||
|
||||
string key;
|
||||
std::string key;
|
||||
if (is3DConv) {
|
||||
key = "mps_3d_convolution:" + std::to_string(stride[0]) + ":" + std::to_string(stride[1]) + ":" +
|
||||
std::to_string(stride[2]) + ":" + std::to_string(dilation[0]) + ":" + std::to_string(dilation[1]) + ":" +
|
||||
@ -404,7 +404,7 @@ static Tensor mps_convolution_backward_input(IntArrayRef input_size,
|
||||
@autoreleasepool {
|
||||
MPSStream* stream = getCurrentMPSStream();
|
||||
|
||||
string mem_format_key;
|
||||
std::string mem_format_key;
|
||||
switch (memory_format) {
|
||||
case at::MemoryFormat::Contiguous:
|
||||
mem_format_key = "Contiguous";
|
||||
@ -417,7 +417,7 @@ static Tensor mps_convolution_backward_input(IntArrayRef input_size,
|
||||
}
|
||||
|
||||
MPSShape* mps_input_shape = getMPSShape(input_size);
|
||||
string key;
|
||||
std::string key;
|
||||
if (is3DConv) {
|
||||
key = "mps_3d_convolution_backward_input:" + std::to_string(stride[0]) + ":" + std::to_string(stride[1]) + ":" +
|
||||
":" + std::to_string(stride[2]) + std::to_string(dilation[0]) + ":" + std::to_string(dilation[1]) + ":" +
|
||||
@ -555,7 +555,7 @@ static Tensor mps_convolution_backward_weights(IntArrayRef weight_size,
|
||||
MPSStream* stream = getCurrentMPSStream();
|
||||
|
||||
MPSShape* mps_weight_shape = getMPSShape(weight_size);
|
||||
string key;
|
||||
std::string key;
|
||||
if (is3DConv) {
|
||||
key = "mps_3d_convolution_backward_weights:" + std::to_string(stride[0]) + ":" + std::to_string(stride[1]) + ":" +
|
||||
std::to_string(stride[2]) + ":" + std::to_string(dilation[0]) + ":" + std::to_string(dilation[1]) + ":" +
|
||||
|
@ -48,7 +48,7 @@ static void copy_cast_mps(at::Tensor& dst,
|
||||
|
||||
@autoreleasepool {
|
||||
const bool needs_conj = src.is_conj() != dst.is_conj();
|
||||
string key = "copy_cast_mps" + getTensorsStringKey({src, dst}, true, /*exclude_shape*/ true) + ":" +
|
||||
std::string key = "copy_cast_mps" + getTensorsStringKey({src, dst}, true, /*exclude_shape*/ true) + ":" +
|
||||
std::to_string(needs_conj);
|
||||
auto cachedGraph = LookUpOrCreateCachedGraph<CachedGraph>(key, [&](auto mpsGraph, auto newCachedGraph) {
|
||||
MPSGraphTensor* inputTensor = mpsGraphUnrankedPlaceHolder(mpsGraph, srcDType);
|
||||
|
@ -482,7 +482,7 @@ static Tensor& multinomial_with_replacement_mps_kernel(const Tensor& self,
|
||||
MPSStream* stream = getCurrentMPSStream();
|
||||
|
||||
@autoreleasepool {
|
||||
string key = "multinomial_with_replacement:" + getTensorsStringKey({self}) + ":" + std::to_string(n_sample);
|
||||
std::string key = "multinomial_with_replacement:" + getTensorsStringKey({self}) + ":" + std::to_string(n_sample);
|
||||
auto cachedGraph = LookUpOrCreateCachedGraph<RandomCachedGraph>(key, [&](auto mpsGraph, auto newCachedGraph) {
|
||||
MPSShape* prob_shape = getMPSShape(self_v);
|
||||
newCachedGraph->stateTensor = mpsGraphRankedPlaceHolder(mpsGraph, MPSDataTypeInt32, @[ @7 ]);
|
||||
|
@ -69,7 +69,7 @@ Tensor& eye_out_mps(int64_t n, int64_t m, Tensor& result) {
|
||||
@autoreleasepool {
|
||||
// A key is used to identify the MPSGraph which was created once, and can be reused if the parameters, data types
|
||||
// etc match the earlier created MPSGraph
|
||||
string key = "eye_out_mps:" + getTensorsStringKey({result});
|
||||
std::string key = "eye_out_mps:" + getTensorsStringKey({result});
|
||||
auto cachedGraph = LookUpOrCreateCachedGraph<CachedGraph>(key, [&](auto* mpsGraph, auto* newCachedGraph) {
|
||||
MPSGraphTensor* onesTensor = [mpsGraph constantWithScalar:1.0f
|
||||
shape:getMPSShape(result)
|
||||
|
@ -71,8 +71,8 @@ static void grid_sampler_2d_mps_impl(Tensor& output,
|
||||
};
|
||||
|
||||
@autoreleasepool {
|
||||
string key = "grid_sampler_2d_mps" + getTensorsStringKey({input, grid}) + ":" + std::to_string(interpolation_mode) +
|
||||
":" + std::to_string(padding_mode) + ":" + std::to_string(align_corners);
|
||||
std::string key = "grid_sampler_2d_mps" + getTensorsStringKey({input, grid}) + ":" +
|
||||
std::to_string(interpolation_mode) + ":" + std::to_string(padding_mode) + ":" + std::to_string(align_corners);
|
||||
|
||||
auto cachedGraph = LookUpOrCreateCachedGraph<CachedGraph>(key, [&](auto mpsGraph, auto newCachedGraph) {
|
||||
MPSGraphTensor* inputTensor = mpsGraphRankedPlaceHolder(mpsGraph, input);
|
||||
|
@ -279,7 +279,7 @@ static Tensor& nonzero_out_native_mps(const Tensor& self, Tensor& out_) {
|
||||
}
|
||||
|
||||
@autoreleasepool {
|
||||
string key = "nonzero_out_native_mps" + getTensorsStringKey(self);
|
||||
std::string key = "nonzero_out_native_mps" + getTensorsStringKey(self);
|
||||
auto cachedGraph = LookUpOrCreateCachedGraph<CachedGraph>(key, [&](auto mpsGraph, auto newCachedGraph) {
|
||||
MPSGraphTensor* inputTensor = mpsGraphRankedPlaceHolder(mpsGraph, self);
|
||||
|
||||
@ -371,7 +371,7 @@ Tensor& nonzero_out_mps(const Tensor& self, Tensor& out_) {
|
||||
}
|
||||
|
||||
@autoreleasepool {
|
||||
string key = "nonzero_out_native_mps" + getTensorsStringKey(self);
|
||||
std::string key = "nonzero_out_native_mps" + getTensorsStringKey(self);
|
||||
auto cachedGraph = LookUpOrCreateCachedGraph<CachedGraph>(key, [&](auto mpsGraph, auto newCachedGraph) {
|
||||
MPSGraphTensor* inputTensor = mpsGraphRankedPlaceHolder(mpsGraph, self);
|
||||
|
||||
@ -452,7 +452,7 @@ Tensor flip_mps(const Tensor& self, IntArrayRef dims) {
|
||||
NSString* ns_dims_key = [[ns_dims valueForKey:@"description"] componentsJoinedByString:@","];
|
||||
// A key is used to identify the MPSGraph which was created once, and can be reused if the parameters, data types
|
||||
// etc match the earlier created MPSGraph
|
||||
string key = "flip_mps:" + getTensorsStringKey({self}) + ":" + string([ns_dims_key UTF8String]);
|
||||
std::string key = "flip_mps:" + getTensorsStringKey({self}) + ":" + std::string([ns_dims_key UTF8String]);
|
||||
auto cachedGraph = LookUpOrCreateCachedGraph<CachedGraph>(key, [&](auto mpsGraph, auto newCachedGraph) {
|
||||
MPSGraphTensor* inputTensor = mpsGraphRankedPlaceHolder(mpsGraph, inputDataType, getMPSShape(self));
|
||||
MPSGraphTensor* outputTensor = [mpsGraph reverseTensor:inputTensor axes:ns_dims name:nil];
|
||||
@ -500,7 +500,7 @@ TORCH_IMPL_FUNC(index_add_mps_out)
|
||||
};
|
||||
|
||||
@autoreleasepool {
|
||||
string key = "index_add_mps_out" + getTensorsStringKey({self, index, source}) + ":" + std::to_string(dim);
|
||||
std::string key = "index_add_mps_out" + getTensorsStringKey({self, index, source}) + ":" + std::to_string(dim);
|
||||
auto cachedGraph = LookUpOrCreateCachedGraph<CachedGraph>(key, [&](auto mpsGraph, auto newCachedGraph) {
|
||||
MPSGraphTensor* inputTensor = mpsGraphRankedPlaceHolder(mpsGraph, self);
|
||||
MPSGraphTensor* indexTensor = mpsGraphRankedPlaceHolder(mpsGraph, index);
|
||||
@ -649,7 +649,7 @@ Tensor& index_select_out_mps(const Tensor& self, int64_t dim, const Tensor& inde
|
||||
}
|
||||
|
||||
@autoreleasepool {
|
||||
string key = "index_select_out_mps" + getTensorsStringKey({self, index}) + ":" + std::to_string(dim);
|
||||
std::string key = "index_select_out_mps" + getTensorsStringKey({self, index}) + ":" + std::to_string(dim);
|
||||
auto cachedGraph = LookUpOrCreateCachedGraph<CachedGraph>(key, [&](auto mpsGraph, auto newCachedGraph) {
|
||||
MPSGraphTensor* inputTensor = mpsGraphRankedPlaceHolder(mpsGraph, inputType, getMPSShape(self));
|
||||
MPSGraphTensor* indexTensor = mpsGraphRankedPlaceHolder(mpsGraph, index);
|
||||
@ -786,8 +786,9 @@ Tensor embedding_dense_backward_mps(const Tensor& grad_,
|
||||
auto stream = at::mps::getCurrentMPSStream();
|
||||
|
||||
@autoreleasepool {
|
||||
string key = "edb_mps:" + getTensorsStringKey({grad_, indices}) + ":num_weights" + std::to_string(num_weights) +
|
||||
":padding_idx" + std::to_string(padding_idx) + ":scaled" + std::to_string(scale_grad_by_freq);
|
||||
std::string key = "edb_mps:" + getTensorsStringKey({grad_, indices}) + ":num_weights" +
|
||||
std::to_string(num_weights) + ":padding_idx" + std::to_string(padding_idx) + ":scaled" +
|
||||
std::to_string(scale_grad_by_freq);
|
||||
auto cachedGraph = LookUpOrCreateCachedGraph<CachedGraph>(key, [&](auto mpsGraph, auto newCachedGraph) {
|
||||
MPSGraphTensor* incomingGradTensor = mpsGraphUnrankedPlaceHolder(mpsGraph, getMPSDataType(grad_));
|
||||
|
||||
@ -926,7 +927,8 @@ Tensor& index_fill_mps_(Tensor& self, int64_t dim, const Tensor& index, const Te
|
||||
auto expanded_source = source.expand(source_shape);
|
||||
|
||||
@autoreleasepool {
|
||||
string key = "index_fill_mps_" + getTensorsStringKey({self, index, expanded_source}) + ":" + std::to_string(dim);
|
||||
std::string key =
|
||||
"index_fill_mps_" + getTensorsStringKey({self, index, expanded_source}) + ":" + std::to_string(dim);
|
||||
auto cachedGraph = LookUpOrCreateCachedGraph<CachedGraph>(key, [&](auto mpsGraph, auto newCachedGraph) {
|
||||
MPSGraphTensor* inputTensor = mpsGraphRankedPlaceHolder(mpsGraph, inputType, getMPSShape(self));
|
||||
MPSGraphTensor* indexTensor = mpsGraphRankedPlaceHolder(mpsGraph, index);
|
||||
|
@ -24,7 +24,7 @@ TORCH_IMPL_FUNC(lerp_Tensor_mps)(const Tensor& self, const Tensor& end, const Te
|
||||
MPSGraphTensor* outputTensor_ = nil;
|
||||
};
|
||||
@autoreleasepool {
|
||||
string key = "lerp_Tensor_mps" + getTensorsStringKey({self, end, weight});
|
||||
std::string key = "lerp_Tensor_mps" + getTensorsStringKey({self, end, weight});
|
||||
auto cachedGraph = LookUpOrCreateCachedGraph<CachedGraph>(key, [&](auto mpsGraph, auto graph) {
|
||||
auto selfTensor = mpsGraphRankedPlaceHolder(mpsGraph, self);
|
||||
auto endTensor = mpsGraphRankedPlaceHolder(mpsGraph, end);
|
||||
|
@ -65,7 +65,7 @@ Tensor _mps_linear(const Tensor& input, const Tensor& weight_arg, const std::opt
|
||||
};
|
||||
|
||||
@autoreleasepool {
|
||||
string key = "mps_linear" + getTensorsStringKey({input, weight, bias});
|
||||
std::string key = "mps_linear" + getTensorsStringKey({input, weight, bias});
|
||||
auto cachedGraph = LookUpOrCreateCachedGraph<CachedGraph>(key, [&](auto* mpsGraph, auto* newCachedGraph) {
|
||||
MPSGraphTensor* inputTensor = mpsGraphRankedPlaceHolder(mpsGraph, input);
|
||||
MPSGraphTensor* weightTensor = mpsGraphRankedPlaceHolder(mpsGraph, weight);
|
||||
@ -154,7 +154,7 @@ static Tensor _mps_linear_backward_input(IntArrayRef input_size, const Tensor& g
|
||||
MPSStream* stream = getCurrentMPSStream();
|
||||
|
||||
@autoreleasepool {
|
||||
string key = "mps_linear_backward_input" + getTensorsStringKey({grad_output, weight_reshaped});
|
||||
std::string key = "mps_linear_backward_input" + getTensorsStringKey({grad_output, weight_reshaped});
|
||||
auto cachedGraph = LookUpOrCreateCachedGraph<CachedGraph>(key, [&](auto* mpsGraph, auto* newCachedGraph) {
|
||||
newCachedGraph->weightTensor_ = mpsGraphRankedPlaceHolder(mpsGraph, weight_reshaped);
|
||||
newCachedGraph->gradOutputTensor_ = mpsGraphRankedPlaceHolder(mpsGraph, grad_output);
|
||||
@ -236,7 +236,7 @@ static std::tuple<Tensor, Tensor> _mps_linear_backward_weights(const Tensor& gra
|
||||
MPSStream* stream = getCurrentMPSStream();
|
||||
|
||||
@autoreleasepool {
|
||||
string key = "mps_linear_backward_weights:" + std::to_string(bias_defined) + ":" +
|
||||
std::string key = "mps_linear_backward_weights:" + std::to_string(bias_defined) + ":" +
|
||||
getTensorsStringKey({input_reshaped, weight, grad_output_reshaped});
|
||||
auto cachedGraph = LookUpOrCreateCachedGraph<CachedGraph>(key, [&](auto mpsGraph, auto newCachedGraph) {
|
||||
MPSGraphTensor* inputTensor = mpsGraphRankedPlaceHolder(mpsGraph, input_reshaped);
|
||||
|
@ -495,7 +495,7 @@ static Tensor& mm_out_mps_impl(const Tensor& self, const Tensor& other, Tensor&
|
||||
}
|
||||
|
||||
@autoreleasepool {
|
||||
string key = "mm_out_mps_impl" + getTensorsStringKey({self, other});
|
||||
std::string key = "mm_out_mps_impl" + getTensorsStringKey({self, other});
|
||||
|
||||
auto cachedGraph = LookUpOrCreateCachedGraph<CachedGraph>(key, [&](auto mpsGraph, auto newCachedGraph) {
|
||||
std::tie(newCachedGraph->inputTensor_, newCachedGraph->otherTensor_, newCachedGraph->outputTensor_) =
|
||||
@ -579,7 +579,7 @@ static Tensor& addbmm_or_baddbmm_out_mps_impl(const Tensor& input,
|
||||
};
|
||||
|
||||
@autoreleasepool {
|
||||
string key = (opType == ADDBMM_OP_TYPE) ? ("addbmm_out_mps_impl") : ("baddbmm_out_mps_impl");
|
||||
std::string key = (opType == ADDBMM_OP_TYPE) ? ("addbmm_out_mps_impl") : ("baddbmm_out_mps_impl");
|
||||
key += getTensorsStringKey({batch1, batch2, input}) + ":" + std::to_string(beta.toDouble()) + ":" +
|
||||
std::to_string(alpha.toDouble());
|
||||
|
||||
@ -682,7 +682,7 @@ static Tensor& addmm_out_mps_impl(const Tensor& bias,
|
||||
};
|
||||
|
||||
@autoreleasepool {
|
||||
string key = "addmm_out_mps_impl" + getTensorsStringKey({self, other, *bias_}) + ":" +
|
||||
std::string key = "addmm_out_mps_impl" + getTensorsStringKey({self, other, *bias_}) + ":" +
|
||||
std::to_string(beta.toDouble()) + ":" + std::to_string(alpha.toDouble());
|
||||
auto cachedGraph = LookUpOrCreateCachedGraph<CachedGraph>(key, [&](auto mpsGraph, auto newCachedGraph) {
|
||||
MPSGraphTensor* biasTensor = mpsGraphRankedPlaceHolder(mpsGraph, *bias_);
|
||||
@ -923,7 +923,7 @@ static Tensor& bmm_out_mps_impl(const Tensor& batch1, const Tensor& batch2, Tens
|
||||
};
|
||||
|
||||
@autoreleasepool {
|
||||
string key = "bmm_out_mps_impl" + getTensorsStringKey({batch1, batch2}, true, /*exclude_shape*/ true) +
|
||||
std::string key = "bmm_out_mps_impl" + getTensorsStringKey({batch1, batch2}, true, /*exclude_shape*/ true) +
|
||||
std::to_string(doTranspose);
|
||||
|
||||
auto cachedGraph = LookUpOrCreateCachedGraph<CachedGraph>(key, [&](auto mpsGraph, auto newCachedGraph) {
|
||||
@ -1260,7 +1260,7 @@ Tensor& addr_out_mps(const Tensor& self,
|
||||
};
|
||||
|
||||
@autoreleasepool {
|
||||
string key = "addr_out_mps_impl" + getTensorsStringKey({vec1, vec2, *self_}) + ":" +
|
||||
std::string key = "addr_out_mps_impl" + getTensorsStringKey({vec1, vec2, *self_}) + ":" +
|
||||
std::to_string(beta.toDouble()) + ":" + std::to_string(alpha.toDouble());
|
||||
auto cachedGraph = LookUpOrCreateCachedGraph<CachedGraph>(key, [&](auto mpsGraph, auto newCachedGraph) {
|
||||
MPSGraphTensor* t1 = mps::mpsGraphRankedPlaceHolder(mpsGraph, getMPSDataType(vec1), inputShape);
|
||||
|
@ -23,7 +23,7 @@
|
||||
namespace at::native {
|
||||
namespace mps {
|
||||
|
||||
static string reductionToString(int64_t reduction) {
|
||||
static std::string reductionToString(int64_t reduction) {
|
||||
switch (reduction) {
|
||||
case Reduction::Mean:
|
||||
return "Mean";
|
||||
@ -58,7 +58,7 @@ static Tensor& mse_loss_backward_out_impl(const Tensor& grad_output,
|
||||
const Tensor& target,
|
||||
int64_t reduction,
|
||||
Tensor& grad_input,
|
||||
const string op_name) {
|
||||
const std::string& op_name) {
|
||||
TORCH_CHECK(target.is_same_size(input), op_name + ": target and input tensors must have identical shapes")
|
||||
auto norm = reduction == Reduction::Mean ? 2. / static_cast<double>(input.numel()) : 2.;
|
||||
|
||||
@ -73,7 +73,7 @@ static Tensor& mse_loss_backward_out_impl(const Tensor& grad_output,
|
||||
};
|
||||
|
||||
@autoreleasepool {
|
||||
string key = op_name + reductionToString(reduction) + ":" + std::to_string(grad_input.sizes()[1]) +
|
||||
std::string key = op_name + reductionToString(reduction) + ":" + std::to_string(grad_input.sizes()[1]) +
|
||||
getTensorsStringKey({input, target, grad_output});
|
||||
auto cachedGraph = LookUpOrCreateCachedGraph<CachedGraph>(key, [&](auto mpsGraph, auto newCachedGraph) {
|
||||
newCachedGraph->inputTensor = mpsGraphRankedPlaceHolder(mpsGraph, input);
|
||||
@ -200,7 +200,7 @@ static Tensor& bce_loss_out_impl(const Tensor& input,
|
||||
int64_t reduction,
|
||||
Tensor& loss,
|
||||
const std::optional<Tensor>& grad_output_opt,
|
||||
const string op_name) {
|
||||
const std::string& op_name) {
|
||||
// TODO: add sanity check for the elements of input tensor to be within [0..1]
|
||||
TORCH_CHECK(target.is_same_size(input), op_name + ": target and input tensors must have identical shapes")
|
||||
|
||||
@ -217,7 +217,7 @@ static Tensor& bce_loss_out_impl(const Tensor& input,
|
||||
Tensor target_squeezed = target.squeeze();
|
||||
|
||||
@autoreleasepool {
|
||||
string key =
|
||||
std::string key =
|
||||
op_name + reductionToString(reduction) + getTensorsStringKey({input_squeezed, target_squeezed, weight});
|
||||
|
||||
auto cachedGraph = LookUpOrCreateCachedGraph<CachedGraph>(key, [&](auto mpsGraph, auto newCachedGraph) {
|
||||
@ -332,7 +332,7 @@ static void nllnd_loss_backward_impl(Tensor& grad_input_arg,
|
||||
grad_output = grad_output_arg.unsqueeze(channel_dim);
|
||||
}
|
||||
@autoreleasepool {
|
||||
string key = "nllnd_loss_backward" + getTensorsStringKey({input, grad_output, target, weight, total_weight}) +
|
||||
std::string key = "nllnd_loss_backward" + getTensorsStringKey({input, grad_output, target, weight, total_weight}) +
|
||||
std::to_string(numClasses) + ":" + std::to_string(ignore_index) + ":" + std::to_string(isWeightsArrayValid) +
|
||||
":" + std::to_string(isTargetCasted) + ":" + reductionToString(reduction);
|
||||
|
||||
@ -483,9 +483,10 @@ static void nllnd_loss_forward_impl(Tensor& output,
|
||||
NSString* ns_shape_key = [[input_shape valueForKey:@"description"] componentsJoinedByString:@","];
|
||||
|
||||
// TODO: Make the key
|
||||
string key = "nllnd_loss_forward_impl:" + std::to_string(ignore_index) + ":" + std::to_string(isWeightsArrayValid) +
|
||||
":" + reductionToString(reduction) + ":" + [ns_shape_key UTF8String] + ":" + getMPSTypeString(input) + ":" +
|
||||
getMPSTypeString(target) + ":" + std::to_string(isTargetCasted) + ":" + getMPSTypeString(weight);
|
||||
std::string key = "nllnd_loss_forward_impl:" + std::to_string(ignore_index) + ":" +
|
||||
std::to_string(isWeightsArrayValid) + ":" + reductionToString(reduction) + ":" + [ns_shape_key UTF8String] +
|
||||
":" + getMPSTypeString(input) + ":" + getMPSTypeString(target) + ":" + std::to_string(isTargetCasted) + ":" +
|
||||
getMPSTypeString(weight);
|
||||
auto cachedGraph = LookUpOrCreateCachedGraph<CachedGraph>(key, [&](auto mpsGraph, auto newCachedGraph) {
|
||||
MPSGraphTensor* inputTensor = mpsGraphRankedPlaceHolder(mpsGraph, getMPSDataType(input), input_shape);
|
||||
MPSGraphTensor* targetTensor = mpsGraphRankedPlaceHolder(mpsGraph, getMPSDataType(target), target_shape);
|
||||
@ -623,7 +624,7 @@ static void smooth_l1_loss_impl(const Tensor& input,
|
||||
MPSShape* input_shape = getMPSShape(input);
|
||||
NSString* ns_shape_key = [[input_shape valueForKey:@"description"] componentsJoinedByString:@","];
|
||||
|
||||
string key = "smooth_l1_loss_impl:" + reductionToString(reduction) + ":" + [ns_shape_key UTF8String] + ":" +
|
||||
std::string key = "smooth_l1_loss_impl:" + reductionToString(reduction) + ":" + [ns_shape_key UTF8String] + ":" +
|
||||
std::to_string(beta) + ":" + getMPSTypeString(input) + ":" + getMPSTypeString(target);
|
||||
auto cachedGraph = LookUpOrCreateCachedGraph<CachedGraph>(key, [&](auto mpsGraph, auto newCachedGraph) {
|
||||
// smooth_l1_loss_mps:
|
||||
@ -762,7 +763,7 @@ static void smooth_l1_loss_backward_impl(const Tensor& grad_output,
|
||||
};
|
||||
|
||||
@autoreleasepool {
|
||||
string key = "smooth_l1_loss_backward" + getTensorsStringKey({input, grad_output, grad_input, target}) + ":" +
|
||||
std::string key = "smooth_l1_loss_backward" + getTensorsStringKey({input, grad_output, grad_input, target}) + ":" +
|
||||
reductionToString(reduction) + ":" + std::to_string(beta);
|
||||
|
||||
auto cachedGraph = LookUpOrCreateCachedGraph<CachedGraph>(key, [&](auto mpsGraph, auto newCachedGraph) {
|
||||
@ -824,7 +825,7 @@ static void smooth_l1_loss_backward_impl(const Tensor& grad_output,
|
||||
// HuberLoss
|
||||
|
||||
Tensor& huber_loss_out_mps(const Tensor& input, const Tensor& target, int64_t reduction, double delta, Tensor& output) {
|
||||
string op_name = __func__;
|
||||
std::string op_name = __func__;
|
||||
using namespace mps;
|
||||
TORCH_CHECK(delta > 0, "huber_loss does not support non-positive values for delta.")
|
||||
TORCH_CHECK(target.is_same_size(input), op_name + ": target and input tensors must have identical shapes")
|
||||
@ -845,7 +846,7 @@ Tensor& huber_loss_out_mps(const Tensor& input, const Tensor& target, int64_t re
|
||||
};
|
||||
|
||||
@autoreleasepool {
|
||||
string key = op_name + ":" + reductionToString(reduction) + ":" + std::to_string(delta) + ":" +
|
||||
std::string key = op_name + ":" + reductionToString(reduction) + ":" + std::to_string(delta) + ":" +
|
||||
getTensorsStringKey({input, target});
|
||||
auto cachedGraph = LookUpOrCreateCachedGraph<CachedGraph>(key, [&](auto mpsGraph, auto newCachedGraph) {
|
||||
MPSGraphTensor* inputTensor = mpsGraphRankedPlaceHolder(mpsGraph, input);
|
||||
@ -926,8 +927,8 @@ Tensor& huber_loss_backward_out_mps(const Tensor& grad_output,
|
||||
MPSShape* input_shape = getMPSShape(input);
|
||||
NSString* ns_shape_key = [[input_shape valueForKey:@"description"] componentsJoinedByString:@","];
|
||||
|
||||
string key = "huber_loss_backward_out_mps:" + reductionToString(reduction) + ":" + std::to_string(delta) + ":" +
|
||||
[ns_shape_key UTF8String] + ":" + getMPSTypeString(input) + ":" + getMPSTypeString(target);
|
||||
std::string key = "huber_loss_backward_out_mps:" + reductionToString(reduction) + ":" + std::to_string(delta) +
|
||||
":" + [ns_shape_key UTF8String] + ":" + getMPSTypeString(input) + ":" + getMPSTypeString(target);
|
||||
auto cachedGraph = LookUpOrCreateCachedGraph<CachedGraph>(key, [&](auto mpsGraph, auto newCachedGraph) {
|
||||
MPSGraphTensor* gradOutputTensor =
|
||||
mpsGraphRankedPlaceHolder(mpsGraph, getMPSDataType(new_grad_output), getMPSShape(new_grad_output));
|
||||
@ -1004,7 +1005,7 @@ Tensor& huber_loss_backward_out_mps(const Tensor& grad_output,
|
||||
|
||||
// MSELoss
|
||||
TORCH_IMPL_FUNC(mse_loss_out_mps)(const Tensor& input, const Tensor& target, int64_t reduction, const Tensor& output_) {
|
||||
string op_name = "mse_loss_out_mps";
|
||||
std::string op_name = "mse_loss_out_mps";
|
||||
using namespace mps;
|
||||
if ((input.numel() == 0) || (target.numel() == 0)) {
|
||||
reduction == Reduction::Mean ? output_.fill_(std::numeric_limits<float>::quiet_NaN()) : output_.zero_();
|
||||
@ -1029,7 +1030,7 @@ TORCH_IMPL_FUNC(mse_loss_out_mps)(const Tensor& input, const Tensor& target, int
|
||||
};
|
||||
|
||||
@autoreleasepool {
|
||||
string key = op_name + reductionToString(reduction) + getTensorsStringKey({input, target});
|
||||
std::string key = op_name + reductionToString(reduction) + getTensorsStringKey({input, target});
|
||||
auto cachedGraph = LookUpOrCreateCachedGraph<CachedGraph>(key, [&](auto mpsGraph, auto newCachedGraph) {
|
||||
newCachedGraph->inputTensor = mpsGraphRankedPlaceHolder(mpsGraph, input);
|
||||
newCachedGraph->targetTensor = mpsGraphRankedPlaceHolder(mpsGraph, target);
|
||||
|
@ -116,7 +116,7 @@ std::tuple<Tensor&, Tensor&, Tensor&> batch_norm_mps_out(const Tensor& self,
|
||||
}
|
||||
|
||||
@autoreleasepool {
|
||||
string mem_format_key;
|
||||
std::string mem_format_key;
|
||||
switch (memory_format) {
|
||||
case at::MemoryFormat::Contiguous:
|
||||
mem_format_key = "Contiguous";
|
||||
@ -143,7 +143,7 @@ std::tuple<Tensor&, Tensor&, Tensor&> batch_norm_mps_out(const Tensor& self,
|
||||
|
||||
NSString* ns_shape_key = [[input_shape valueForKey:@"description"] componentsJoinedByString:@","];
|
||||
|
||||
string key = "batch_norm_mps_out:" + mem_format_key + ":" + std::to_string(epsilon) + ":" +
|
||||
std::string key = "batch_norm_mps_out:" + mem_format_key + ":" + std::to_string(epsilon) + ":" +
|
||||
std::to_string(momentum) + ":" + std::to_string(train) + ":" + std::to_string(has_running_mean) + ":" +
|
||||
std::to_string(has_weight) + ":" + std::to_string(has_bias) + ":" + [ns_shape_key UTF8String] + ":" +
|
||||
getTensorsStringKey({self,
|
||||
@ -489,8 +489,8 @@ std::tuple<Tensor&, Tensor&, Tensor&> _batch_norm_legit_no_stats_mps_out(const T
|
||||
self, weight_opt, bias_opt, Tensor(), Tensor(), train, momentum, epsilon, output, save_mean, save_var);
|
||||
}
|
||||
|
||||
static string get_mem_string(c10::MemoryFormat memory_format) {
|
||||
string mem_format_key;
|
||||
static std::string get_mem_string(c10::MemoryFormat memory_format) {
|
||||
std::string mem_format_key;
|
||||
switch (memory_format) {
|
||||
case at::MemoryFormat::Contiguous:
|
||||
mem_format_key = "Contiguous";
|
||||
@ -599,7 +599,7 @@ std::tuple<Tensor, Tensor, Tensor> batch_norm_backward_mps(const Tensor& grad_ou
|
||||
}
|
||||
|
||||
@autoreleasepool {
|
||||
string mem_format_key;
|
||||
std::string mem_format_key;
|
||||
switch (memory_format) {
|
||||
case at::MemoryFormat::Contiguous:
|
||||
mem_format_key = "Contiguous";
|
||||
@ -623,7 +623,7 @@ std::tuple<Tensor, Tensor, Tensor> batch_norm_backward_mps(const Tensor& grad_ou
|
||||
|
||||
NSString* ns_shape_key = [[input_shape valueForKey:@"description"] componentsJoinedByString:@","];
|
||||
|
||||
string key = "batch_norm_backward_mps:" + mem_format_key + ":" + std::to_string(epsilon) + ":" +
|
||||
std::string key = "batch_norm_backward_mps:" + mem_format_key + ":" + std::to_string(epsilon) + ":" +
|
||||
std::to_string(train) + ":" + std::to_string(has_running_mean) + ":" + std::to_string(has_weight) + ":" +
|
||||
[ns_shape_key UTF8String] + ":" + c10::Join(",", grad_input_mask) + ":" + getMPSTypeString(input);
|
||||
auto input_mps_dtype = getMPSDataType(input);
|
||||
@ -1082,8 +1082,9 @@ std::tuple<Tensor, Tensor, Tensor> layer_norm_backward_mps(const Tensor& grad_ou
|
||||
for (const auto i : c10::irange(num_normalized_dims))
|
||||
bn_gamma_shape[i + 2] = input_shape[i + num_channel_dims];
|
||||
|
||||
string key = "layer_norm_backward_mps:" + std::to_string(has_weight) + ":" + getArrayRefString(normalized_shape) +
|
||||
":" + getArrayRefString((*X).sizes()) + ":" + c10::Join(",", grad_input_mask) + ":" + getMPSTypeString(*X);
|
||||
std::string key = "layer_norm_backward_mps:" + std::to_string(has_weight) + ":" +
|
||||
getArrayRefString(normalized_shape) + ":" + getArrayRefString((*X).sizes()) + ":" +
|
||||
c10::Join(",", grad_input_mask) + ":" + getMPSTypeString(*X);
|
||||
auto cachedGraph = LookUpOrCreateCachedGraph<CachedGraph>(key, [&](auto mpsGraph, auto newCachedGraph) {
|
||||
MPSGraphTensor* inputTensor = mpsGraphRankedPlaceHolder(mpsGraph, *X);
|
||||
MPSGraphTensor* gradOutputTensor = mpsGraphRankedPlaceHolder(mpsGraph, *dOut);
|
||||
|
@ -31,7 +31,7 @@ static Tensor& pad_out_template(Tensor& output,
|
||||
const std::optional<Tensor>& grad_output_opt,
|
||||
MPSGraphPaddingMode mode,
|
||||
double constantValue,
|
||||
const string op_name) {
|
||||
const std::string& op_name) {
|
||||
using CachedGraph = MPSUnaryGradCachedGraph;
|
||||
const int padding_size = (int)padding.size();
|
||||
int padding_dim = padding_size / 2; // either 1D, 2D, or 3D
|
||||
@ -244,7 +244,7 @@ static Tensor& pad_out_template(Tensor& output,
|
||||
}
|
||||
|
||||
@autoreleasepool {
|
||||
string key = op_name + getTensorsStringKey({input, grad_output, output}) + ":[" + getArrayRefString(padding) +
|
||||
std::string key = op_name + getTensorsStringKey({input, grad_output, output}) + ":[" + getArrayRefString(padding) +
|
||||
"]:" + std::to_string(constantValue);
|
||||
|
||||
auto cachedGraph = LookUpOrCreateCachedGraph<CachedGraph>(key, [&](auto mpsGraph, auto newCachedGraph) {
|
||||
|
@ -44,7 +44,7 @@ static Tensor pixel_shuffle_helper(const Tensor& self, int64_t factor, bool upsc
|
||||
}
|
||||
|
||||
@autoreleasepool {
|
||||
string key = (upscale ? "pixel_shuffle_" : "pixel_unshuffle_") + getTensorsStringKey({self}) + "_factor_" +
|
||||
std::string key = (upscale ? "pixel_shuffle_" : "pixel_unshuffle_") + getTensorsStringKey({self}) + "_factor_" +
|
||||
std::to_string(factor);
|
||||
CachedGraph* cachedGraph = LookUpOrCreateCachedGraph<CachedGraph>(key, [&](auto mpsGraph, auto newCachedGraph) {
|
||||
const auto ndims = self.ndimension();
|
||||
|
@ -19,7 +19,7 @@ static void addc_mul_div_out_mps(const Tensor& self,
|
||||
const Scalar& value_opt, // default value = 1.0
|
||||
const Tensor& output,
|
||||
const bool is_div,
|
||||
const string op_name) {
|
||||
const std::string& op_name) {
|
||||
if (value_opt.toDouble() == 0.0) {
|
||||
output.copy_(self);
|
||||
return;
|
||||
@ -44,7 +44,7 @@ static void addc_mul_div_out_mps(const Tensor& self,
|
||||
output_ = at::empty_like(self, MemoryFormat::Contiguous);
|
||||
}
|
||||
|
||||
string key = op_name + getTensorsStringKey({self, tensor1, tensor2});
|
||||
std::string key = op_name + getTensorsStringKey({self, tensor1, tensor2});
|
||||
|
||||
auto cachedGraph = LookUpOrCreateCachedGraph<CachedGraph>(key, [&](auto mpsGraph, auto newCachedGraph) {
|
||||
ScalarType common_dtype =
|
||||
|
@ -45,7 +45,7 @@ static void pool2d_template(const Tensor& input,
|
||||
bool count_include_pad,
|
||||
const std::optional<int64_t> divisor_override,
|
||||
PoolingOpBlock poolingBlock,
|
||||
const c10::string& op_name) {
|
||||
const std::string& op_name) {
|
||||
const int64_t ndims = input.ndimension();
|
||||
const Tensor& grad_output = *(at::borrow_from_optional_tensor(grad_output_opt));
|
||||
const Tensor& indices = *(at::borrow_from_optional_tensor(indices_opt));
|
||||
@ -140,10 +140,10 @@ static void pool2d_template(const Tensor& input,
|
||||
padH = padW = 0;
|
||||
}
|
||||
@autoreleasepool {
|
||||
string key = op_name + getTensorsStringKey({input, indices, grad_output}) + ":K[" + getArrayRefString(kernel_size) +
|
||||
"]:S[" + getArrayRefString(stride) + "]:P[" + getArrayRefString(padding) + "]:D[" +
|
||||
getArrayRefString(dilation) + "]" + (ceil_mode ? ":ceil" : "") + (count_include_pad ? ":include_pad" : "") +
|
||||
(has_divisor ? ":divisor" : "") + ":" +
|
||||
std::string key = op_name + getTensorsStringKey({input, indices, grad_output}) + ":K[" +
|
||||
getArrayRefString(kernel_size) + "]:S[" + getArrayRefString(stride) + "]:P[" + getArrayRefString(padding) +
|
||||
"]:D[" + getArrayRefString(dilation) + "]" + (ceil_mode ? ":ceil" : "") +
|
||||
(count_include_pad ? ":include_pad" : "") + (has_divisor ? ":divisor" : "") + ":" +
|
||||
(suggested_memory_format == MemoryFormat::ChannelsLast ? "NHWC" : "NCHW");
|
||||
|
||||
MPSShape* inputShape = getMPSShape(input, memory_format);
|
||||
@ -250,7 +250,7 @@ static void avg_pool2d_template(const Tensor& input,
|
||||
bool ceil_mode,
|
||||
bool count_include_pad,
|
||||
const std::optional<int64_t> divisor_override,
|
||||
const c10::string& op_name) {
|
||||
const std::string& op_name) {
|
||||
const Tensor& grad_output = *(at::borrow_from_optional_tensor(grad_output_opt));
|
||||
const bool is_backward_pass = grad_output.defined();
|
||||
const bool use_divisor = divisor_override.has_value() && divisor_override.value() != 0;
|
||||
|
@ -106,7 +106,7 @@ Tensor& arange_mps_out(const Scalar& start, const Scalar& end, const Scalar& ste
|
||||
auto stream = getCurrentMPSStream();
|
||||
auto mpsDataType = getMPSDataType(result);
|
||||
@autoreleasepool {
|
||||
string key = "arange_mps_out" + getTensorsStringKey({result}) + ":" + std::to_string(size);
|
||||
std::string key = "arange_mps_out" + getTensorsStringKey({result}) + ":" + std::to_string(size);
|
||||
auto cachedGraph = cache_->LookUpAs<RangeCachedGraph>(key);
|
||||
if (!cachedGraph) {
|
||||
cachedGraph = cache_->CreateCachedGraphAs<RangeCachedGraph>(key, ^MPSCachedGraph*() {
|
||||
@ -173,7 +173,7 @@ Tensor& range_mps_out(const Scalar& start, const Scalar& end, const Scalar& step
|
||||
auto stream = getCurrentMPSStream();
|
||||
auto mpsDataType = getMPSDataType(result);
|
||||
@autoreleasepool {
|
||||
string key = "arange_mps_out" + getTensorsStringKey({result}) + ":" + std::to_string(size);
|
||||
std::string key = "arange_mps_out" + getTensorsStringKey({result}) + ":" + std::to_string(size);
|
||||
auto cachedGraph = cache_->LookUpAs<RangeCachedGraph>(key);
|
||||
if (!cachedGraph) {
|
||||
cachedGraph = cache_->CreateCachedGraphAs<RangeCachedGraph>(key, ^MPSCachedGraph*() {
|
||||
@ -221,7 +221,7 @@ Tensor& linspace_out_mps(const Scalar& start, const Scalar& end, int64_t steps,
|
||||
bool start_less_end = (start.to<double>() <= end.to<double>());
|
||||
|
||||
@autoreleasepool {
|
||||
string key = "linspace_out_mps:" + getTensorsStringKey({result}) + ":" + std::to_string(steps) +
|
||||
std::string key = "linspace_out_mps:" + getTensorsStringKey({result}) + ":" + std::to_string(steps) +
|
||||
std::to_string(start_less_end);
|
||||
auto cachedGraph = cache_->LookUpAs<RangeCachedGraph>(key);
|
||||
|
||||
|
@ -225,7 +225,7 @@ static void reduction_out_mps(const Tensor& input_t,
|
||||
@autoreleasepool {
|
||||
std::string dtype_str = dtype.has_value() ? getMPSTypeString(dtype.value()) : "";
|
||||
NSString* ns_key = [[wrappedAxes valueForKey:@"description"] componentsJoinedByString:@","];
|
||||
string key = func_name + ":" + string([ns_key UTF8String]) + ":" + getTensorsStringKey(input_t) + ":" +
|
||||
std::string key = func_name + ":" + std::string([ns_key UTF8String]) + ":" + getTensorsStringKey(input_t) + ":" +
|
||||
std::to_string(keepdim) + ":" + std::to_string(reduction_type) + ":" + getTensorsStringKey(output_t) + ":" +
|
||||
dtype_str;
|
||||
using CachedGraph = MPSUnaryCachedGraph;
|
||||
@ -365,10 +365,10 @@ static void impl_func_norm_mps(const Tensor& input_tensor,
|
||||
auto stream = getCurrentMPSStream();
|
||||
@autoreleasepool {
|
||||
NSString* ns_key = [[wrappedAxes valueForKey:@"description"] componentsJoinedByString:@","];
|
||||
string keepdim_info = (keepdim) ? "keepdim=1" : "keepdim=0";
|
||||
string tensor_key = cdist ? getTensorsStringKey({input_tensor, other_tensor}) : getTensorsStringKey({input_t});
|
||||
string key = string("norm_out_mps:") + [ns_key UTF8String] + ":" + tensor_key + ":p" + std::to_string(p) + ":" +
|
||||
keepdim_info + ":" + toString(in_dtype);
|
||||
std::string keepdim_info = (keepdim) ? "keepdim=1" : "keepdim=0";
|
||||
std::string tensor_key = cdist ? getTensorsStringKey({input_tensor, other_tensor}) : getTensorsStringKey({input_t});
|
||||
std::string key = std::string("norm_out_mps:") + [ns_key UTF8String] + ":" + tensor_key + ":p" + std::to_string(p) +
|
||||
":" + keepdim_info + ":" + toString(in_dtype);
|
||||
|
||||
auto cachedGraph = LookUpOrCreateCachedGraph<MPSBinaryCachedGraph>(key, [&](auto mpsGraph, auto newCachedGraph) {
|
||||
newCachedGraph->inputTensor_ = mpsGraphRankedPlaceHolder(mpsGraph, input_tensor);
|
||||
@ -456,7 +456,7 @@ static Tensor std_var_common_impl_mps(const Tensor& input_t,
|
||||
IntArrayRef dim_value = use_dim ? dim.value() : NULL;
|
||||
|
||||
if (use_dim) {
|
||||
string errMessage = (stdVarType == STANDARD_DEVIATION) ? "std_mps" : "var_mps";
|
||||
std::string errMessage = (stdVarType == STANDARD_DEVIATION) ? "std_mps" : "var_mps";
|
||||
errMessage += ": reduction dim must be in the range of input shape";
|
||||
for (const auto dim : dim_value) {
|
||||
auto wrap_dim = maybe_wrap_dim(dim, num_input_dims);
|
||||
@ -576,13 +576,13 @@ static Tensor std_var_common_impl_mps(const Tensor& input_t,
|
||||
auto stream = getCurrentMPSStream();
|
||||
|
||||
@autoreleasepool {
|
||||
string op_key = (stdVarType == STANDARD_DEVIATION) ? "std_mps" : "var_mps";
|
||||
std::string op_key = (stdVarType == STANDARD_DEVIATION) ? "std_mps" : "var_mps";
|
||||
NSString* ns_key = [[wrappedAxes valueForKey:@"description"] componentsJoinedByString:@","];
|
||||
string bessel_corrected = (use_correction && correction_value) ? "unbiased " : "biased ";
|
||||
string use_dim_info = (use_dim) ? "use_dim=1:" + std::to_string(dim_value.size()) : "use_dim=0";
|
||||
string keepdim_info = (keepdim) ? "keepdim=1" : "keepdim=0";
|
||||
string key = op_key + ":" + getTensorsStringKey(input_t) + ":" + use_dim_info + ":" + keepdim_info + ":" +
|
||||
string([ns_key UTF8String]) + ":" + bessel_corrected + ":" + std::to_string(correction_value);
|
||||
std::string bessel_corrected = (use_correction && correction_value) ? "unbiased " : "biased ";
|
||||
std::string use_dim_info = (use_dim) ? "use_dim=1:" + std::to_string(dim_value.size()) : "use_dim=0";
|
||||
std::string keepdim_info = (keepdim) ? "keepdim=1" : "keepdim=0";
|
||||
std::string key = op_key + ":" + getTensorsStringKey(input_t) + ":" + use_dim_info + ":" + keepdim_info + ":" +
|
||||
std::string([ns_key UTF8String]) + ":" + bessel_corrected + ":" + std::to_string(correction_value);
|
||||
|
||||
auto cachedGraph = LookUpOrCreateCachedGraph<CachedGraph>(key, [&](auto mpsGraph, auto newCachedGraph) {
|
||||
MPSGraphTensor* inputTensor = mpsGraphRankedPlaceHolder(mpsGraph, input_t);
|
||||
@ -708,7 +708,7 @@ static Tensor min_max_mps_impl(const Tensor& input_t, MPSReductionType reduction
|
||||
}
|
||||
|
||||
@autoreleasepool {
|
||||
string key = func_name + getTensorsStringKey(input_t);
|
||||
std::string key = func_name + getTensorsStringKey(input_t);
|
||||
CachedGraph* cachedGraph = LookUpOrCreateCachedGraph<CachedGraph>(key, [&](auto mpsGraph, auto newCachedGraph) {
|
||||
MPSGraphTensor* inputTensor = mpsGraphRankedPlaceHolder(mpsGraph, input_t);
|
||||
|
||||
@ -785,7 +785,7 @@ static void min_max_out_mps(const Tensor& input_t,
|
||||
auto stream = getCurrentMPSStream();
|
||||
|
||||
@autoreleasepool {
|
||||
string key = func_name + getTensorsStringKey({input_t, indices_t}) + ":" + std::to_string(dim_);
|
||||
std::string key = func_name + getTensorsStringKey({input_t, indices_t}) + ":" + std::to_string(dim_);
|
||||
auto cachedGraph = LookUpOrCreateCachedGraph<CachedGraph>(key, [&](auto mpsGraph, auto newCachedGraph) {
|
||||
MPSGraphTensor* inputTensor = mpsGraphRankedPlaceHolder(mpsGraph, input_t);
|
||||
MPSGraphTensor* outputTensor = nil;
|
||||
@ -943,8 +943,8 @@ static void argmax_argmin_out_mps(const Tensor& input_t,
|
||||
|
||||
@autoreleasepool {
|
||||
NSString* ns_key = [[apparent_in_shape valueForKey:@"description"] componentsJoinedByString:@","];
|
||||
string key =
|
||||
func_name + ":" + std::to_string(dim_) + ":" + getTensorsStringKey(input_t) + ":" + string([ns_key UTF8String]);
|
||||
std::string key = func_name + ":" + std::to_string(dim_) + ":" + getTensorsStringKey(input_t) + ":" +
|
||||
std::string([ns_key UTF8String]);
|
||||
auto cachedGraph = LookUpOrCreateCachedGraph<CachedGraph>(key, [&](auto mpsGraph, auto newCachedGraph) {
|
||||
auto inputScalarType = input_t.scalar_type();
|
||||
MPSGraphTensor* inputTensor =
|
||||
@ -1299,7 +1299,7 @@ static void all_any_common_impl_mps(const Tensor& input_t,
|
||||
}
|
||||
|
||||
@autoreleasepool {
|
||||
string key = op_name + "_out_mps:" + getTensorsStringKey(input_t) + ":" + std::to_string(dim_);
|
||||
std::string key = op_name + "_out_mps:" + getTensorsStringKey(input_t) + ":" + std::to_string(dim_);
|
||||
auto cachedGraph = LookUpOrCreateCachedGraph<CachedGraph>(key, [&](auto mpsGraph, auto newCachedGraph) {
|
||||
auto inputTensor = mpsGraphRankedPlaceHolder(mpsGraph, input_t);
|
||||
|
||||
@ -1373,7 +1373,7 @@ TORCH_IMPL_FUNC(any_all_out_mps)(const Tensor& input_t, const Tensor& output_t)
|
||||
MPS_CHECK_INT64_OP_SUPPORTED(input_t, macOS13_3_plus, "any_all_out");
|
||||
|
||||
@autoreleasepool {
|
||||
string key = string("any_all_out_mps:") + getTensorsStringKey(input_t);
|
||||
std::string key = std::string("any_all_out_mps:") + getTensorsStringKey(input_t);
|
||||
auto cachedGraph = LookUpOrCreateCachedGraph<CachedGraph>(key, [&](auto mpsGraph, auto newCachedGraph) {
|
||||
auto inputTensor = mpsGraphRankedPlaceHolder(mpsGraph, input_t);
|
||||
auto castInputTensor = castToIHFTypes(mpsGraph, inputTensor, input_t, /*includesInt64=*/macOS13_3_plus);
|
||||
@ -1424,7 +1424,7 @@ TORCH_IMPL_FUNC(all_all_out_mps)(const Tensor& input_t, const Tensor& output_t)
|
||||
MPS_CHECK_INT64_OP_SUPPORTED(input_t, macOS13_3_plus, "all_all_out");
|
||||
|
||||
@autoreleasepool {
|
||||
string key = string("all_all_out_mps:") + getTensorsStringKey(input_t);
|
||||
std::string key = std::string("all_all_out_mps:") + getTensorsStringKey(input_t);
|
||||
auto cachedGraph = LookUpOrCreateCachedGraph<CachedGraph>(key, [&](auto mpsGraph, auto newCachedGraph) {
|
||||
auto inputTensor = mpsGraphRankedPlaceHolder(mpsGraph, input_t);
|
||||
auto castInputTensor = castToIHFTypes(mpsGraph, inputTensor, input_t, /*includesInt64=*/macOS13_3_plus);
|
||||
@ -1581,7 +1581,7 @@ static void median_out_mps_common(const Tensor& input_t,
|
||||
auto stream = getCurrentMPSStream();
|
||||
|
||||
@autoreleasepool {
|
||||
string key = func_name + ":" + std::to_string(dim_) + ":" + getTensorsStringKey(input_t) + ":" +
|
||||
std::string key = func_name + ":" + std::to_string(dim_) + ":" + getTensorsStringKey(input_t) + ":" +
|
||||
getTensorsStringKey(indices);
|
||||
auto cachedGraph = LookUpOrCreateCachedGraph<CachedGraph>(key, [&](auto mpsGraph, auto newCachedGraph) {
|
||||
MPSGraphTensor* inputTensor = mpsGraphRankedPlaceHolder(mpsGraph, input_t);
|
||||
|
@ -38,7 +38,7 @@ void renorm_out_mps(const Tensor& self, const Scalar& p, int64_t dim, const Scal
|
||||
id<MTLBuffer> normBuffer = getMTLBufferStorage(norm);
|
||||
id<MTLBuffer> factorBuffer = getMTLBufferStorage(factor);
|
||||
|
||||
string key = "renorm_" + scalarToMetalTypeString(self);
|
||||
std::string key = "renorm_" + scalarToMetalTypeString(self);
|
||||
MPSStream* mpsStream = getCurrentMPSStream();
|
||||
id<MTLComputeCommandEncoder> computeEncoder = mpsStream->commandEncoder();
|
||||
id<MTLComputePipelineState> renormPSO = lib.getPipelineStateForFunc(key);
|
||||
|
@ -64,7 +64,7 @@ Tensor repeat_mps(const Tensor& self, IntArrayRef repeats) {
|
||||
auto outputDataType = getMPSDataType(result);
|
||||
|
||||
@autoreleasepool {
|
||||
string key = "repeat_mps:" + getTensorsStringKey(self) + ":" + getArrayRefString(repeats);
|
||||
std::string key = "repeat_mps:" + getTensorsStringKey(self) + ":" + getArrayRefString(repeats);
|
||||
auto cachedGraph = LookUpOrCreateCachedGraph<CachedGraph>(key, [&](auto mpsGraph, auto newCachedGraph) {
|
||||
MPSGraphTensor* inputTensor = mpsGraphRankedPlaceHolder(mpsGraph, inputDataType, getMPSShape(expanded_tensor));
|
||||
MPSGraphTensor* outputTensor = [mpsGraph tileTensor:inputTensor withMultiplier:getMPSShape(repeats) name:nil];
|
||||
|
@ -131,7 +131,7 @@ std::tuple<Tensor, Tensor, Tensor, Tensor, Tensor, Tensor> _lstm_mps(const Tenso
|
||||
MPSStream* stream = getCurrentMPSStream();
|
||||
|
||||
@autoreleasepool {
|
||||
string key = "lstm_" + getTensorsStringKey({input, hx[0], hx[1]}) + getMPSTypeString(input) + "_num_layers_" +
|
||||
std::string key = "lstm_" + getTensorsStringKey({input, hx[0], hx[1]}) + getMPSTypeString(input) + "_num_layers_" +
|
||||
std::to_string(num_layers) + "_bidirectional_" + std::to_string(bidirectional) + "_has_biases_" +
|
||||
std::to_string(has_biases) + "_dropout_" + std::to_string(dropout_p) + "_batch_first_" +
|
||||
std::to_string(batch_first);
|
||||
@ -408,10 +408,10 @@ std::tuple<Tensor, std::vector<Tensor>, std::vector<Tensor>> lstm_mps_backward(c
|
||||
// Get stream
|
||||
MPSStream* stream = getCurrentMPSStream();
|
||||
@autoreleasepool {
|
||||
string key = "lstm_backward_" + getTensorsStringKey({input, z_state, cell_state_fwd, grad_y, grad_cy, grad_hy}) +
|
||||
getMPSTypeString(input) + "_num_layers_" + std::to_string(num_layers) + "_bidirectional_" +
|
||||
std::to_string(bidirectional) + "_has_biases_" + std::to_string(has_biases) + "_batch_first_" +
|
||||
std::to_string(batch_first);
|
||||
std::string key = "lstm_backward_" +
|
||||
getTensorsStringKey({input, z_state, cell_state_fwd, grad_y, grad_cy, grad_hy}) + getMPSTypeString(input) +
|
||||
"_num_layers_" + std::to_string(num_layers) + "_bidirectional_" + std::to_string(bidirectional) +
|
||||
"_has_biases_" + std::to_string(has_biases) + "_batch_first_" + std::to_string(batch_first);
|
||||
auto cachedGraph = LookUpOrCreateCachedGraph<CachedGraph>(key, [&](auto mpsGraph, auto newCachedGraph) {
|
||||
NSMutableArray<MPSGraphTensor*>* kernelWeightsList = [[NSMutableArray alloc] initWithCapacity:params.size()];
|
||||
NSMutableArray<MPSGraphTensor*>* recurrentKernelWeightsList =
|
||||
|
@ -60,7 +60,7 @@ TORCH_IMPL_FUNC(gather_out_mps)
|
||||
if (output_type == MPSDataTypeUInt8) {
|
||||
output_type = MPSDataTypeInt8;
|
||||
}
|
||||
string key = "gather_out_mps" + getTensorsStringKey({self, index, output}) + ":" + std::to_string(dim);
|
||||
std::string key = "gather_out_mps" + getTensorsStringKey({self, index, output}) + ":" + std::to_string(dim);
|
||||
auto cachedGraph = LookUpOrCreateCachedGraph<CachedGraph>(key, [&](auto mpsGraph, auto newCachedGraph) {
|
||||
MPSGraphTensor* inputTensor = mpsGraphRankedPlaceHolder(mpsGraph, input_type, getMPSShape(self));
|
||||
MPSGraphTensor* indexTensor = mpsGraphRankedPlaceHolder(mpsGraph, index);
|
||||
@ -132,7 +132,7 @@ static void scatter_mps_general(const Tensor& self_arg,
|
||||
const Tensor& index,
|
||||
const Tensor& src,
|
||||
const Tensor& output,
|
||||
string func_name,
|
||||
std::string func_name,
|
||||
const std::string_view reduce) {
|
||||
using namespace mps;
|
||||
|
||||
@ -189,7 +189,7 @@ static void scatter_mps_general(const Tensor& self_arg,
|
||||
needsCast = true;
|
||||
}
|
||||
|
||||
string key = func_name + getTensorsStringKey({self, index, src, output}) + ":" + std::to_string(dim) + ":" +
|
||||
std::string key = func_name + getTensorsStringKey({self, index, src, output}) + ":" + std::to_string(dim) + ":" +
|
||||
std::string(reduce);
|
||||
auto cachedGraph = LookUpOrCreateCachedGraph<CachedGraph>(key, [&](auto mpsGraph, auto newCachedGraph) {
|
||||
MPSGraphTensor* inputTensor = mpsGraphRankedPlaceHolder(mpsGraph, self);
|
||||
|
@ -98,8 +98,8 @@ TORCH_IMPL_FUNC(topk_out_mps)
|
||||
// Input as placeholders
|
||||
MPSShape* input_shape = getMPSShape(self);
|
||||
NSString* ns_shape_key = [[input_shape valueForKey:@"description"] componentsJoinedByString:@","];
|
||||
string key = string("topk:") + [ns_shape_key UTF8String] + ":" + getMPSTypeString(self) + ":k" + std::to_string(k) +
|
||||
":dim" + std::to_string(dim_) + ":largest" + std::to_string(largest);
|
||||
std::string key = std::string("topk:") + [ns_shape_key UTF8String] + ":" + getMPSTypeString(self) + ":k" +
|
||||
std::to_string(k) + ":dim" + std::to_string(dim_) + ":largest" + std::to_string(largest);
|
||||
auto cachedGraph = LookUpOrCreateCachedGraph<CachedGraph>(key, [&](auto mpsGraph, auto newCachedGraph) {
|
||||
newCachedGraph->selfTensor = mpsGraphRankedPlaceHolder(mpsGraph, getMPSDataType(self), input_shape);
|
||||
|
||||
@ -253,7 +253,7 @@ TORCH_IMPL_FUNC(cat_out_mps)
|
||||
};
|
||||
|
||||
@autoreleasepool {
|
||||
string key = "cat_out_mps:" + std::to_string(dimension) + ":" +
|
||||
std::string key = "cat_out_mps:" + std::to_string(dimension) + ":" +
|
||||
(memory_format == MemoryFormat::ChannelsLast ? "NHWC" : "NCHW");
|
||||
if (!all_same_dtype) {
|
||||
key += getTensorsStringKey(input_tensors, true, all_same_sizes_and_stride);
|
||||
|
@ -63,7 +63,7 @@ TORCH_IMPL_FUNC(softmax_mps_out)
|
||||
MPSStream* stream = getCurrentMPSStream();
|
||||
|
||||
@autoreleasepool {
|
||||
string mem_format_key = get_mem_format_string(memory_format);
|
||||
std::string mem_format_key = get_mem_format_string(memory_format);
|
||||
MPSShape* input_shape_readonly = mps::getMPSShape(input);
|
||||
int num_input_dims = [input_shape_readonly count];
|
||||
// Check - Channels last implies 4d
|
||||
@ -93,8 +93,8 @@ TORCH_IMPL_FUNC(softmax_mps_out)
|
||||
|
||||
NSString* ns_shape_key = [[input_shape valueForKey:@"description"] componentsJoinedByString:@","];
|
||||
|
||||
string key = "softmax_mps_out" + getTensorsStringKey(input, true, /*exclude_shape*/ true) + ":" + mem_format_key +
|
||||
":" + std::to_string(dim_);
|
||||
std::string key = "softmax_mps_out" + getTensorsStringKey(input, true, /*exclude_shape*/ true) + ":" +
|
||||
mem_format_key + ":" + std::to_string(dim_);
|
||||
|
||||
auto cachedGraph = LookUpOrCreateCachedGraph<CachedGraph>(key, [&](auto mpsGraph, auto newCachedGraph) {
|
||||
MPSGraphTensor* inputTensor = mpsGraphUnrankedPlaceHolder(mpsGraph, getMPSDataType(input.scalar_type()));
|
||||
@ -159,7 +159,7 @@ TORCH_IMPL_FUNC(softmax_backward_mps_out)
|
||||
MPSShape* grad_shape = mps::getMPSShape(grad);
|
||||
NSString* ns_shape_key = [[grad_shape valueForKey:@"description"] componentsJoinedByString:@","];
|
||||
|
||||
string key = "softmax_backward_mps_out:" + getMPSTypeString(output) + ":" + [ns_shape_key UTF8String] + ":" +
|
||||
std::string key = "softmax_backward_mps_out:" + getMPSTypeString(output) + ":" + [ns_shape_key UTF8String] + ":" +
|
||||
std::to_string(dim_);
|
||||
auto cachedGraph = LookUpOrCreateCachedGraph<CachedGraph>(key, [&](auto mpsGraph, auto newCachedGraph) {
|
||||
MPSGraphTensor* softmaxTensor = mpsGraphRankedPlaceHolder(mpsGraph, getMPSDataType(output), grad_shape);
|
||||
|
@ -50,7 +50,7 @@ TORCH_IMPL_FUNC(sort_stable_out_mps)
|
||||
// Input as placeholders
|
||||
MPSShape* input_shape = getMPSShape(self);
|
||||
NSString* ns_shape_key = [[input_shape valueForKey:@"description"] componentsJoinedByString:@","];
|
||||
string key = string("sort:") + [ns_shape_key UTF8String] + ":" + getMPSTypeString(self) + ":dim" +
|
||||
std::string key = std::string("sort:") + [ns_shape_key UTF8String] + ":" + getMPSTypeString(self) + ":dim" +
|
||||
std::to_string(dim) + ":descending" + std::to_string(descending);
|
||||
auto cachedGraph = LookUpOrCreateCachedGraph<CachedGraph>(key, [&](auto mpsGraph, auto newCachedGraph) {
|
||||
newCachedGraph->selfTensor = mpsGraphRankedPlaceHolder(mpsGraph, getMPSDataType(self), input_shape);
|
||||
|
@ -19,7 +19,7 @@ static Tensor& bincount_mps_impl(const Tensor& self, const Tensor& weights, Tens
|
||||
bool has_weights = weights.defined();
|
||||
|
||||
@autoreleasepool {
|
||||
string key = "bincount_mps_impl" + getTensorsStringKey({self, weights});
|
||||
std::string key = "bincount_mps_impl" + getTensorsStringKey({self, weights});
|
||||
auto cachedGraph = LookUpOrCreateCachedGraph<CachedGraph>(key, [&](auto mpsGraph, auto newCachedGraph) {
|
||||
MPSGraphTensor* inputTensor = mpsGraphRankedPlaceHolder(mpsGraph, self);
|
||||
MPSGraphTensor* scatterDataTensor = mpsGraphUnrankedPlaceHolder(mpsGraph, getMPSScalarType(output.scalar_type()));
|
||||
|
@ -80,7 +80,7 @@ static void clamp_mps_graph(CachedGraph* cachedGraph,
|
||||
cachedGraph->outputTensor = outputTensor;
|
||||
}
|
||||
|
||||
static void check_min_max_dims(const OptionalTensorRef clamp_opt, const Tensor& input_t, string op_name) {
|
||||
static void check_min_max_dims(const OptionalTensorRef clamp_opt, const Tensor& input_t, std::string op_name) {
|
||||
if (!clamp_opt->is_same_size(input_t)) {
|
||||
auto num_clamp_dims = clamp_opt->dim();
|
||||
auto num_input_dims = input_t.dim();
|
||||
@ -119,7 +119,7 @@ static void clamp_tensor_out_mps(const Tensor& input_t,
|
||||
const OptionalTensorRef min_opt,
|
||||
const OptionalTensorRef max_opt,
|
||||
const Tensor& output_t,
|
||||
string op_name) {
|
||||
std::string op_name) {
|
||||
const bool has_min = (min_opt.has_value() && min_opt->defined());
|
||||
const bool has_max = (max_opt.has_value() && max_opt->defined());
|
||||
|
||||
@ -173,7 +173,7 @@ static void clamp_tensor_out_mps(const Tensor& input_t,
|
||||
: getTensorsStringKey({input_t, min_opt_tensor}))
|
||||
: (has_max ? getTensorsStringKey({input_t, max_opt_tensor}) : getTensorsStringKey({input_t}));
|
||||
|
||||
string key = op_name + (has_min ? "_min" : "") + (has_max ? "_max" : "") + "_tensor" + tensor_key;
|
||||
std::string key = op_name + (has_min ? "_min" : "") + (has_max ? "_max" : "") + "_tensor" + tensor_key;
|
||||
auto cachedGraph = LookUpOrCreateCachedGraph<CachedGraph>(key, [&](auto mpsGraph, auto newCachedGraph) {
|
||||
if (has_min) {
|
||||
newCachedGraph->minTensor = mpsGraphRankedPlaceHolder(mpsGraph, min_opt_tensor);
|
||||
@ -221,7 +221,7 @@ static void clamp_scalar_out_mps(const Tensor& input_t,
|
||||
const OptionalScalarRef min_opt,
|
||||
const OptionalScalarRef max_opt,
|
||||
const Tensor& output_t,
|
||||
string op_name) {
|
||||
std::string op_name) {
|
||||
using scalar_t = double;
|
||||
|
||||
const bool has_min = (min_opt.has_value());
|
||||
@ -243,7 +243,7 @@ static void clamp_scalar_out_mps(const Tensor& input_t,
|
||||
|
||||
@autoreleasepool {
|
||||
// the optional min/max refs could affect how we build the cached graph
|
||||
string key = op_name + (has_min ? ("_min:" + std::to_string(min_scalar)) : "") +
|
||||
std::string key = op_name + (has_min ? ("_min:" + std::to_string(min_scalar)) : "") +
|
||||
(has_max ? ("_max:" + std::to_string(max_scalar)) : "") + "_scalar:" + getTensorsStringKey({input_t});
|
||||
auto cachedGraph = LookUpOrCreateCachedGraph<CachedGraph>(key, [&](auto mpsGraph, auto newCachedGraph) {
|
||||
if (has_min)
|
||||
@ -278,7 +278,7 @@ static void isin_Tensor_Tensor_out_mps(const Tensor& elements,
|
||||
bool assume_unique,
|
||||
bool invert,
|
||||
const Tensor& out,
|
||||
string op_name) {
|
||||
std::string op_name) {
|
||||
if (elements.numel() == 0) {
|
||||
return;
|
||||
}
|
||||
@ -301,7 +301,7 @@ static void isin_Tensor_Tensor_out_mps(const Tensor& elements,
|
||||
common_type);
|
||||
|
||||
@autoreleasepool {
|
||||
string key = op_name + getTensorsStringKey({elements, test_elements}) + std::to_string(invert);
|
||||
std::string key = op_name + getTensorsStringKey({elements, test_elements}) + std::to_string(invert);
|
||||
|
||||
auto cachedGraph = LookUpOrCreateCachedGraph<MPSBinaryCachedGraph>(key, [&](auto mpsGraph, auto newCachedGraph) {
|
||||
newCachedGraph->inputTensor_ = mpsGraphUnrankedPlaceHolder(mpsGraph, getMPSDataType(elements.scalar_type()));
|
||||
@ -440,7 +440,7 @@ static void where_kernel_mps(TensorIterator& iter) {
|
||||
MPSDataType otherDataType = getMPSScalarType(other.scalar_type());
|
||||
|
||||
@autoreleasepool {
|
||||
string key = "where_self_out_mps:" + getTensorsStringKey({cond_bool, self, other});
|
||||
std::string key = "where_self_out_mps:" + getTensorsStringKey({cond_bool, self, other});
|
||||
|
||||
auto cachedGraph = LookUpOrCreateCachedGraph<CachedGraph>(key, [&](auto mpsGraph, auto newCachedGraph) {
|
||||
MPSGraphTensor* conditionTensor = mpsGraphRankedPlaceHolder(mpsGraph, conditionDataType, getMPSShape(cond_bool));
|
||||
@ -508,7 +508,7 @@ Tensor& nan_to_num_out_mps(const Tensor& self,
|
||||
};
|
||||
|
||||
@autoreleasepool {
|
||||
string key = "nan_to_num" + getTensorsStringKey({self});
|
||||
std::string key = "nan_to_num" + getTensorsStringKey({self});
|
||||
MPSDataType self_dtype = getMPSScalarType(self.scalar_type());
|
||||
|
||||
auto cachedGraph = LookUpOrCreateCachedGraph<CachedGraph>(key, [&](auto mpsGraph, auto newCachedGraph) {
|
||||
|
@ -37,7 +37,7 @@ TORCH_IMPL_FUNC(triu_mps_out)
|
||||
auto stream = getCurrentMPSStream();
|
||||
|
||||
@autoreleasepool {
|
||||
string key = "triu_mps_out" + mps::getTensorsStringKey({self}) + ":" + std::to_string(k);
|
||||
std::string key = "triu_mps_out" + mps::getTensorsStringKey({self}) + ":" + std::to_string(k);
|
||||
auto cachedGraph = LookUpOrCreateCachedGraph<CachedGraph>(key, [&](auto mpsGraph, auto newCachedGraph) {
|
||||
MPSGraphTensor* outputTensor = nil;
|
||||
auto inputTensor = mpsGraphRankedPlaceHolder(mpsGraph, self);
|
||||
@ -85,7 +85,7 @@ TORCH_IMPL_FUNC(tril_mps_out)
|
||||
auto stream = getCurrentMPSStream();
|
||||
|
||||
@autoreleasepool {
|
||||
string key = "tril_mps_out" + mps::getTensorsStringKey({self}) + ":" + std::to_string(k);
|
||||
std::string key = "tril_mps_out" + mps::getTensorsStringKey({self}) + ":" + std::to_string(k);
|
||||
auto cachedGraph = LookUpOrCreateCachedGraph<CachedGraph>(key, [&](auto mpsGraph, auto newCachedGraph) {
|
||||
MPSGraphTensor* outputTensor = nil;
|
||||
|
||||
|
@ -82,7 +82,7 @@ static void unary_op_noresize(const Tensor& self, const Tensor& output_, std::st
|
||||
}
|
||||
|
||||
@autoreleasepool {
|
||||
string key = op_name + getTensorsStringKey({self, output});
|
||||
std::string key = op_name + getTensorsStringKey({self, output});
|
||||
auto cachedGraph = LookUpOrCreateCachedGraph<MPSUnaryCachedGraph>(key, [&](auto mpsGraph, auto newCachedGraph) {
|
||||
newCachedGraph->inputTensor_ = mpsGraphRankedPlaceHolder(mpsGraph, self);
|
||||
MPSGraphTensor* castTensor = newCachedGraph->inputTensor_;
|
||||
@ -337,7 +337,7 @@ TORCH_IMPL_FUNC(expm1_out_mps)(const Tensor& self, const Tensor& output) {
|
||||
});
|
||||
}
|
||||
|
||||
static void logit_mps_impl(const Tensor& self, std::optional<double> eps, Tensor& output, const std::string op_name) {
|
||||
static void logit_mps_impl(const Tensor& self, std::optional<double> eps, Tensor& output, const std::string& op_name) {
|
||||
std::string key = op_name + ":[" + (eps.has_value() ? std::to_string(eps.value()) : "NULL") + "]";
|
||||
|
||||
mps::unary_op(self, output, key, ^MPSGraphTensor*(MPSGraph* mpsGraph, MPSGraphTensor* inputTensor) {
|
||||
|
@ -184,7 +184,7 @@ static UniqueCachedGraph* getUniqueGraph(const Tensor& self,
|
||||
const bool consecutive,
|
||||
std::optional<int64_t> dim) {
|
||||
@autoreleasepool {
|
||||
string key = getUniqueKey(self.scalar_type(), self.sizes(), return_inverse, return_counts, consecutive, dim);
|
||||
std::string key = getUniqueKey(self.scalar_type(), self.sizes(), return_inverse, return_counts, consecutive, dim);
|
||||
return LookUpOrCreateCachedGraph<UniqueCachedGraph>(key, [&](auto mpsGraph, auto newCachedGraph) {
|
||||
newCachedGraph->inputTensor_ = mpsGraphRankedPlaceHolder(mpsGraph, getMPSScalarType(self), getMPSShape(self));
|
||||
auto outputTensors = buildUniqueGraph(self, newCachedGraph, return_inverse, return_counts, consecutive, dim);
|
||||
|
@ -103,7 +103,7 @@ static void upsample_out_template(const Tensor& input,
|
||||
MPSStream* stream = getCurrentMPSStream();
|
||||
|
||||
@autoreleasepool {
|
||||
string key = "upsample_" + std::string(resize_mode_str) + (align_corners ? "_aligned_corners" : "") +
|
||||
std::string key = "upsample_" + std::string(resize_mode_str) + (align_corners ? "_aligned_corners" : "") +
|
||||
getTensorsStringKey({input}) + ":[" + std::to_string(scale_h) + "," + std::to_string(scale_w) + "]:[" +
|
||||
(is_backward_pass ? getArrayRefString(input_size) : "Undefined") + "]";
|
||||
|
||||
@ -259,7 +259,7 @@ static void upsample_kernel_out_template(const Tensor& input,
|
||||
std::optional<double> scale_h_opt,
|
||||
std::optional<double> scale_w_opt,
|
||||
const Tensor& output,
|
||||
const std::string name) {
|
||||
const std::string& name) {
|
||||
if (output.numel() == 0) {
|
||||
return;
|
||||
}
|
||||
|
@ -40,7 +40,7 @@ std::tuple<Tensor, Tensor> weight_norm_mps(const Tensor& v, const Tensor& g, int
|
||||
auto w = at::empty_like(v, LEGACY_CONTIGUOUS_MEMORY_FORMAT);
|
||||
auto norms = at::empty_like(g, LEGACY_CONTIGUOUS_MEMORY_FORMAT);
|
||||
|
||||
string key = "weight_norm_mps_" + std::to_string(dim) + getTensorsStringKey({v, g});
|
||||
std::string key = "weight_norm_mps_" + std::to_string(dim) + getTensorsStringKey({v, g});
|
||||
|
||||
NSMutableArray* reduction_dims = [NSMutableArray array];
|
||||
for (int i = 0; i < v.dim(); ++i) {
|
||||
@ -101,7 +101,7 @@ std::tuple<Tensor, Tensor> weight_norm_backward_mps(const Tensor& grad_w,
|
||||
auto grad_v = at::empty_like(saved_v, LEGACY_CONTIGUOUS_MEMORY_FORMAT);
|
||||
auto grad_g = at::empty_like(saved_g, LEGACY_CONTIGUOUS_MEMORY_FORMAT);
|
||||
|
||||
string key =
|
||||
std::string key =
|
||||
"weight_norm_backward_mps_" + std::to_string(dim) + getTensorsStringKey({grad_w, saved_v, saved_g, saved_norms});
|
||||
|
||||
NSMutableArray* reduction_dims = [NSMutableArray array];
|
||||
|
@ -256,7 +256,7 @@ enum xnn_status xnnp_define_q_tensor(const Tensor& tensor, MemoryFormat format,
|
||||
|
||||
template <typename scalar_t, bool ReLUFused = false>
|
||||
Tensor xnnp_add(Tensor qa, Tensor qb, double scale, int64_t zero_point) {
|
||||
const string func_name = "xnnp_add()";
|
||||
const std::string_view func_name = "xnnp_add()";
|
||||
TORCH_CHECK(qa.ndimension() > 0, func_name, ": Got empty input tensor.");
|
||||
TORCH_CHECK(at::native::xnnpack::available(), func_name, ": XNNPACK is not available")
|
||||
|
||||
|
@ -81,7 +81,7 @@ Tensor _mul_out_xnnpack(
|
||||
const Tensor& other,
|
||||
double output_scale,
|
||||
int64_t output_zero_point) {
|
||||
const string func_name = "xnnp_mul()";
|
||||
const std::string_view func_name = "xnnp_mul()";
|
||||
TORCH_CHECK(self.ndimension() > 0, func_name, ": Got empty input tensor.");
|
||||
TORCH_CHECK(
|
||||
at::native::xnnpack::available(), func_name, ": XNNPACK is not available")
|
||||
|
@ -233,7 +233,7 @@ static Tensor intersection_binary_op_with_wrapped_scalar(const Tensor& sparse, c
|
||||
}
|
||||
|
||||
template <typename op_t>
|
||||
static Tensor& intersection_binary_op_with_wrapped_scalar_(Tensor& sparse, const Tensor& scalar, const string& op_name, const op_t& op) {
|
||||
static Tensor& intersection_binary_op_with_wrapped_scalar_(Tensor& sparse, const Tensor& scalar, const std::string& op_name, const op_t& op) {
|
||||
// NOTE: intersection_binary_op_with_wrapped_scalar_ assumes scalar.numel() == 1.
|
||||
const auto broadcasted_shape = infer_size(sparse.sizes(), scalar.sizes());
|
||||
if (sparse.sizes() != broadcasted_shape) {
|
||||
|
@ -2,9 +2,9 @@
|
||||
#include <c10/util/TypeIndex.h>
|
||||
#include <gtest/gtest.h>
|
||||
|
||||
using c10::string_view;
|
||||
using c10::util::get_fully_qualified_type_name;
|
||||
using c10::util::get_type_index;
|
||||
using std::string_view;
|
||||
|
||||
// NOLINTBEGIN(modernize-unary-static-assert)
|
||||
namespace {
|
||||
|
@ -37,7 +37,7 @@ void SetStackTraceFetcher(std::function<::c10::Backtrace()> fetcher) {
|
||||
GetFetchStackTrace() = std::move(fetcher);
|
||||
}
|
||||
|
||||
void SetStackTraceFetcher(std::function<string()> fetcher) {
|
||||
void SetStackTraceFetcher(std::function<std::string()> fetcher) {
|
||||
SetStackTraceFetcher([fetcher = std::move(fetcher)] {
|
||||
return std::make_shared<PrecomputedLazyValue<std::string>>(fetcher());
|
||||
});
|
||||
@ -125,14 +125,14 @@ bool IsAPIUsageDebugMode() {
|
||||
return val.has_value() && !val.value().empty(); // any non-empty value
|
||||
}
|
||||
|
||||
void APIUsageDebug(const string& event) {
|
||||
void APIUsageDebug(const std::string& event) {
|
||||
// use stderr to avoid messing with glog
|
||||
std::cerr << "PYTORCH_API_USAGE " << event << '\n';
|
||||
}
|
||||
|
||||
APIUsageLoggerType* GetAPIUsageLogger() {
|
||||
static APIUsageLoggerType func =
|
||||
IsAPIUsageDebugMode() ? &APIUsageDebug : [](const string&) {};
|
||||
IsAPIUsageDebugMode() ? &APIUsageDebug : [](const std::string&) {};
|
||||
return &func;
|
||||
}
|
||||
|
||||
|
@ -34,26 +34,26 @@ namespace detail {
|
||||
template <typename T>
|
||||
inline constexpr c10::c10_string_view fully_qualified_type_name_impl() {
|
||||
#if defined(_MSC_VER) && !defined(__clang__)
|
||||
constexpr c10::string_view fun_sig = __FUNCSIG__;
|
||||
constexpr std::string_view fun_sig = __FUNCSIG__;
|
||||
#if defined(__NVCC__)
|
||||
constexpr c10::string_view prefix =
|
||||
constexpr std::string_view prefix =
|
||||
"c10::basic_string_view<char> c10::util::detail::fully_qualified_type_name_impl<";
|
||||
constexpr c10::string_view suffix = ">()";
|
||||
constexpr std::string_view suffix = ">()";
|
||||
#else
|
||||
constexpr c10::string_view prefix =
|
||||
constexpr std::string_view prefix =
|
||||
"class c10::basic_string_view<char> __cdecl c10::util::detail::fully_qualified_type_name_impl<";
|
||||
constexpr c10::string_view suffix = ">(void)";
|
||||
constexpr std::string_view suffix = ">(void)";
|
||||
#endif
|
||||
#elif defined(__clang__)
|
||||
constexpr c10::string_view fun_sig = __PRETTY_FUNCTION__;
|
||||
constexpr c10::string_view prefix =
|
||||
constexpr std::string_view fun_sig = __PRETTY_FUNCTION__;
|
||||
constexpr std::string_view prefix =
|
||||
"c10::c10_string_view c10::util::detail::fully_qualified_type_name_impl() [T = ";
|
||||
constexpr c10::string_view suffix = "]";
|
||||
constexpr std::string_view suffix = "]";
|
||||
#elif defined(__GNUC__)
|
||||
constexpr c10::string_view fun_sig = __PRETTY_FUNCTION__;
|
||||
constexpr c10::string_view prefix =
|
||||
constexpr std::string_view fun_sig = __PRETTY_FUNCTION__;
|
||||
constexpr std::string_view prefix =
|
||||
"constexpr c10::c10_string_view c10::util::detail::fully_qualified_type_name_impl() [with T = ";
|
||||
constexpr c10::string_view suffix =
|
||||
constexpr std::string_view suffix =
|
||||
"; c10::c10_string_view = c10::basic_string_view<char>]";
|
||||
#endif
|
||||
#if !defined(__CUDA_ARCH__) && !defined(__CUDA_ARCH_LIST__)
|
||||
|
@ -19,7 +19,6 @@
|
||||
#include <c10/util/TypeIndex.h>
|
||||
#include <c10/util/TypeTraits.h>
|
||||
#include <c10/util/irange.h>
|
||||
#include <c10/util/string_view.h>
|
||||
|
||||
#include <c10/core/ScalarType.h>
|
||||
|
||||
|
@ -370,7 +370,7 @@ inline Tensor glu(const Tensor& input, const GLUFuncOptions& options = {}) {
|
||||
|
||||
#ifndef DOXYGEN_SHOULD_SKIP_THIS
|
||||
namespace detail {
|
||||
inline Tensor gelu(const Tensor& input, const string& approximate) {
|
||||
inline Tensor gelu(const Tensor& input, const std::string& approximate) {
|
||||
return torch::gelu(input, approximate);
|
||||
}
|
||||
} // namespace detail
|
||||
|
@ -852,7 +852,7 @@ void RNNCellImplBase<Derived>::pretty_print(std::ostream& stream) const {
|
||||
template <typename Derived>
|
||||
void RNNCellImplBase<Derived>::check_forward_input(
|
||||
const Tensor& input,
|
||||
const string& name) const {
|
||||
const std::string& name) const {
|
||||
TORCH_CHECK(
|
||||
input.dim() == 1 || input.dim() == 2,
|
||||
"Expected ",
|
||||
|
@ -2,7 +2,6 @@
|
||||
#include <c10/core/ScalarType.h>
|
||||
#include <c10/util/ApproximateClock.h>
|
||||
#include <c10/util/irange.h>
|
||||
#include <c10/util/string_view.h>
|
||||
#include <torch/csrc/distributed/c10d/Store.hpp>
|
||||
#include <torch/csrc/distributed/c10d/Types.hpp>
|
||||
#include <torch/csrc/distributed/c10d/Utils.hpp>
|
||||
|
@ -1,7 +1,6 @@
|
||||
#include <torch/csrc/python_headers.h>
|
||||
|
||||
#include <c10/util/intrusive_ptr.h>
|
||||
#include <c10/util/string_view.h>
|
||||
#include <torch/csrc/distributed/c10d/FileStore.hpp>
|
||||
#include <torch/csrc/distributed/c10d/Functional.hpp>
|
||||
#include <torch/csrc/distributed/c10d/GroupRegistry.hpp>
|
||||
@ -10,6 +9,7 @@
|
||||
#include <torch/csrc/distributed/c10d/control_collectives/ControlCollectives.hpp>
|
||||
#include <torch/csrc/distributed/c10d/control_collectives/StoreCollectives.hpp>
|
||||
#include <torch/csrc/distributed/c10d/control_plane/WorkerServer.hpp>
|
||||
#include <string_view>
|
||||
#include <utility>
|
||||
#include <vector>
|
||||
#ifndef _WIN32
|
||||
|
@ -146,10 +146,10 @@ void removeCurrentName(
|
||||
store.set(allWorkerInfosKey, newAllWorkerInfosVector);
|
||||
}
|
||||
|
||||
const string storeKeyBarrierId = "_ID_";
|
||||
const string storeKeyProcessCount = "PROCESS_COUNT";
|
||||
const string storeKeyActiveCallCount = "ACTIVE_CALLS";
|
||||
const string storeKeyReady = "READY";
|
||||
constexpr const auto storeKeyBarrierId = "_ID_";
|
||||
constexpr const auto storeKeyProcessCount = "PROCESS_COUNT";
|
||||
constexpr const auto storeKeyActiveCallCount = "ACTIVE_CALLS";
|
||||
constexpr const auto storeKeyReady = "READY";
|
||||
static std::atomic<int> barrierId(0);
|
||||
|
||||
static std::tuple<std::string, std::string, std::string> getNextKeyIds() {
|
||||
|
@ -527,7 +527,7 @@ void AOTIModelPackageLoader::load_constants(
|
||||
bool user_managed) {
|
||||
std::unordered_map<std::string, std::string> constant_name_to_fqn =
|
||||
runner_->getConstantNamesToOriginalFQNs();
|
||||
std::unordered_map<std::string, at::string> fqn_to_constant_name;
|
||||
std::unordered_map<std::string, std::string> fqn_to_constant_name;
|
||||
for (const auto& it : constant_name_to_fqn) {
|
||||
fqn_to_constant_name.emplace(it.second, it.first);
|
||||
}
|
||||
|
@ -2,9 +2,9 @@
|
||||
|
||||
#include <ATen/core/jit_type.h>
|
||||
#include <ATen/core/type_factory.h>
|
||||
#include <c10/util/string_view.h>
|
||||
#include <torch/csrc/jit/frontend/parser_constants.h>
|
||||
#include <torch/custom_class.h>
|
||||
#include <string_view>
|
||||
|
||||
using torch::jit::valid_single_char_tokens;
|
||||
|
||||
|
@ -63,8 +63,7 @@ void flatten_rec(PyObject* obj, ParsedArgs& args) {
|
||||
structure.push_back(D::DictClose);
|
||||
Py_DECREF(dict_items);
|
||||
} else if (THPUtils_checkString(obj)) {
|
||||
string str = THPUtils_unpackString(obj);
|
||||
args.desc.strings.emplace_back(str);
|
||||
args.desc.strings.emplace_back(THPUtils_unpackString(obj));
|
||||
args.desc.structure.push_back(D::String);
|
||||
} else if (THPVariable_Check(obj)) {
|
||||
auto& var = THPVariable_Unpack(obj);
|
||||
@ -142,8 +141,8 @@ py::object unflatten_rec(
|
||||
ArrayRef<Variable>::iterator& var_it,
|
||||
ArrayRef<Variable>::iterator& var_it_end,
|
||||
std::string::const_iterator& desc_it,
|
||||
std::vector<string>::const_iterator& str_it,
|
||||
std::vector<string>::const_iterator& str_it_end) {
|
||||
std::vector<std::string>::const_iterator& str_it,
|
||||
std::vector<std::string>::const_iterator& str_it_end) {
|
||||
char type = *desc_it++;
|
||||
if (type == D::TupleOpen) {
|
||||
std::vector<py::object> objs;
|
||||
|
@ -102,7 +102,7 @@ void prepare_and_call_rpc_op(
|
||||
std::vector<std::string> names;
|
||||
for (const auto& entry : kwargsDict) {
|
||||
const IValue& keyIValue = entry.key();
|
||||
const string& keyStr = keyIValue.toStringRef();
|
||||
const std::string& keyStr = keyIValue.toStringRef();
|
||||
names.emplace_back(keyStr);
|
||||
}
|
||||
throw std::runtime_error(functionSchema.findErrorInKwargs(names));
|
||||
|
@ -2,6 +2,7 @@
|
||||
|
||||
#include <ATen/core/qualified_name.h>
|
||||
#include <string>
|
||||
#include <string_view>
|
||||
#include <utility>
|
||||
#include <vector>
|
||||
|
||||
@ -11,7 +12,6 @@
|
||||
#include <c10/util/ArrayRef.h>
|
||||
#include <c10/util/FbcodeMaps.h>
|
||||
#include <c10/util/intrusive_ptr.h>
|
||||
#include <c10/util/string_view.h>
|
||||
#include <torch/csrc/Export.h>
|
||||
|
||||
namespace torch::jit {
|
||||
|
@ -103,7 +103,7 @@ void initLazyBindings(PyObject* module) {
|
||||
|
||||
lazy.def(
|
||||
"_mark_step",
|
||||
// TODO(whc) this API should probably change from vector<string> to
|
||||
// TODO(whc) this API should probably change from vector<std::string> to
|
||||
// vector<c10::device> but in a separate PR
|
||||
[](const std::string& device_str,
|
||||
const std::vector<std::string>& devices,
|
||||
|
Reference in New Issue
Block a user