mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-23 23:04:52 +08:00
Compare commits (49 commits, SHA1):

9538bf4e7c
219da29dfd
fb013ecb24
6af4c6acad
786c24a4cd
5d8c7f39d4
c9c1fed065
94fea82d66
447173198b
b79d056e76
eb567b1f40
1dd2431f86
5fcb5f0c8b
a55d0d9718
8c1247cffb
70a1e85718
adb699189b
45dccfddcd
3e09123797
61f922c2ca
984b1a8c35
205410cb44
cac7a22b92
8a09940a54
1d233b8f50
491c4a5dcb
4345d98663
a838e90964
29081059b6
f8c45996d5
c13e03c874
053930e194
9a38cae299
55901fb3da
fc77fdca6f
648625b230
207c2248a8
a206dcc79e
f2d7f235a6
402b289f3b
a32157c67c
24e7f29099
5b5d269d34
fa88f390a0
fe39c07826
cba195c8ed
16e67be7f1
7afffdf48b
ca45649eb5
@@ -1099,7 +1099,6 @@ exclude_patterns = [
     'test/test_namedtuple_return_api.py',
     'test/test_native_functions.py',
     'test/test_native_mha.py',
-    'test/test_nestedtensor.py',
     'test/test_nn.py',
     'test/test_out_dtype_op.py',
     'test/test_overrides.py',
@@ -462,7 +462,7 @@ inline Tensor _sum_to(
       reduce_dims.push_back(i);
     }
   for (int64_t i = leading_dims; i < static_cast<int64_t>(sizes.size()); ++i) {
-    if (shape[i - leading_dims] == 1 &&
+    if (TORCH_GUARD_SIZE_OBLIVIOUS(sym_eq(shape[i - leading_dims], 1)) &&
         TORCH_GUARD_SIZE_OBLIVIOUS(sym_ne(sizes[i], 1))) {
       reduce_dims.push_back(i);
     }
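Reviewer note: the hunk above replaces a direct `== 1` comparison, which would install a guard on a possibly symbolic size, with the size-oblivious query TORCH_GUARD_SIZE_OBLIVIOUS(sym_eq(...)). The underlying reduction logic is easier to see on concrete integer shapes; here is a minimal C++ sketch of it, assuming plain int64_t sizes rather than SymInt (an illustrative simplification, not the real _sum_to):

    #include <cstdint>
    #include <vector>

    // Sketch: which dims must be summed so `sizes` reduces back to `shape`
    // after broadcasting. The real _sum_to operates on c10::SymInt and uses
    // TORCH_GUARD_SIZE_OBLIVIOUS to avoid guarding on data-dependent sizes.
    std::vector<int64_t> dims_to_reduce(
        const std::vector<int64_t>& shape,   // target (smaller) shape
        const std::vector<int64_t>& sizes) { // current (larger) shape
      std::vector<int64_t> reduce_dims;
      const int64_t leading_dims =
          static_cast<int64_t>(sizes.size() - shape.size());
      for (int64_t i = 0; i < leading_dims; ++i) {
        reduce_dims.push_back(i); // dims absent from `shape` are always reduced
      }
      for (int64_t i = leading_dims; i < static_cast<int64_t>(sizes.size()); ++i) {
        if (shape[i - leading_dims] == 1 && sizes[i] != 1) {
          reduce_dims.push_back(i); // broadcast dim: size 1 expanded to sizes[i]
        }
      }
      return reduce_dims;
    }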
@@ -478,8 +478,6 @@ namespace impl {
 // (maybe except for some internal prim ops).
 using GenericList = List<IValue>;

-const IValue* ptr_to_first_element(const GenericList& list);
-
 }
 }

@@ -350,11 +350,4 @@ void List<T>::unsafeSetElementType(TypePtr t) {
   impl_->elementType = std::move(t);
 }

-namespace impl {
-
-inline const IValue* ptr_to_first_element(const GenericList& list) {
-  return &list.impl_->list[0];
-}
-
-}
 }
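Both hunks delete impl::ptr_to_first_element, which handed out a raw pointer into a GenericList's backing vector; such a pointer dangles as soon as the list reallocates. Callers are expected to go through the list's element API instead. A short sketch of the safer pattern, using c10::List's public interface (hedged; not the exact call sites this commit touched):

    #include <ATen/core/List.h>
    #include <ATen/core/ivalue.h>
    #include <c10/util/Exception.h>

    using c10::IValue;
    using c10::impl::GenericList;

    // Instead of a raw pointer into the backing store (invalidated by any
    // mutation), read elements through the bounds-checked list interface.
    IValue first_element(const GenericList& list) {
      TORCH_CHECK(!list.empty(), "expected a non-empty list");
      return list.get(0); // copies the IValue; nothing can dangle
    }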
@@ -1195,15 +1195,6 @@ Tensor istft(const Tensor& self, const int64_t n_fft, const optional<int64_t> hop_lengthOpt,
 #undef REPR
 }

-static Tensor istft(const Tensor& self, const int64_t n_fft, const optional<int64_t> hop_lengthOpt,
-                    const optional<int64_t> win_lengthOpt, const Tensor& window,
-                    const bool center, const bool normalized, const optional<bool> onesidedOpt,
-                    const optional<int64_t> lengthOpt) {
-  return at::native::istft(
-      self, n_fft, hop_lengthOpt, win_lengthOpt, window, center, normalized,
-      onesidedOpt, lengthOpt, /*return_complex=*/false);
-}
-
 void _fft_fill_with_conjugate_symmetry_(const Tensor& input, IntArrayRef dim_) {
   const auto input_sizes = input.sizes();
   const auto input_strides = input.strides();
@@ -210,7 +210,6 @@
 #include <ATen/ops/zeros_native.h>
 #endif

-#include <c10/util/StringUtil.h>
 #include <algorithm>
 #include <cstdint>
 #include <utility>
@@ -13,7 +13,8 @@ void run_cudnn_SDP_fprop(
     int64_t h,
     int64_t s_q,
     int64_t s_kv,
-    int64_t d,
+    int64_t d_qk,
+    int64_t d_v,
     float scaling_factor,
     bool isTraining,
     bool is_causal,
@@ -34,7 +35,8 @@ void run_cudnn_SDP_bprop(
     int64_t h,
     int64_t s_q,
     int64_t s_kv,
-    int64_t d,
+    int64_t d_qk,
+    int64_t d_v,
     float scaling_factor,
     bool is_causal,
     float dropout_probability,
@@ -128,7 +130,8 @@ struct MHAParams {
   int64_t h;
   int64_t s_q;
   int64_t s_kv;
-  int64_t d;
+  int64_t d_qk;
+  int64_t d_v;
   double dropout_probability;
   bool is_causal;
   bool return_softmaxstats;
@@ -140,7 +143,8 @@ void setMHAParams(
     int64_t h,
     int64_t s_q,
     int64_t s_kv,
-    int64_t d,
+    int64_t d_qk,
+    int64_t d_v,
     const Tensor& q,
     const Tensor& k,
     const Tensor& v,
@@ -155,7 +159,8 @@ void setMHAParams(
   }
   params.b = b;
   params.h = h;
-  params.d = d;
+  params.d_qk = d_qk;
+  params.d_v = d_v;
   params.s_q = s_q;
   params.s_kv = s_kv;
   params.dropout_probability = dropout_probability;
@@ -193,7 +198,8 @@ struct MHACacheKeyWrapper : ParamsWrapper<MHAParams> {
       int64_t h,
       int64_t s_q,
       int64_t s_kv,
-      int64_t d,
+      int64_t d_qk,
+      int64_t d_v,
       const Tensor& q,
       const Tensor& k,
       const Tensor& v,
@@ -206,7 +212,8 @@ struct MHACacheKeyWrapper : ParamsWrapper<MHAParams> {
         h,
         s_q,
         s_kv,
-        d,
+        d_qk,
+        d_v,
         q,
         k,
         v,
@@ -249,7 +256,8 @@ auto build_graph_and_tensors(
     int64_t h,
     int64_t s_q,
     int64_t s_kv,
-    int64_t d,
+    int64_t d_qk,
+    int64_t d_v,
     float scaling_factor,
     bool return_softmaxstats,
     bool is_causal,
@@ -383,7 +391,8 @@ auto build_graph_and_tensors_backward(
     int64_t h,
     int64_t s_q,
     int64_t s_kv,
-    int64_t d,
+    int64_t d_qk,
+    int64_t d_v,
     float scaling_factor,
     bool is_causal,
     float dropout_probability,
@@ -514,7 +523,8 @@ void run_cudnn_SDP_fprop(
    int64_t h,
    int64_t s_q,
    int64_t s_kv,
-    int64_t d,
+    int64_t d_qk,
+    int64_t d_v,
     float scaling_factor,
     bool return_softmaxstats,
     bool is_causal,
@@ -528,7 +538,7 @@ void run_cudnn_SDP_fprop(
     Tensor& dropoutoffset) {
   cudnnHandle_t handle = getCudnnHandle();
   o = at::empty_strided(
-      {b, h, s_q, d}, {s_q * h * d, d, h * d, 1}, q.options());
+      {b, h, s_q, d_v}, {s_q * h * d_v, d_v, h * d_v, 1}, q.options());
   if (return_softmaxstats) {
     // TODO(eqy): verify that this is correct
     softmaxstats = at::empty({b, h, s_q}, q.options().dtype(kFloat));
@@ -539,7 +549,8 @@ void run_cudnn_SDP_fprop(
       h,
       s_q,
       s_kv,
-      d,
+      d_qk,
+      d_v,
       q,
       k,
       v,
@@ -556,7 +567,8 @@ void run_cudnn_SDP_fprop(
         h,
         s_q,
         s_kv,
-        d,
+        d_qk,
+        d_v,
         scaling_factor,
         return_softmaxstats,
         is_causal,
@@ -599,7 +611,8 @@ void run_cudnn_SDP_bprop(
     int64_t h,
     int64_t s_q,
     int64_t s_kv,
-    int64_t d,
+    int64_t d_qk,
+    int64_t d_v,
     float scaling_factor,
     bool is_causal,
     float dropout_probability,
@@ -623,7 +636,18 @@ void run_cudnn_SDP_bprop(
   }
   cudnnHandle_t handle = getCudnnHandle();
   auto key = MHACacheKeyWrapper(
-      b, h, s_q, s_kv, d, q, k, v, dropout_probability, is_causal, true);
+      b,
+      h,
+      s_q,
+      s_kv,
+      d_qk,
+      d_v,
+      q,
+      k,
+      v,
+      dropout_probability,
+      is_causal,
+      true);
   auto graph_and_tensors_backward_ptr = mhagraphbackwardcache.find(key);
   graph_and_tensors_backward graph_and_tensors_backward_values;
   if (graph_and_tensors_backward_ptr) {
@@ -634,7 +658,8 @@ void run_cudnn_SDP_bprop(
         h,
         s_q,
         s_kv,
-        d,
+        d_qk,
+        d_v,
         scaling_factor,
         is_causal,
         dropout_probability,
@@ -684,5 +709,4 @@ void run_cudnn_SDP_bprop(
-
 } // namespace native
 } // namespace at

 #endif
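The common thread in the MHA.cpp hunks: the single head dimension d is split into d_qk (shared by query and key) and d_v (value), and the output allocation switches to {b, h, s_q, d_v}, because attention output inherits the value's head dim, not the query/key one. A small sketch of that shape arithmetic in plain C++, mirroring the empty_strided call above (the packed stride layout is taken from the diff; names are illustrative):

    #include <array>
    #include <cstdint>

    // Attention shapes: q is [b, h, s_q, d_qk], k is [b, h, s_kv, d_qk],
    // v is [b, h, s_kv, d_v]. softmax(q k^T) is [b, h, s_q, s_kv], so
    // multiplying by v gives the output a d_v head dim.
    struct OutShape {
      std::array<int64_t, 4> sizes;
      std::array<int64_t, 4> strides;
    };

    OutShape sdpa_out_shape(int64_t b, int64_t h, int64_t s_q, int64_t d_v) {
      // Same packed strides as the empty_strided call above: innermost d_v,
      // then heads, then query sequence, then batch.
      return {{b, h, s_q, d_v}, {s_q * h * d_v, d_v, h * d_v, 1}};
    }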
@@ -9,7 +9,8 @@ void run_cudnn_SDP_fprop(
     int64_t h,
     int64_t s_q,
     int64_t s_kv,
-    int64_t d,
+    int64_t d_k,
+    int64_t d_v,
     float scaling_factor,
     bool isTraining,
     bool is_causal,
@@ -27,7 +28,8 @@ void run_cudnn_SDP_bprop(
     int64_t h,
     int64_t s_q,
     int64_t s_kv,
-    int64_t d,
+    int64_t d_k,
+    int64_t d_v,
     float scaling_factor,
     bool is_causal,
     float dropout_probability,
@@ -18,26 +18,21 @@ kernel void erfinv_mps_kernel( device {0} *output [[buffer(0)]],
   /* coefficients in rational expansion */

   float y_abs = abs(y);
-  if(y_abs > 1.0f){{
-    output[index] = NAN;
+  if (y_abs >= 1.0f) {{
+    output[index] = {0}( y_abs > 1.0f ? NAN : copysign(INFINITY, y));
     return;
   }}
-  if(y_abs == 1.0f){{
-    output[index] = copysign(INFINITY, y);
-    return;
-  }}
-  if(y_abs <= 0.7f) {{
+  if (y_abs <= 0.7f) {{
     z = y * y;
-    num = (((a[3]*z + a[2])*z + a[1])*z + a[0]);
-    dem = ((((b[3]*z + b[2])*z + b[1])*z +b[0]) * z + 1.0f);
+    num = ((a[3] * z + a[2]) * z + a[1])*z + a[0];
+    dem = (((b[3] * z + b[2]) * z + b[1]) * z +b[0]) * z + 1.0f;
     x = y * num / dem;
-  }}
-  else{{
+  }} else {{
     z = sqrt(-1.0f*log((1.0-y_abs)/2.0));
-    num = ((c[3]*z + c[2])*z + c[1]) * z + c[0];
-    dem = (d[1]*z + d[0])*z + 1.0f;
+    num = ((c[3] * z + c[2]) * z + c[1]) * z + c[0];
+    dem = (d[1] * z + d[0]) * z + 1.0f;
     x = copysign(num, y) / dem;
   }}

-  output[index] = x;
-}})METAL";
+  output[index] = {0}(x);
+}})METAL";
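The kernel fix folds the |y| > 1 (NaN) and |y| == 1 (signed infinity) cases into a single branch and casts the result to the output type {0}. For reference, the computation is the classic two-branch rational approximation of erfinv: a central polynomial ratio for |y| <= 0.7, and a tail expansion in sqrt(-log((1-|y|)/2)) otherwise. A standalone C++ float sketch follows; the kernel's a/b/c/d arrays are defined elsewhere in the template, so the coefficient values below are the commonly published ones for this fit, an assumption rather than a quote from the source:

    #include <cmath>

    // Hedged sketch of the erfinv rational approximation, not the exact MPS
    // kernel. Coefficients are the widely used ones for this two-branch fit.
    float erfinv_approx(float y) {
      const float a[4] = {0.886226899f, -1.645349621f, 0.914624893f, -0.140543331f};
      const float b[4] = {-2.118377725f, 1.442710462f, -0.329097515f, 0.012229801f};
      const float c[4] = {-1.970840454f, -1.624906493f, 3.429567803f, 1.641345311f};
      const float d[2] = {3.543889200f, 1.637067800f};

      const float y_abs = std::fabs(y);
      if (y_abs >= 1.0f) {
        // +-1 maps to signed infinity; outside [-1, 1] is NaN (the kernel's fix).
        return y_abs > 1.0f ? NAN : std::copysign(INFINITY, y);
      }
      float x, z, num, dem;
      if (y_abs <= 0.7f) {
        z = y * y;
        num = ((a[3] * z + a[2]) * z + a[1]) * z + a[0];
        dem = (((b[3] * z + b[2]) * z + b[1]) * z + b[0]) * z + 1.0f;
        x = y * num / dem;
      } else {
        z = std::sqrt(-std::log((1.0f - y_abs) / 2.0f));
        num = ((c[3] * z + c[2]) * z + c[1]) * z + c[0];
        dem = (d[1] * z + d[0]) * z + 1.0f;
        x = std::copysign(num, y) / dem;
      }
      return x;
    }

Implementations that need full float accuracy typically append a Newton refinement step on erf(x) - y after this approximation.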
@@ -143,7 +143,7 @@ TORCH_IMPL_FUNC(leaky_relu_out_mps)(const Tensor& self, const Scalar& negative_slope,
   Tensor output_ = at::empty_like(self, executeGatherOp ? MemoryFormat::Contiguous : MemoryFormat::Preserve);

   @autoreleasepool {
-    string key = "leaky_relu" + getTensorsStringKey({self}) + ":" + to_string(negative_slope.to<double>());
+    string key = "leaky_relu" + getTensorsStringKey({self}) + ":" + std::to_string(negative_slope.to<double>());
     auto cachedGraph = LookUpOrCreateCachedGraph<CachedGraph>(key, [&](auto mpsGraph, auto newCachedGraph) {
       MPSGraphTensor* inputTensor = mpsGraphRankedPlaceHolder(mpsGraph, self);

@@ -193,8 +193,8 @@ TORCH_IMPL_FUNC(leaky_relu_backward_out_mps)
   Tensor output_ = at::empty_like(self, self.suggest_memory_format());

   @autoreleasepool {
-    string key =
-        "leaky_relu_backward" + getTensorsStringKey({self, grad_output}) + ":" + to_string(negative_slope.to<double>());
+    string key = "leaky_relu_backward" + getTensorsStringKey({self, grad_output}) + ":" +
+        std::to_string(negative_slope.to<double>());
     auto cachedGraph = LookUpOrCreateCachedGraph<CachedGraph>(key, [&](auto mpsGraph, auto newCachedGraph) {
       MPSGraphTensor* inputTensor = mpsGraphRankedPlaceHolder(mpsGraph, self);
       MPSGraphTensor* gradOutputTensor = mpsGraphRankedPlaceHolder(mpsGraph, grad_output);
@@ -242,7 +242,7 @@ TORCH_IMPL_FUNC(log_softmax_mps_out)
   MPSStream* stream = at::mps::getCurrentMPSStream();

   @autoreleasepool {
-    string key = "log_softmax_mps_out" + getTensorsStringKey({self}) + ":" + to_string(dim);
+    string key = "log_softmax_mps_out" + getTensorsStringKey({self}) + ":" + std::to_string(dim);
     auto cachedGraph = LookUpOrCreateCachedGraph<CachedGraph>(key, [&](auto mpsGraph, auto newCachedGraph) {
       MPSGraphTensor* inputTensor = mpsGraphRankedPlaceHolder(mpsGraph, self);

@@ -285,7 +285,7 @@ TORCH_IMPL_FUNC(log_softmax_backward_mps_out)
   MPSStream* stream = at::mps::getCurrentMPSStream();

   @autoreleasepool {
-    string key = "log_softmax_backward_mps_out:" + getMPSTypeString(grad_output) + ":" + to_string(dim);
+    string key = "log_softmax_backward_mps_out:" + getMPSTypeString(grad_output) + ":" + std::to_string(dim);
     auto cachedGraph = LookUpOrCreateCachedGraph<CachedGraph>(key, [&](auto mpsGraph, auto newCachedGraph) {
       MPSGraphTensor* gradOutputTensor = mpsGraphUnrankedPlaceHolder(mpsGraph, getMPSDataType(grad_output));
       MPSGraphTensor* outputTensor = mpsGraphUnrankedPlaceHolder(mpsGraph, getMPSDataType(output));
@@ -539,8 +539,8 @@ TORCH_IMPL_FUNC(threshold_out_mps)
   MPSStream* stream = getCurrentMPSStream();

   @autoreleasepool {
-    string key = "threshold_out_mps" + getTensorsStringKey({self}) + ":" + to_string(threshold.to<double>()) + ":" +
-        to_string(value.to<double>());
+    string key = "threshold_out_mps" + getTensorsStringKey({self}) + ":" + std::to_string(threshold.to<double>()) +
+        ":" + std::to_string(value.to<double>());

     auto cachedGraph = LookUpOrCreateCachedGraph<CachedGraph>(key, [&](auto mpsGraph, auto newCachedGraph) {
       MPSGraphTensor* inputTensor = mpsGraphRankedPlaceHolder(mpsGraph, self);
@@ -587,7 +587,7 @@ TORCH_IMPL_FUNC(threshold_backward_out_mps)

   @autoreleasepool {
     string key =
-        "threshold_backward_out_mps" + getTensorsStringKey({self, grad}) + ":" + to_string(threshold.to<double>());
+        "threshold_backward_out_mps" + getTensorsStringKey({self, grad}) + ":" + std::to_string(threshold.to<double>());

     auto cachedGraph = LookUpOrCreateCachedGraph<CachedGraph>(key, [&](auto mpsGraph, auto newCachedGraph) {
       MPSGraphTensor* inputTensor = mpsGraphRankedPlaceHolder(mpsGraph, self);
@@ -826,8 +826,8 @@ static void elu_variants_out_mps(const Tensor& self,
   MPSStream* stream = getCurrentMPSStream();

   @autoreleasepool {
-    string key = func_name + ":" + getTensorsStringKey({self}) + ":" + to_string(alpha.to<double>()) + ":" +
-        to_string(scale.to<double>()) + ":" + to_string(input_scale.to<double>());
+    string key = func_name + ":" + getTensorsStringKey({self}) + ":" + std::to_string(alpha.to<double>()) + ":" +
+        std::to_string(scale.to<double>()) + ":" + std::to_string(input_scale.to<double>());

     auto cachedGraph = LookUpOrCreateCachedGraph<CachedGraph>(key, [&](auto mpsGraph, auto newCachedGraph) {
       MPSGraphTensor* inputTensor = mpsGraphRankedPlaceHolder(mpsGraph, self);
@@ -916,8 +916,8 @@ TORCH_IMPL_FUNC(elu_backward_out_mps)

   @autoreleasepool {
     string key = "elu_backward_out_mps:" + getTensorsStringKey({grad_output, self_or_result}) + ":" +
-        to_string(alpha.to<double>()) + ":" + to_string(scale.to<double>()) + ":" +
-        to_string(input_scale.to<double>()) + ":" + to_string(is_result);
+        std::to_string(alpha.to<double>()) + ":" + std::to_string(scale.to<double>()) + ":" +
+        std::to_string(input_scale.to<double>()) + ":" + std::to_string(is_result);

     auto cachedGraph = LookUpOrCreateCachedGraph<CachedGraph>(key, [&](auto mpsGraph, auto newCachedGraph) {
       MPSGraphTensor* gradOutputTensor = mpsGraphRankedPlaceHolder(mpsGraph, grad_output);
@@ -1010,7 +1010,7 @@ TORCH_IMPL_FUNC(glu_out_mps)(const Tensor& self, const int64_t dim, const Tensor& output)
   MPSStream* stream = getCurrentMPSStream();

   @autoreleasepool {
-    string key = "glu_out_mps" + getTensorsStringKey({self}) + ":" + to_string(dim);
+    string key = "glu_out_mps" + getTensorsStringKey({self}) + ":" + std::to_string(dim);
     auto cachedGraph = LookUpOrCreateCachedGraph<CachedGraph>(key, [&](auto mpsGraph, auto newCachedGraph) {
       MPSGraphTensor* inputTensor = mpsGraphRankedPlaceHolder(mpsGraph, getMPSDataType(self), getMPSShape(self));
       NSArray<MPSGraphTensor*>* outputTensorsArray = [mpsGraph splitTensor:inputTensor
@@ -1052,7 +1052,7 @@ Tensor& glu_backward_mps_out(const Tensor& grad_output, const Tensor& self, const int64_t dim,
   MPSStream* stream = getCurrentMPSStream();

   @autoreleasepool {
-    string key = "glu_backward_mps_out" + getTensorsStringKey({grad_output, self}) + ":" + to_string(dim);
+    string key = "glu_backward_mps_out" + getTensorsStringKey({grad_output, self}) + ":" + std::to_string(dim);
     auto cachedGraph = LookUpOrCreateCachedGraph<CachedGraph>(key, [&](auto mpsGraph, auto newCachedGraph) {
       MPSGraphTensor* inputTensor = mpsGraphRankedPlaceHolder(mpsGraph, getMPSDataType(self), getMPSShape(self));
       MPSGraphTensor* gradOutputTensor =
@@ -1855,8 +1855,8 @@ Tensor& hardtanh_backward_out_mps(const Tensor& grad_output,
   MPSStream* stream = getCurrentMPSStream();

   @autoreleasepool {
-    string key = "hardtanh_backward_out_mps:" + getTensorsStringKey({grad_output}) + ":" + to_string(min.to<double>()) +
-        ":" + to_string(max.to<double>());
+    string key = "hardtanh_backward_out_mps:" + getTensorsStringKey({grad_output}) + ":" +
+        std::to_string(min.to<double>()) + ":" + std::to_string(max.to<double>());

     auto cachedGraph = LookUpOrCreateCachedGraph<CachedGraph>(key, [&](auto mpsGraph, auto newCachedGraph) {
       MPSGraphTensor* gradOutputTensor = mpsGraphRankedPlaceHolder(mpsGraph, grad_output);
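All of the MPS hunks in this compare are the same mechanical fix: graph-cache keys are built by string concatenation, and the unqualified to_string calls (which only resolved because a using-declaration happened to be in scope) become std::to_string. The key pattern itself, reduced to a sketch (make_cache_key and tensors_key are illustrative names, not PyTorch API):

    #include <cstdint>
    #include <string>

    // Sketch of an MPS-style graph cache key: operator name plus every value
    // that changes the compiled graph (shape/dtype summary and scalar
    // attributes). Qualifying std::to_string makes the lookup explicit instead
    // of depending on `using std::to_string;` somewhere above.
    std::string make_cache_key(const std::string& op,
                               const std::string& tensors_key, // e.g. shapes + dtypes
                               double alpha,
                               int64_t dim) {
      return op + tensors_key + ":" + std::to_string(alpha) + ":" + std::to_string(dim);
    }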
@@ -136,8 +136,8 @@ static Tensor& addmv_out_mps_impl(const Tensor& self,
   Tensor matMulVec = at::mm(mat, vec.unsqueeze(1)).squeeze(1);

   @autoreleasepool {
-    string key = "addmv_out_mps_impl" + getTensorsStringKey({self, matMulVec}) + ":" + to_string(beta_.toDouble()) +
-        ":" + to_string(alpha_.toDouble());
+    string key = "addmv_out_mps_impl" + getTensorsStringKey({self, matMulVec}) + ":" +
+        std::to_string(beta_.toDouble()) + ":" + std::to_string(alpha_.toDouble());
     auto cachedGraph = LookUpOrCreateCachedGraph<CachedGraph>(key, [&](auto mpsGraph, auto newCachedGraph) {
       MPSGraphTensor* matMulVecTensor = mpsGraphRankedPlaceHolder(mpsGraph, matMulVec);
       MPSGraphTensor* selfTensor = mpsGraphRankedPlaceHolder(mpsGraph, self);
@@ -33,7 +33,7 @@ static Tensor& fill_scalar_mps_impl(Tensor& self, const Scalar& value) {
   };

   @autoreleasepool {
-    string key = "fill_scalar_mps_impl" + getTensorsStringKey(self) + ":" + to_string(value.toDouble());
+    string key = "fill_scalar_mps_impl" + getTensorsStringKey(self) + ":" + std::to_string(value.toDouble());

     auto cachedGraph = LookUpOrCreateCachedGraph<CachedGraph>(key, [&](auto mpsGraph, auto newCachedGraph) {
       MPSGraphTensor* inputTensor = mpsGraphScalarPlaceHolder(mpsGraph, getMPSDataType(self.scalar_type()));
@@ -193,24 +193,24 @@ static Tensor _mps_convolution_impl(const Tensor& input_t,

     string bias_shape_key;
     if (bias_defined) {
-      bias_shape_key = to_string(bias_shape[0]);
+      bias_shape_key = std::to_string(bias_shape[0]);
     } else {
       bias_shape_key = "nobias";
     }

     string key;
     if (is3DConv) {
-      key = "mps_3d_convolution:" + to_string(stride[0]) + ":" + to_string(stride[1]) + ":" + to_string(stride[2]) +
-          ":" + to_string(dilation[0]) + ":" + to_string(dilation[1]) + ":" + to_string(dilation[2]) + ":" +
-          to_string(padding[0]) + ":" + to_string(padding[1]) + ":" + to_string(padding[2]) + ":" + to_string(groups) +
-          ":" + mem_format_key + mps::getTensorsStringKey({input_t, weight_t}) + ":" + to_string(bias_defined) + ":" +
-          bias_shape_key;
+      key = "mps_3d_convolution:" + std::to_string(stride[0]) + ":" + std::to_string(stride[1]) + ":" +
+          std::to_string(stride[2]) + ":" + std::to_string(dilation[0]) + ":" + std::to_string(dilation[1]) + ":" +
+          std::to_string(dilation[2]) + ":" + std::to_string(padding[0]) + ":" + std::to_string(padding[1]) + ":" +
+          std::to_string(padding[2]) + ":" + std::to_string(groups) + ":" + mem_format_key +
+          mps::getTensorsStringKey({input_t, weight_t}) + ":" + std::to_string(bias_defined) + ":" + bias_shape_key;

     } else {
-      key = "mps_convolution:" + to_string(stride[0]) + ":" + to_string(stride[1]) + ":" + to_string(dilation[0]) +
-          ":" + to_string(dilation[1]) + ":" + to_string(padding[0]) + ":" + to_string(padding[1]) + ":" +
-          to_string(groups) + ":" + mem_format_key + mps::getTensorsStringKey({input_t, weight_t}) + ":" +
-          to_string(bias_defined) + ":" + bias_shape_key;
+      key = "mps_convolution:" + std::to_string(stride[0]) + ":" + std::to_string(stride[1]) + ":" +
+          std::to_string(dilation[0]) + ":" + std::to_string(dilation[1]) + ":" + std::to_string(padding[0]) + ":" +
+          std::to_string(padding[1]) + ":" + std::to_string(groups) + ":" + mem_format_key +
+          mps::getTensorsStringKey({input_t, weight_t}) + ":" + std::to_string(bias_defined) + ":" + bias_shape_key;
     }

     MPSShape* inputShape = mps::getMPSShape(input_t, memory_format);
@@ -388,16 +388,16 @@ static Tensor mps_convolution_backward_input(IntArrayRef input_size,
     NSString* ns_shape_key = [[gradOutputShape valueForKey:@"description"] componentsJoinedByString:@","];
     string key;
     if (is3DConv) {
-      key = "mps_3d_convolution_backward_input:" + to_string(stride[0]) + ":" + to_string(stride[1]) + ":" + ":" +
-          to_string(stride[2]) + to_string(dilation[0]) + ":" + to_string(dilation[1]) + ":" + to_string(dilation[2]) +
-          ":" + to_string(padding[0]) + ":" + to_string(padding[1]) + ":" + to_string(padding[2]) + ":" +
-          to_string(groups) + ":" + mem_format_key + getTensorsStringKey({grad_output_t, weight_t}) + ":" +
-          string([ns_shape_key UTF8String]);
+      key = "mps_3d_convolution_backward_input:" + std::to_string(stride[0]) + ":" + std::to_string(stride[1]) + ":" +
+          ":" + std::to_string(stride[2]) + std::to_string(dilation[0]) + ":" + std::to_string(dilation[1]) + ":" +
+          std::to_string(dilation[2]) + ":" + std::to_string(padding[0]) + ":" + std::to_string(padding[1]) + ":" +
+          std::to_string(padding[2]) + ":" + std::to_string(groups) + ":" + mem_format_key +
+          getTensorsStringKey({grad_output_t, weight_t}) + ":" + string([ns_shape_key UTF8String]);

     } else {
-      key = "mps_convolution_backward_input:" + to_string(stride[0]) + ":" + to_string(stride[1]) + ":" +
-          to_string(dilation[0]) + ":" + to_string(dilation[1]) + ":" + to_string(padding[0]) + ":" +
-          to_string(padding[1]) + ":" + to_string(groups) + ":" + mem_format_key +
+      key = "mps_convolution_backward_input:" + std::to_string(stride[0]) + ":" + std::to_string(stride[1]) + ":" +
+          std::to_string(dilation[0]) + ":" + std::to_string(dilation[1]) + ":" + std::to_string(padding[0]) + ":" +
+          std::to_string(padding[1]) + ":" + std::to_string(groups) + ":" + mem_format_key +
           getTensorsStringKey({grad_output_t, weight_t}) + ":" + string([ns_shape_key UTF8String]);
     }
     auto cachedGraph = LookUpOrCreateCachedGraph<CachedGraph>(key, [&](auto mpsGraph, auto newCachedGraph) {
@@ -547,15 +547,15 @@ static Tensor mps_convolution_backward_weights(IntArrayRef weight_size,
     NSString* ns_shape_key = [[gradOutputShape valueForKey:@"description"] componentsJoinedByString:@","];
     string key;
     if (is3DConv) {
-      key = "mps_3d_convolution_backward_weights:" + to_string(stride[0]) + ":" + to_string(stride[1]) + ":" +
-          to_string(stride[2]) + ":" + to_string(dilation[0]) + ":" + to_string(dilation[1]) + ":" +
-          to_string(dilation[2]) + ":" + to_string(padding[0]) + ":" + to_string(padding[1]) + ":" +
-          to_string(padding[2]) + ":" + to_string(groups) + ":" + mem_format_key +
+      key = "mps_3d_convolution_backward_weights:" + std::to_string(stride[0]) + ":" + std::to_string(stride[1]) + ":" +
+          std::to_string(stride[2]) + ":" + std::to_string(dilation[0]) + ":" + std::to_string(dilation[1]) + ":" +
+          std::to_string(dilation[2]) + ":" + std::to_string(padding[0]) + ":" + std::to_string(padding[1]) + ":" +
+          std::to_string(padding[2]) + ":" + std::to_string(groups) + ":" + mem_format_key +
           getTensorsStringKey({grad_output_t, input_t, grad_weight_t}) + ":" + string([ns_shape_key UTF8String]);
     } else {
-      key = "mps_convolution_backward_weights:" + to_string(stride[0]) + ":" + to_string(stride[1]) + ":" +
-          to_string(dilation[0]) + ":" + to_string(dilation[1]) + ":" + to_string(padding[0]) + ":" +
-          to_string(padding[1]) + ":" + to_string(groups) + ":" + mem_format_key +
+      key = "mps_convolution_backward_weights:" + std::to_string(stride[0]) + ":" + std::to_string(stride[1]) + ":" +
+          std::to_string(dilation[0]) + ":" + std::to_string(dilation[1]) + ":" + std::to_string(padding[0]) + ":" +
+          std::to_string(padding[1]) + ":" + std::to_string(groups) + ":" + mem_format_key +
           getTensorsStringKey({grad_output_t, input_t, grad_weight_t}) + ":" + string([ns_shape_key UTF8String]);
     }
     auto cachedGraph = LookUpOrCreateCachedGraph<CachedGraph>(key, [&](auto mpsGraph, auto newCachedGraph) {
@@ -63,7 +63,7 @@ Tensor& random_mps_impl(Tensor& self,

   @autoreleasepool {
     string key = op_name + getTensorsStringKey({self, mean_opt.value_or(Tensor()), std_opt.value_or(Tensor())}) + ":" +
-        to_string(val1) + ":" + to_string(val2);
+        std::to_string(val1) + ":" + std::to_string(val2);
     auto cachedGraph = LookUpOrCreateCachedGraph<RandomCachedGraph>(key, [&](auto mpsGraph, auto newCachedGraph) {
       newCachedGraph->stateTensor =
           mpsGraphRankedPlaceHolder(mpsGraph, MPSDataTypeInt32, @[ @(at::mps::detail::PHILOX_STATE_N) ]);
@@ -469,7 +469,7 @@ static Tensor& multinomial_with_replacement_mps_kernel(const Tensor& self,
   MPSStream* stream = getCurrentMPSStream();

   @autoreleasepool {
-    string key = "multinomial_with_replacement:" + getTensorsStringKey({self}) + ":" + to_string(n_sample);
+    string key = "multinomial_with_replacement:" + getTensorsStringKey({self}) + ":" + std::to_string(n_sample);
     auto cachedGraph = LookUpOrCreateCachedGraph<RandomCachedGraph>(key, [&](auto mpsGraph, auto newCachedGraph) {
       MPSShape* prob_shape = getMPSShape(self_v);
       newCachedGraph->stateTensor = mpsGraphRankedPlaceHolder(mpsGraph, MPSDataTypeInt32, @[ @7 ]);
@@ -236,7 +236,7 @@ static std::tuple<Tensor, Tensor> _mps_linear_backward_weights(const Tensor& grad_output,
   MPSStream* stream = getCurrentMPSStream();

   @autoreleasepool {
-    string key = "mps_linear_backward_weights:" + to_string(bias_defined) + ":" +
+    string key = "mps_linear_backward_weights:" + std::to_string(bias_defined) + ":" +
         getTensorsStringKey({input_reshaped, weight, grad_output_reshaped});
     auto cachedGraph = LookUpOrCreateCachedGraph<CachedGraph>(key, [&](auto mpsGraph, auto newCachedGraph) {
       MPSGraphTensor* inputTensor = mpsGraphRankedPlaceHolder(mpsGraph, input_reshaped);
@@ -229,8 +229,8 @@ static Tensor& addbmm_or_baddbmm_out_mps_impl(const Tensor& input,

   @autoreleasepool {
     string key = (opType == ADDBMM_OP_TYPE) ? ("addbmm_out_mps_impl") : ("baddbmm_out_mps_impl");
-    key += getTensorsStringKey({batch1, batch2, input}) + ":" + to_string(beta.toDouble()) + ":" +
-        to_string(alpha.toDouble());
+    key += getTensorsStringKey({batch1, batch2, input}) + ":" + std::to_string(beta.toDouble()) + ":" +
+        std::to_string(alpha.toDouble());

     auto cachedGraph = LookUpOrCreateCachedGraph<CachedGraph>(key, [&](auto mpsGraph, auto newCachedGraph) {
       MPSGraphTensor* inputTensor = mps::mpsGraphRankedPlaceHolder(mpsGraph, input);
@@ -331,8 +331,8 @@ static Tensor& addmm_out_mps_impl(const Tensor& bias,
   };

   @autoreleasepool {
-    string key = "addmm_out_mps_impl" + getTensorsStringKey({self, other, *bias_}) + ":" + to_string(beta.toDouble()) +
-        ":" + to_string(alpha.toDouble());
+    string key = "addmm_out_mps_impl" + getTensorsStringKey({self, other, *bias_}) + ":" +
+        std::to_string(beta.toDouble()) + ":" + std::to_string(alpha.toDouble());
     auto cachedGraph = LookUpOrCreateCachedGraph<CachedGraph>(key, [&](auto mpsGraph, auto newCachedGraph) {
       MPSGraphTensor* selfTensor = nil;
       MPSGraphTensor* otherTensor = nil;
@@ -615,8 +615,8 @@ Tensor& addr_out_mps(const Tensor& self,
   };

   @autoreleasepool {
-    string key = "addr_out_mps_impl" + getTensorsStringKey({vec1, vec2, *self_}) + ":" + to_string(beta.toDouble()) +
-        ":" + to_string(alpha.toDouble());
+    string key = "addr_out_mps_impl" + getTensorsStringKey({vec1, vec2, *self_}) + ":" +
+        std::to_string(beta.toDouble()) + ":" + std::to_string(alpha.toDouble());
     auto cachedGraph = LookUpOrCreateCachedGraph<CachedGraph>(key, [&](auto mpsGraph, auto newCachedGraph) {
       MPSGraphTensor* t1 = mps::mpsGraphRankedPlaceHolder(mpsGraph, getMPSDataType(vec1), inputShape);
       MPSGraphTensor* t2 = mps::mpsGraphRankedPlaceHolder(mpsGraph, getMPSDataType(vec2), otherShape);
@@ -69,7 +69,7 @@ static Tensor& mse_loss_backward_out_impl(const Tensor& grad_output,
   };

   @autoreleasepool {
-    string key = op_name + reductionToString(reduction) + ":" + to_string(grad_input.sizes()[1]) +
+    string key = op_name + reductionToString(reduction) + ":" + std::to_string(grad_input.sizes()[1]) +
         getTensorsStringKey({input, target, grad_output});
     auto cachedGraph = LookUpOrCreateCachedGraph<CachedGraph>(key, [&](auto mpsGraph, auto newCachedGraph) {
       newCachedGraph->inputTensor = mpsGraphRankedPlaceHolder(mpsGraph, input);
@@ -327,8 +327,8 @@ static void nllnd_loss_backward_impl(Tensor& grad_input_arg,
   }
   @autoreleasepool {
     string key = "nllnd_loss_backward" + getTensorsStringKey({input, grad_output, target, weight, total_weight}) +
-        to_string(numClasses) + ":" + to_string(ignore_index) + ":" + to_string(isWeightsArrayValid) + ":" +
-        to_string(isTargetCasted) + ":" + reductionToString(reduction);
+        std::to_string(numClasses) + ":" + std::to_string(ignore_index) + ":" + std::to_string(isWeightsArrayValid) +
+        ":" + std::to_string(isTargetCasted) + ":" + reductionToString(reduction);

     auto cachedGraph = LookUpOrCreateCachedGraph<CachedGraph>(key, [&](auto mpsGraph, auto newCachedGraph) {
       MPSGraphTensor* inputTensor = mpsGraphRankedPlaceHolder(mpsGraph, input);
@@ -463,9 +463,9 @@ static void nllnd_loss_forward_impl(Tensor& output,
   NSString* ns_shape_key = [[input_shape valueForKey:@"description"] componentsJoinedByString:@","];

   // TODO: Make the key
-  string key = "nllnd_loss_forward_impl:" + to_string(ignore_index) + ":" + to_string(isWeightsArrayValid) + ":" +
-      reductionToString(reduction) + ":" + [ns_shape_key UTF8String] + ":" + getMPSTypeString(input) + ":" +
-      getMPSTypeString(target) + ":" + to_string(isTargetCasted) + ":" + getMPSTypeString(weight);
+  string key = "nllnd_loss_forward_impl:" + std::to_string(ignore_index) + ":" + std::to_string(isWeightsArrayValid) +
+      ":" + reductionToString(reduction) + ":" + [ns_shape_key UTF8String] + ":" + getMPSTypeString(input) + ":" +
+      getMPSTypeString(target) + ":" + std::to_string(isTargetCasted) + ":" + getMPSTypeString(weight);
   auto cachedGraph = LookUpOrCreateCachedGraph<CachedGraph>(key, [&](auto mpsGraph, auto newCachedGraph) {
     MPSGraphTensor* inputTensor = mpsGraphRankedPlaceHolder(mpsGraph, getMPSDataType(input), input_shape);
     MPSGraphTensor* targetTensor = mpsGraphRankedPlaceHolder(mpsGraph, getMPSDataType(target), target_shape);
@@ -598,7 +598,7 @@ static void smooth_l1_loss_impl(const Tensor& input,
   NSString* ns_shape_key = [[input_shape valueForKey:@"description"] componentsJoinedByString:@","];

   string key = "smooth_l1_loss_impl:" + reductionToString(reduction) + ":" + [ns_shape_key UTF8String] + ":" +
-      to_string(beta) + ":" + getMPSTypeString(input) + ":" + getMPSTypeString(target);
+      std::to_string(beta) + ":" + getMPSTypeString(input) + ":" + getMPSTypeString(target);
   auto cachedGraph = LookUpOrCreateCachedGraph<CachedGraph>(key, [&](auto mpsGraph, auto newCachedGraph) {
     // smooth_l1_loss_mps:
     // ln = 0.5 * ( xn - yn ) ^ 2 / beta, if |xn - yn| < beta
@@ -734,7 +734,7 @@ static void smooth_l1_loss_backward_impl(const Tensor& grad_output,

   @autoreleasepool {
     string key = "smooth_l1_loss_backward" + getTensorsStringKey({input, grad_output, grad_input, target}) + ":" +
-        reductionToString(reduction) + ":" + to_string(beta);
+        reductionToString(reduction) + ":" + std::to_string(beta);

     auto cachedGraph = LookUpOrCreateCachedGraph<CachedGraph>(key, [&](auto mpsGraph, auto newCachedGraph) {
       MPSGraphTensor* inputTensor = mpsGraphRankedPlaceHolder(mpsGraph, input);
@@ -106,7 +106,7 @@ Tensor& arange_mps_out(const Scalar& start, const Scalar& end, const Scalar& step, Tensor& result) {
   auto stream = getCurrentMPSStream();
   auto mpsDataType = getMPSDataType(result);
   @autoreleasepool {
-    string key = "arange_mps_out" + getTensorsStringKey({result}) + ":" + to_string(size);
+    string key = "arange_mps_out" + getTensorsStringKey({result}) + ":" + std::to_string(size);
     auto cachedGraph = cache_->LookUpAs<RangeCachedGraph>(key);
     if (!cachedGraph) {
       cachedGraph = cache_->CreateCachedGraphAs<RangeCachedGraph>(key, ^MPSCachedGraph*() {
@@ -173,7 +173,7 @@ Tensor& range_mps_out(const Scalar& start, const Scalar& end, const Scalar& step, Tensor& result) {
   auto stream = getCurrentMPSStream();
   auto mpsDataType = getMPSDataType(result);
   @autoreleasepool {
-    string key = "arange_mps_out" + getTensorsStringKey({result}) + ":" + to_string(size);
+    string key = "arange_mps_out" + getTensorsStringKey({result}) + ":" + std::to_string(size);
     auto cachedGraph = cache_->LookUpAs<RangeCachedGraph>(key);
     if (!cachedGraph) {
       cachedGraph = cache_->CreateCachedGraphAs<RangeCachedGraph>(key, ^MPSCachedGraph*() {
@@ -221,8 +221,8 @@ Tensor& linspace_out_mps(const Scalar& start, const Scalar& end, int64_t steps, Tensor& result) {
   bool start_less_end = (start.to<double>() <= end.to<double>());

   @autoreleasepool {
-    string key =
-        "linspace_out_mps:" + getTensorsStringKey({result}) + ":" + to_string(steps) + to_string(start_less_end);
+    string key = "linspace_out_mps:" + getTensorsStringKey({result}) + ":" + std::to_string(steps) +
+        std::to_string(start_less_end);
     auto cachedGraph = cache_->LookUpAs<RangeCachedGraph>(key);

     if (!cachedGraph) {
|
@ -359,8 +359,8 @@ static void impl_func_norm_mps(const Tensor& input_tensor,
|
||||
NSString* ns_key = [[wrappedAxes valueForKey:@"description"] componentsJoinedByString:@","];
|
||||
string keepdim_info = (keepdim) ? "keepdim=1" : "keepdim=0";
|
||||
string tensor_key = cdist ? getTensorsStringKey({input_tensor, other_tensor}) : getTensorsStringKey({input_t});
|
||||
string key = string("norm_out_mps:") + [ns_key UTF8String] + ":" + tensor_key + ":p" + to_string(p) + ":" +
|
||||
keepdim_info + ":" + toString(in_dtype) + ":" + to_string(castInputData);
|
||||
string key = string("norm_out_mps:") + [ns_key UTF8String] + ":" + tensor_key + ":p" + std::to_string(p) + ":" +
|
||||
keepdim_info + ":" + toString(in_dtype) + ":" + std::to_string(castInputData);
|
||||
|
||||
auto cachedGraph = LookUpOrCreateCachedGraph<MPSBinaryCachedGraph>(key, [&](auto mpsGraph, auto newCachedGraph) {
|
||||
newCachedGraph->inputTensor_ = mpsGraphRankedPlaceHolder(mpsGraph, input_tensor);
|
||||
@ -572,7 +572,7 @@ static Tensor std_var_common_impl_mps(const Tensor& input_t,
|
||||
string op_key = (stdVarType == STANDARD_DEVIATION) ? "std_mps" : "var_mps";
|
||||
NSString* ns_key = [[wrappedAxes valueForKey:@"description"] componentsJoinedByString:@","];
|
||||
string bessel_corrected = (use_correction && correction_value) ? "unbiased " : "biased ";
|
||||
string use_dim_info = (use_dim) ? "use_dim=1:" + to_string(dim_value.size()) : "use_dim=0";
|
||||
string use_dim_info = (use_dim) ? "use_dim=1:" + std::to_string(dim_value.size()) : "use_dim=0";
|
||||
string keepdim_info = (keepdim) ? "keepdim=1" : "keepdim=0";
|
||||
string key = op_key + ":" + getTensorsStringKey(input_t) + ":" + use_dim_info + ":" + keepdim_info + ":" +
|
||||
string([ns_key UTF8String]) + ":" + bessel_corrected + ":" + std::to_string(correction_value);
|
||||
@ -700,7 +700,7 @@ static void min_max_out_mps(const Tensor& input_t,
|
||||
auto stream = at::mps::getCurrentMPSStream();
|
||||
|
||||
@autoreleasepool {
|
||||
string key = func_name + getTensorsStringKey({input_t, indices_t}) + ":" + to_string(dim_);
|
||||
string key = func_name + getTensorsStringKey({input_t, indices_t}) + ":" + std::to_string(dim_);
|
||||
auto cachedGraph = LookUpOrCreateCachedGraph<CachedGraph>(key, [&](auto mpsGraph, auto newCachedGraph) {
|
||||
MPSGraphTensor* inputTensor = mpsGraphRankedPlaceHolder(mpsGraph, input_t);
|
||||
MPSGraphTensor* outputTensor = nil;
|
||||
@ -860,7 +860,7 @@ static void argmax_argmin_out_mps(const Tensor& input_t,
|
||||
@autoreleasepool {
|
||||
NSString* ns_key = [[apparent_in_shape valueForKey:@"description"] componentsJoinedByString:@","];
|
||||
string key =
|
||||
func_name + ":" + to_string(dim_) + ":" + getTensorsStringKey(input_t) + ":" + string([ns_key UTF8String]);
|
||||
func_name + ":" + std::to_string(dim_) + ":" + getTensorsStringKey(input_t) + ":" + string([ns_key UTF8String]);
|
||||
auto cachedGraph = LookUpOrCreateCachedGraph<CachedGraph>(key, [&](auto mpsGraph, auto newCachedGraph) {
|
||||
auto inputScalarType = input_t.scalar_type();
|
||||
MPSGraphTensor* inputTensor =
|
||||
@ -1217,7 +1217,7 @@ TORCH_IMPL_FUNC(any_out_mps)
|
||||
|
||||
@autoreleasepool {
|
||||
MPSShape* input_t_shape = getMPSShape(input_t);
|
||||
string key = string("any_out_mps:") + getMPSShapeString(input_t_shape) + ":" + to_string(dim_) + ":" +
|
||||
string key = string("any_out_mps:") + getMPSShapeString(input_t_shape) + ":" + std::to_string(dim_) + ":" +
|
||||
getMPSTypeString(input_t);
|
||||
auto cachedGraph = LookUpOrCreateCachedGraph<CachedGraph>(key, [&](auto mpsGraph, auto newCachedGraph) {
|
||||
MPSDataType input_type = getMPSDataType(input_t);
|
||||
@ -1313,7 +1313,7 @@ TORCH_IMPL_FUNC(all_out_mps)
|
||||
|
||||
@autoreleasepool {
|
||||
MPSShape* input_t_shape = getMPSShape(input_t);
|
||||
string key = string("all_out_mps:") + getMPSShapeString(input_t_shape) + ":" + to_string(dim_) + ":" +
|
||||
string key = string("all_out_mps:") + getMPSShapeString(input_t_shape) + ":" + std::to_string(dim_) + ":" +
|
||||
getMPSTypeString(input_t);
|
||||
auto cachedGraph = LookUpOrCreateCachedGraph<CachedGraph>(key, [&](auto mpsGraph, auto newCachedGraph) {
|
||||
MPSDataType input_type = getMPSDataType(input_t);
|
||||
@ -1531,8 +1531,8 @@ static void median_out_mps(const Tensor& input_t,
|
||||
auto stream = at::mps::getCurrentMPSStream();
|
||||
|
||||
@autoreleasepool {
|
||||
string key =
|
||||
func_name + ":" + to_string(dim_) + ":" + getTensorsStringKey(input_t) + ":" + getTensorsStringKey(indices_t);
|
||||
string key = func_name + ":" + std::to_string(dim_) + ":" + getTensorsStringKey(input_t) + ":" +
|
||||
getTensorsStringKey(indices_t);
|
||||
auto cachedGraph = LookUpOrCreateCachedGraph<CachedGraph>(key, [&](auto mpsGraph, auto newCachedGraph) {
|
||||
MPSGraphTensor* inputTensor = mpsGraphRankedPlaceHolder(mpsGraph, input_t);
|
||||
MPSGraphTensor* castInputTensor =
|
||||
|
@@ -108,8 +108,8 @@ TORCH_IMPL_FUNC(topk_out_mps)
   // Input as placeholders
   MPSShape* input_shape = getMPSShape(self);
   NSString* ns_shape_key = [[input_shape valueForKey:@"description"] componentsJoinedByString:@","];
-  string key = string("topk:") + [ns_shape_key UTF8String] + ":" + getMPSTypeString(self) + ":k" + to_string(k) +
-      ":dim" + to_string(dim_) + ":largest" + to_string(largest);
+  string key = string("topk:") + [ns_shape_key UTF8String] + ":" + getMPSTypeString(self) + ":k" + std::to_string(k) +
+      ":dim" + std::to_string(dim_) + ":largest" + std::to_string(largest);
   auto cachedGraph = LookUpOrCreateCachedGraph<CachedGraph>(key, [&](auto mpsGraph, auto newCachedGraph) {
     newCachedGraph->selfTensor = mpsGraphRankedPlaceHolder(mpsGraph, getMPSDataType(self), input_shape);

@@ -320,12 +320,12 @@ TORCH_IMPL_FUNC(cat_out_mps)
   };

   @autoreleasepool {
-    string key =
-        "cat_out_mps:" + to_string(dimension) + ":" + (memory_format == MemoryFormat::ChannelsLast ? "NHWC" : "NCHW");
+    string key = "cat_out_mps:" + std::to_string(dimension) + ":" +
+        (memory_format == MemoryFormat::ChannelsLast ? "NHWC" : "NCHW");
     if (!all_same_dtype) {
       key += getTensorsStringKey(input_tensors, true, all_same_sizes_and_stride);
     } else {
-      key += ":" + getMPSTypeString(input_tensors[0].scalar_type(), true) + ":" + to_string(inputs.size());
+      key += ":" + getMPSTypeString(input_tensors[0].scalar_type(), true) + ":" + std::to_string(inputs.size());
     }
     for (auto idx : skipped_tensor_indices) {
       key += "," + std::to_string(idx);
@@ -60,8 +60,8 @@ TORCH_IMPL_FUNC(sort_stable_out_mps)
   // Input as placeholders
   MPSShape* input_shape = getMPSShape(self);
   NSString* ns_shape_key = [[input_shape valueForKey:@"description"] componentsJoinedByString:@","];
-  string key = string("sort:") + [ns_shape_key UTF8String] + ":" + getMPSTypeString(self) + ":dim" + to_string(dim) +
-      ":descending" + to_string(descending);
+  string key = string("sort:") + [ns_shape_key UTF8String] + ":" + getMPSTypeString(self) + ":dim" +
+      std::to_string(dim) + ":descending" + std::to_string(descending);
   auto cachedGraph = LookUpOrCreateCachedGraph<CachedGraph>(key, [&](auto mpsGraph, auto newCachedGraph) {
     newCachedGraph->selfTensor = mpsGraphRankedPlaceHolder(mpsGraph, getMPSDataType(self), input_shape);
@@ -240,8 +240,8 @@ static void clamp_scalar_out_mps(const Tensor& input_t,

   @autoreleasepool {
     // the optional min/max refs could affect how we build the cached graph
-    string key = op_name + (has_min ? ("_min:" + to_string(min_scalar)) : "") +
-        (has_max ? ("_max:" + to_string(max_scalar)) : "") + "_scalar:" + getTensorsStringKey({input_t});
+    string key = op_name + (has_min ? ("_min:" + std::to_string(min_scalar)) : "") +
+        (has_max ? ("_max:" + std::to_string(max_scalar)) : "") + "_scalar:" + getTensorsStringKey({input_t});
     auto cachedGraph = LookUpOrCreateCachedGraph<CachedGraph>(key, [&](auto mpsGraph, auto newCachedGraph) {
       if (has_min)
         newCachedGraph->minTensor = [mpsGraph
@@ -13,32 +13,6 @@
 #include <fmt/format.h>

 namespace at::native {
-static const std::string& getMetalType(const c10::ScalarType& t) {
-  // Mapping from c10::ScalarType to integral type that can be used for unary ops
-  static std::unordered_map<c10::ScalarType, std::string> scalar_to_metal_type = {
-      {c10::ScalarType::Half, "half"},
-      {c10::ScalarType::Float, "float"},
-      {c10::ScalarType::Long, "long"},
-      {c10::ScalarType::Int, "int"},
-      {c10::ScalarType::Short, "short"},
-      {c10::ScalarType::Bool, "bool"},
-      {c10::ScalarType::Char, "int8_t"},
-      {c10::ScalarType::Byte, "uint8_t"},
-  };
-
-  auto it = scalar_to_metal_type.find(t);
-  TORCH_CHECK(it != scalar_to_metal_type.end(), "Unsupported type ", t);
-  return it->second;
-}
-
-static const std::string& getMetalType(const c10::Scalar& s) {
-  return getMetalType(s.type());
-}
-
-static const std::string& getMetalType(const Tensor& t) {
-  return getMetalType(t.scalar_type());
-}
-
 static mps::MetalShaderLibrary lib(UNARY_KERNEL_TEMPLATE, 2);

 TORCH_IMPL_FUNC(erfinv_out_mps)(const Tensor& self, const Tensor& output_) {
@@ -57,7 +31,8 @@ TORCH_IMPL_FUNC(erfinv_out_mps)(const Tensor& self, const Tensor& output_) {
   }
   using namespace mps;
   @autoreleasepool {
-    auto cplState = lib.getPipelineStateForFunc("erfinv_mps_kernel", {getMetalType(outputTensor), getMetalType(self)});
+    auto cplState = lib.getPipelineStateForFunc("erfinv_mps_kernel",
+                                                {scalarToMetalTypeString(outputTensor), scalarToMetalTypeString(self)});

     if (!self.is_contiguous()) {
       inputTensor = inputTensor.contiguous();
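The deleted getMetalType overloads duplicated a scalar-type-to-Metal-type-name table that the shared MPS helper scalarToMetalTypeString already provides, so the kernel launch now goes through the shared helper. The idea itself is a plain lookup table; a self-contained sketch (the local enum and function names are hypothetical stand-ins for c10::ScalarType and the real helper):

    #include <stdexcept>
    #include <string>
    #include <unordered_map>

    enum class ScalarType { Half, Float, Long, Int, Short, Bool, Char, Byte };

    // Map a scalar type to the Metal source-level type name used when
    // instantiating a templated kernel; unknown types are a hard error.
    const std::string& metal_type_name(ScalarType t) {
      static const std::unordered_map<ScalarType, std::string> table = {
          {ScalarType::Half, "half"},   {ScalarType::Float, "float"},
          {ScalarType::Long, "long"},   {ScalarType::Int, "int"},
          {ScalarType::Short, "short"}, {ScalarType::Bool, "bool"},
          {ScalarType::Char, "int8_t"}, {ScalarType::Byte, "uint8_t"},
      };
      auto it = table.find(t);
      if (it == table.end()) {
        throw std::runtime_error("Unsupported scalar type");
      }
      return it->second;
    }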
@@ -36,8 +36,8 @@ static std::string getUniqueKey(const ScalarType& dtype,
                                 const bool consecutive,
                                 c10::optional<int64_t> dimOpt) {
   return "_unique2_mps:" + getMPSTypeString(dtype) + "[" + getArrayRefString(base_shape) + "]:[" +
-      (dimOpt.has_value() ? to_string(dimOpt.value()) : "None") + "]:[" + to_string(return_inverse) + "]:[" +
-      to_string(return_counts) + "]:[" + to_string(consecutive) + "]";
+      (dimOpt.has_value() ? std::to_string(dimOpt.value()) : "None") + "]:[" + std::to_string(return_inverse) + "]:[" +
+      std::to_string(return_counts) + "]:[" + std::to_string(consecutive) + "]";
 }

 // dim arg not supported when non consecutive, ie sorted
@@ -99,7 +99,7 @@ static void upsample_out_template(const Tensor& input,

   @autoreleasepool {
     string key = "upsample_" + std::string(resize_mode_str) + (align_corners ? "_aligned_corners" : "") +
-        getTensorsStringKey({input}) + ":[" + to_string(scale_h) + "," + to_string(scale_w) + "]:[" +
+        getTensorsStringKey({input}) + ":[" + std::to_string(scale_h) + "," + std::to_string(scale_w) + "]:[" +
         (is_backward_pass ? getArrayRefString(input_size) : "Undefined") + "]";

     auto cachedGraph = LookUpOrCreateCachedGraph<CachedGraph>(key, [&](auto mpsGraph, auto newCachedGraph) {
@@ -42,7 +42,7 @@ static std::string getStridedKey(const ScalarType& self_dtype,
   }

   return (is_scatter ? "scatter:" : "gather:") + dtype_key + "[" + getArrayRefString(base_shape) + "]:[" +
-      getArrayRefString(new_shape) + "]:[" + getArrayRefString(stride) + "]:[" + to_string(storage_offset) + "]";
+      getArrayRefString(new_shape) + "]:[" + getArrayRefString(stride) + "]:[" + std::to_string(storage_offset) + "]";
 }

 // initializes the MTLBuffers for tensor data and runs the MPSGraph for the view op
@@ -242,7 +242,7 @@ Tensor add(Tensor qa, Tensor qb, double output_scale, int64_t output_zero_point) {
       run(plan_desc);
       execution_plan_cache[key] = plan_desc;
       return quantized_output.view(orig_sizes);
-    } catch (cudnn_frontend::cudnnException &e) {std::cout << "cudnn error:" << e.what() << std::endl;} catch(c10::CuDNNError &e) { std::cout << "other error" << e.what() << std::endl;}
+    } catch (cudnn_frontend::cudnnException &e) {} catch(c10::CuDNNError &e) {}
   }

   TORCH_CHECK(false, "Unable to find an engine to execute this computation in Quantized Add Cudnn");
@@ -252,7 +252,7 @@ void PackedConvWeightCudnn<kSpatialDim>::apply_impl_helper(const at::Tensor& quantized_output,
       run(plan);
       execution_plan_cache.emplace(key, plan);
       return;
-    } catch (cudnn_frontend::cudnnException &e) {std::cout << "cudnn error:" << e.what() << std::endl;} catch(c10::CuDNNError &e) { std::cout << "other error" << e.what() << std::endl;}
+    } catch (cudnn_frontend::cudnnException &e) {} catch(c10::CuDNNError &e) {}
   }

   TORCH_CHECK(false, "Unable to find an engine to execute this computation in Quantized Conv2D Cudnn");
@@ -286,7 +286,7 @@ void PackedLinearWeightCudnn::apply_impl_helper(const at::Tensor& quantized_output,
       run(plan);
       execution_plan_cache.emplace(key, plan);
       return;
-    } catch (cudnn_frontend::cudnnException &e) {std::cout << "cudnn error:" << e.what() << std::endl;} catch(c10::CuDNNError &e) { std::cout << "other error" << e.what() << std::endl;}
+    } catch (cudnn_frontend::cudnnException &e) {} catch(c10::CuDNNError &e) {}
  }

   TORCH_CHECK(false, "Unable to find an engine to execute this computation Quantized Linear Cudnn");
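These three hunks only silence the per-attempt std::cout logging; the surrounding control flow is the usual cuDNN-frontend engine search. Candidate execution plans are tried in order, per-plan failures are expected and swallowed, and the code hard-errors (the TORCH_CHECK(false, ...) after each loop) only when nothing ran. A generic sketch of that pattern (hypothetical Plan type, not the cudnn_frontend API):

    #include <functional>
    #include <stdexcept>
    #include <vector>

    struct Plan {
      std::function<void()> run; // attempt to execute this engine config
    };

    // Try candidate plans in order of preference. Individual failures are
    // normal (not every engine config supports every problem), so they are
    // swallowed; we only raise if no plan works at all.
    void run_first_working_plan(const std::vector<Plan>& candidates) {
      for (const auto& plan : candidates) {
        try {
          plan.run();
          return; // success; a real implementation would also cache this plan
        } catch (const std::exception&) {
          // expected: fall through to the next candidate
        }
      }
      throw std::runtime_error(
          "Unable to find an engine to execute this computation");
    }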
@@ -764,8 +764,8 @@ std::tuple<Tensor, Tensor, Tensor, Tensor> _scaled_dot_product_cudnn_attention_cuda(
   const int64_t batch_size = query.size(0);
   const int64_t num_heads = query.size(1);
   const int64_t max_seqlen_batch_q = query.size(2);
-  const int64_t head_dim = query.size(3);
-
+  const int64_t head_dim_qk = query.size(3);
+  const int64_t head_dim_v = value.size(3);
   const int64_t max_seqlen_batch_k = key.size(2);
   const int64_t max_seqlen_batch_v = value.size(2);
   TORCH_CHECK(
@@ -806,7 +806,8 @@ std::tuple<Tensor, Tensor, Tensor, Tensor> _scaled_dot_product_cudnn_attention_cuda(
       num_heads/*int64_t h*/,
       max_seqlen_batch_q/*int64_t s_q*/,
       max_seqlen_batch_k/*int64_t s_kv*/,
-      head_dim/*int64_t d*/,
+      head_dim_qk/*int64_t d_qk*/,
+      head_dim_v/*int64_t d_v*/,
       softmax_scale/*float scaling_factor*/,
       compute_logsumexp/* bool */,
       is_causal/* bool */,
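As in the kernels, the attention entry points now read two head dims: d_qk from query.size(3) (which key must match) and d_v from value.size(3). A sketch of the dimension bookkeeping with the minimal consistency checks those reads imply (generic C++; TensorLike is any type exposing size(int) -> int64_t, and this is illustrative, not the PyTorch code):

    #include <cstdint>
    #include <stdexcept>

    // SDPA shapes in BHSD order: q is [b, h, s_q, d_qk], k is [b, h, s_kv, d_qk],
    // v is [b, h, s_kv, d_v].
    struct SdpaDims {
      int64_t b, h, s_q, s_kv, d_qk, d_v;
    };

    template <typename TensorLike>
    SdpaDims extract_sdpa_dims(const TensorLike& q,
                               const TensorLike& k,
                               const TensorLike& v) {
      SdpaDims dims{q.size(0), q.size(1), q.size(2), k.size(2),
                    q.size(3), v.size(3)};
      // q and k must agree on the contracted dim; k and v on sequence length.
      if (k.size(3) != dims.d_qk || v.size(2) != dims.s_kv) {
        throw std::invalid_argument("mismatched key/value shapes");
      }
      return dims;
    }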
@@ -194,12 +194,11 @@ std::tuple<Tensor, Tensor, Tensor> _scaled_dot_product_cudnn_attention_backward_cuda(

   const int64_t batch_size = query.size(0);
   const int64_t num_heads = query.size(1);
-  const int64_t head_dim = query.size(3);
+  const int64_t head_dim_qk = query.size(3);
+  const int64_t head_dim_v = value.size(3);
   const int64_t max_seqlen_batch_q = query.size(1);
   const int64_t max_seqlen_batch_k = key.size(1);
-
   const auto softmax_scale = sdp::calculate_scale(query, scale).as_float_unchecked();
-
   auto dq = at::empty_like(query);
   auto dk = at::empty_like(key);
   auto dv = at::empty_like(value);
@@ -207,7 +206,8 @@ std::tuple<Tensor, Tensor, Tensor> _scaled_dot_product_cudnn_attention_backward_cuda(
       num_heads /*int64_t h*/,
       max_seqlen_batch_q /*int64_t s_q*/,
       max_seqlen_batch_k /*int64_t s_kv*/,
-      head_dim /*int64_t d*/,
+      head_dim_qk /*int64_t d_qk*/,
+      head_dim_v /*int64_t d_v*/,
       softmax_scale /*float scaling_factor*/,
       is_causal /*bool is_causal*/,
       dropout_p /*float dropout_probability*/,
@@ -378,4 +378,4 @@ vision_maskrcnn,pass,17



-yolov3,pass,2
+yolov3,pass,0

@@ -286,4 +286,4 @@ vision_maskrcnn,pass,34



-yolov3,pass,9
+yolov3,fail_accuracy,8

@@ -242,7 +242,7 @@ pyhpc_equation_of_state,pass,0



-pyhpc_isoneutral_mixing,fail_to_run,0
+pyhpc_isoneutral_mixing,pass,0



@@ -350,4 +350,4 @@ vision_maskrcnn,fail_to_run,0



-yolov3,fail_to_run,0
+yolov3,pass,0

@@ -338,4 +338,4 @@ vision_maskrcnn,pass,28



-yolov3,pass,2
+yolov3,pass,0

@@ -338,4 +338,4 @@ vision_maskrcnn,pass,28



-yolov3,pass,2
+yolov3,pass,0

@@ -242,7 +242,7 @@ pyhpc_equation_of_state,pass,0



-pyhpc_isoneutral_mixing,fail_to_run,0
+pyhpc_isoneutral_mixing,pass,0



@@ -350,4 +350,4 @@ vision_maskrcnn,fail_to_run,0



-yolov3,fail_to_run,0
+yolov3,pass,0

@@ -374,4 +374,4 @@ vision_maskrcnn,pass,17



-yolov3,pass,2
+yolov3,pass,0

@@ -282,4 +282,4 @@ vision_maskrcnn,pass,34



-yolov3,pass,9
+yolov3,fail_accuracy,8

@@ -298,4 +298,4 @@ vision_maskrcnn,pass,28



-yolov3,pass,2
+yolov3,pass,0

@@ -374,4 +374,4 @@ vision_maskrcnn,pass,17



-yolov3,pass,2
+yolov3,pass,0

@@ -282,4 +282,4 @@ vision_maskrcnn,pass,34



-yolov3,pass,9
+yolov3,pass,8

@@ -378,4 +378,4 @@ vision_maskrcnn,pass,17



-yolov3,pass,2
+yolov3,pass,0

@@ -286,4 +286,4 @@ vision_maskrcnn,pass,34



-yolov3,pass,9
+yolov3,pass,8

@@ -378,4 +378,4 @@ vision_maskrcnn,pass,17



-yolov3,pass,2
+yolov3,pass,0

@@ -286,4 +286,4 @@ vision_maskrcnn,pass,34



-yolov3,pass,9
+yolov3,pass,8
@@ -4,12 +4,11 @@ phlippe_densenet,float32,static,default,1.3988316
 basic_gnn_gcn,float32,dynamic,default,1.074576405
 llama_v2_7b_16h,float32,dynamic,default,1.211740245
 resnet50,float32,dynamic,default,1.65984261
-timm_efficientnet,float32,static,cpp,2.271561735
+#timm_efficientnet,float32,static,cpp,2.1938112
 mobilenet_v3_large,float32,static,cpp,2.63375628
 timm_resnest,float32,dynamic,cpp,1.67998548
 pyhpc_turbulent_kinetic_energy,float32,dynamic,cpp,1.59968463
-#hf_GPT2,float32,dynamic,cpp,
-hf_GPT2,float32,dynamic,cpp,1.379885175
+#hf_GPT2,float32,dynamic,cpp,1.292704418
 resnext50_32x4d,amp,static,default,1.461687045
 vgg16,amp,static,default,1.267194285
 hf_Longformer,amp,dynamic,default,0.997006035
@@ -17,6 +16,6 @@ hf_Bert_large,amp,dynamic,default,0.99391146
 llama,amp,static,default,1.32950568
 timm_regnet,amp,static,cpp,1.157188305
 lennard_jones,amp,static,cpp,2.240104485
-hf_T5_generate,amp,dynamic,cpp,1.447656135
+#hf_T5_generate,amp,dynamic,cpp,1.29339502
 timm_vovnet,amp,dynamic,cpp,1.07856471
 mobilenet_v2,amp,dynamic,cpp,2.27774577
|
@ -25,10 +25,6 @@ from torch._dynamo.utils import clone_inputs
|
||||
# We are primarily interested in tf32 datatype
|
||||
torch.backends.cuda.matmul.allow_tf32 = True
|
||||
|
||||
# Enable FX graph caching
|
||||
if "TORCHINDUCTOR_FX_GRAPH_CACHE" not in os.environ:
|
||||
torch._inductor.config.fx_graph_cache = True
|
||||
|
||||
|
||||
def _reassign_parameters(model):
|
||||
# torch_geometric models register parameter as tensors due to
|
||||
|
@ -272,6 +272,38 @@ TEST(StaticRuntime, autogen_addr) {
      /*check_resize=*/true);
}

TEST(StaticRuntime, autogen__test_functorch_fallback) {
  const std::string script = R"IR(
    graph(%self: Tensor, %other: Tensor):
        %bias: None = prim::Constant()
        %ret = aten::_test_functorch_fallback(%self, %other)
        %cloned = aten::clone(%ret, %bias)
        return (%cloned)
  )IR";

  auto self0 = at::rand({6, 6, 6});
  auto other0 = at::rand({6, 6, 6});
  std::vector<IValue> args{self0, other0};
  testStaticRuntime(
      script,
      args,
      {},
      /*use_allclose=*/false,
      /*use_equalnan=*/false,
      /*check_resize=*/true);

  auto self1 = at::rand({22, 22, 22});
  auto other1 = at::rand({22, 22, 22});
  std::vector<IValue> args2{self1, other1};
  testStaticRuntime(
      script,
      args,
      args2,
      /*use_allclose=*/false,
      /*use_equalnan=*/false,
      /*check_resize=*/true);
}

TEST(StaticRuntime, autogen_argmax) {
  const std::string script = R"IR(
    graph(%self: Tensor, %dim: int?, %keepdim: bool):
@ -4440,6 +4472,40 @@ TEST(StaticRuntime, autogen_masked_select) {
      /*check_resize=*/true);
}

TEST(StaticRuntime, autogen_nonzero_static) {
  const std::string script = R"IR(
    graph(%self: Tensor, %size: int, %fill_value: int):
        %bias: None = prim::Constant()
        %ret = aten::nonzero_static(%self, %size, %fill_value)
        %cloned = aten::clone(%ret, %bias)
        return (%cloned)
  )IR";

  auto self0 = at::rand({6, 6, 6});
  auto size0 = 1;
  auto fill_value0 = 1;
  std::vector<IValue> args{self0, size0, fill_value0};
  testStaticRuntime(
      script,
      args,
      {},
      /*use_allclose=*/false,
      /*use_equalnan=*/false,
      /*check_resize=*/true);

  auto self1 = at::rand({22, 22, 22});
  auto size1 = 1;
  auto fill_value1 = 1;
  std::vector<IValue> args2{self1, size1, fill_value1};
  testStaticRuntime(
      script,
      args,
      args2,
      /*use_allclose=*/false,
      /*use_equalnan=*/false,
      /*check_resize=*/true);
}

TEST(StaticRuntime, autogen_gather) {
  const std::string script = R"IR(
    graph(%self: Tensor, %dim: int, %index: Tensor, %sparse_grad: bool):
@ -7106,222 +7172,6 @@ TEST(StaticRuntime, autogen_special_multigammaln) {
      /*check_resize=*/true);
}

TEST(StaticRuntime, autogen_fft_fft) {
  const std::string script = R"IR(
    graph(%self: Tensor, %n: int?, %dim: int, %norm: str?):
        %bias: None = prim::Constant()
        %ret = aten::fft_fft(%self, %n, %dim, %norm)
        %cloned = aten::clone(%ret, %bias)
        return (%cloned)
  )IR";

  auto self0 = at::rand({6, 6, 6});
  auto n0 = 1;
  auto dim0 = 1;
  auto norm0 = "forward";
  std::vector<IValue> args{self0, n0, dim0, norm0};
  testStaticRuntime(
      script,
      args,
      {},
      /*use_allclose=*/false,
      /*use_equalnan=*/false,
      /*check_resize=*/true);

  auto self1 = at::rand({22, 22, 22});
  auto n1 = 1;
  auto dim1 = 1;
  auto norm1 = "forward";
  std::vector<IValue> args2{self1, n1, dim1, norm1};
  testStaticRuntime(
      script,
      args,
      args2,
      /*use_allclose=*/false,
      /*use_equalnan=*/false,
      /*check_resize=*/true);
}

TEST(StaticRuntime, autogen_fft_ifft) {
  const std::string script = R"IR(
    graph(%self: Tensor, %n: int?, %dim: int, %norm: str?):
        %bias: None = prim::Constant()
        %ret = aten::fft_ifft(%self, %n, %dim, %norm)
        %cloned = aten::clone(%ret, %bias)
        return (%cloned)
  )IR";

  auto self0 = at::rand({6, 6, 6});
  auto n0 = 1;
  auto dim0 = 1;
  auto norm0 = "forward";
  std::vector<IValue> args{self0, n0, dim0, norm0};
  testStaticRuntime(
      script,
      args,
      {},
      /*use_allclose=*/false,
      /*use_equalnan=*/false,
      /*check_resize=*/true);

  auto self1 = at::rand({22, 22, 22});
  auto n1 = 1;
  auto dim1 = 1;
  auto norm1 = "forward";
  std::vector<IValue> args2{self1, n1, dim1, norm1};
  testStaticRuntime(
      script,
      args,
      args2,
      /*use_allclose=*/false,
      /*use_equalnan=*/false,
      /*check_resize=*/true);
}

TEST(StaticRuntime, autogen_fft_rfft) {
  const std::string script = R"IR(
    graph(%self: Tensor, %n: int?, %dim: int, %norm: str?):
        %bias: None = prim::Constant()
        %ret = aten::fft_rfft(%self, %n, %dim, %norm)
        %cloned = aten::clone(%ret, %bias)
        return (%cloned)
  )IR";

  auto self0 = at::rand({6, 6, 6});
  auto n0 = 1;
  auto dim0 = 1;
  auto norm0 = "forward";
  std::vector<IValue> args{self0, n0, dim0, norm0};
  testStaticRuntime(
      script,
      args,
      {},
      /*use_allclose=*/false,
      /*use_equalnan=*/false,
      /*check_resize=*/true);

  auto self1 = at::rand({22, 22, 22});
  auto n1 = 1;
  auto dim1 = 1;
  auto norm1 = "forward";
  std::vector<IValue> args2{self1, n1, dim1, norm1};
  testStaticRuntime(
      script,
      args,
      args2,
      /*use_allclose=*/false,
      /*use_equalnan=*/false,
      /*check_resize=*/true);
}

TEST(StaticRuntime, autogen_fft_irfft) {
  const std::string script = R"IR(
    graph(%self: Tensor, %n: int?, %dim: int, %norm: str?):
        %bias: None = prim::Constant()
        %ret = aten::fft_irfft(%self, %n, %dim, %norm)
        %cloned = aten::clone(%ret, %bias)
        return (%cloned)
  )IR";

  auto self0 = at::rand({6, 6, 6});
  auto n0 = 1;
  auto dim0 = 1;
  auto norm0 = "forward";
  std::vector<IValue> args{self0, n0, dim0, norm0};
  testStaticRuntime(
      script,
      args,
      {},
      /*use_allclose=*/false,
      /*use_equalnan=*/false,
      /*check_resize=*/true);

  auto self1 = at::rand({22, 22, 22});
  auto n1 = 1;
  auto dim1 = 1;
  auto norm1 = "forward";
  std::vector<IValue> args2{self1, n1, dim1, norm1};
  testStaticRuntime(
      script,
      args,
      args2,
      /*use_allclose=*/false,
      /*use_equalnan=*/false,
      /*check_resize=*/true);
}

TEST(StaticRuntime, autogen_fft_hfft) {
  const std::string script = R"IR(
    graph(%self: Tensor, %n: int?, %dim: int, %norm: str?):
        %bias: None = prim::Constant()
        %ret = aten::fft_hfft(%self, %n, %dim, %norm)
        %cloned = aten::clone(%ret, %bias)
        return (%cloned)
  )IR";

  auto self0 = at::rand({6, 6, 6});
  auto n0 = 1;
  auto dim0 = 1;
  auto norm0 = "forward";
  std::vector<IValue> args{self0, n0, dim0, norm0};
  testStaticRuntime(
      script,
      args,
      {},
      /*use_allclose=*/false,
      /*use_equalnan=*/false,
      /*check_resize=*/true);

  auto self1 = at::rand({22, 22, 22});
  auto n1 = 1;
  auto dim1 = 1;
  auto norm1 = "forward";
  std::vector<IValue> args2{self1, n1, dim1, norm1};
  testStaticRuntime(
      script,
      args,
      args2,
      /*use_allclose=*/false,
      /*use_equalnan=*/false,
      /*check_resize=*/true);
}

TEST(StaticRuntime, autogen_fft_ihfft) {
  const std::string script = R"IR(
    graph(%self: Tensor, %n: int?, %dim: int, %norm: str?):
        %bias: None = prim::Constant()
        %ret = aten::fft_ihfft(%self, %n, %dim, %norm)
        %cloned = aten::clone(%ret, %bias)
        return (%cloned)
  )IR";

  auto self0 = at::rand({6, 6, 6});
  auto n0 = 1;
  auto dim0 = 1;
  auto norm0 = "forward";
  std::vector<IValue> args{self0, n0, dim0, norm0};
  testStaticRuntime(
      script,
      args,
      {},
      /*use_allclose=*/false,
      /*use_equalnan=*/false,
      /*check_resize=*/true);

  auto self1 = at::rand({22, 22, 22});
  auto n1 = 1;
  auto dim1 = 1;
  auto norm1 = "forward";
  std::vector<IValue> args2{self1, n1, dim1, norm1};
  testStaticRuntime(
      script,
      args,
      args2,
      /*use_allclose=*/false,
      /*use_equalnan=*/false,
      /*check_resize=*/true);
}

TEST(StaticRuntime, autogen_linalg_cross) {
  const std::string script = R"IR(
    graph(%self: Tensor, %other: Tensor, %dim: int):
@ -827,6 +827,7 @@ libtorch_python_core_sources = [
    "torch/csrc/dynamo/guards.cpp",
    "torch/csrc/dynamo/init.cpp",
    "torch/csrc/functorch/init.cpp",
    "torch/csrc/fx/node.cpp",
    "torch/csrc/mps/Module.cpp",
    "torch/csrc/mtia/Module.cpp",
    "torch/csrc/inductor/aoti_runner/pybind.cpp",
@ -62,8 +62,8 @@ Overall, the ``pipelining`` package provides the following features:
application on the Llama model.


Step 1: build ``PipelineStage`` for execution
*********************************************
Step 1: build ``PipelineStage``
*******************************

Before we can use a ``PipelineSchedule``, we need to create ``PipelineStage``
objects that wrap the part of the model running in that stage. The
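
As an aside, a minimal sketch of the step this doc hunk describes, assuming the manual PipelineStage(module, stage_index, num_stages, device, input_args=...) constructor from torch.distributed.pipelining (argument names vary across releases) and a single-process gloo group purely for illustration:

import os
import torch
import torch.distributed as dist
from torch.distributed.pipelining import PipelineStage

os.environ.setdefault("MASTER_ADDR", "localhost")
os.environ.setdefault("MASTER_PORT", "29500")
dist.init_process_group("gloo", rank=0, world_size=1)

part = torch.nn.Linear(16, 16)  # the slice of the model this rank runs
mb = torch.randn(4, 16)  # one microbatch; some releases use it for shape inference
stage = PipelineStage(
    part, stage_index=0, num_stages=1, device=torch.device("cpu"), input_args=(mb,)
)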
@ -779,4 +779,5 @@ Tensor class reference
    Tensor.where
    Tensor.xlogy
    Tensor.xlogy_
    Tensor.xpu
    Tensor.zero_
@ -80,6 +80,48 @@ class WorkerServerTest(TestCase):
            resp = pool.request("POST", "/handler/dump_nccl_trace_pickle")
            self.assertEqual(resp.status, 200)
            out = pickle.loads(resp.data)
            self.assertIsInstance(out, dict)
            self.assertIn("version", out)

    @requires_cuda
    def test_dump_nccl_trace_pickle_with_params(self) -> None:
        with local_worker_server() as pool:
            # bad key - not lower case
            resp = pool.request(
                "POST", "/handler/dump_nccl_trace_pickle?includeCollectives=true"
            )
            self.assertEqual(resp.status, 400)
            # unknown key
            resp = pool.request(
                "POST", "/handler/dump_nccl_trace_pickle?unknownkey=true"
            )
            self.assertEqual(resp.status, 400)
            # bad value - not a bool
            resp = pool.request(
                "POST", "/handler/dump_nccl_trace_pickle?includecollectives=notabool"
            )
            self.assertEqual(resp.status, 400)
            # bad value - value not lowercase
            resp = pool.request(
                "POST", "/handler/dump_nccl_trace_pickle?includecollectives=True"
            )
            self.assertEqual(resp.status, 400)
            # good key and value
            resp = pool.request(
                "POST", "/handler/dump_nccl_trace_pickle?includecollectives=true"
            )
            self.assertEqual(resp.status, 200)
            # good key and value
            resp = pool.request(
                "POST", "/handler/dump_nccl_trace_pickle?includestacktraces=true"
            )
            self.assertEqual(resp.status, 200)
            # multiple good keys and values
            resp = pool.request(
                "POST",
                "/handler/dump_nccl_trace_pickle?includecollectives=true&includestacktraces=false&onlyactive=true",
            )
            self.assertEqual(resp.status, 200)

    def test_tcp(self) -> None:
        import requests
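
Summarizing the handler contract these tests pin down (an editor's paraphrase): keys and boolean values must be lowercase, values must be "true" or "false", and unknown keys are rejected with 400. A sketch against a hypothetical local endpoint:

# Host and port are hypothetical; the handler path and the lowercase
# query parameters mirror the tests above.
import pickle
import urllib3

pool = urllib3.HTTPConnectionPool("localhost", port=29501)
resp = pool.request(
    "POST",
    "/handler/dump_nccl_trace_pickle"
    "?includecollectives=true&includestacktraces=false&onlyactive=true",
)
assert resp.status == 200
trace = pickle.loads(resp.data)  # a dict carrying at least a "version" key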
@ -13,7 +13,6 @@ import shutil
import subprocess
import sys
import tempfile
import unittest
import uuid
from contextlib import closing
from unittest import mock
@ -23,12 +22,13 @@ import torch.distributed.run as launch
from torch.distributed.elastic.agent.server.api import RunResult, WorkerState
from torch.distributed.elastic.multiprocessing import DefaultLogsSpecs
from torch.distributed.elastic.multiprocessing.errors import ChildFailedError
from torch.distributed.elastic.rendezvous.etcd_server import EtcdServer
from torch.distributed.elastic.utils import get_socket_with_port
from torch.distributed.elastic.utils.distributed import get_free_port
from torch.testing._internal.common_utils import (
    run_tests,
    skip_but_pass_in_sandcastle_if,
    TEST_WITH_DEV_DBG_ASAN,
    TestCase,
)


@ -63,19 +63,7 @@ class MockException(Exception):
    pass


class ElasticLaunchTest(unittest.TestCase):
    @classmethod
    def setUpClass(cls):
        # start a standalone, single process etcd server to use for all tests
        cls._etcd_server = EtcdServer()
        cls._etcd_server.start()
        cls._etcd_endpoint = cls._etcd_server.get_endpoint()

    @classmethod
    def tearDownClass(cls):
        # stop the standalone etcd server
        cls._etcd_server.stop()

class ElasticLaunchTest(TestCase):
    def setUp(self):
        self.test_dir = tempfile.mkdtemp()

@ -103,8 +91,6 @@ class ElasticLaunchTest(unittest.TestCase):
        args = [
            f"--nnodes={nnodes}",
            f"--nproc-per-node={nproc_per_node}",
            "--rdzv-backend=etcd",
            f"--rdzv-endpoint={self._etcd_endpoint}",
            f"--rdzv-id={run_id}",
            "--monitor-interval=1",
            "--start-method=spawn",
@ -156,8 +142,6 @@ class ElasticLaunchTest(unittest.TestCase):
        args = [
            f"--nnodes={nnodes}",
            f"--nproc-per-node={nproc_per_node}",
            "--rdzv-backend=etcd",
            f"--rdzv-endpoint={self._etcd_endpoint}",
            f"--rdzv-id={run_id}",
            "--monitor-interval=1",
            "--start-method=spawn",
@ -187,8 +171,6 @@ class ElasticLaunchTest(unittest.TestCase):
        world_size = 1
        args = [
            f"--nnodes={nnodes}",
            "--rdzv-backend=etcd",
            f"--rdzv-endpoint={self._etcd_endpoint}",
            f"--rdzv-id={run_id}",
            "--monitor-interval=1",
            "--start-method=spawn",
@ -220,8 +202,6 @@ class ElasticLaunchTest(unittest.TestCase):

        os.environ["PET_NNODES"] = str(nnodes)
        os.environ["PET_NPROC_PER_NODE"] = str(nproc_per_node)
        os.environ["PET_RDZV_BACKEND"] = "etcd"
        os.environ["PET_RDZV_ENDPOINT"] = self._etcd_endpoint
        os.environ["PET_RDZV_ID"] = run_id
        os.environ["PET_MONITOR_INTERVAL"] = "1"
        os.environ["PET_START_METHOD"] = "spawn"
@ -250,8 +230,6 @@ class ElasticLaunchTest(unittest.TestCase):
        args = [
            f"--nnodes={nnodes}",
            f"--nproc-per-node={nproc_type}",
            "--rdzv-backend=etcd",
            f"--rdzv-endpoint={self._etcd_endpoint}",
            f"--rdzv-id={run_id}",
            "--monitor-interval=1",
            "--start-method=spawn",
@ -272,7 +250,8 @@ class ElasticLaunchTest(unittest.TestCase):
    @skip_but_pass_in_sandcastle_if(
        TEST_WITH_DEV_DBG_ASAN, "test incompatible with dev/dbg asan"
    )
    def test_nproc_launch_auto_configurations(self):
    @patch("torch.cuda.is_available", return_value=False)
    def test_nproc_launch_auto_configurations(self, _mock1):
        self._test_nproc_launch_configuration("auto", os.cpu_count())

    @skip_but_pass_in_sandcastle_if(
@ -310,8 +289,9 @@ class ElasticLaunchTest(unittest.TestCase):
        args = [
            f"--nnodes={min_nodes}:{max_nodes}",
            f"--nproc-per-node={nproc_per_node}",
            "--rdzv-backend=etcd",
            f"--rdzv-endpoint={self._etcd_endpoint}",
            "--rdzv-backend=c10d",
            f"--rdzv-endpoint=localhost:{get_free_port()}",
            "--rdzv-conf='join_timeout=5,last_call_timeout=1,timeout=5'",
            f"--rdzv-id={run_id}",
            "--monitor-interval=1",
            "--start-method=spawn",
@ -343,8 +323,9 @@ class ElasticLaunchTest(unittest.TestCase):
        args = [
            f"--nnodes={min_nodes}:{max_nodes}",
            f"--nproc-per-node={nproc_per_node}",
            "--rdzv-backend=etcd",
            f"--rdzv-endpoint={self._etcd_endpoint}",
            "--rdzv-backend=c10d",
            f"--rdzv-endpoint=localhost:{get_free_port()}",
            "--rdzv-conf='join_timeout=5,last_call_timeout=1,timeout=5'",
            f"--rdzv-id={run_id}",
            "--monitor-interval=1",
            "--max-restarts=0",
@ -376,8 +357,9 @@ class ElasticLaunchTest(unittest.TestCase):
        args = [
            f"--nnodes={min_nodes}:{max_nodes}",
            f"--nproc-per-node={nproc_per_node}",
            "--rdzv-backend=etcd",
            f"--rdzv-endpoint={self._etcd_endpoint}",
            "--rdzv-backend=c10d",
            f"--rdzv-endpoint=localhost:{get_free_port()}",
            "--rdzv_conf=timeout=5",
            f"--rdzv-id={run_id}",
            "--monitor-interval=1",
            "--max-restarts=0",
@ -452,8 +434,9 @@ class ElasticLaunchTest(unittest.TestCase):
        args = [
            f"--nnodes={min_nodes}:{max_nodes}",
            f"--nproc-per-node={nproc_per_node}",
            "--rdzv-backend=etcd",
            f"--rdzv-endpoint={self._etcd_endpoint}",
            "--rdzv-backend=c10d",
            f"--rdzv-endpoint=localhost:{get_free_port()}",
            "--rdzv_conf=timeout=5",
            f"--rdzv-id={run_id}",
            "--monitor-interval=1",
            "--start-method=spawn",
@ -608,21 +591,6 @@ class ElasticLaunchTest(unittest.TestCase):
            is_torchelastic_launched = fp.readline()
        self.assertEqual("False", is_torchelastic_launched)

    def test_init_method_tcp(self):
        port = get_free_port()
        with patch.object(
            sys,
            "argv",
            [
                path("bin/test_script_init_method.py"),
                f"--init-method=tcp://localhost:{port}",
                "--rank=0",
                "--world-size=1",
            ],
        ):
            runpy.run_path(sys.argv[0], run_name="__main__")
        # nothing to validate, just make sure it runs

    @skip_but_pass_in_sandcastle_if(
        TEST_WITH_DEV_DBG_ASAN, "test incompatible with dev/dbg asan"
    )
@ -642,27 +610,6 @@ class ElasticLaunchTest(unittest.TestCase):
        )
        # nothing to validate, just make sure it runs

    def test_init_method_env(self):
        port = get_free_port()
        with patch.dict(
            os.environ,
            {
                "RANK": "0",
                "WORLD_SIZE": "1",
                "MASTER_ADDR": "localhost",
                "MASTER_PORT": str(port),
            },
        ), patch.object(
            sys,
            "argv",
            [
                path("bin/test_script_init_method.py"),
                "--init-method=env://",
            ],
        ):
            runpy.run_path(sys.argv[0], run_name="__main__")
        # nothing to validate, just make sure it runs

    @skip_but_pass_in_sandcastle_if(
        TEST_WITH_DEV_DBG_ASAN, "test incompatible with dev/dbg asan"
    )
@ -681,3 +628,7 @@ class ElasticLaunchTest(unittest.TestCase):
            ]
        )
        # nothing to validate, just make sure it runs


if __name__ == "__main__":
    run_tests()
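
For orientation (not part of the diff): the args lists above feed torch.distributed.run, so an equivalent direct launch with the new c10d rendezvous flags might look like this sketch (the script path is hypothetical):

import torch.distributed.run as launch
from torch.distributed.elastic.utils.distributed import get_free_port

launch.main(
    [
        "--nnodes=1:2",
        "--nproc-per-node=2",
        "--rdzv-backend=c10d",
        f"--rdzv-endpoint=localhost:{get_free_port()}",
        "--rdzv-id=example_run",
        "--monitor-interval=1",
        "--start-method=spawn",
        "path/to/train_script.py",  # hypothetical training script
    ]
)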
@ -3662,7 +3662,8 @@ class NCCLTraceTest(NCCLTraceTestBase):
    @requires_nccl()
    @skip_but_pass_in_sandcastle_if(not TEST_MULTIGPU, "NCCL test requires 2+ GPUs")
    @parametrize("timing_enabled", [True, False])
    def test_trace_while_active(self, timing_enabled):
    @parametrize("only_active", [True, False])
    def test_trace_while_active(self, timing_enabled, only_active):
        if self.rank == self.MAIN_PROCESS_RANK:
            for c in self.children_pipes:
                self.assertEqual(c.recv(), "next")
@ -3683,17 +3684,26 @@ class NCCLTraceTest(NCCLTraceTestBase):
            if self.rank != 0:
                pg.allreduce(a).wait()
            e.synchronize()
            t = pickle.loads(torch._C._distributed_c10d._dump_nccl_trace())
            t = pickle.loads(
                torch._C._distributed_c10d._dump_nccl_trace(onlyActive=only_active)
            )
            t = t["entries"]
            self.assertEqual(t[-1]["profiling_name"], "nccl:all_reduce")
            if self.rank == 0:
                self.assertEqual(t[-1]["collective_seq_id"], 1)
                self.assertEqual(t[-1]["state"], "completed")
            else:
                self.assertEqual(t[-1]["collective_seq_id"], 2)
                self.assertEqual(
                    t[-1]["state"], self.started_or_scheduled(timing_enabled)
                )
            if only_active:
                if self.rank == 0:
                    self.assertEqual(len(t), 0)
                else:
                    self.assertEqual(len(t), 1)
            if not only_active:
                if self.rank == 0:
                    self.assertEqual(t[-1]["profiling_name"], "nccl:all_reduce")
                    self.assertEqual(t[-1]["collective_seq_id"], 1)
                    self.assertEqual(t[-1]["state"], "completed")
                else:
                    self.assertEqual(t[-1]["profiling_name"], "nccl:all_reduce")
                    self.assertEqual(t[-1]["collective_seq_id"], 2)
                    self.assertEqual(
                        t[-1]["state"], self.started_or_scheduled(timing_enabled)
                    )

            self.parent.send("next")
        self.assertEqual("next", self.parent.recv())
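
The new keyword exercised here filters the flight recorder to in-flight collectives; a sketch of the call (a private API, subject to change), assuming a NCCL process group is already up:

import pickle
import torch

# onlyActive=True drops completed entries, per the assertions above.
trace = pickle.loads(torch._C._distributed_c10d._dump_nccl_trace(onlyActive=True))
entries = trace["entries"]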
@ -1084,14 +1084,12 @@ class TestSingleProc(DynamoDistributedSingleProcTestCase):
        # far from an exhaustive check of all the expected guards, just check a couple of them.
        FileCheck().check("""local "L['self']" TYPE_MATCH""").check(
            """local "L['self']" ID_MATCH"""
        ).check(f"""{expected_guard_source} "L['self'].net" TYPE_MATCH""").check(
            f"""{expected_guard_source} "L['self'].net" ID_MATCH"""
        ).check(
            f"""{expected_guard_source} "L['self']._modules['net']" TYPE_MATCH"""
            f"""{expected_guard_source} "L['self'].net[0]" TYPE_MATCH"""
        ).check(
            f"""{expected_guard_source} "L['self']._modules['net']" ID_MATCH"""
        ).check(
            f"""{expected_guard_source} "L['self']._modules['net']._modules['0']" TYPE_MATCH"""
        ).check(
            f"""{expected_guard_source} "L['self']._modules['net']._modules['1']" ID_MATCH"""
            f"""{expected_guard_source} "L['self'].net[0]" ID_MATCH"""
        ).run(
            GUARDS_FILE.getvalue()
        )
@ -304,13 +304,12 @@ class DecoratorTests(torch._dynamo.test_case.TestCase):
        self.assertEqual(cnt.frame_count, 0)

    def test_torch_guards_stack_frame_register_inlining_disable(self):
        y = torch.nn.Parameter(torch.tensor([0.25, 0.25]))
        x = torch.tensor([0.5, 0.5])

        class encoder(torch.nn.Module):
            def __init__(self, y):
                super().__init__()
                self.register_parameter("param", y)
                self.a = y

            @torch._dynamo.disable
            def helper(self, x, y):
@ -318,9 +317,9 @@ class DecoratorTests(torch._dynamo.test_case.TestCase):

            def forward(self, a, *args):
                x = a + a
                return self.helper(x, self.param)
                return self.helper(x, self.a)

        e = encoder(y)
        e = encoder(2.0)

        seen_frames = []
        import contextlib
@ -465,6 +464,44 @@ class DecoratorTests(torch._dynamo.test_case.TestCase):

        self.assertEqual(cnt.frame_count, 1)

    def test_assume_constant_result_on_user_defined_fn(self):
        @torch._dynamo.assume_constant_result
        def const_fn(n, s):
            return torch.full([n], s)

        def fn(B):
            B = const_fn(B.size(0), 13)
            X = B * 2
            return X.tolist()

        B_list = [8] * 32

        B = torch.tensor(B_list, dtype=torch.int32)
        torch._dynamo.decorators.mark_static(B, 0)

        torch._dynamo.config.capture_scalar_outputs = True
        torch._dynamo.config.capture_dynamic_output_shape_ops = True

        self.assertEqual(
            fn(B), torch.compile(fn, backend="eager", fullgraph=True, dynamic=True)(B)
        )

    def test_assume_constant_result_on_computation_with_graph_input(self):
        @torch._dynamo.assume_constant_result
        def check(y):
            return y[0].item() == 1

        def fn(x, y):
            if check(y):
                return x + 2
            else:
                return x + 1

        y = torch.tensor([1])
        x = torch.tensor(1)

        self.assertEqual(fn(x, y), torch.compile(fn)(x, y))


if __name__ == "__main__":
    from torch._dynamo.test_case import run_tests
@ -253,6 +253,7 @@ Target Expressions:
==> (>= 0 s1)
==> (>= 0 s2)
==> (>= 0 s3)
==> (>= 9223372036854775806 s0)

Failed Source Expressions:
==> (== (+ L['shape'][0] L['shape'][1] L['shape'][2]) L['x'].size()[0])""",
@ -286,14 +287,14 @@ Failure occurred while running node:
Model:
==> L['shape'][0]: 1
==> L['shape'][1]: 1
==> L['shape'][2]: 0
==> L['shape'][2]: 2
==> L['x'].size()[0]: 3
==> L['x'].storage_offset(): 0
==> L['x'].stride()[0]: 1
==> s0: 3
==> s1: 1
==> s2: 1
==> s3: 0
==> s3: 2

Assertions:
==> (== 0 L['x'].storage_offset())
@ -317,6 +318,10 @@ Target Expressions:
==> (== L['shape'][2] s3)
==> (== L['x'].size()[0] s0)
==> (> s0 0)
==> (>= 9223372036854775806 s0)
==> (>= 9223372036854775807 s1)
==> (>= 9223372036854775807 s2)
==> (>= 9223372036854775807 s3)

Failed Source Expressions:
==> (== (+ L['shape'][0] L['shape'][1] L['shape'][2]) L['x'].size()[0])""",
@ -3473,6 +3473,7 @@ class GraphModule(torch.nn.Module):
        ]
        false_guard_code = [
            "Ne(cast_symbool_to_symint_guardless(L['pred']), 1)",
            "-9223372036854775808 <= cast_symbool_to_symint_guardless(L['pred'])",
        ]
        test_symbool_guards(
            f,
@ -3,7 +3,6 @@ import enum
import functools
import pprint
import re
import sys
import unittest
import warnings

@ -2860,7 +2859,7 @@ class GraphModule(torch.nn.Module):

        _vjp_treespec_compare = torch._functorch.eager_transforms._vjp_treespec_compare(o, _add_batch_dim_1)

        _autograd_grad = torch._functorch.eager_transforms._autograd_grad([o], [diff_primals], [_add_batch_dim_1], retain_graph = True, create_graph = True); _add_batch_dim_1 = None
        _autograd_grad = torch._functorch.eager_transforms._autograd_grad([o], [diff_primals], [_add_batch_dim_1], retain_graph = True, create_graph = True); o = diff_primals = _add_batch_dim_1 = None
        batched_outputs = _autograd_grad[0]; _autograd_grad = None

        chunked_result = torch._C._functorch._remove_batch_dim(batched_outputs, 3, 12, 0); batched_outputs = None
@ -2896,7 +2895,7 @@ class GraphModule(torch.nn.Module):
        jac_out_in: "f32[4, 3, 4, 3, 12]" = split_2[0]; split_2 = None

        unflatten: "f32[4, 3, 4, 3, 4, 3]" = jac_out_in.unflatten(-1, (4, 3)); jac_out_in = None
        return (unflatten, diff_primals, o)
        return (unflatten,)
""",
        )

@ -2964,8 +2963,8 @@ class GraphModule(torch.nn.Module):
        _saved_tensors_hooks_disable_2 = torch._C._autograd._saved_tensors_hooks_disable("torch.func transforms don't yet support saved tensor hooks. Please open an issue with your use case.")
        _grad_increment_nesting = torch._C._functorch._grad_increment_nesting()

        _wrap_for_grad_2 = torch._C._functorch._wrap_for_grad(child_2, 3)
        child_4 = torch._C._functorch._wrap_for_grad(child_3, 3)
        _wrap_for_grad_2 = torch._C._functorch._wrap_for_grad(child_2, 3); child_2 = None
        child_4 = torch._C._functorch._wrap_for_grad(child_3, 3); child_3 = None

        set_inplace_requires_grad_allowed = torch._C._functorch.set_inplace_requires_grad_allowed(True)

@ -3002,7 +3001,7 @@ class GraphModule(torch.nn.Module):

        _vjp_treespec_compare = torch._functorch.eager_transforms._vjp_treespec_compare(o, _add_batch_dim_1)

        _autograd_grad = torch._functorch.eager_transforms._autograd_grad([o], [child_4], [_add_batch_dim_1], retain_graph = True, create_graph = True); _add_batch_dim_1 = None
        _autograd_grad = torch._functorch.eager_transforms._autograd_grad([o], [child_4], [_add_batch_dim_1], retain_graph = True, create_graph = True); o = child_4 = _add_batch_dim_1 = None
        child_5 = _autograd_grad[0]; _autograd_grad = None

        child_6 = torch._C._functorch._remove_batch_dim(child_5, 3, 12, 0); child_5 = None
@ -3041,17 +3040,10 @@ class GraphModule(torch.nn.Module):
        unflatten: "f32[4, 3, 3, 4, 3, 4]" = jac_out_in.unflatten(-1, (3, 4)); jac_out_in = None""",
        )

        # Python 3.10 and 3.11 produces slightly different graphs
        if sys.version_info[:2] > (3, 10):
            self.assertExpectedInline(
                actual.split("\n")[-2],
                """ return (unflatten, child_2, _wrap_for_grad_1, child_3, child_4, o)""",
            )
        else:
            self.assertExpectedInline(
                actual.split("\n")[-2],
                """ return (unflatten, child_3, child_2, _wrap_for_grad_1, child_4, o)""",
            )
        self.assertExpectedInline(
            actual.split("\n")[-2],
            """ return (unflatten,)""",
        )

    @unittest.expectedFailure
    def test_hessian_disable_capture(self):
@ -3160,7 +3152,7 @@ class GraphModule(torch.nn.Module):

        _vjp_treespec_compare = torch._functorch.eager_transforms._vjp_treespec_compare(o, _add_batch_dim)

        _autograd_grad = torch._functorch.eager_transforms._autograd_grad([o], [diff_primals], [_add_batch_dim], retain_graph = True, create_graph = True); _add_batch_dim = None
        _autograd_grad = torch._functorch.eager_transforms._autograd_grad([o], [diff_primals], [_add_batch_dim], retain_graph = True, create_graph = True); o = diff_primals = _add_batch_dim = None
        batched_outputs = _autograd_grad[0]; _autograd_grad = None

        chunked_result: "f32[12, 4, 3]" = torch._C._functorch._remove_batch_dim(batched_outputs, 1, 12, 0); batched_outputs = None
@ -3172,7 +3164,7 @@ class GraphModule(torch.nn.Module):
        split_1: "f32[12, 4, 3]" = split[0]; split = None

        output_input: "f32[4, 3, 4, 3]" = split_1.view((4, 3, 4, 3)); split_1 = None
        return (output_input, diff_primals, o)
        return (output_input,)
""",
        )

@ -3243,7 +3235,7 @@ class GraphModule(torch.nn.Module):

        _vjp_treespec_compare = torch._functorch.eager_transforms._vjp_treespec_compare(o, _add_batch_dim)

        _autograd_grad = torch._functorch.eager_transforms._autograd_grad([o], [diff_primals], [_add_batch_dim], retain_graph = True, create_graph = True); _add_batch_dim = None
        _autograd_grad = torch._functorch.eager_transforms._autograd_grad([o], [diff_primals], [_add_batch_dim], retain_graph = True, create_graph = True); o = diff_primals = _add_batch_dim = None
        batched_outputs = _autograd_grad[0]; _autograd_grad = None

        chunked_result: "f32[12, 3, 4]" = torch._C._functorch._remove_batch_dim(batched_outputs, 1, 12, 0); batched_outputs = None
@ -3255,7 +3247,7 @@ class GraphModule(torch.nn.Module):
        split_1: "f32[12, 3, 4]" = split[0]; split = None

        output_input: "f32[3, 4, 3, 4]" = split_1.view((3, 4, 3, 4)); split_1 = None
        return (output_input, diff_primals, o)
        return (output_input,)
""",
        )

@ -3328,7 +3320,7 @@ class GraphModule(torch.nn.Module):

        _vjp_treespec_compare = torch._functorch.eager_transforms._vjp_treespec_compare(o, _add_batch_dim)

        _autograd_grad = torch._functorch.eager_transforms._autograd_grad([o], [diff_primals], [_add_batch_dim], retain_graph = True, create_graph = True); _add_batch_dim = None
        _autograd_grad = torch._functorch.eager_transforms._autograd_grad([o], [diff_primals], [_add_batch_dim], retain_graph = True, create_graph = True); o = diff_primals = _add_batch_dim = None
        batched_outputs = _autograd_grad[0]; _autograd_grad = None

        chunked_result: "f32[12, 3, 4]" = torch._C._functorch._remove_batch_dim(batched_outputs, 1, 12, 0); batched_outputs = None
@ -3340,7 +3332,7 @@ class GraphModule(torch.nn.Module):
        split_1: "f32[12, 3, 4]" = split[0]; split = None

        output_input: "f32[3, 4, 3, 4]" = split_1.view((3, 4, 3, 4)); split_1 = None
        return (output_input, aux_1, diff_primals, o)
        return (output_input, aux_1)
""",
        )

@ -3776,7 +3768,7 @@ class GraphModule(torch.nn.Module):

        _grad_decrement_nesting = torch._C._functorch._grad_decrement_nesting()
        _saved_tensors_hooks_enable = torch._C._autograd._saved_tensors_hooks_enable()
        return (grad_input_1, y)
        return (y, grad_input_1)
""",
        )

@ -5187,10 +5179,10 @@ class GraphModule(torch.nn.Module):
            actual,
            """\
class GraphModule(torch.nn.Module):
    def forward(self, L_self_buffers_tensor_constant0_: "f32[3, 3, 3]"):
        l_self_buffers_tensor_constant0_ = L_self_buffers_tensor_constant0_
    def forward(self, L_self_tensor_constant0: "f32[3, 3, 3]"):
        l_self_tensor_constant0 = L_self_tensor_constant0

        alias_default: "f32[3, 3, 3]" = torch.ops.aten.alias.default(l_self_buffers_tensor_constant0_); l_self_buffers_tensor_constant0_ = None
        alias_default: "f32[3, 3, 3]" = torch.ops.aten.alias.default(l_self_tensor_constant0); l_self_tensor_constant0 = None

        sin_default: "f32[3, 3, 3]" = torch.ops.aten.sin.default(alias_default)

@ -5209,16 +5201,16 @@ class GraphModule(torch.nn.Module):
            actual,
            """\
class GraphModule(torch.nn.Module):
    def forward(self, L_self_modules_FX_CONST_FOLDED_ATTRS_parameters_0_: "f32[3, 3, 3]", L_self_modules_FX_CONST_FOLDED_ATTRS_parameters_1_: "f32[3, 3, 3]", L_flat_tangents_1_: "f32[3, 3, 3]"):
        l_self_modules_fx_const_folded_attrs_parameters_0_ = L_self_modules_FX_CONST_FOLDED_ATTRS_parameters_0_
        l_self_modules_fx_const_folded_attrs_parameters_1_ = L_self_modules_FX_CONST_FOLDED_ATTRS_parameters_1_
    def forward(self, getattr_L_self_FX_CONST_FOLDED_ATTRS_0_: "f32[3, 3, 3]", getattr_L_self_FX_CONST_FOLDED_ATTRS_1_: "f32[3, 3, 3]", L_flat_tangents_1_: "f32[3, 3, 3]"):
        getattr_l_self_fx_const_folded_attrs_0_ = getattr_L_self_FX_CONST_FOLDED_ATTRS_0_
        getattr_l_self_fx_const_folded_attrs_1_ = getattr_L_self_FX_CONST_FOLDED_ATTRS_1_
        l_flat_tangents_1_ = L_flat_tangents_1_

        _new_zeros_with_same_feature_meta_default: "f32[3, 3, 3]" = torch.ops.aten._new_zeros_with_same_feature_meta.default(l_flat_tangents_1_, l_self_modules_fx_const_folded_attrs_parameters_0_); l_self_modules_fx_const_folded_attrs_parameters_0_ = None
        _new_zeros_with_same_feature_meta_default: "f32[3, 3, 3]" = torch.ops.aten._new_zeros_with_same_feature_meta.default(l_flat_tangents_1_, getattr_l_self_fx_const_folded_attrs_0_); getattr_l_self_fx_const_folded_attrs_0_ = None

        copy__default: "f32[3, 3, 3]" = torch.ops.aten.copy_.default(_new_zeros_with_same_feature_meta_default, l_flat_tangents_1_); _new_zeros_with_same_feature_meta_default = l_flat_tangents_1_ = None

        mul_tensor: "f32[3, 3, 3]" = torch.ops.aten.mul.Tensor(copy__default, l_self_modules_fx_const_folded_attrs_parameters_1_); copy__default = l_self_modules_fx_const_folded_attrs_parameters_1_ = None
        mul_tensor: "f32[3, 3, 3]" = torch.ops.aten.mul.Tensor(copy__default, getattr_l_self_fx_const_folded_attrs_1_); copy__default = getattr_l_self_fx_const_folded_attrs_1_ = None
        return (mul_tensor,)
""",
        )
@ -9309,7 +9309,7 @@ ShapeEnv not equal: field values don't match:
> Left: {0: 0, 1: 1, 2: s1, 3: s0}
> Right: {0: 0, 1: 1}
==> var_to_range: values don't match.
> Left: {s0: VR[2, int_oo], s1: VR[2, int_oo]}
> Left: {s0: VR[2, 9223372036854775806], s1: VR[2, 9223372036854775806]}
> Right: {}
==> var_to_sources: values don't match.
> Left: {s0: [TensorPropertySource(base=ConstantSource(source_name='x'), prop=<TensorProperty.SIZE: 0>, idx=0)], s1: [TensorPropertySource(base=ConstantSource(source_name='x'), prop=<TensorProperty.SIZE: 0>, idx=1)]}
@ -9343,7 +9343,7 @@ ShapeEnv not equal: field values don't match:
> Left: 2
> Right: 0
==> var_to_range: values don't match.
> Left: {u0: VR[-int_oo, int_oo], u1: VR[0, 1], zuf0: VR[-oo, oo]}
> Left: {u0: VR[-9223372036854775808, 9223372036854775807], u1: VR[0, 1], zuf0: VR[-oo, oo]}
> Right: {}
""",
        )
@ -9420,8 +9420,8 @@ ShapeEnv not equal: field values don't match:
> Left: {s0: 3}
> Right: {}
==> var_to_range: values don't match.
> Left: {s0: VR[3, 3], s1: VR[2, int_oo]}
> Right: {s0: VR[2, int_oo], s1: VR[2, int_oo]}
> Left: {s0: VR[3, 3], s1: VR[2, 9223372036854775806]}
> Right: {s0: VR[2, 9223372036854775806], s1: VR[2, 9223372036854775806]}
""",
        )
        self._replay_and_check(main)
@ -9458,8 +9458,8 @@ ShapeEnv not equal: field values don't match:
> Left: {_assert, ge, x_size_0_, x_size_1_, x_storage_offset, x_stride_0_, x_stride_1_}
> Right: {x_size_0_, x_size_1_, x_storage_offset, x_stride_0_, x_stride_1_}
==> var_to_range: values don't match.
> Left: {s0: VR[3, int_oo], s1: VR[2, int_oo]}
> Right: {s0: VR[2, int_oo], s1: VR[2, int_oo]}
> Left: {s0: VR[3, 9223372036854775806], s1: VR[2, 9223372036854775806]}
> Right: {s0: VR[2, 9223372036854775806], s1: VR[2, 9223372036854775806]}
""",
        )
        self._replay_and_check(main)
@ -201,19 +201,6 @@ class TestDynamismExpression(TestCase):
            dynamic_shapes={"x": {0: dim_x}},
        )

    def test_export_slice_maxsize(self):
        class Slice(torch.nn.Module):
            def forward(self, *args):
                return torch.ops.aten.slice.Tensor(*args)

        inp = (torch.rand((10, 3, 224, 224)), 0, 0, 9223372036854775807)
        dynamic_shapes = (({0: Dim("dim")}, None, None, None),)
        torch.export.export(
            Slice(),
            inp,
            dynamic_shapes=dynamic_shapes,
        )

    def test_export_constraints_error(self):
        class ConflictingConstraints(torch.nn.Module):
            def forward(self, x):
@ -5196,7 +5183,7 @@ def forward(self, x, y):
        }
        export(f, (inputs,), dynamic_shapes=dynamic_shapes)

    def test_disable_forced_specializations_ok(self):
    def test_disable_forced_specializations(self):
        # check that _disable_forced_specializations and _allow_complex_guards_as_runtime_asserts flags
        # both behave correctly, avoiding forced specializations and deferring to runtime.
        # case 1: modulo guards

@ -312,6 +312,31 @@ class TestUnflatten(TestCase):
            export_module.module(), unflattened, (torch.randn((2, 3)),)
        )

    @unittest.skipIf(IS_WINDOWS, "Windows not supported for this test")
    def test_unflatten_preserve_with_unused_input(self):
        class M1(torch.nn.Module):
            def forward(self, x, a, b):
                return x + a, b

        class M(torch.nn.Module):
            def __init__(self):
                super().__init__()
                self.m1 = M1()

            def forward(self, x, y):
                a, b = torch.topk(y, 2)
                return self.m1(x, a, b)[0]

        ep = torch.export.export(
            M(),
            (torch.randn(2), torch.randn(5)),
            preserve_module_call_signature=("m1",),
            strict=False,
        )
        ep.graph.eliminate_dead_code()
        unflattened = unflatten(ep)
        self.compare_outputs(ep.module(), unflattened, (torch.randn(2), torch.randn(5)))

    def test_unflatten_wrong_input(self):
        class Mod(torch.nn.Module):
            def __init__(self):
Some files were not shown because too many files have changed in this diff