mirror of
				https://github.com/pytorch/pytorch.git
				synced 2025-10-31 12:15:03 +08:00 
			
		
		
		
	Compare commits
	
		
			42 Commits
		
	
	
		
	
	| Author | SHA1 | Date | |
|---|---|---|---|
| 3ddec713b8 | |||
| 85eeb90d2c | |||
| 7f6daf289b | |||
| 3d55d84ec2 | |||
| bb2a995529 | |||
| 9538bf4e7c | |||
| 219da29dfd | |||
| fb013ecb24 | |||
| 6af4c6acad | |||
| 786c24a4cd | |||
| 5d8c7f39d4 | |||
| c9c1fed065 | |||
| 94fea82d66 | |||
| 447173198b | |||
| b79d056e76 | |||
| eb567b1f40 | |||
| 1dd2431f86 | |||
| 5fcb5f0c8b | |||
| a55d0d9718 | |||
| 8c1247cffb | |||
| 70a1e85718 | |||
| adb699189b | |||
| 45dccfddcd | |||
| 3e09123797 | |||
| 61f922c2ca | |||
| 984b1a8c35 | |||
| 205410cb44 | |||
| cac7a22b92 | |||
| 8a09940a54 | |||
| 1d233b8f50 | |||
| 491c4a5dcb | |||
| 4345d98663 | |||
| a838e90964 | |||
| 29081059b6 | |||
| f8c45996d5 | |||
| c13e03c874 | |||
| 053930e194 | |||
| 9a38cae299 | |||
| 55901fb3da | |||
| fc77fdca6f | |||
| 648625b230 | |||
| 207c2248a8 | 
| @ -1099,7 +1099,6 @@ exclude_patterns = [ | |||||||
|     'test/test_namedtuple_return_api.py', |     'test/test_namedtuple_return_api.py', | ||||||
|     'test/test_native_functions.py', |     'test/test_native_functions.py', | ||||||
|     'test/test_native_mha.py', |     'test/test_native_mha.py', | ||||||
|     'test/test_nestedtensor.py', |  | ||||||
|     'test/test_nn.py', |     'test/test_nn.py', | ||||||
|     'test/test_out_dtype_op.py', |     'test/test_out_dtype_op.py', | ||||||
|     'test/test_overrides.py', |     'test/test_overrides.py', | ||||||
|  | |||||||
| @ -462,7 +462,7 @@ inline Tensor _sum_to( | |||||||
|     reduce_dims.push_back(i); |     reduce_dims.push_back(i); | ||||||
|   } |   } | ||||||
|   for (int64_t i = leading_dims; i < static_cast<int64_t>(sizes.size()); ++i) { |   for (int64_t i = leading_dims; i < static_cast<int64_t>(sizes.size()); ++i) { | ||||||
|     if (shape[i - leading_dims] == 1 && |     if (TORCH_GUARD_SIZE_OBLIVIOUS(sym_eq(shape[i - leading_dims], 1)) && | ||||||
|         TORCH_GUARD_SIZE_OBLIVIOUS(sym_ne(sizes[i], 1))) { |         TORCH_GUARD_SIZE_OBLIVIOUS(sym_ne(sizes[i], 1))) { | ||||||
|       reduce_dims.push_back(i); |       reduce_dims.push_back(i); | ||||||
|     } |     } | ||||||
|  | |||||||
| @ -478,8 +478,6 @@ namespace impl { | |||||||
| // (maybe except for some internal prim ops). | // (maybe except for some internal prim ops). | ||||||
| using GenericList = List<IValue>; | using GenericList = List<IValue>; | ||||||
|  |  | ||||||
| const IValue* ptr_to_first_element(const GenericList& list); |  | ||||||
|  |  | ||||||
| } | } | ||||||
| } | } | ||||||
|  |  | ||||||
|  | |||||||
| @ -350,11 +350,4 @@ void List<T>::unsafeSetElementType(TypePtr t) { | |||||||
|   impl_->elementType = std::move(t); |   impl_->elementType = std::move(t); | ||||||
| } | } | ||||||
|  |  | ||||||
| namespace impl { |  | ||||||
|  |  | ||||||
| inline const IValue* ptr_to_first_element(const GenericList& list) { |  | ||||||
|   return &list.impl_->list[0]; |  | ||||||
| } |  | ||||||
|  |  | ||||||
| } |  | ||||||
| } | } | ||||||
|  | |||||||
| @ -1195,15 +1195,6 @@ Tensor istft(const Tensor& self, const int64_t n_fft, const optional<int64_t> ho | |||||||
| #undef REPR | #undef REPR | ||||||
| } | } | ||||||
|  |  | ||||||
| static Tensor istft(const Tensor& self, const int64_t n_fft, const optional<int64_t> hop_lengthOpt, |  | ||||||
|              const optional<int64_t> win_lengthOpt, const Tensor& window, |  | ||||||
|              const bool center, const bool normalized, const optional<bool> onesidedOpt, |  | ||||||
|              const optional<int64_t> lengthOpt) { |  | ||||||
|   return at::native::istft( |  | ||||||
|       self, n_fft, hop_lengthOpt, win_lengthOpt, window, center, normalized, |  | ||||||
|       onesidedOpt, lengthOpt, /*return_complex=*/false); |  | ||||||
| } |  | ||||||
|  |  | ||||||
| void _fft_fill_with_conjugate_symmetry_(const Tensor& input, IntArrayRef dim_) { | void _fft_fill_with_conjugate_symmetry_(const Tensor& input, IntArrayRef dim_) { | ||||||
|   const auto input_sizes = input.sizes(); |   const auto input_sizes = input.sizes(); | ||||||
|   const auto input_strides = input.strides(); |   const auto input_strides = input.strides(); | ||||||
|  | |||||||
| @ -210,7 +210,6 @@ | |||||||
| #include <ATen/ops/zeros_native.h> | #include <ATen/ops/zeros_native.h> | ||||||
| #endif | #endif | ||||||
|  |  | ||||||
| #include <c10/util/StringUtil.h> |  | ||||||
| #include <algorithm> | #include <algorithm> | ||||||
| #include <cstdint> | #include <cstdint> | ||||||
| #include <utility> | #include <utility> | ||||||
|  | |||||||
| @ -13,7 +13,8 @@ void run_cudnn_SDP_fprop( | |||||||
|     int64_t h, |     int64_t h, | ||||||
|     int64_t s_q, |     int64_t s_q, | ||||||
|     int64_t s_kv, |     int64_t s_kv, | ||||||
|     int64_t d, |     int64_t d_qk, | ||||||
|  |     int64_t d_v, | ||||||
|     float scaling_factor, |     float scaling_factor, | ||||||
|     bool isTraining, |     bool isTraining, | ||||||
|     bool is_causal, |     bool is_causal, | ||||||
| @ -34,7 +35,8 @@ void run_cudnn_SDP_bprop( | |||||||
|     int64_t h, |     int64_t h, | ||||||
|     int64_t s_q, |     int64_t s_q, | ||||||
|     int64_t s_kv, |     int64_t s_kv, | ||||||
|     int64_t d, |     int64_t d_qk, | ||||||
|  |     int64_t d_v, | ||||||
|     float scaling_factor, |     float scaling_factor, | ||||||
|     bool is_causal, |     bool is_causal, | ||||||
|     float dropout_probability, |     float dropout_probability, | ||||||
| @ -128,7 +130,8 @@ struct MHAParams { | |||||||
|   int64_t h; |   int64_t h; | ||||||
|   int64_t s_q; |   int64_t s_q; | ||||||
|   int64_t s_kv; |   int64_t s_kv; | ||||||
|   int64_t d; |   int64_t d_qk; | ||||||
|  |   int64_t d_v; | ||||||
|   double dropout_probability; |   double dropout_probability; | ||||||
|   bool is_causal; |   bool is_causal; | ||||||
|   bool return_softmaxstats; |   bool return_softmaxstats; | ||||||
| @ -140,7 +143,8 @@ void setMHAParams( | |||||||
|     int64_t h, |     int64_t h, | ||||||
|     int64_t s_q, |     int64_t s_q, | ||||||
|     int64_t s_kv, |     int64_t s_kv, | ||||||
|     int64_t d, |     int64_t d_qk, | ||||||
|  |     int64_t d_v, | ||||||
|     const Tensor& q, |     const Tensor& q, | ||||||
|     const Tensor& k, |     const Tensor& k, | ||||||
|     const Tensor& v, |     const Tensor& v, | ||||||
| @ -155,7 +159,8 @@ void setMHAParams( | |||||||
|   } |   } | ||||||
|   params.b = b; |   params.b = b; | ||||||
|   params.h = h; |   params.h = h; | ||||||
|   params.d = d; |   params.d_qk = d_qk; | ||||||
|  |   params.d_v = d_v; | ||||||
|   params.s_q = s_q; |   params.s_q = s_q; | ||||||
|   params.s_kv = s_kv; |   params.s_kv = s_kv; | ||||||
|   params.dropout_probability = dropout_probability; |   params.dropout_probability = dropout_probability; | ||||||
| @ -193,7 +198,8 @@ struct MHACacheKeyWrapper : ParamsWrapper<MHAParams> { | |||||||
|       int64_t h, |       int64_t h, | ||||||
|       int64_t s_q, |       int64_t s_q, | ||||||
|       int64_t s_kv, |       int64_t s_kv, | ||||||
|       int64_t d, |       int64_t d_qk, | ||||||
|  |       int64_t d_v, | ||||||
|       const Tensor& q, |       const Tensor& q, | ||||||
|       const Tensor& k, |       const Tensor& k, | ||||||
|       const Tensor& v, |       const Tensor& v, | ||||||
| @ -206,7 +212,8 @@ struct MHACacheKeyWrapper : ParamsWrapper<MHAParams> { | |||||||
|         h, |         h, | ||||||
|         s_q, |         s_q, | ||||||
|         s_kv, |         s_kv, | ||||||
|         d, |         d_qk, | ||||||
|  |         d_v, | ||||||
|         q, |         q, | ||||||
|         k, |         k, | ||||||
|         v, |         v, | ||||||
| @ -249,7 +256,8 @@ auto build_graph_and_tensors( | |||||||
|     int64_t h, |     int64_t h, | ||||||
|     int64_t s_q, |     int64_t s_q, | ||||||
|     int64_t s_kv, |     int64_t s_kv, | ||||||
|     int64_t d, |     int64_t d_qk, | ||||||
|  |     int64_t d_v, | ||||||
|     float scaling_factor, |     float scaling_factor, | ||||||
|     bool return_softmaxstats, |     bool return_softmaxstats, | ||||||
|     bool is_causal, |     bool is_causal, | ||||||
| @ -383,7 +391,8 @@ auto build_graph_and_tensors_backward( | |||||||
|     int64_t h, |     int64_t h, | ||||||
|     int64_t s_q, |     int64_t s_q, | ||||||
|     int64_t s_kv, |     int64_t s_kv, | ||||||
|     int64_t d, |     int64_t d_qk, | ||||||
|  |     int64_t d_v, | ||||||
|     float scaling_factor, |     float scaling_factor, | ||||||
|     bool is_causal, |     bool is_causal, | ||||||
|     float dropout_probability, |     float dropout_probability, | ||||||
| @ -514,7 +523,8 @@ void run_cudnn_SDP_fprop( | |||||||
|     int64_t h, |     int64_t h, | ||||||
|     int64_t s_q, |     int64_t s_q, | ||||||
|     int64_t s_kv, |     int64_t s_kv, | ||||||
|     int64_t d, |     int64_t d_qk, | ||||||
|  |     int64_t d_v, | ||||||
|     float scaling_factor, |     float scaling_factor, | ||||||
|     bool return_softmaxstats, |     bool return_softmaxstats, | ||||||
|     bool is_causal, |     bool is_causal, | ||||||
| @ -528,7 +538,7 @@ void run_cudnn_SDP_fprop( | |||||||
|     Tensor& dropoutoffset) { |     Tensor& dropoutoffset) { | ||||||
|   cudnnHandle_t handle = getCudnnHandle(); |   cudnnHandle_t handle = getCudnnHandle(); | ||||||
|   o = at::empty_strided( |   o = at::empty_strided( | ||||||
|       {b, h, s_q, d}, {s_q * h * d, d, h * d, 1}, q.options()); |       {b, h, s_q, d_v}, {s_q * h * d_v, d_v, h * d_v, 1}, q.options()); | ||||||
|   if (return_softmaxstats) { |   if (return_softmaxstats) { | ||||||
|     // TODO(eqy): verify that this is correct |     // TODO(eqy): verify that this is correct | ||||||
|     softmaxstats = at::empty({b, h, s_q}, q.options().dtype(kFloat)); |     softmaxstats = at::empty({b, h, s_q}, q.options().dtype(kFloat)); | ||||||
| @ -539,7 +549,8 @@ void run_cudnn_SDP_fprop( | |||||||
|       h, |       h, | ||||||
|       s_q, |       s_q, | ||||||
|       s_kv, |       s_kv, | ||||||
|       d, |       d_qk, | ||||||
|  |       d_v, | ||||||
|       q, |       q, | ||||||
|       k, |       k, | ||||||
|       v, |       v, | ||||||
| @ -556,7 +567,8 @@ void run_cudnn_SDP_fprop( | |||||||
|         h, |         h, | ||||||
|         s_q, |         s_q, | ||||||
|         s_kv, |         s_kv, | ||||||
|         d, |         d_qk, | ||||||
|  |         d_v, | ||||||
|         scaling_factor, |         scaling_factor, | ||||||
|         return_softmaxstats, |         return_softmaxstats, | ||||||
|         is_causal, |         is_causal, | ||||||
| @ -599,7 +611,8 @@ void run_cudnn_SDP_bprop( | |||||||
|     int64_t h, |     int64_t h, | ||||||
|     int64_t s_q, |     int64_t s_q, | ||||||
|     int64_t s_kv, |     int64_t s_kv, | ||||||
|     int64_t d, |     int64_t d_qk, | ||||||
|  |     int64_t d_v, | ||||||
|     float scaling_factor, |     float scaling_factor, | ||||||
|     bool is_causal, |     bool is_causal, | ||||||
|     float dropout_probability, |     float dropout_probability, | ||||||
| @ -623,7 +636,18 @@ void run_cudnn_SDP_bprop( | |||||||
|   } |   } | ||||||
|   cudnnHandle_t handle = getCudnnHandle(); |   cudnnHandle_t handle = getCudnnHandle(); | ||||||
|   auto key = MHACacheKeyWrapper( |   auto key = MHACacheKeyWrapper( | ||||||
|       b, h, s_q, s_kv, d, q, k, v, dropout_probability, is_causal, true); |       b, | ||||||
|  |       h, | ||||||
|  |       s_q, | ||||||
|  |       s_kv, | ||||||
|  |       d_qk, | ||||||
|  |       d_v, | ||||||
|  |       q, | ||||||
|  |       k, | ||||||
|  |       v, | ||||||
|  |       dropout_probability, | ||||||
|  |       is_causal, | ||||||
|  |       true); | ||||||
|   auto graph_and_tensors_backward_ptr = mhagraphbackwardcache.find(key); |   auto graph_and_tensors_backward_ptr = mhagraphbackwardcache.find(key); | ||||||
|   graph_and_tensors_backward graph_and_tensors_backward_values; |   graph_and_tensors_backward graph_and_tensors_backward_values; | ||||||
|   if (graph_and_tensors_backward_ptr) { |   if (graph_and_tensors_backward_ptr) { | ||||||
| @ -634,7 +658,8 @@ void run_cudnn_SDP_bprop( | |||||||
|         h, |         h, | ||||||
|         s_q, |         s_q, | ||||||
|         s_kv, |         s_kv, | ||||||
|         d, |         d_qk, | ||||||
|  |         d_v, | ||||||
|         scaling_factor, |         scaling_factor, | ||||||
|         is_causal, |         is_causal, | ||||||
|         dropout_probability, |         dropout_probability, | ||||||
| @ -684,5 +709,4 @@ void run_cudnn_SDP_bprop( | |||||||
|  |  | ||||||
| } // namespace native | } // namespace native | ||||||
| } // namespace at | } // namespace at | ||||||
|  |  | ||||||
| #endif | #endif | ||||||
|  | |||||||
| @ -9,7 +9,8 @@ void run_cudnn_SDP_fprop( | |||||||
|     int64_t h, |     int64_t h, | ||||||
|     int64_t s_q, |     int64_t s_q, | ||||||
|     int64_t s_kv, |     int64_t s_kv, | ||||||
|     int64_t d, |     int64_t d_k, | ||||||
|  |     int64_t d_v, | ||||||
|     float scaling_factor, |     float scaling_factor, | ||||||
|     bool isTraining, |     bool isTraining, | ||||||
|     bool is_causal, |     bool is_causal, | ||||||
| @ -27,7 +28,8 @@ void run_cudnn_SDP_bprop( | |||||||
|     int64_t h, |     int64_t h, | ||||||
|     int64_t s_q, |     int64_t s_q, | ||||||
|     int64_t s_kv, |     int64_t s_kv, | ||||||
|     int64_t d, |     int64_t d_k, | ||||||
|  |     int64_t d_v, | ||||||
|     float scaling_factor, |     float scaling_factor, | ||||||
|     bool is_causal, |     bool is_causal, | ||||||
|     float dropout_probability, |     float dropout_probability, | ||||||
|  | |||||||
| @ -18,26 +18,21 @@ kernel void erfinv_mps_kernel( device {0} *output [[buffer(0)]], | |||||||
|   /* coefficients in rational expansion */ |   /* coefficients in rational expansion */ | ||||||
|  |  | ||||||
|   float y_abs = abs(y); |   float y_abs = abs(y); | ||||||
|   if(y_abs > 1.0f){{ |   if (y_abs >= 1.0f) {{ | ||||||
|     output[index] = NAN; |     output[index] = {0}( y_abs > 1.0f ? NAN : copysign(INFINITY, y)); | ||||||
|     return; |     return; | ||||||
|   }} |   }} | ||||||
|   if(y_abs == 1.0f){{ |   if (y_abs <= 0.7f) {{ | ||||||
|     output[index] = copysign(INFINITY, y); |  | ||||||
|     return; |  | ||||||
|   }} |  | ||||||
|   if(y_abs <= 0.7f) {{ |  | ||||||
|     z = y * y; |     z = y * y; | ||||||
|     num = (((a[3]*z + a[2])*z + a[1])*z + a[0]); |     num = ((a[3] * z + a[2]) * z + a[1])*z + a[0]; | ||||||
|     dem = ((((b[3]*z + b[2])*z + b[1])*z +b[0]) * z + 1.0f); |     dem = (((b[3] * z + b[2]) * z + b[1]) * z +b[0]) * z + 1.0f; | ||||||
|     x = y * num / dem; |     x = y * num / dem; | ||||||
|   }} |   }} else {{ | ||||||
|   else{{ |  | ||||||
|     z = sqrt(-1.0f*log((1.0-y_abs)/2.0)); |     z = sqrt(-1.0f*log((1.0-y_abs)/2.0)); | ||||||
|     num = ((c[3]*z + c[2])*z + c[1]) * z + c[0]; |     num = ((c[3] * z + c[2]) * z + c[1]) * z + c[0]; | ||||||
|     dem = (d[1]*z + d[0])*z + 1.0f; |     dem = (d[1] * z + d[0]) * z + 1.0f; | ||||||
|     x = copysign(num, y) / dem; |     x = copysign(num, y) / dem; | ||||||
|   }} |   }} | ||||||
|  |  | ||||||
|   output[index] = x; |   output[index] = {0}(x); | ||||||
| }})METAL"; | }})METAL"; | ||||||
| @ -143,7 +143,7 @@ TORCH_IMPL_FUNC(leaky_relu_out_mps)(const Tensor& self, const Scalar& negative_s | |||||||
|   Tensor output_ = at::empty_like(self, executeGatherOp ? MemoryFormat::Contiguous : MemoryFormat::Preserve); |   Tensor output_ = at::empty_like(self, executeGatherOp ? MemoryFormat::Contiguous : MemoryFormat::Preserve); | ||||||
|  |  | ||||||
|   @autoreleasepool { |   @autoreleasepool { | ||||||
|     string key = "leaky_relu" + getTensorsStringKey({self}) + ":" + to_string(negative_slope.to<double>()); |     string key = "leaky_relu" + getTensorsStringKey({self}) + ":" + std::to_string(negative_slope.to<double>()); | ||||||
|     auto cachedGraph = LookUpOrCreateCachedGraph<CachedGraph>(key, [&](auto mpsGraph, auto newCachedGraph) { |     auto cachedGraph = LookUpOrCreateCachedGraph<CachedGraph>(key, [&](auto mpsGraph, auto newCachedGraph) { | ||||||
|       MPSGraphTensor* inputTensor = mpsGraphRankedPlaceHolder(mpsGraph, self); |       MPSGraphTensor* inputTensor = mpsGraphRankedPlaceHolder(mpsGraph, self); | ||||||
|  |  | ||||||
| @ -193,8 +193,8 @@ TORCH_IMPL_FUNC(leaky_relu_backward_out_mps) | |||||||
|   Tensor output_ = at::empty_like(self, self.suggest_memory_format()); |   Tensor output_ = at::empty_like(self, self.suggest_memory_format()); | ||||||
|  |  | ||||||
|   @autoreleasepool { |   @autoreleasepool { | ||||||
|     string key = |     string key = "leaky_relu_backward" + getTensorsStringKey({self, grad_output}) + ":" + | ||||||
|         "leaky_relu_backward" + getTensorsStringKey({self, grad_output}) + ":" + to_string(negative_slope.to<double>()); |         std::to_string(negative_slope.to<double>()); | ||||||
|     auto cachedGraph = LookUpOrCreateCachedGraph<CachedGraph>(key, [&](auto mpsGraph, auto newCachedGraph) { |     auto cachedGraph = LookUpOrCreateCachedGraph<CachedGraph>(key, [&](auto mpsGraph, auto newCachedGraph) { | ||||||
|       MPSGraphTensor* inputTensor = mpsGraphRankedPlaceHolder(mpsGraph, self); |       MPSGraphTensor* inputTensor = mpsGraphRankedPlaceHolder(mpsGraph, self); | ||||||
|       MPSGraphTensor* gradOutputTensor = mpsGraphRankedPlaceHolder(mpsGraph, grad_output); |       MPSGraphTensor* gradOutputTensor = mpsGraphRankedPlaceHolder(mpsGraph, grad_output); | ||||||
| @ -242,7 +242,7 @@ TORCH_IMPL_FUNC(log_softmax_mps_out) | |||||||
|   MPSStream* stream = at::mps::getCurrentMPSStream(); |   MPSStream* stream = at::mps::getCurrentMPSStream(); | ||||||
|  |  | ||||||
|   @autoreleasepool { |   @autoreleasepool { | ||||||
|     string key = "log_softmax_mps_out" + getTensorsStringKey({self}) + ":" + to_string(dim); |     string key = "log_softmax_mps_out" + getTensorsStringKey({self}) + ":" + std::to_string(dim); | ||||||
|     auto cachedGraph = LookUpOrCreateCachedGraph<CachedGraph>(key, [&](auto mpsGraph, auto newCachedGraph) { |     auto cachedGraph = LookUpOrCreateCachedGraph<CachedGraph>(key, [&](auto mpsGraph, auto newCachedGraph) { | ||||||
|       MPSGraphTensor* inputTensor = mpsGraphRankedPlaceHolder(mpsGraph, self); |       MPSGraphTensor* inputTensor = mpsGraphRankedPlaceHolder(mpsGraph, self); | ||||||
|  |  | ||||||
| @ -285,7 +285,7 @@ TORCH_IMPL_FUNC(log_softmax_backward_mps_out) | |||||||
|   MPSStream* stream = at::mps::getCurrentMPSStream(); |   MPSStream* stream = at::mps::getCurrentMPSStream(); | ||||||
|  |  | ||||||
|   @autoreleasepool { |   @autoreleasepool { | ||||||
|     string key = "log_softmax_backward_mps_out:" + getMPSTypeString(grad_output) + ":" + to_string(dim); |     string key = "log_softmax_backward_mps_out:" + getMPSTypeString(grad_output) + ":" + std::to_string(dim); | ||||||
|     auto cachedGraph = LookUpOrCreateCachedGraph<CachedGraph>(key, [&](auto mpsGraph, auto newCachedGraph) { |     auto cachedGraph = LookUpOrCreateCachedGraph<CachedGraph>(key, [&](auto mpsGraph, auto newCachedGraph) { | ||||||
|       MPSGraphTensor* gradOutputTensor = mpsGraphUnrankedPlaceHolder(mpsGraph, getMPSDataType(grad_output)); |       MPSGraphTensor* gradOutputTensor = mpsGraphUnrankedPlaceHolder(mpsGraph, getMPSDataType(grad_output)); | ||||||
|       MPSGraphTensor* outputTensor = mpsGraphUnrankedPlaceHolder(mpsGraph, getMPSDataType(output)); |       MPSGraphTensor* outputTensor = mpsGraphUnrankedPlaceHolder(mpsGraph, getMPSDataType(output)); | ||||||
| @ -539,8 +539,8 @@ TORCH_IMPL_FUNC(threshold_out_mps) | |||||||
|   MPSStream* stream = getCurrentMPSStream(); |   MPSStream* stream = getCurrentMPSStream(); | ||||||
|  |  | ||||||
|   @autoreleasepool { |   @autoreleasepool { | ||||||
|     string key = "threshold_out_mps" + getTensorsStringKey({self}) + ":" + to_string(threshold.to<double>()) + ":" + |     string key = "threshold_out_mps" + getTensorsStringKey({self}) + ":" + std::to_string(threshold.to<double>()) + | ||||||
|         to_string(value.to<double>()); |         ":" + std::to_string(value.to<double>()); | ||||||
|  |  | ||||||
|     auto cachedGraph = LookUpOrCreateCachedGraph<CachedGraph>(key, [&](auto mpsGraph, auto newCachedGraph) { |     auto cachedGraph = LookUpOrCreateCachedGraph<CachedGraph>(key, [&](auto mpsGraph, auto newCachedGraph) { | ||||||
|       MPSGraphTensor* inputTensor = mpsGraphRankedPlaceHolder(mpsGraph, self); |       MPSGraphTensor* inputTensor = mpsGraphRankedPlaceHolder(mpsGraph, self); | ||||||
| @ -587,7 +587,7 @@ TORCH_IMPL_FUNC(threshold_backward_out_mps) | |||||||
|  |  | ||||||
|   @autoreleasepool { |   @autoreleasepool { | ||||||
|     string key = |     string key = | ||||||
|         "threshold_backward_out_mps" + getTensorsStringKey({self, grad}) + ":" + to_string(threshold.to<double>()); |         "threshold_backward_out_mps" + getTensorsStringKey({self, grad}) + ":" + std::to_string(threshold.to<double>()); | ||||||
|  |  | ||||||
|     auto cachedGraph = LookUpOrCreateCachedGraph<CachedGraph>(key, [&](auto mpsGraph, auto newCachedGraph) { |     auto cachedGraph = LookUpOrCreateCachedGraph<CachedGraph>(key, [&](auto mpsGraph, auto newCachedGraph) { | ||||||
|       MPSGraphTensor* inputTensor = mpsGraphRankedPlaceHolder(mpsGraph, self); |       MPSGraphTensor* inputTensor = mpsGraphRankedPlaceHolder(mpsGraph, self); | ||||||
| @ -826,8 +826,8 @@ static void elu_variants_out_mps(const Tensor& self, | |||||||
|   MPSStream* stream = getCurrentMPSStream(); |   MPSStream* stream = getCurrentMPSStream(); | ||||||
|  |  | ||||||
|   @autoreleasepool { |   @autoreleasepool { | ||||||
|     string key = func_name + ":" + getTensorsStringKey({self}) + ":" + to_string(alpha.to<double>()) + ":" + |     string key = func_name + ":" + getTensorsStringKey({self}) + ":" + std::to_string(alpha.to<double>()) + ":" + | ||||||
|         to_string(scale.to<double>()) + ":" + to_string(input_scale.to<double>()); |         std::to_string(scale.to<double>()) + ":" + std::to_string(input_scale.to<double>()); | ||||||
|  |  | ||||||
|     auto cachedGraph = LookUpOrCreateCachedGraph<CachedGraph>(key, [&](auto mpsGraph, auto newCachedGraph) { |     auto cachedGraph = LookUpOrCreateCachedGraph<CachedGraph>(key, [&](auto mpsGraph, auto newCachedGraph) { | ||||||
|       MPSGraphTensor* inputTensor = mpsGraphRankedPlaceHolder(mpsGraph, self); |       MPSGraphTensor* inputTensor = mpsGraphRankedPlaceHolder(mpsGraph, self); | ||||||
| @ -916,8 +916,8 @@ TORCH_IMPL_FUNC(elu_backward_out_mps) | |||||||
|  |  | ||||||
|   @autoreleasepool { |   @autoreleasepool { | ||||||
|     string key = "elu_backward_out_mps:" + getTensorsStringKey({grad_output, self_or_result}) + ":" + |     string key = "elu_backward_out_mps:" + getTensorsStringKey({grad_output, self_or_result}) + ":" + | ||||||
|         to_string(alpha.to<double>()) + ":" + to_string(scale.to<double>()) + ":" + |         std::to_string(alpha.to<double>()) + ":" + std::to_string(scale.to<double>()) + ":" + | ||||||
|         to_string(input_scale.to<double>()) + ":" + to_string(is_result); |         std::to_string(input_scale.to<double>()) + ":" + std::to_string(is_result); | ||||||
|  |  | ||||||
|     auto cachedGraph = LookUpOrCreateCachedGraph<CachedGraph>(key, [&](auto mpsGraph, auto newCachedGraph) { |     auto cachedGraph = LookUpOrCreateCachedGraph<CachedGraph>(key, [&](auto mpsGraph, auto newCachedGraph) { | ||||||
|       MPSGraphTensor* gradOutputTensor = mpsGraphRankedPlaceHolder(mpsGraph, grad_output); |       MPSGraphTensor* gradOutputTensor = mpsGraphRankedPlaceHolder(mpsGraph, grad_output); | ||||||
| @ -1010,7 +1010,7 @@ TORCH_IMPL_FUNC(glu_out_mps)(const Tensor& self, const int64_t dim, const Tensor | |||||||
|   MPSStream* stream = getCurrentMPSStream(); |   MPSStream* stream = getCurrentMPSStream(); | ||||||
|  |  | ||||||
|   @autoreleasepool { |   @autoreleasepool { | ||||||
|     string key = "glu_out_mps" + getTensorsStringKey({self}) + ":" + to_string(dim); |     string key = "glu_out_mps" + getTensorsStringKey({self}) + ":" + std::to_string(dim); | ||||||
|     auto cachedGraph = LookUpOrCreateCachedGraph<CachedGraph>(key, [&](auto mpsGraph, auto newCachedGraph) { |     auto cachedGraph = LookUpOrCreateCachedGraph<CachedGraph>(key, [&](auto mpsGraph, auto newCachedGraph) { | ||||||
|       MPSGraphTensor* inputTensor = mpsGraphRankedPlaceHolder(mpsGraph, getMPSDataType(self), getMPSShape(self)); |       MPSGraphTensor* inputTensor = mpsGraphRankedPlaceHolder(mpsGraph, getMPSDataType(self), getMPSShape(self)); | ||||||
|       NSArray<MPSGraphTensor*>* outputTensorsArray = [mpsGraph splitTensor:inputTensor |       NSArray<MPSGraphTensor*>* outputTensorsArray = [mpsGraph splitTensor:inputTensor | ||||||
| @ -1052,7 +1052,7 @@ Tensor& glu_backward_mps_out(const Tensor& grad_output, const Tensor& self, cons | |||||||
|   MPSStream* stream = getCurrentMPSStream(); |   MPSStream* stream = getCurrentMPSStream(); | ||||||
|  |  | ||||||
|   @autoreleasepool { |   @autoreleasepool { | ||||||
|     string key = "glu_backward_mps_out" + getTensorsStringKey({grad_output, self}) + ":" + to_string(dim); |     string key = "glu_backward_mps_out" + getTensorsStringKey({grad_output, self}) + ":" + std::to_string(dim); | ||||||
|     auto cachedGraph = LookUpOrCreateCachedGraph<CachedGraph>(key, [&](auto mpsGraph, auto newCachedGraph) { |     auto cachedGraph = LookUpOrCreateCachedGraph<CachedGraph>(key, [&](auto mpsGraph, auto newCachedGraph) { | ||||||
|       MPSGraphTensor* inputTensor = mpsGraphRankedPlaceHolder(mpsGraph, getMPSDataType(self), getMPSShape(self)); |       MPSGraphTensor* inputTensor = mpsGraphRankedPlaceHolder(mpsGraph, getMPSDataType(self), getMPSShape(self)); | ||||||
|       MPSGraphTensor* gradOutputTensor = |       MPSGraphTensor* gradOutputTensor = | ||||||
| @ -1855,8 +1855,8 @@ Tensor& hardtanh_backward_out_mps(const Tensor& grad_output, | |||||||
|   MPSStream* stream = getCurrentMPSStream(); |   MPSStream* stream = getCurrentMPSStream(); | ||||||
|  |  | ||||||
|   @autoreleasepool { |   @autoreleasepool { | ||||||
|     string key = "hardtanh_backward_out_mps:" + getTensorsStringKey({grad_output}) + ":" + to_string(min.to<double>()) + |     string key = "hardtanh_backward_out_mps:" + getTensorsStringKey({grad_output}) + ":" + | ||||||
|         ":" + to_string(max.to<double>()); |         std::to_string(min.to<double>()) + ":" + std::to_string(max.to<double>()); | ||||||
|  |  | ||||||
|     auto cachedGraph = LookUpOrCreateCachedGraph<CachedGraph>(key, [&](auto mpsGraph, auto newCachedGraph) { |     auto cachedGraph = LookUpOrCreateCachedGraph<CachedGraph>(key, [&](auto mpsGraph, auto newCachedGraph) { | ||||||
|       MPSGraphTensor* gradOutputTensor = mpsGraphRankedPlaceHolder(mpsGraph, grad_output); |       MPSGraphTensor* gradOutputTensor = mpsGraphRankedPlaceHolder(mpsGraph, grad_output); | ||||||
|  | |||||||
| @ -136,8 +136,8 @@ static Tensor& addmv_out_mps_impl(const Tensor& self, | |||||||
|   Tensor matMulVec = at::mm(mat, vec.unsqueeze(1)).squeeze(1); |   Tensor matMulVec = at::mm(mat, vec.unsqueeze(1)).squeeze(1); | ||||||
|  |  | ||||||
|   @autoreleasepool { |   @autoreleasepool { | ||||||
|     string key = "addmv_out_mps_impl" + getTensorsStringKey({self, matMulVec}) + ":" + to_string(beta_.toDouble()) + |     string key = "addmv_out_mps_impl" + getTensorsStringKey({self, matMulVec}) + ":" + | ||||||
|         ":" + to_string(alpha_.toDouble()); |         std::to_string(beta_.toDouble()) + ":" + std::to_string(alpha_.toDouble()); | ||||||
|     auto cachedGraph = LookUpOrCreateCachedGraph<CachedGraph>(key, [&](auto mpsGraph, auto newCachedGraph) { |     auto cachedGraph = LookUpOrCreateCachedGraph<CachedGraph>(key, [&](auto mpsGraph, auto newCachedGraph) { | ||||||
|       MPSGraphTensor* matMulVecTensor = mpsGraphRankedPlaceHolder(mpsGraph, matMulVec); |       MPSGraphTensor* matMulVecTensor = mpsGraphRankedPlaceHolder(mpsGraph, matMulVec); | ||||||
|       MPSGraphTensor* selfTensor = mpsGraphRankedPlaceHolder(mpsGraph, self); |       MPSGraphTensor* selfTensor = mpsGraphRankedPlaceHolder(mpsGraph, self); | ||||||
|  | |||||||
| @ -33,7 +33,7 @@ static Tensor& fill_scalar_mps_impl(Tensor& self, const Scalar& value) { | |||||||
|   }; |   }; | ||||||
|  |  | ||||||
|   @autoreleasepool { |   @autoreleasepool { | ||||||
|     string key = "fill_scalar_mps_impl" + getTensorsStringKey(self) + ":" + to_string(value.toDouble()); |     string key = "fill_scalar_mps_impl" + getTensorsStringKey(self) + ":" + std::to_string(value.toDouble()); | ||||||
|  |  | ||||||
|     auto cachedGraph = LookUpOrCreateCachedGraph<CachedGraph>(key, [&](auto mpsGraph, auto newCachedGraph) { |     auto cachedGraph = LookUpOrCreateCachedGraph<CachedGraph>(key, [&](auto mpsGraph, auto newCachedGraph) { | ||||||
|       MPSGraphTensor* inputTensor = mpsGraphScalarPlaceHolder(mpsGraph, getMPSDataType(self.scalar_type())); |       MPSGraphTensor* inputTensor = mpsGraphScalarPlaceHolder(mpsGraph, getMPSDataType(self.scalar_type())); | ||||||
|  | |||||||
| @ -193,24 +193,24 @@ static Tensor _mps_convolution_impl(const Tensor& input_t, | |||||||
|  |  | ||||||
|     string bias_shape_key; |     string bias_shape_key; | ||||||
|     if (bias_defined) { |     if (bias_defined) { | ||||||
|       bias_shape_key = to_string(bias_shape[0]); |       bias_shape_key = std::to_string(bias_shape[0]); | ||||||
|     } else { |     } else { | ||||||
|       bias_shape_key = "nobias"; |       bias_shape_key = "nobias"; | ||||||
|     } |     } | ||||||
|  |  | ||||||
|     string key; |     string key; | ||||||
|     if (is3DConv) { |     if (is3DConv) { | ||||||
|       key = "mps_3d_convolution:" + to_string(stride[0]) + ":" + to_string(stride[1]) + ":" + to_string(stride[2]) + |       key = "mps_3d_convolution:" + std::to_string(stride[0]) + ":" + std::to_string(stride[1]) + ":" + | ||||||
|           ":" + to_string(dilation[0]) + ":" + to_string(dilation[1]) + ":" + to_string(dilation[2]) + ":" + |           std::to_string(stride[2]) + ":" + std::to_string(dilation[0]) + ":" + std::to_string(dilation[1]) + ":" + | ||||||
|           to_string(padding[0]) + ":" + to_string(padding[1]) + ":" + to_string(padding[2]) + ":" + to_string(groups) + |           std::to_string(dilation[2]) + ":" + std::to_string(padding[0]) + ":" + std::to_string(padding[1]) + ":" + | ||||||
|           ":" + mem_format_key + mps::getTensorsStringKey({input_t, weight_t}) + ":" + to_string(bias_defined) + ":" + |           std::to_string(padding[2]) + ":" + std::to_string(groups) + ":" + mem_format_key + | ||||||
|           bias_shape_key; |           mps::getTensorsStringKey({input_t, weight_t}) + ":" + std::to_string(bias_defined) + ":" + bias_shape_key; | ||||||
|  |  | ||||||
|     } else { |     } else { | ||||||
|       key = "mps_convolution:" + to_string(stride[0]) + ":" + to_string(stride[1]) + ":" + to_string(dilation[0]) + |       key = "mps_convolution:" + std::to_string(stride[0]) + ":" + std::to_string(stride[1]) + ":" + | ||||||
|           ":" + to_string(dilation[1]) + ":" + to_string(padding[0]) + ":" + to_string(padding[1]) + ":" + |           std::to_string(dilation[0]) + ":" + std::to_string(dilation[1]) + ":" + std::to_string(padding[0]) + ":" + | ||||||
|           to_string(groups) + ":" + mem_format_key + mps::getTensorsStringKey({input_t, weight_t}) + ":" + |           std::to_string(padding[1]) + ":" + std::to_string(groups) + ":" + mem_format_key + | ||||||
|           to_string(bias_defined) + ":" + bias_shape_key; |           mps::getTensorsStringKey({input_t, weight_t}) + ":" + std::to_string(bias_defined) + ":" + bias_shape_key; | ||||||
|     } |     } | ||||||
|  |  | ||||||
|     MPSShape* inputShape = mps::getMPSShape(input_t, memory_format); |     MPSShape* inputShape = mps::getMPSShape(input_t, memory_format); | ||||||
| @ -388,16 +388,16 @@ static Tensor mps_convolution_backward_input(IntArrayRef input_size, | |||||||
|     NSString* ns_shape_key = [[gradOutputShape valueForKey:@"description"] componentsJoinedByString:@","]; |     NSString* ns_shape_key = [[gradOutputShape valueForKey:@"description"] componentsJoinedByString:@","]; | ||||||
|     string key; |     string key; | ||||||
|     if (is3DConv) { |     if (is3DConv) { | ||||||
|       key = "mps_3d_convolution_backward_input:" + to_string(stride[0]) + ":" + to_string(stride[1]) + ":" + ":" + |       key = "mps_3d_convolution_backward_input:" + std::to_string(stride[0]) + ":" + std::to_string(stride[1]) + ":" + | ||||||
|           to_string(stride[2]) + to_string(dilation[0]) + ":" + to_string(dilation[1]) + ":" + to_string(dilation[2]) + |           ":" + std::to_string(stride[2]) + std::to_string(dilation[0]) + ":" + std::to_string(dilation[1]) + ":" + | ||||||
|           ":" + to_string(padding[0]) + ":" + to_string(padding[1]) + ":" + to_string(padding[2]) + ":" + |           std::to_string(dilation[2]) + ":" + std::to_string(padding[0]) + ":" + std::to_string(padding[1]) + ":" + | ||||||
|           to_string(groups) + ":" + mem_format_key + getTensorsStringKey({grad_output_t, weight_t}) + ":" + |           std::to_string(padding[2]) + ":" + std::to_string(groups) + ":" + mem_format_key + | ||||||
|           string([ns_shape_key UTF8String]); |           getTensorsStringKey({grad_output_t, weight_t}) + ":" + string([ns_shape_key UTF8String]); | ||||||
|  |  | ||||||
|     } else { |     } else { | ||||||
|       key = "mps_convolution_backward_input:" + to_string(stride[0]) + ":" + to_string(stride[1]) + ":" + |       key = "mps_convolution_backward_input:" + std::to_string(stride[0]) + ":" + std::to_string(stride[1]) + ":" + | ||||||
|           to_string(dilation[0]) + ":" + to_string(dilation[1]) + ":" + to_string(padding[0]) + ":" + |           std::to_string(dilation[0]) + ":" + std::to_string(dilation[1]) + ":" + std::to_string(padding[0]) + ":" + | ||||||
|           to_string(padding[1]) + ":" + to_string(groups) + ":" + mem_format_key + |           std::to_string(padding[1]) + ":" + std::to_string(groups) + ":" + mem_format_key + | ||||||
|           getTensorsStringKey({grad_output_t, weight_t}) + ":" + string([ns_shape_key UTF8String]); |           getTensorsStringKey({grad_output_t, weight_t}) + ":" + string([ns_shape_key UTF8String]); | ||||||
|     } |     } | ||||||
|     auto cachedGraph = LookUpOrCreateCachedGraph<CachedGraph>(key, [&](auto mpsGraph, auto newCachedGraph) { |     auto cachedGraph = LookUpOrCreateCachedGraph<CachedGraph>(key, [&](auto mpsGraph, auto newCachedGraph) { | ||||||
| @ -547,15 +547,15 @@ static Tensor mps_convolution_backward_weights(IntArrayRef weight_size, | |||||||
|     NSString* ns_shape_key = [[gradOutputShape valueForKey:@"description"] componentsJoinedByString:@","]; |     NSString* ns_shape_key = [[gradOutputShape valueForKey:@"description"] componentsJoinedByString:@","]; | ||||||
|     string key; |     string key; | ||||||
|     if (is3DConv) { |     if (is3DConv) { | ||||||
|       key = "mps_3d_convolution_backward_weights:" + to_string(stride[0]) + ":" + to_string(stride[1]) + ":" + |       key = "mps_3d_convolution_backward_weights:" + std::to_string(stride[0]) + ":" + std::to_string(stride[1]) + ":" + | ||||||
|           to_string(stride[2]) + ":" + to_string(dilation[0]) + ":" + to_string(dilation[1]) + ":" + |           std::to_string(stride[2]) + ":" + std::to_string(dilation[0]) + ":" + std::to_string(dilation[1]) + ":" + | ||||||
|           to_string(dilation[2]) + ":" + to_string(padding[0]) + ":" + to_string(padding[1]) + ":" + |           std::to_string(dilation[2]) + ":" + std::to_string(padding[0]) + ":" + std::to_string(padding[1]) + ":" + | ||||||
|           to_string(padding[2]) + ":" + to_string(groups) + ":" + mem_format_key + |           std::to_string(padding[2]) + ":" + std::to_string(groups) + ":" + mem_format_key + | ||||||
|           getTensorsStringKey({grad_output_t, input_t, grad_weight_t}) + ":" + string([ns_shape_key UTF8String]); |           getTensorsStringKey({grad_output_t, input_t, grad_weight_t}) + ":" + string([ns_shape_key UTF8String]); | ||||||
|     } else { |     } else { | ||||||
|       key = "mps_convolution_backward_weights:" + to_string(stride[0]) + ":" + to_string(stride[1]) + ":" + |       key = "mps_convolution_backward_weights:" + std::to_string(stride[0]) + ":" + std::to_string(stride[1]) + ":" + | ||||||
|           to_string(dilation[0]) + ":" + to_string(dilation[1]) + ":" + to_string(padding[0]) + ":" + |           std::to_string(dilation[0]) + ":" + std::to_string(dilation[1]) + ":" + std::to_string(padding[0]) + ":" + | ||||||
|           to_string(padding[1]) + ":" + to_string(groups) + ":" + mem_format_key + |           std::to_string(padding[1]) + ":" + std::to_string(groups) + ":" + mem_format_key + | ||||||
|           getTensorsStringKey({grad_output_t, input_t, grad_weight_t}) + ":" + string([ns_shape_key UTF8String]); |           getTensorsStringKey({grad_output_t, input_t, grad_weight_t}) + ":" + string([ns_shape_key UTF8String]); | ||||||
|     } |     } | ||||||
|     auto cachedGraph = LookUpOrCreateCachedGraph<CachedGraph>(key, [&](auto mpsGraph, auto newCachedGraph) { |     auto cachedGraph = LookUpOrCreateCachedGraph<CachedGraph>(key, [&](auto mpsGraph, auto newCachedGraph) { | ||||||
|  | |||||||
| @ -63,7 +63,7 @@ Tensor& random_mps_impl(Tensor& self, | |||||||
|  |  | ||||||
|   @autoreleasepool { |   @autoreleasepool { | ||||||
|     string key = op_name + getTensorsStringKey({self, mean_opt.value_or(Tensor()), std_opt.value_or(Tensor())}) + ":" + |     string key = op_name + getTensorsStringKey({self, mean_opt.value_or(Tensor()), std_opt.value_or(Tensor())}) + ":" + | ||||||
|         to_string(val1) + ":" + to_string(val2); |         std::to_string(val1) + ":" + std::to_string(val2); | ||||||
|     auto cachedGraph = LookUpOrCreateCachedGraph<RandomCachedGraph>(key, [&](auto mpsGraph, auto newCachedGraph) { |     auto cachedGraph = LookUpOrCreateCachedGraph<RandomCachedGraph>(key, [&](auto mpsGraph, auto newCachedGraph) { | ||||||
|       newCachedGraph->stateTensor = |       newCachedGraph->stateTensor = | ||||||
|           mpsGraphRankedPlaceHolder(mpsGraph, MPSDataTypeInt32, @[ @(at::mps::detail::PHILOX_STATE_N) ]); |           mpsGraphRankedPlaceHolder(mpsGraph, MPSDataTypeInt32, @[ @(at::mps::detail::PHILOX_STATE_N) ]); | ||||||
| @ -469,7 +469,7 @@ static Tensor& multinomial_with_replacement_mps_kernel(const Tensor& self, | |||||||
|   MPSStream* stream = getCurrentMPSStream(); |   MPSStream* stream = getCurrentMPSStream(); | ||||||
|  |  | ||||||
|   @autoreleasepool { |   @autoreleasepool { | ||||||
|     string key = "multinomial_with_replacement:" + getTensorsStringKey({self}) + ":" + to_string(n_sample); |     string key = "multinomial_with_replacement:" + getTensorsStringKey({self}) + ":" + std::to_string(n_sample); | ||||||
|     auto cachedGraph = LookUpOrCreateCachedGraph<RandomCachedGraph>(key, [&](auto mpsGraph, auto newCachedGraph) { |     auto cachedGraph = LookUpOrCreateCachedGraph<RandomCachedGraph>(key, [&](auto mpsGraph, auto newCachedGraph) { | ||||||
|       MPSShape* prob_shape = getMPSShape(self_v); |       MPSShape* prob_shape = getMPSShape(self_v); | ||||||
|       newCachedGraph->stateTensor = mpsGraphRankedPlaceHolder(mpsGraph, MPSDataTypeInt32, @[ @7 ]); |       newCachedGraph->stateTensor = mpsGraphRankedPlaceHolder(mpsGraph, MPSDataTypeInt32, @[ @7 ]); | ||||||
|  | |||||||
| @ -236,7 +236,7 @@ static std::tuple<Tensor, Tensor> _mps_linear_backward_weights(const Tensor& gra | |||||||
|   MPSStream* stream = getCurrentMPSStream(); |   MPSStream* stream = getCurrentMPSStream(); | ||||||
|  |  | ||||||
|   @autoreleasepool { |   @autoreleasepool { | ||||||
|     string key = "mps_linear_backward_weights:" + to_string(bias_defined) + ":" + |     string key = "mps_linear_backward_weights:" + std::to_string(bias_defined) + ":" + | ||||||
|         getTensorsStringKey({input_reshaped, weight, grad_output_reshaped}); |         getTensorsStringKey({input_reshaped, weight, grad_output_reshaped}); | ||||||
|     auto cachedGraph = LookUpOrCreateCachedGraph<CachedGraph>(key, [&](auto mpsGraph, auto newCachedGraph) { |     auto cachedGraph = LookUpOrCreateCachedGraph<CachedGraph>(key, [&](auto mpsGraph, auto newCachedGraph) { | ||||||
|       MPSGraphTensor* inputTensor = mpsGraphRankedPlaceHolder(mpsGraph, input_reshaped); |       MPSGraphTensor* inputTensor = mpsGraphRankedPlaceHolder(mpsGraph, input_reshaped); | ||||||
|  | |||||||
| @ -229,8 +229,8 @@ static Tensor& addbmm_or_baddbmm_out_mps_impl(const Tensor& input, | |||||||
|  |  | ||||||
|   @autoreleasepool { |   @autoreleasepool { | ||||||
|     string key = (opType == ADDBMM_OP_TYPE) ? ("addbmm_out_mps_impl") : ("baddbmm_out_mps_impl"); |     string key = (opType == ADDBMM_OP_TYPE) ? ("addbmm_out_mps_impl") : ("baddbmm_out_mps_impl"); | ||||||
|     key += getTensorsStringKey({batch1, batch2, input}) + ":" + to_string(beta.toDouble()) + ":" + |     key += getTensorsStringKey({batch1, batch2, input}) + ":" + std::to_string(beta.toDouble()) + ":" + | ||||||
|         to_string(alpha.toDouble()); |         std::to_string(alpha.toDouble()); | ||||||
|  |  | ||||||
|     auto cachedGraph = LookUpOrCreateCachedGraph<CachedGraph>(key, [&](auto mpsGraph, auto newCachedGraph) { |     auto cachedGraph = LookUpOrCreateCachedGraph<CachedGraph>(key, [&](auto mpsGraph, auto newCachedGraph) { | ||||||
|       MPSGraphTensor* inputTensor = mps::mpsGraphRankedPlaceHolder(mpsGraph, input); |       MPSGraphTensor* inputTensor = mps::mpsGraphRankedPlaceHolder(mpsGraph, input); | ||||||
| @ -331,8 +331,8 @@ static Tensor& addmm_out_mps_impl(const Tensor& bias, | |||||||
|   }; |   }; | ||||||
|  |  | ||||||
|   @autoreleasepool { |   @autoreleasepool { | ||||||
|     string key = "addmm_out_mps_impl" + getTensorsStringKey({self, other, *bias_}) + ":" + to_string(beta.toDouble()) + |     string key = "addmm_out_mps_impl" + getTensorsStringKey({self, other, *bias_}) + ":" + | ||||||
|         ":" + to_string(alpha.toDouble()); |         std::to_string(beta.toDouble()) + ":" + std::to_string(alpha.toDouble()); | ||||||
|     auto cachedGraph = LookUpOrCreateCachedGraph<CachedGraph>(key, [&](auto mpsGraph, auto newCachedGraph) { |     auto cachedGraph = LookUpOrCreateCachedGraph<CachedGraph>(key, [&](auto mpsGraph, auto newCachedGraph) { | ||||||
|       MPSGraphTensor* selfTensor = nil; |       MPSGraphTensor* selfTensor = nil; | ||||||
|       MPSGraphTensor* otherTensor = nil; |       MPSGraphTensor* otherTensor = nil; | ||||||
| @ -615,8 +615,8 @@ Tensor& addr_out_mps(const Tensor& self, | |||||||
|   }; |   }; | ||||||
|  |  | ||||||
|   @autoreleasepool { |   @autoreleasepool { | ||||||
|     string key = "addr_out_mps_impl" + getTensorsStringKey({vec1, vec2, *self_}) + ":" + to_string(beta.toDouble()) + |     string key = "addr_out_mps_impl" + getTensorsStringKey({vec1, vec2, *self_}) + ":" + | ||||||
|         ":" + to_string(alpha.toDouble()); |         std::to_string(beta.toDouble()) + ":" + std::to_string(alpha.toDouble()); | ||||||
|     auto cachedGraph = LookUpOrCreateCachedGraph<CachedGraph>(key, [&](auto mpsGraph, auto newCachedGraph) { |     auto cachedGraph = LookUpOrCreateCachedGraph<CachedGraph>(key, [&](auto mpsGraph, auto newCachedGraph) { | ||||||
|       MPSGraphTensor* t1 = mps::mpsGraphRankedPlaceHolder(mpsGraph, getMPSDataType(vec1), inputShape); |       MPSGraphTensor* t1 = mps::mpsGraphRankedPlaceHolder(mpsGraph, getMPSDataType(vec1), inputShape); | ||||||
|       MPSGraphTensor* t2 = mps::mpsGraphRankedPlaceHolder(mpsGraph, getMPSDataType(vec2), otherShape); |       MPSGraphTensor* t2 = mps::mpsGraphRankedPlaceHolder(mpsGraph, getMPSDataType(vec2), otherShape); | ||||||
|  | |||||||
| @ -69,7 +69,7 @@ static Tensor& mse_loss_backward_out_impl(const Tensor& grad_output, | |||||||
|   }; |   }; | ||||||
|  |  | ||||||
|   @autoreleasepool { |   @autoreleasepool { | ||||||
|     string key = op_name + reductionToString(reduction) + ":" + to_string(grad_input.sizes()[1]) + |     string key = op_name + reductionToString(reduction) + ":" + std::to_string(grad_input.sizes()[1]) + | ||||||
|         getTensorsStringKey({input, target, grad_output}); |         getTensorsStringKey({input, target, grad_output}); | ||||||
|     auto cachedGraph = LookUpOrCreateCachedGraph<CachedGraph>(key, [&](auto mpsGraph, auto newCachedGraph) { |     auto cachedGraph = LookUpOrCreateCachedGraph<CachedGraph>(key, [&](auto mpsGraph, auto newCachedGraph) { | ||||||
|       newCachedGraph->inputTensor = mpsGraphRankedPlaceHolder(mpsGraph, input); |       newCachedGraph->inputTensor = mpsGraphRankedPlaceHolder(mpsGraph, input); | ||||||
| @ -327,8 +327,8 @@ static void nllnd_loss_backward_impl(Tensor& grad_input_arg, | |||||||
|   } |   } | ||||||
|   @autoreleasepool { |   @autoreleasepool { | ||||||
|     string key = "nllnd_loss_backward" + getTensorsStringKey({input, grad_output, target, weight, total_weight}) + |     string key = "nllnd_loss_backward" + getTensorsStringKey({input, grad_output, target, weight, total_weight}) + | ||||||
|         to_string(numClasses) + ":" + to_string(ignore_index) + ":" + to_string(isWeightsArrayValid) + ":" + |         std::to_string(numClasses) + ":" + std::to_string(ignore_index) + ":" + std::to_string(isWeightsArrayValid) + | ||||||
|         to_string(isTargetCasted) + ":" + reductionToString(reduction); |         ":" + std::to_string(isTargetCasted) + ":" + reductionToString(reduction); | ||||||
|  |  | ||||||
|     auto cachedGraph = LookUpOrCreateCachedGraph<CachedGraph>(key, [&](auto mpsGraph, auto newCachedGraph) { |     auto cachedGraph = LookUpOrCreateCachedGraph<CachedGraph>(key, [&](auto mpsGraph, auto newCachedGraph) { | ||||||
|       MPSGraphTensor* inputTensor = mpsGraphRankedPlaceHolder(mpsGraph, input); |       MPSGraphTensor* inputTensor = mpsGraphRankedPlaceHolder(mpsGraph, input); | ||||||
| @ -463,9 +463,9 @@ static void nllnd_loss_forward_impl(Tensor& output, | |||||||
|     NSString* ns_shape_key = [[input_shape valueForKey:@"description"] componentsJoinedByString:@","]; |     NSString* ns_shape_key = [[input_shape valueForKey:@"description"] componentsJoinedByString:@","]; | ||||||
|  |  | ||||||
|     // TODO: Make the key |     // TODO: Make the key | ||||||
|     string key = "nllnd_loss_forward_impl:" + to_string(ignore_index) + ":" + to_string(isWeightsArrayValid) + ":" + |     string key = "nllnd_loss_forward_impl:" + std::to_string(ignore_index) + ":" + std::to_string(isWeightsArrayValid) + | ||||||
|         reductionToString(reduction) + ":" + [ns_shape_key UTF8String] + ":" + getMPSTypeString(input) + ":" + |         ":" + reductionToString(reduction) + ":" + [ns_shape_key UTF8String] + ":" + getMPSTypeString(input) + ":" + | ||||||
|         getMPSTypeString(target) + ":" + to_string(isTargetCasted) + ":" + getMPSTypeString(weight); |         getMPSTypeString(target) + ":" + std::to_string(isTargetCasted) + ":" + getMPSTypeString(weight); | ||||||
|     auto cachedGraph = LookUpOrCreateCachedGraph<CachedGraph>(key, [&](auto mpsGraph, auto newCachedGraph) { |     auto cachedGraph = LookUpOrCreateCachedGraph<CachedGraph>(key, [&](auto mpsGraph, auto newCachedGraph) { | ||||||
|       MPSGraphTensor* inputTensor = mpsGraphRankedPlaceHolder(mpsGraph, getMPSDataType(input), input_shape); |       MPSGraphTensor* inputTensor = mpsGraphRankedPlaceHolder(mpsGraph, getMPSDataType(input), input_shape); | ||||||
|       MPSGraphTensor* targetTensor = mpsGraphRankedPlaceHolder(mpsGraph, getMPSDataType(target), target_shape); |       MPSGraphTensor* targetTensor = mpsGraphRankedPlaceHolder(mpsGraph, getMPSDataType(target), target_shape); | ||||||
| @ -598,7 +598,7 @@ static void smooth_l1_loss_impl(const Tensor& input, | |||||||
|     NSString* ns_shape_key = [[input_shape valueForKey:@"description"] componentsJoinedByString:@","]; |     NSString* ns_shape_key = [[input_shape valueForKey:@"description"] componentsJoinedByString:@","]; | ||||||
|  |  | ||||||
|     string key = "smooth_l1_loss_impl:" + reductionToString(reduction) + ":" + [ns_shape_key UTF8String] + ":" + |     string key = "smooth_l1_loss_impl:" + reductionToString(reduction) + ":" + [ns_shape_key UTF8String] + ":" + | ||||||
|         to_string(beta) + ":" + getMPSTypeString(input) + ":" + getMPSTypeString(target); |         std::to_string(beta) + ":" + getMPSTypeString(input) + ":" + getMPSTypeString(target); | ||||||
|     auto cachedGraph = LookUpOrCreateCachedGraph<CachedGraph>(key, [&](auto mpsGraph, auto newCachedGraph) { |     auto cachedGraph = LookUpOrCreateCachedGraph<CachedGraph>(key, [&](auto mpsGraph, auto newCachedGraph) { | ||||||
|       // smooth_l1_loss_mps: |       // smooth_l1_loss_mps: | ||||||
|       // ln = 0.5 * ( xn - yn ) ^ 2 / beta,       if |xn - yn| < beta |       // ln = 0.5 * ( xn - yn ) ^ 2 / beta,       if |xn - yn| < beta | ||||||
| @ -734,7 +734,7 @@ static void smooth_l1_loss_backward_impl(const Tensor& grad_output, | |||||||
|  |  | ||||||
|   @autoreleasepool { |   @autoreleasepool { | ||||||
|     string key = "smooth_l1_loss_backward" + getTensorsStringKey({input, grad_output, grad_input, target}) + ":" + |     string key = "smooth_l1_loss_backward" + getTensorsStringKey({input, grad_output, grad_input, target}) + ":" + | ||||||
|         reductionToString(reduction) + ":" + to_string(beta); |         reductionToString(reduction) + ":" + std::to_string(beta); | ||||||
|  |  | ||||||
|     auto cachedGraph = LookUpOrCreateCachedGraph<CachedGraph>(key, [&](auto mpsGraph, auto newCachedGraph) { |     auto cachedGraph = LookUpOrCreateCachedGraph<CachedGraph>(key, [&](auto mpsGraph, auto newCachedGraph) { | ||||||
|       MPSGraphTensor* inputTensor = mpsGraphRankedPlaceHolder(mpsGraph, input); |       MPSGraphTensor* inputTensor = mpsGraphRankedPlaceHolder(mpsGraph, input); | ||||||
|  | |||||||
| @ -106,7 +106,7 @@ Tensor& arange_mps_out(const Scalar& start, const Scalar& end, const Scalar& ste | |||||||
|     auto stream = getCurrentMPSStream(); |     auto stream = getCurrentMPSStream(); | ||||||
|     auto mpsDataType = getMPSDataType(result); |     auto mpsDataType = getMPSDataType(result); | ||||||
|     @autoreleasepool { |     @autoreleasepool { | ||||||
|       string key = "arange_mps_out" + getTensorsStringKey({result}) + ":" + to_string(size); |       string key = "arange_mps_out" + getTensorsStringKey({result}) + ":" + std::to_string(size); | ||||||
|       auto cachedGraph = cache_->LookUpAs<RangeCachedGraph>(key); |       auto cachedGraph = cache_->LookUpAs<RangeCachedGraph>(key); | ||||||
|       if (!cachedGraph) { |       if (!cachedGraph) { | ||||||
|         cachedGraph = cache_->CreateCachedGraphAs<RangeCachedGraph>(key, ^MPSCachedGraph*() { |         cachedGraph = cache_->CreateCachedGraphAs<RangeCachedGraph>(key, ^MPSCachedGraph*() { | ||||||
| @ -173,7 +173,7 @@ Tensor& range_mps_out(const Scalar& start, const Scalar& end, const Scalar& step | |||||||
|     auto stream = getCurrentMPSStream(); |     auto stream = getCurrentMPSStream(); | ||||||
|     auto mpsDataType = getMPSDataType(result); |     auto mpsDataType = getMPSDataType(result); | ||||||
|     @autoreleasepool { |     @autoreleasepool { | ||||||
|       string key = "arange_mps_out" + getTensorsStringKey({result}) + ":" + to_string(size); |       string key = "arange_mps_out" + getTensorsStringKey({result}) + ":" + std::to_string(size); | ||||||
|       auto cachedGraph = cache_->LookUpAs<RangeCachedGraph>(key); |       auto cachedGraph = cache_->LookUpAs<RangeCachedGraph>(key); | ||||||
|       if (!cachedGraph) { |       if (!cachedGraph) { | ||||||
|         cachedGraph = cache_->CreateCachedGraphAs<RangeCachedGraph>(key, ^MPSCachedGraph*() { |         cachedGraph = cache_->CreateCachedGraphAs<RangeCachedGraph>(key, ^MPSCachedGraph*() { | ||||||
| @ -221,8 +221,8 @@ Tensor& linspace_out_mps(const Scalar& start, const Scalar& end, int64_t steps, | |||||||
|     bool start_less_end = (start.to<double>() <= end.to<double>()); |     bool start_less_end = (start.to<double>() <= end.to<double>()); | ||||||
|  |  | ||||||
|     @autoreleasepool { |     @autoreleasepool { | ||||||
|       string key = |       string key = "linspace_out_mps:" + getTensorsStringKey({result}) + ":" + std::to_string(steps) + | ||||||
|           "linspace_out_mps:" + getTensorsStringKey({result}) + ":" + to_string(steps) + to_string(start_less_end); |           std::to_string(start_less_end); | ||||||
|       auto cachedGraph = cache_->LookUpAs<RangeCachedGraph>(key); |       auto cachedGraph = cache_->LookUpAs<RangeCachedGraph>(key); | ||||||
|  |  | ||||||
|       if (!cachedGraph) { |       if (!cachedGraph) { | ||||||
|  | |||||||
| @ -359,8 +359,8 @@ static void impl_func_norm_mps(const Tensor& input_tensor, | |||||||
|     NSString* ns_key = [[wrappedAxes valueForKey:@"description"] componentsJoinedByString:@","]; |     NSString* ns_key = [[wrappedAxes valueForKey:@"description"] componentsJoinedByString:@","]; | ||||||
|     string keepdim_info = (keepdim) ? "keepdim=1" : "keepdim=0"; |     string keepdim_info = (keepdim) ? "keepdim=1" : "keepdim=0"; | ||||||
|     string tensor_key = cdist ? getTensorsStringKey({input_tensor, other_tensor}) : getTensorsStringKey({input_t}); |     string tensor_key = cdist ? getTensorsStringKey({input_tensor, other_tensor}) : getTensorsStringKey({input_t}); | ||||||
|     string key = string("norm_out_mps:") + [ns_key UTF8String] + ":" + tensor_key + ":p" + to_string(p) + ":" + |     string key = string("norm_out_mps:") + [ns_key UTF8String] + ":" + tensor_key + ":p" + std::to_string(p) + ":" + | ||||||
|         keepdim_info + ":" + toString(in_dtype) + ":" + to_string(castInputData); |         keepdim_info + ":" + toString(in_dtype) + ":" + std::to_string(castInputData); | ||||||
|  |  | ||||||
|     auto cachedGraph = LookUpOrCreateCachedGraph<MPSBinaryCachedGraph>(key, [&](auto mpsGraph, auto newCachedGraph) { |     auto cachedGraph = LookUpOrCreateCachedGraph<MPSBinaryCachedGraph>(key, [&](auto mpsGraph, auto newCachedGraph) { | ||||||
|       newCachedGraph->inputTensor_ = mpsGraphRankedPlaceHolder(mpsGraph, input_tensor); |       newCachedGraph->inputTensor_ = mpsGraphRankedPlaceHolder(mpsGraph, input_tensor); | ||||||
| @ -572,7 +572,7 @@ static Tensor std_var_common_impl_mps(const Tensor& input_t, | |||||||
|     string op_key = (stdVarType == STANDARD_DEVIATION) ? "std_mps" : "var_mps"; |     string op_key = (stdVarType == STANDARD_DEVIATION) ? "std_mps" : "var_mps"; | ||||||
|     NSString* ns_key = [[wrappedAxes valueForKey:@"description"] componentsJoinedByString:@","]; |     NSString* ns_key = [[wrappedAxes valueForKey:@"description"] componentsJoinedByString:@","]; | ||||||
|     string bessel_corrected = (use_correction && correction_value) ? "unbiased " : "biased "; |     string bessel_corrected = (use_correction && correction_value) ? "unbiased " : "biased "; | ||||||
|     string use_dim_info = (use_dim) ? "use_dim=1:" + to_string(dim_value.size()) : "use_dim=0"; |     string use_dim_info = (use_dim) ? "use_dim=1:" + std::to_string(dim_value.size()) : "use_dim=0"; | ||||||
|     string keepdim_info = (keepdim) ? "keepdim=1" : "keepdim=0"; |     string keepdim_info = (keepdim) ? "keepdim=1" : "keepdim=0"; | ||||||
|     string key = op_key + ":" + getTensorsStringKey(input_t) + ":" + use_dim_info + ":" + keepdim_info + ":" + |     string key = op_key + ":" + getTensorsStringKey(input_t) + ":" + use_dim_info + ":" + keepdim_info + ":" + | ||||||
|         string([ns_key UTF8String]) + ":" + bessel_corrected + ":" + std::to_string(correction_value); |         string([ns_key UTF8String]) + ":" + bessel_corrected + ":" + std::to_string(correction_value); | ||||||
| @ -700,7 +700,7 @@ static void min_max_out_mps(const Tensor& input_t, | |||||||
|   auto stream = at::mps::getCurrentMPSStream(); |   auto stream = at::mps::getCurrentMPSStream(); | ||||||
|  |  | ||||||
|   @autoreleasepool { |   @autoreleasepool { | ||||||
|     string key = func_name + getTensorsStringKey({input_t, indices_t}) + ":" + to_string(dim_); |     string key = func_name + getTensorsStringKey({input_t, indices_t}) + ":" + std::to_string(dim_); | ||||||
|     auto cachedGraph = LookUpOrCreateCachedGraph<CachedGraph>(key, [&](auto mpsGraph, auto newCachedGraph) { |     auto cachedGraph = LookUpOrCreateCachedGraph<CachedGraph>(key, [&](auto mpsGraph, auto newCachedGraph) { | ||||||
|       MPSGraphTensor* inputTensor = mpsGraphRankedPlaceHolder(mpsGraph, input_t); |       MPSGraphTensor* inputTensor = mpsGraphRankedPlaceHolder(mpsGraph, input_t); | ||||||
|       MPSGraphTensor* outputTensor = nil; |       MPSGraphTensor* outputTensor = nil; | ||||||
| @ -860,7 +860,7 @@ static void argmax_argmin_out_mps(const Tensor& input_t, | |||||||
|   @autoreleasepool { |   @autoreleasepool { | ||||||
|     NSString* ns_key = [[apparent_in_shape valueForKey:@"description"] componentsJoinedByString:@","]; |     NSString* ns_key = [[apparent_in_shape valueForKey:@"description"] componentsJoinedByString:@","]; | ||||||
|     string key = |     string key = | ||||||
|         func_name + ":" + to_string(dim_) + ":" + getTensorsStringKey(input_t) + ":" + string([ns_key UTF8String]); |         func_name + ":" + std::to_string(dim_) + ":" + getTensorsStringKey(input_t) + ":" + string([ns_key UTF8String]); | ||||||
|     auto cachedGraph = LookUpOrCreateCachedGraph<CachedGraph>(key, [&](auto mpsGraph, auto newCachedGraph) { |     auto cachedGraph = LookUpOrCreateCachedGraph<CachedGraph>(key, [&](auto mpsGraph, auto newCachedGraph) { | ||||||
|       auto inputScalarType = input_t.scalar_type(); |       auto inputScalarType = input_t.scalar_type(); | ||||||
|       MPSGraphTensor* inputTensor = |       MPSGraphTensor* inputTensor = | ||||||
| @ -1217,7 +1217,7 @@ TORCH_IMPL_FUNC(any_out_mps) | |||||||
|  |  | ||||||
|   @autoreleasepool { |   @autoreleasepool { | ||||||
|     MPSShape* input_t_shape = getMPSShape(input_t); |     MPSShape* input_t_shape = getMPSShape(input_t); | ||||||
|     string key = string("any_out_mps:") + getMPSShapeString(input_t_shape) + ":" + to_string(dim_) + ":" + |     string key = string("any_out_mps:") + getMPSShapeString(input_t_shape) + ":" + std::to_string(dim_) + ":" + | ||||||
|         getMPSTypeString(input_t); |         getMPSTypeString(input_t); | ||||||
|     auto cachedGraph = LookUpOrCreateCachedGraph<CachedGraph>(key, [&](auto mpsGraph, auto newCachedGraph) { |     auto cachedGraph = LookUpOrCreateCachedGraph<CachedGraph>(key, [&](auto mpsGraph, auto newCachedGraph) { | ||||||
|       MPSDataType input_type = getMPSDataType(input_t); |       MPSDataType input_type = getMPSDataType(input_t); | ||||||
| @ -1313,7 +1313,7 @@ TORCH_IMPL_FUNC(all_out_mps) | |||||||
|  |  | ||||||
|   @autoreleasepool { |   @autoreleasepool { | ||||||
|     MPSShape* input_t_shape = getMPSShape(input_t); |     MPSShape* input_t_shape = getMPSShape(input_t); | ||||||
|     string key = string("all_out_mps:") + getMPSShapeString(input_t_shape) + ":" + to_string(dim_) + ":" + |     string key = string("all_out_mps:") + getMPSShapeString(input_t_shape) + ":" + std::to_string(dim_) + ":" + | ||||||
|         getMPSTypeString(input_t); |         getMPSTypeString(input_t); | ||||||
|     auto cachedGraph = LookUpOrCreateCachedGraph<CachedGraph>(key, [&](auto mpsGraph, auto newCachedGraph) { |     auto cachedGraph = LookUpOrCreateCachedGraph<CachedGraph>(key, [&](auto mpsGraph, auto newCachedGraph) { | ||||||
|       MPSDataType input_type = getMPSDataType(input_t); |       MPSDataType input_type = getMPSDataType(input_t); | ||||||
| @ -1531,8 +1531,8 @@ static void median_out_mps(const Tensor& input_t, | |||||||
|   auto stream = at::mps::getCurrentMPSStream(); |   auto stream = at::mps::getCurrentMPSStream(); | ||||||
|  |  | ||||||
|   @autoreleasepool { |   @autoreleasepool { | ||||||
|     string key = |     string key = func_name + ":" + std::to_string(dim_) + ":" + getTensorsStringKey(input_t) + ":" + | ||||||
|         func_name + ":" + to_string(dim_) + ":" + getTensorsStringKey(input_t) + ":" + getTensorsStringKey(indices_t); |         getTensorsStringKey(indices_t); | ||||||
|     auto cachedGraph = LookUpOrCreateCachedGraph<CachedGraph>(key, [&](auto mpsGraph, auto newCachedGraph) { |     auto cachedGraph = LookUpOrCreateCachedGraph<CachedGraph>(key, [&](auto mpsGraph, auto newCachedGraph) { | ||||||
|       MPSGraphTensor* inputTensor = mpsGraphRankedPlaceHolder(mpsGraph, input_t); |       MPSGraphTensor* inputTensor = mpsGraphRankedPlaceHolder(mpsGraph, input_t); | ||||||
|       MPSGraphTensor* castInputTensor = |       MPSGraphTensor* castInputTensor = | ||||||
|  | |||||||
| @ -108,8 +108,8 @@ TORCH_IMPL_FUNC(topk_out_mps) | |||||||
|     // Input as placeholders |     // Input as placeholders | ||||||
|     MPSShape* input_shape = getMPSShape(self); |     MPSShape* input_shape = getMPSShape(self); | ||||||
|     NSString* ns_shape_key = [[input_shape valueForKey:@"description"] componentsJoinedByString:@","]; |     NSString* ns_shape_key = [[input_shape valueForKey:@"description"] componentsJoinedByString:@","]; | ||||||
|     string key = string("topk:") + [ns_shape_key UTF8String] + ":" + getMPSTypeString(self) + ":k" + to_string(k) + |     string key = string("topk:") + [ns_shape_key UTF8String] + ":" + getMPSTypeString(self) + ":k" + std::to_string(k) + | ||||||
|         ":dim" + to_string(dim_) + ":largest" + to_string(largest); |         ":dim" + std::to_string(dim_) + ":largest" + std::to_string(largest); | ||||||
|     auto cachedGraph = LookUpOrCreateCachedGraph<CachedGraph>(key, [&](auto mpsGraph, auto newCachedGraph) { |     auto cachedGraph = LookUpOrCreateCachedGraph<CachedGraph>(key, [&](auto mpsGraph, auto newCachedGraph) { | ||||||
|       newCachedGraph->selfTensor = mpsGraphRankedPlaceHolder(mpsGraph, getMPSDataType(self), input_shape); |       newCachedGraph->selfTensor = mpsGraphRankedPlaceHolder(mpsGraph, getMPSDataType(self), input_shape); | ||||||
|  |  | ||||||
| @ -320,12 +320,12 @@ TORCH_IMPL_FUNC(cat_out_mps) | |||||||
|   }; |   }; | ||||||
|  |  | ||||||
|   @autoreleasepool { |   @autoreleasepool { | ||||||
|     string key = |     string key = "cat_out_mps:" + std::to_string(dimension) + ":" + | ||||||
|         "cat_out_mps:" + to_string(dimension) + ":" + (memory_format == MemoryFormat::ChannelsLast ? "NHWC" : "NCHW"); |         (memory_format == MemoryFormat::ChannelsLast ? "NHWC" : "NCHW"); | ||||||
|     if (!all_same_dtype) { |     if (!all_same_dtype) { | ||||||
|       key += getTensorsStringKey(input_tensors, true, all_same_sizes_and_stride); |       key += getTensorsStringKey(input_tensors, true, all_same_sizes_and_stride); | ||||||
|     } else { |     } else { | ||||||
|       key += ":" + getMPSTypeString(input_tensors[0].scalar_type(), true) + ":" + to_string(inputs.size()); |       key += ":" + getMPSTypeString(input_tensors[0].scalar_type(), true) + ":" + std::to_string(inputs.size()); | ||||||
|     } |     } | ||||||
|     for (auto idx : skipped_tensor_indices) { |     for (auto idx : skipped_tensor_indices) { | ||||||
|       key += "," + std::to_string(idx); |       key += "," + std::to_string(idx); | ||||||
|  | |||||||
| @ -60,8 +60,8 @@ TORCH_IMPL_FUNC(sort_stable_out_mps) | |||||||
|     // Input as placeholders |     // Input as placeholders | ||||||
|     MPSShape* input_shape = getMPSShape(self); |     MPSShape* input_shape = getMPSShape(self); | ||||||
|     NSString* ns_shape_key = [[input_shape valueForKey:@"description"] componentsJoinedByString:@","]; |     NSString* ns_shape_key = [[input_shape valueForKey:@"description"] componentsJoinedByString:@","]; | ||||||
|     string key = string("sort:") + [ns_shape_key UTF8String] + ":" + getMPSTypeString(self) + ":dim" + to_string(dim) + |     string key = string("sort:") + [ns_shape_key UTF8String] + ":" + getMPSTypeString(self) + ":dim" + | ||||||
|         ":descending" + to_string(descending); |         std::to_string(dim) + ":descending" + std::to_string(descending); | ||||||
|     auto cachedGraph = LookUpOrCreateCachedGraph<CachedGraph>(key, [&](auto mpsGraph, auto newCachedGraph) { |     auto cachedGraph = LookUpOrCreateCachedGraph<CachedGraph>(key, [&](auto mpsGraph, auto newCachedGraph) { | ||||||
|       newCachedGraph->selfTensor = mpsGraphRankedPlaceHolder(mpsGraph, getMPSDataType(self), input_shape); |       newCachedGraph->selfTensor = mpsGraphRankedPlaceHolder(mpsGraph, getMPSDataType(self), input_shape); | ||||||
|  |  | ||||||
|  | |||||||
| @ -240,8 +240,8 @@ static void clamp_scalar_out_mps(const Tensor& input_t, | |||||||
|  |  | ||||||
|   @autoreleasepool { |   @autoreleasepool { | ||||||
|     // the optional min/max refs could affect how we build the cached graph |     // the optional min/max refs could affect how we build the cached graph | ||||||
|     string key = op_name + (has_min ? ("_min:" + to_string(min_scalar)) : "") + |     string key = op_name + (has_min ? ("_min:" + std::to_string(min_scalar)) : "") + | ||||||
|         (has_max ? ("_max:" + to_string(max_scalar)) : "") + "_scalar:" + getTensorsStringKey({input_t}); |         (has_max ? ("_max:" + std::to_string(max_scalar)) : "") + "_scalar:" + getTensorsStringKey({input_t}); | ||||||
|     auto cachedGraph = LookUpOrCreateCachedGraph<CachedGraph>(key, [&](auto mpsGraph, auto newCachedGraph) { |     auto cachedGraph = LookUpOrCreateCachedGraph<CachedGraph>(key, [&](auto mpsGraph, auto newCachedGraph) { | ||||||
|       if (has_min) |       if (has_min) | ||||||
|         newCachedGraph->minTensor = [mpsGraph |         newCachedGraph->minTensor = [mpsGraph | ||||||
|  | |||||||
| @ -13,32 +13,6 @@ | |||||||
| #include <fmt/format.h> | #include <fmt/format.h> | ||||||
|  |  | ||||||
| namespace at::native { | namespace at::native { | ||||||
| static const std::string& getMetalType(const c10::ScalarType& t) { |  | ||||||
|   // Mapping from c10::ScalarType to integral type that can be used for unary ops |  | ||||||
|   static std::unordered_map<c10::ScalarType, std::string> scalar_to_metal_type = { |  | ||||||
|       {c10::ScalarType::Half, "half"}, |  | ||||||
|       {c10::ScalarType::Float, "float"}, |  | ||||||
|       {c10::ScalarType::Long, "long"}, |  | ||||||
|       {c10::ScalarType::Int, "int"}, |  | ||||||
|       {c10::ScalarType::Short, "short"}, |  | ||||||
|       {c10::ScalarType::Bool, "bool"}, |  | ||||||
|       {c10::ScalarType::Char, "int8_t"}, |  | ||||||
|       {c10::ScalarType::Byte, "uint8_t"}, |  | ||||||
|   }; |  | ||||||
|  |  | ||||||
|   auto it = scalar_to_metal_type.find(t); |  | ||||||
|   TORCH_CHECK(it != scalar_to_metal_type.end(), "Unsupported type ", t); |  | ||||||
|   return it->second; |  | ||||||
| } |  | ||||||
|  |  | ||||||
| static const std::string& getMetalType(const c10::Scalar& s) { |  | ||||||
|   return getMetalType(s.type()); |  | ||||||
| } |  | ||||||
|  |  | ||||||
| static const std::string& getMetalType(const Tensor& t) { |  | ||||||
|   return getMetalType(t.scalar_type()); |  | ||||||
| } |  | ||||||
|  |  | ||||||
| static mps::MetalShaderLibrary lib(UNARY_KERNEL_TEMPLATE, 2); | static mps::MetalShaderLibrary lib(UNARY_KERNEL_TEMPLATE, 2); | ||||||
|  |  | ||||||
| TORCH_IMPL_FUNC(erfinv_out_mps)(const Tensor& self, const Tensor& output_) { | TORCH_IMPL_FUNC(erfinv_out_mps)(const Tensor& self, const Tensor& output_) { | ||||||
| @ -57,7 +31,8 @@ TORCH_IMPL_FUNC(erfinv_out_mps)(const Tensor& self, const Tensor& output_) { | |||||||
|   } |   } | ||||||
|   using namespace mps; |   using namespace mps; | ||||||
|   @autoreleasepool { |   @autoreleasepool { | ||||||
|     auto cplState = lib.getPipelineStateForFunc("erfinv_mps_kernel", {getMetalType(outputTensor), getMetalType(self)}); |     auto cplState = lib.getPipelineStateForFunc("erfinv_mps_kernel", | ||||||
|  |                                                 {scalarToMetalTypeString(outputTensor), scalarToMetalTypeString(self)}); | ||||||
|  |  | ||||||
|     if (!self.is_contiguous()) { |     if (!self.is_contiguous()) { | ||||||
|       inputTensor = inputTensor.contiguous(); |       inputTensor = inputTensor.contiguous(); | ||||||
|  | |||||||
| @ -36,8 +36,8 @@ static std::string getUniqueKey(const ScalarType& dtype, | |||||||
|                                 const bool consecutive, |                                 const bool consecutive, | ||||||
|                                 c10::optional<int64_t> dimOpt) { |                                 c10::optional<int64_t> dimOpt) { | ||||||
|   return "_unique2_mps:" + getMPSTypeString(dtype) + "[" + getArrayRefString(base_shape) + "]:[" + |   return "_unique2_mps:" + getMPSTypeString(dtype) + "[" + getArrayRefString(base_shape) + "]:[" + | ||||||
|       (dimOpt.has_value() ? to_string(dimOpt.value()) : "None") + "]:[" + to_string(return_inverse) + "]:[" + |       (dimOpt.has_value() ? std::to_string(dimOpt.value()) : "None") + "]:[" + std::to_string(return_inverse) + "]:[" + | ||||||
|       to_string(return_counts) + "]:[" + to_string(consecutive) + "]"; |       std::to_string(return_counts) + "]:[" + std::to_string(consecutive) + "]"; | ||||||
| } | } | ||||||
|  |  | ||||||
| // dim arg not supported when non consecutive, ie sorted | // dim arg not supported when non consecutive, ie sorted | ||||||
|  | |||||||
| @ -99,7 +99,7 @@ static void upsample_out_template(const Tensor& input, | |||||||
|  |  | ||||||
|   @autoreleasepool { |   @autoreleasepool { | ||||||
|     string key = "upsample_" + std::string(resize_mode_str) + (align_corners ? "_aligned_corners" : "") + |     string key = "upsample_" + std::string(resize_mode_str) + (align_corners ? "_aligned_corners" : "") + | ||||||
|         getTensorsStringKey({input}) + ":[" + to_string(scale_h) + "," + to_string(scale_w) + "]:[" + |         getTensorsStringKey({input}) + ":[" + std::to_string(scale_h) + "," + std::to_string(scale_w) + "]:[" + | ||||||
|         (is_backward_pass ? getArrayRefString(input_size) : "Undefined") + "]"; |         (is_backward_pass ? getArrayRefString(input_size) : "Undefined") + "]"; | ||||||
|  |  | ||||||
|     auto cachedGraph = LookUpOrCreateCachedGraph<CachedGraph>(key, [&](auto mpsGraph, auto newCachedGraph) { |     auto cachedGraph = LookUpOrCreateCachedGraph<CachedGraph>(key, [&](auto mpsGraph, auto newCachedGraph) { | ||||||
|  | |||||||
| @ -42,7 +42,7 @@ static std::string getStridedKey(const ScalarType& self_dtype, | |||||||
|   } |   } | ||||||
|  |  | ||||||
|   return (is_scatter ? "scatter:" : "gather:") + dtype_key + "[" + getArrayRefString(base_shape) + "]:[" + |   return (is_scatter ? "scatter:" : "gather:") + dtype_key + "[" + getArrayRefString(base_shape) + "]:[" + | ||||||
|       getArrayRefString(new_shape) + "]:[" + getArrayRefString(stride) + "]:[" + to_string(storage_offset) + "]"; |       getArrayRefString(new_shape) + "]:[" + getArrayRefString(stride) + "]:[" + std::to_string(storage_offset) + "]"; | ||||||
| } | } | ||||||
|  |  | ||||||
| // initializes the MTLBuffers for tensor data and runs the MPSGraph for the view op | // initializes the MTLBuffers for tensor data and runs the MPSGraph for the view op | ||||||
|  | |||||||
| @ -764,8 +764,8 @@ std::tuple<Tensor, Tensor, Tensor, Tensor> _scaled_dot_product_cudnn_attention_c | |||||||
|   const int64_t batch_size = query.size(0); |   const int64_t batch_size = query.size(0); | ||||||
|   const int64_t num_heads = query.size(1); |   const int64_t num_heads = query.size(1); | ||||||
|   const int64_t max_seqlen_batch_q = query.size(2); |   const int64_t max_seqlen_batch_q = query.size(2); | ||||||
|   const int64_t head_dim = query.size(3); |   const int64_t head_dim_qk = query.size(3); | ||||||
|  |   const int64_t head_dim_v = value.size(3); | ||||||
|   const int64_t max_seqlen_batch_k = key.size(2); |   const int64_t max_seqlen_batch_k = key.size(2); | ||||||
|   const int64_t max_seqlen_batch_v = value.size(2); |   const int64_t max_seqlen_batch_v = value.size(2); | ||||||
|   TORCH_CHECK( |   TORCH_CHECK( | ||||||
| @ -806,7 +806,8 @@ std::tuple<Tensor, Tensor, Tensor, Tensor> _scaled_dot_product_cudnn_attention_c | |||||||
|                       num_heads/*int64_t h*/, |                       num_heads/*int64_t h*/, | ||||||
|                       max_seqlen_batch_q/*int64_t s_q*/, |                       max_seqlen_batch_q/*int64_t s_q*/, | ||||||
|                       max_seqlen_batch_k/*int64_t s_kv*/, |                       max_seqlen_batch_k/*int64_t s_kv*/, | ||||||
|                       head_dim/*int64_t d*/, |                       head_dim_qk/*int64_t d_qk*/, | ||||||
|  |                       head_dim_v/*int64_t d_v*/, | ||||||
|                       softmax_scale/*float scaling_factor*/, |                       softmax_scale/*float scaling_factor*/, | ||||||
|                       compute_logsumexp/* bool */, |                       compute_logsumexp/* bool */, | ||||||
|                       is_causal/* bool */, |                       is_causal/* bool */, | ||||||
|  | |||||||
| @ -194,12 +194,11 @@ std::tuple<Tensor, Tensor, Tensor> _scaled_dot_product_cudnn_attention_backward_ | |||||||
|  |  | ||||||
|     const int64_t batch_size = query.size(0); |     const int64_t batch_size = query.size(0); | ||||||
|     const int64_t num_heads = query.size(1); |     const int64_t num_heads = query.size(1); | ||||||
|     const int64_t head_dim = query.size(3); |     const int64_t head_dim_qk = query.size(3); | ||||||
|  |     const int64_t head_dim_v = value.size(3); | ||||||
|     const int64_t max_seqlen_batch_q = query.size(1); |     const int64_t max_seqlen_batch_q = query.size(1); | ||||||
|     const int64_t max_seqlen_batch_k = key.size(1); |     const int64_t max_seqlen_batch_k = key.size(1); | ||||||
|  |  | ||||||
|     const auto softmax_scale = sdp::calculate_scale(query, scale).as_float_unchecked(); |     const auto softmax_scale = sdp::calculate_scale(query, scale).as_float_unchecked(); | ||||||
|  |  | ||||||
|     auto dq = at::empty_like(query); |     auto dq = at::empty_like(query); | ||||||
|     auto dk = at::empty_like(key); |     auto dk = at::empty_like(key); | ||||||
|     auto dv = at::empty_like(value); |     auto dv = at::empty_like(value); | ||||||
| @ -207,7 +206,8 @@ std::tuple<Tensor, Tensor, Tensor> _scaled_dot_product_cudnn_attention_backward_ | |||||||
|                         num_heads /*int64_t h*/, |                         num_heads /*int64_t h*/, | ||||||
|                         max_seqlen_batch_q /*int64_t s_q*/, |                         max_seqlen_batch_q /*int64_t s_q*/, | ||||||
|                         max_seqlen_batch_k /*int64_t s_kv*/, |                         max_seqlen_batch_k /*int64_t s_kv*/, | ||||||
|                         head_dim /*int64_t d*/, |                         head_dim_qk /*int64_t d_qk*/, | ||||||
|  |                         head_dim_v /*int64_t d_v*/, | ||||||
|                         softmax_scale /*float scaling_factor*/, |                         softmax_scale /*float scaling_factor*/, | ||||||
|                         is_causal /*bool is_causal*/, |                         is_causal /*bool is_causal*/, | ||||||
|                         dropout_p /*float dropout_probability*/, |                         dropout_p /*float dropout_probability*/, | ||||||
|  | |||||||
| @ -14,11 +14,11 @@ AllenaiLongformerBase,pass,9 | |||||||
|  |  | ||||||
|  |  | ||||||
|  |  | ||||||
| BartForCausalLM,pass,12 | BartForCausalLM,pass,6 | ||||||
|  |  | ||||||
|  |  | ||||||
|  |  | ||||||
| BartForConditionalGeneration,pass,24 | BartForConditionalGeneration,pass,8 | ||||||
|  |  | ||||||
|  |  | ||||||
|  |  | ||||||
| @ -34,11 +34,11 @@ BlenderbotForCausalLM,eager_fail_to_run,0 | |||||||
|  |  | ||||||
|  |  | ||||||
|  |  | ||||||
| BlenderbotSmallForCausalLM,pass,12 | BlenderbotSmallForCausalLM,pass,6 | ||||||
|  |  | ||||||
|  |  | ||||||
|  |  | ||||||
| BlenderbotSmallForConditionalGeneration,pass,24 | BlenderbotSmallForConditionalGeneration,pass,8 | ||||||
|  |  | ||||||
|  |  | ||||||
|  |  | ||||||
| @ -102,11 +102,11 @@ M2M100ForConditionalGeneration,pass,4 | |||||||
|  |  | ||||||
|  |  | ||||||
|  |  | ||||||
| MBartForCausalLM,pass,12 | MBartForCausalLM,pass,6 | ||||||
|  |  | ||||||
|  |  | ||||||
|  |  | ||||||
| MBartForConditionalGeneration,pass,24 | MBartForConditionalGeneration,pass,8 | ||||||
|  |  | ||||||
|  |  | ||||||
|  |  | ||||||
| @ -130,23 +130,23 @@ MobileBertForQuestionAnswering,pass,3 | |||||||
|  |  | ||||||
|  |  | ||||||
|  |  | ||||||
| OPTForCausalLM,pass,12 | OPTForCausalLM,pass,6 | ||||||
|  |  | ||||||
|  |  | ||||||
|  |  | ||||||
| PLBartForCausalLM,pass,12 | PLBartForCausalLM,pass,6 | ||||||
|  |  | ||||||
|  |  | ||||||
|  |  | ||||||
| PLBartForConditionalGeneration,pass,29 | PLBartForConditionalGeneration,pass,8 | ||||||
|  |  | ||||||
|  |  | ||||||
|  |  | ||||||
| PegasusForCausalLM,pass,12 | PegasusForCausalLM,pass,6 | ||||||
|  |  | ||||||
|  |  | ||||||
|  |  | ||||||
| PegasusForConditionalGeneration,pass,23 | PegasusForConditionalGeneration,pass,7 | ||||||
|  |  | ||||||
|  |  | ||||||
|  |  | ||||||
| @ -158,7 +158,7 @@ RobertaForQuestionAnswering,pass,5 | |||||||
|  |  | ||||||
|  |  | ||||||
|  |  | ||||||
| Speech2Text2ForCausalLM,pass,12 | Speech2Text2ForCausalLM,pass,6 | ||||||
|  |  | ||||||
|  |  | ||||||
|  |  | ||||||
| @ -170,11 +170,11 @@ T5Small,pass,5 | |||||||
|  |  | ||||||
|  |  | ||||||
|  |  | ||||||
| TrOCRForCausalLM,pass,12 | TrOCRForCausalLM,pass,6 | ||||||
|  |  | ||||||
|  |  | ||||||
|  |  | ||||||
| XGLMForCausalLM,pass,12 | XGLMForCausalLM,pass,6 | ||||||
|  |  | ||||||
|  |  | ||||||
|  |  | ||||||
|  | |||||||
| 
 | 
| @ -150,7 +150,7 @@ hf_Bert_large,pass,0 | |||||||
|  |  | ||||||
|  |  | ||||||
|  |  | ||||||
| hf_BigBird,pass,46 | hf_BigBird,pass,43 | ||||||
|  |  | ||||||
|  |  | ||||||
|  |  | ||||||
|  | |||||||
| 
 | 
| @ -98,7 +98,7 @@ hf_Bert_large,pass,6 | |||||||
|  |  | ||||||
|  |  | ||||||
|  |  | ||||||
| hf_BigBird,pass, 52 | hf_BigBird,pass,49 | ||||||
|  |  | ||||||
|  |  | ||||||
|  |  | ||||||
|  | |||||||
| 
 | 
| @ -14,11 +14,11 @@ AllenaiLongformerBase,pass,9 | |||||||
|  |  | ||||||
|  |  | ||||||
|  |  | ||||||
| BartForCausalLM,pass,12 | BartForCausalLM,pass,6 | ||||||
|  |  | ||||||
|  |  | ||||||
|  |  | ||||||
| BartForConditionalGeneration,pass,24 | BartForConditionalGeneration,pass,8 | ||||||
|  |  | ||||||
|  |  | ||||||
|  |  | ||||||
| @ -34,11 +34,11 @@ BlenderbotForCausalLM,eager_fail_to_run,0 | |||||||
|  |  | ||||||
|  |  | ||||||
|  |  | ||||||
| BlenderbotSmallForCausalLM,pass,12 | BlenderbotSmallForCausalLM,pass,6 | ||||||
|  |  | ||||||
|  |  | ||||||
|  |  | ||||||
| BlenderbotSmallForConditionalGeneration,pass,24 | BlenderbotSmallForConditionalGeneration,pass,8 | ||||||
|  |  | ||||||
|  |  | ||||||
|  |  | ||||||
| @ -102,11 +102,11 @@ M2M100ForConditionalGeneration,pass,4 | |||||||
|  |  | ||||||
|  |  | ||||||
|  |  | ||||||
| MBartForCausalLM,pass,12 | MBartForCausalLM,pass,6 | ||||||
|  |  | ||||||
|  |  | ||||||
|  |  | ||||||
| MBartForConditionalGeneration,pass,24 | MBartForConditionalGeneration,pass,8 | ||||||
|  |  | ||||||
|  |  | ||||||
|  |  | ||||||
| @ -130,23 +130,23 @@ MobileBertForQuestionAnswering,pass,3 | |||||||
|  |  | ||||||
|  |  | ||||||
|  |  | ||||||
| OPTForCausalLM,pass,12 | OPTForCausalLM,pass,6 | ||||||
|  |  | ||||||
|  |  | ||||||
|  |  | ||||||
| PLBartForCausalLM,pass,12 | PLBartForCausalLM,pass,6 | ||||||
|  |  | ||||||
|  |  | ||||||
|  |  | ||||||
| PLBartForConditionalGeneration,pass,29 | PLBartForConditionalGeneration,pass,8 | ||||||
|  |  | ||||||
|  |  | ||||||
|  |  | ||||||
| PegasusForCausalLM,pass,12 | PegasusForCausalLM,pass,6 | ||||||
|  |  | ||||||
|  |  | ||||||
|  |  | ||||||
| PegasusForConditionalGeneration,pass,23 | PegasusForConditionalGeneration,pass,7 | ||||||
|  |  | ||||||
|  |  | ||||||
|  |  | ||||||
| @ -158,7 +158,7 @@ RobertaForQuestionAnswering,pass,5 | |||||||
|  |  | ||||||
|  |  | ||||||
|  |  | ||||||
| Speech2Text2ForCausalLM,pass,12 | Speech2Text2ForCausalLM,pass,6 | ||||||
|  |  | ||||||
|  |  | ||||||
|  |  | ||||||
| @ -170,11 +170,11 @@ T5Small,pass,5 | |||||||
|  |  | ||||||
|  |  | ||||||
|  |  | ||||||
| TrOCRForCausalLM,pass,12 | TrOCRForCausalLM,pass,6 | ||||||
|  |  | ||||||
|  |  | ||||||
|  |  | ||||||
| XGLMForCausalLM,pass,12 | XGLMForCausalLM,pass,6 | ||||||
|  |  | ||||||
|  |  | ||||||
|  |  | ||||||
|  | |||||||
| 
 | 
| @ -14,11 +14,11 @@ AllenaiLongformerBase,pass,9 | |||||||
|  |  | ||||||
|  |  | ||||||
|  |  | ||||||
| BartForCausalLM,pass,12 | BartForCausalLM,pass,6 | ||||||
|  |  | ||||||
|  |  | ||||||
|  |  | ||||||
| BartForConditionalGeneration,pass,24 | BartForConditionalGeneration,pass,8 | ||||||
|  |  | ||||||
|  |  | ||||||
|  |  | ||||||
| @ -34,11 +34,11 @@ BlenderbotForCausalLM,eager_fail_to_run,0 | |||||||
|  |  | ||||||
|  |  | ||||||
|  |  | ||||||
| BlenderbotSmallForCausalLM,pass,12 | BlenderbotSmallForCausalLM,pass,6 | ||||||
|  |  | ||||||
|  |  | ||||||
|  |  | ||||||
| BlenderbotSmallForConditionalGeneration,pass,24 | BlenderbotSmallForConditionalGeneration,pass,8 | ||||||
|  |  | ||||||
|  |  | ||||||
|  |  | ||||||
| @ -102,11 +102,11 @@ M2M100ForConditionalGeneration,pass,4 | |||||||
|  |  | ||||||
|  |  | ||||||
|  |  | ||||||
| MBartForCausalLM,pass,12 | MBartForCausalLM,pass,6 | ||||||
|  |  | ||||||
|  |  | ||||||
|  |  | ||||||
| MBartForConditionalGeneration,pass,24 | MBartForConditionalGeneration,pass,8 | ||||||
|  |  | ||||||
|  |  | ||||||
|  |  | ||||||
| @ -130,23 +130,23 @@ MobileBertForQuestionAnswering,pass,3 | |||||||
|  |  | ||||||
|  |  | ||||||
|  |  | ||||||
| OPTForCausalLM,pass,12 | OPTForCausalLM,pass,6 | ||||||
|  |  | ||||||
|  |  | ||||||
|  |  | ||||||
| PLBartForCausalLM,pass,12 | PLBartForCausalLM,pass,6 | ||||||
|  |  | ||||||
|  |  | ||||||
|  |  | ||||||
| PLBartForConditionalGeneration,pass,29 | PLBartForConditionalGeneration,pass,8 | ||||||
|  |  | ||||||
|  |  | ||||||
|  |  | ||||||
| PegasusForCausalLM,pass,12 | PegasusForCausalLM,pass,6 | ||||||
|  |  | ||||||
|  |  | ||||||
|  |  | ||||||
| PegasusForConditionalGeneration,pass,23 | PegasusForConditionalGeneration,pass,7 | ||||||
|  |  | ||||||
|  |  | ||||||
|  |  | ||||||
| @ -158,7 +158,7 @@ RobertaForQuestionAnswering,pass,5 | |||||||
|  |  | ||||||
|  |  | ||||||
|  |  | ||||||
| Speech2Text2ForCausalLM,pass,12 | Speech2Text2ForCausalLM,pass,6 | ||||||
|  |  | ||||||
|  |  | ||||||
|  |  | ||||||
| @ -170,11 +170,11 @@ T5Small,pass,5 | |||||||
|  |  | ||||||
|  |  | ||||||
|  |  | ||||||
| TrOCRForCausalLM,pass,12 | TrOCRForCausalLM,pass,6 | ||||||
|  |  | ||||||
|  |  | ||||||
|  |  | ||||||
| XGLMForCausalLM,pass,12 | XGLMForCausalLM,pass,6 | ||||||
|  |  | ||||||
|  |  | ||||||
|  |  | ||||||
|  | |||||||
| 
 | 
| @ -14,11 +14,11 @@ AllenaiLongformerBase,pass,9 | |||||||
|  |  | ||||||
|  |  | ||||||
|  |  | ||||||
| BartForCausalLM,pass,12 | BartForCausalLM,pass,6 | ||||||
|  |  | ||||||
|  |  | ||||||
|  |  | ||||||
| BartForConditionalGeneration,pass,24 | BartForConditionalGeneration,pass,8 | ||||||
|  |  | ||||||
|  |  | ||||||
|  |  | ||||||
| @ -34,11 +34,11 @@ BlenderbotForCausalLM,eager_fail_to_run,0 | |||||||
|  |  | ||||||
|  |  | ||||||
|  |  | ||||||
| BlenderbotSmallForCausalLM,pass,12 | BlenderbotSmallForCausalLM,pass,6 | ||||||
|  |  | ||||||
|  |  | ||||||
|  |  | ||||||
| BlenderbotSmallForConditionalGeneration,pass,24 | BlenderbotSmallForConditionalGeneration,pass,8 | ||||||
|  |  | ||||||
|  |  | ||||||
|  |  | ||||||
| @ -102,11 +102,11 @@ M2M100ForConditionalGeneration,pass,4 | |||||||
|  |  | ||||||
|  |  | ||||||
|  |  | ||||||
| MBartForCausalLM,pass,12 | MBartForCausalLM,pass,6 | ||||||
|  |  | ||||||
|  |  | ||||||
|  |  | ||||||
| MBartForConditionalGeneration,pass,24 | MBartForConditionalGeneration,pass,8 | ||||||
|  |  | ||||||
|  |  | ||||||
|  |  | ||||||
| @ -130,23 +130,23 @@ MobileBertForQuestionAnswering,pass,3 | |||||||
|  |  | ||||||
|  |  | ||||||
|  |  | ||||||
| OPTForCausalLM,pass,12 | OPTForCausalLM,pass,6 | ||||||
|  |  | ||||||
|  |  | ||||||
|  |  | ||||||
| PLBartForCausalLM,pass,12 | PLBartForCausalLM,pass,6 | ||||||
|  |  | ||||||
|  |  | ||||||
|  |  | ||||||
| PLBartForConditionalGeneration,pass,29 | PLBartForConditionalGeneration,pass,8 | ||||||
|  |  | ||||||
|  |  | ||||||
|  |  | ||||||
| PegasusForCausalLM,pass,12 | PegasusForCausalLM,pass,6 | ||||||
|  |  | ||||||
|  |  | ||||||
|  |  | ||||||
| PegasusForConditionalGeneration,pass,23 | PegasusForConditionalGeneration,pass,7 | ||||||
|  |  | ||||||
|  |  | ||||||
|  |  | ||||||
| @ -158,7 +158,7 @@ RobertaForQuestionAnswering,pass,5 | |||||||
|  |  | ||||||
|  |  | ||||||
|  |  | ||||||
| Speech2Text2ForCausalLM,pass,12 | Speech2Text2ForCausalLM,pass,6 | ||||||
|  |  | ||||||
|  |  | ||||||
|  |  | ||||||
| @ -170,11 +170,11 @@ T5Small,pass,5 | |||||||
|  |  | ||||||
|  |  | ||||||
|  |  | ||||||
| TrOCRForCausalLM,pass,12 | TrOCRForCausalLM,pass,6 | ||||||
|  |  | ||||||
|  |  | ||||||
|  |  | ||||||
| XGLMForCausalLM,pass,12 | XGLMForCausalLM,pass,6 | ||||||
|  |  | ||||||
|  |  | ||||||
|  |  | ||||||
|  | |||||||
| 
 | 
| @ -150,7 +150,7 @@ hf_Bert_large,pass,0 | |||||||
|  |  | ||||||
|  |  | ||||||
|  |  | ||||||
| hf_BigBird,pass,46 | hf_BigBird,pass,43 | ||||||
|  |  | ||||||
|  |  | ||||||
|  |  | ||||||
|  | |||||||
| 
 | 
| @ -98,7 +98,7 @@ hf_Bert_large,pass,6 | |||||||
|  |  | ||||||
|  |  | ||||||
|  |  | ||||||
| hf_BigBird,pass,52 | hf_BigBird,pass,49 | ||||||
|  |  | ||||||
|  |  | ||||||
|  |  | ||||||
|  | |||||||
| 
 | 
| @ -14,11 +14,11 @@ AllenaiLongformerBase,pass,9 | |||||||
|  |  | ||||||
|  |  | ||||||
|  |  | ||||||
| BartForCausalLM,pass,12 | BartForCausalLM,pass,6 | ||||||
|  |  | ||||||
|  |  | ||||||
|  |  | ||||||
| BartForConditionalGeneration,pass,24 | BartForConditionalGeneration,pass,8 | ||||||
|  |  | ||||||
|  |  | ||||||
|  |  | ||||||
| @ -34,11 +34,11 @@ BlenderbotForCausalLM,eager_fail_to_run,0 | |||||||
|  |  | ||||||
|  |  | ||||||
|  |  | ||||||
| BlenderbotSmallForCausalLM,pass,12 | BlenderbotSmallForCausalLM,pass,6 | ||||||
|  |  | ||||||
|  |  | ||||||
|  |  | ||||||
| BlenderbotSmallForConditionalGeneration,pass,24 | BlenderbotSmallForConditionalGeneration,pass,8 | ||||||
|  |  | ||||||
|  |  | ||||||
|  |  | ||||||
| @ -102,11 +102,11 @@ M2M100ForConditionalGeneration,pass,4 | |||||||
|  |  | ||||||
|  |  | ||||||
|  |  | ||||||
| MBartForCausalLM,pass,12 | MBartForCausalLM,pass,6 | ||||||
|  |  | ||||||
|  |  | ||||||
|  |  | ||||||
| MBartForConditionalGeneration,pass,24 | MBartForConditionalGeneration,pass,8 | ||||||
|  |  | ||||||
|  |  | ||||||
|  |  | ||||||
| @ -130,23 +130,23 @@ MobileBertForQuestionAnswering,pass,3 | |||||||
|  |  | ||||||
|  |  | ||||||
|  |  | ||||||
| OPTForCausalLM,pass,12 | OPTForCausalLM,pass,6 | ||||||
|  |  | ||||||
|  |  | ||||||
|  |  | ||||||
| PLBartForCausalLM,pass,12 | PLBartForCausalLM,pass,6 | ||||||
|  |  | ||||||
|  |  | ||||||
|  |  | ||||||
| PLBartForConditionalGeneration,pass,29 | PLBartForConditionalGeneration,pass,8 | ||||||
|  |  | ||||||
|  |  | ||||||
|  |  | ||||||
| PegasusForCausalLM,pass,12 | PegasusForCausalLM,pass,6 | ||||||
|  |  | ||||||
|  |  | ||||||
|  |  | ||||||
| PegasusForConditionalGeneration,pass,23 | PegasusForConditionalGeneration,pass,7 | ||||||
|  |  | ||||||
|  |  | ||||||
|  |  | ||||||
| @ -158,7 +158,7 @@ RobertaForQuestionAnswering,pass,5 | |||||||
|  |  | ||||||
|  |  | ||||||
|  |  | ||||||
| Speech2Text2ForCausalLM,pass,12 | Speech2Text2ForCausalLM,pass,6 | ||||||
|  |  | ||||||
|  |  | ||||||
|  |  | ||||||
| @ -170,11 +170,11 @@ T5Small,pass,5 | |||||||
|  |  | ||||||
|  |  | ||||||
|  |  | ||||||
| TrOCRForCausalLM,pass,12 | TrOCRForCausalLM,pass,6 | ||||||
|  |  | ||||||
|  |  | ||||||
|  |  | ||||||
| XGLMForCausalLM,pass,12 | XGLMForCausalLM,pass,6 | ||||||
|  |  | ||||||
|  |  | ||||||
|  |  | ||||||
|  | |||||||
| 
 | 
| @ -150,7 +150,7 @@ hf_Bert_large,pass,0 | |||||||
|  |  | ||||||
|  |  | ||||||
|  |  | ||||||
| hf_BigBird,fail_accuracy,46 | hf_BigBird,fail_accuracy,43 | ||||||
|  |  | ||||||
|  |  | ||||||
|  |  | ||||||
|  | |||||||
| 
 | 
| @ -98,7 +98,7 @@ hf_Bert_large,pass,6 | |||||||
|  |  | ||||||
|  |  | ||||||
|  |  | ||||||
| hf_BigBird,pass,52 | hf_BigBird,pass,49 | ||||||
|  |  | ||||||
|  |  | ||||||
|  |  | ||||||
|  | |||||||
| 
 | 
| @ -14,11 +14,11 @@ AllenaiLongformerBase,pass,9 | |||||||
|  |  | ||||||
|  |  | ||||||
|  |  | ||||||
| BartForCausalLM,pass,12 | BartForCausalLM,pass,6 | ||||||
|  |  | ||||||
|  |  | ||||||
|  |  | ||||||
| BartForConditionalGeneration,pass,24 | BartForConditionalGeneration,pass,8 | ||||||
|  |  | ||||||
|  |  | ||||||
|  |  | ||||||
| @ -34,11 +34,11 @@ BlenderbotForCausalLM,eager_fail_to_run,0 | |||||||
|  |  | ||||||
|  |  | ||||||
|  |  | ||||||
| BlenderbotSmallForCausalLM,pass,12 | BlenderbotSmallForCausalLM,pass,6 | ||||||
|  |  | ||||||
|  |  | ||||||
|  |  | ||||||
| BlenderbotSmallForConditionalGeneration,pass,24 | BlenderbotSmallForConditionalGeneration,pass,8 | ||||||
|  |  | ||||||
|  |  | ||||||
|  |  | ||||||
| @ -102,11 +102,11 @@ M2M100ForConditionalGeneration,pass,4 | |||||||
|  |  | ||||||
|  |  | ||||||
|  |  | ||||||
| MBartForCausalLM,pass,12 | MBartForCausalLM,pass,6 | ||||||
|  |  | ||||||
|  |  | ||||||
|  |  | ||||||
| MBartForConditionalGeneration,pass,24 | MBartForConditionalGeneration,pass,8 | ||||||
|  |  | ||||||
|  |  | ||||||
|  |  | ||||||
| @ -130,23 +130,23 @@ MobileBertForQuestionAnswering,pass,3 | |||||||
|  |  | ||||||
|  |  | ||||||
|  |  | ||||||
| OPTForCausalLM,pass,12 | OPTForCausalLM,pass,6 | ||||||
|  |  | ||||||
|  |  | ||||||
|  |  | ||||||
| PLBartForCausalLM,pass,12 | PLBartForCausalLM,pass,6 | ||||||
|  |  | ||||||
|  |  | ||||||
|  |  | ||||||
| PLBartForConditionalGeneration,pass,29 | PLBartForConditionalGeneration,pass,8 | ||||||
|  |  | ||||||
|  |  | ||||||
|  |  | ||||||
| PegasusForCausalLM,pass,12 | PegasusForCausalLM,pass,6 | ||||||
|  |  | ||||||
|  |  | ||||||
|  |  | ||||||
| PegasusForConditionalGeneration,pass,23 | PegasusForConditionalGeneration,pass,7 | ||||||
|  |  | ||||||
|  |  | ||||||
|  |  | ||||||
| @ -158,7 +158,7 @@ RobertaForQuestionAnswering,pass,5 | |||||||
|  |  | ||||||
|  |  | ||||||
|  |  | ||||||
| Speech2Text2ForCausalLM,pass,12 | Speech2Text2ForCausalLM,pass,6 | ||||||
|  |  | ||||||
|  |  | ||||||
|  |  | ||||||
| @ -170,11 +170,11 @@ T5Small,pass,5 | |||||||
|  |  | ||||||
|  |  | ||||||
|  |  | ||||||
| TrOCRForCausalLM,pass,12 | TrOCRForCausalLM,pass,6 | ||||||
|  |  | ||||||
|  |  | ||||||
|  |  | ||||||
| XGLMForCausalLM,pass,12 | XGLMForCausalLM,pass,6 | ||||||
|  |  | ||||||
|  |  | ||||||
|  |  | ||||||
|  | |||||||
| 
 | 
| @ -150,7 +150,7 @@ hf_Bert_large,pass,0 | |||||||
|  |  | ||||||
|  |  | ||||||
|  |  | ||||||
| hf_BigBird,pass,46 | hf_BigBird,pass,43 | ||||||
|  |  | ||||||
|  |  | ||||||
|  |  | ||||||
|  | |||||||
| 
 | 
| @ -98,7 +98,7 @@ hf_Bert_large,pass,6 | |||||||
|  |  | ||||||
|  |  | ||||||
|  |  | ||||||
| hf_BigBird,pass,52 | hf_BigBird,pass,49 | ||||||
|  |  | ||||||
|  |  | ||||||
|  |  | ||||||
|  | |||||||
| 
 | 
| @ -14,11 +14,11 @@ AllenaiLongformerBase,pass,9 | |||||||
|  |  | ||||||
|  |  | ||||||
|  |  | ||||||
| BartForCausalLM,pass,12 | BartForCausalLM,pass,6 | ||||||
|  |  | ||||||
|  |  | ||||||
|  |  | ||||||
| BartForConditionalGeneration,pass,24 | BartForConditionalGeneration,pass,8 | ||||||
|  |  | ||||||
|  |  | ||||||
|  |  | ||||||
| @ -34,11 +34,11 @@ BlenderbotForCausalLM,eager_fail_to_run,0 | |||||||
|  |  | ||||||
|  |  | ||||||
|  |  | ||||||
| BlenderbotSmallForCausalLM,pass,12 | BlenderbotSmallForCausalLM,pass,6 | ||||||
|  |  | ||||||
|  |  | ||||||
|  |  | ||||||
| BlenderbotSmallForConditionalGeneration,pass,24 | BlenderbotSmallForConditionalGeneration,pass,8 | ||||||
|  |  | ||||||
|  |  | ||||||
|  |  | ||||||
| @ -102,11 +102,11 @@ M2M100ForConditionalGeneration,pass,4 | |||||||
|  |  | ||||||
|  |  | ||||||
|  |  | ||||||
| MBartForCausalLM,pass,12 | MBartForCausalLM,pass,6 | ||||||
|  |  | ||||||
|  |  | ||||||
|  |  | ||||||
| MBartForConditionalGeneration,pass,24 | MBartForConditionalGeneration,pass,8 | ||||||
|  |  | ||||||
|  |  | ||||||
|  |  | ||||||
| @ -130,23 +130,23 @@ MobileBertForQuestionAnswering,pass,3 | |||||||
|  |  | ||||||
|  |  | ||||||
|  |  | ||||||
| OPTForCausalLM,pass,12 | OPTForCausalLM,pass,6 | ||||||
|  |  | ||||||
|  |  | ||||||
|  |  | ||||||
| PLBartForCausalLM,pass,12 | PLBartForCausalLM,pass,6 | ||||||
|  |  | ||||||
|  |  | ||||||
|  |  | ||||||
| PLBartForConditionalGeneration,pass,29 | PLBartForConditionalGeneration,pass,8 | ||||||
|  |  | ||||||
|  |  | ||||||
|  |  | ||||||
| PegasusForCausalLM,pass,12 | PegasusForCausalLM,pass,6 | ||||||
|  |  | ||||||
|  |  | ||||||
|  |  | ||||||
| PegasusForConditionalGeneration,pass,23 | PegasusForConditionalGeneration,pass,7 | ||||||
|  |  | ||||||
|  |  | ||||||
|  |  | ||||||
| @ -158,7 +158,7 @@ RobertaForQuestionAnswering,pass,5 | |||||||
|  |  | ||||||
|  |  | ||||||
|  |  | ||||||
| Speech2Text2ForCausalLM,pass,12 | Speech2Text2ForCausalLM,pass,6 | ||||||
|  |  | ||||||
|  |  | ||||||
|  |  | ||||||
| @ -170,11 +170,11 @@ T5Small,pass,5 | |||||||
|  |  | ||||||
|  |  | ||||||
|  |  | ||||||
| TrOCRForCausalLM,pass,12 | TrOCRForCausalLM,pass,6 | ||||||
|  |  | ||||||
|  |  | ||||||
|  |  | ||||||
| XGLMForCausalLM,pass,12 | XGLMForCausalLM,pass,6 | ||||||
|  |  | ||||||
|  |  | ||||||
|  |  | ||||||
|  | |||||||
| 
 | 
| @ -150,7 +150,7 @@ hf_Bert_large,pass,0 | |||||||
|  |  | ||||||
|  |  | ||||||
|  |  | ||||||
| hf_BigBird,fail_accuracy,46 | hf_BigBird,fail_accuracy,43 | ||||||
|  |  | ||||||
|  |  | ||||||
|  |  | ||||||
|  | |||||||
| 
 | 
| @ -98,7 +98,7 @@ hf_Bert_large,pass,6 | |||||||
|  |  | ||||||
|  |  | ||||||
|  |  | ||||||
| hf_BigBird,pass,52 | hf_BigBird,pass,49 | ||||||
|  |  | ||||||
|  |  | ||||||
|  |  | ||||||
|  | |||||||
| 
 | 
| @ -272,6 +272,38 @@ TEST(StaticRuntime, autogen_addr) { | |||||||
|       /*check_resize=*/true); |       /*check_resize=*/true); | ||||||
| } | } | ||||||
|  |  | ||||||
|  | TEST(StaticRuntime, autogen__test_functorch_fallback) { | ||||||
|  |   const std::string script = R"IR( | ||||||
|  |     graph(%self: Tensor, %other: Tensor): | ||||||
|  |         %bias: None = prim::Constant() | ||||||
|  |         %ret = aten::_test_functorch_fallback(%self, %other) | ||||||
|  |         %cloned = aten::clone(%ret, %bias) | ||||||
|  |         return (%cloned) | ||||||
|  |   )IR"; | ||||||
|  |  | ||||||
|  |   auto self0 = at::rand({6, 6, 6}); | ||||||
|  |   auto other0 = at::rand({6, 6, 6}); | ||||||
|  |   std::vector<IValue> args{self0, other0}; | ||||||
|  |   testStaticRuntime( | ||||||
|  |       script, | ||||||
|  |       args, | ||||||
|  |       {}, | ||||||
|  |       /*use_allclose=*/false, | ||||||
|  |       /*use_equalnan=*/false, | ||||||
|  |       /*check_resize=*/true); | ||||||
|  |  | ||||||
|  |   auto self1 = at::rand({22, 22, 22}); | ||||||
|  |   auto other1 = at::rand({22, 22, 22}); | ||||||
|  |   std::vector<IValue> args2{self1, other1}; | ||||||
|  |   testStaticRuntime( | ||||||
|  |       script, | ||||||
|  |       args, | ||||||
|  |       args2, | ||||||
|  |       /*use_allclose=*/false, | ||||||
|  |       /*use_equalnan=*/false, | ||||||
|  |       /*check_resize=*/true); | ||||||
|  | } | ||||||
|  |  | ||||||
| TEST(StaticRuntime, autogen_argmax) { | TEST(StaticRuntime, autogen_argmax) { | ||||||
|   const std::string script = R"IR( |   const std::string script = R"IR( | ||||||
|     graph(%self: Tensor, %dim: int?, %keepdim: bool): |     graph(%self: Tensor, %dim: int?, %keepdim: bool): | ||||||
| @ -4440,6 +4472,40 @@ TEST(StaticRuntime, autogen_masked_select) { | |||||||
|       /*check_resize=*/true); |       /*check_resize=*/true); | ||||||
| } | } | ||||||
|  |  | ||||||
|  | TEST(StaticRuntime, autogen_nonzero_static) { | ||||||
|  |   const std::string script = R"IR( | ||||||
|  |     graph(%self: Tensor, %size: int, %fill_value: int): | ||||||
|  |         %bias: None = prim::Constant() | ||||||
|  |         %ret = aten::nonzero_static(%self, %size, %fill_value) | ||||||
|  |         %cloned = aten::clone(%ret, %bias) | ||||||
|  |         return (%cloned) | ||||||
|  |   )IR"; | ||||||
|  |  | ||||||
|  |   auto self0 = at::rand({6, 6, 6}); | ||||||
|  |   auto size0 = 1; | ||||||
|  |   auto fill_value0 = 1; | ||||||
|  |   std::vector<IValue> args{self0, size0, fill_value0}; | ||||||
|  |   testStaticRuntime( | ||||||
|  |       script, | ||||||
|  |       args, | ||||||
|  |       {}, | ||||||
|  |       /*use_allclose=*/false, | ||||||
|  |       /*use_equalnan=*/false, | ||||||
|  |       /*check_resize=*/true); | ||||||
|  |  | ||||||
|  |   auto self1 = at::rand({22, 22, 22}); | ||||||
|  |   auto size1 = 1; | ||||||
|  |   auto fill_value1 = 1; | ||||||
|  |   std::vector<IValue> args2{self1, size1, fill_value1}; | ||||||
|  |   testStaticRuntime( | ||||||
|  |       script, | ||||||
|  |       args, | ||||||
|  |       args2, | ||||||
|  |       /*use_allclose=*/false, | ||||||
|  |       /*use_equalnan=*/false, | ||||||
|  |       /*check_resize=*/true); | ||||||
|  | } | ||||||
|  |  | ||||||
| TEST(StaticRuntime, autogen_gather) { | TEST(StaticRuntime, autogen_gather) { | ||||||
|   const std::string script = R"IR( |   const std::string script = R"IR( | ||||||
|     graph(%self: Tensor, %dim: int, %index: Tensor, %sparse_grad: bool): |     graph(%self: Tensor, %dim: int, %index: Tensor, %sparse_grad: bool): | ||||||
| @ -7106,222 +7172,6 @@ TEST(StaticRuntime, autogen_special_multigammaln) { | |||||||
|       /*check_resize=*/true); |       /*check_resize=*/true); | ||||||
| } | } | ||||||
|  |  | ||||||
| TEST(StaticRuntime, autogen_fft_fft) { |  | ||||||
|   const std::string script = R"IR( |  | ||||||
|     graph(%self: Tensor, %n: int?, %dim: int, %norm: str?): |  | ||||||
|         %bias: None = prim::Constant() |  | ||||||
|         %ret = aten::fft_fft(%self, %n, %dim, %norm) |  | ||||||
|         %cloned = aten::clone(%ret, %bias) |  | ||||||
|         return (%cloned) |  | ||||||
|   )IR"; |  | ||||||
|  |  | ||||||
|   auto self0 = at::rand({6, 6, 6}); |  | ||||||
|   auto n0 = 1; |  | ||||||
|   auto dim0 = 1; |  | ||||||
|   auto norm0 = "forward"; |  | ||||||
|   std::vector<IValue> args{self0, n0, dim0, norm0}; |  | ||||||
|   testStaticRuntime( |  | ||||||
|       script, |  | ||||||
|       args, |  | ||||||
|       {}, |  | ||||||
|       /*use_allclose=*/false, |  | ||||||
|       /*use_equalnan=*/false, |  | ||||||
|       /*check_resize=*/true); |  | ||||||
|  |  | ||||||
|   auto self1 = at::rand({22, 22, 22}); |  | ||||||
|   auto n1 = 1; |  | ||||||
|   auto dim1 = 1; |  | ||||||
|   auto norm1 = "forward"; |  | ||||||
|   std::vector<IValue> args2{self1, n1, dim1, norm1}; |  | ||||||
|   testStaticRuntime( |  | ||||||
|       script, |  | ||||||
|       args, |  | ||||||
|       args2, |  | ||||||
|       /*use_allclose=*/false, |  | ||||||
|       /*use_equalnan=*/false, |  | ||||||
|       /*check_resize=*/true); |  | ||||||
| } |  | ||||||
|  |  | ||||||
| TEST(StaticRuntime, autogen_fft_ifft) { |  | ||||||
|   const std::string script = R"IR( |  | ||||||
|     graph(%self: Tensor, %n: int?, %dim: int, %norm: str?): |  | ||||||
|         %bias: None = prim::Constant() |  | ||||||
|         %ret = aten::fft_ifft(%self, %n, %dim, %norm) |  | ||||||
|         %cloned = aten::clone(%ret, %bias) |  | ||||||
|         return (%cloned) |  | ||||||
|   )IR"; |  | ||||||
|  |  | ||||||
|   auto self0 = at::rand({6, 6, 6}); |  | ||||||
|   auto n0 = 1; |  | ||||||
|   auto dim0 = 1; |  | ||||||
|   auto norm0 = "forward"; |  | ||||||
|   std::vector<IValue> args{self0, n0, dim0, norm0}; |  | ||||||
|   testStaticRuntime( |  | ||||||
|       script, |  | ||||||
|       args, |  | ||||||
|       {}, |  | ||||||
|       /*use_allclose=*/false, |  | ||||||
|       /*use_equalnan=*/false, |  | ||||||
|       /*check_resize=*/true); |  | ||||||
|  |  | ||||||
|   auto self1 = at::rand({22, 22, 22}); |  | ||||||
|   auto n1 = 1; |  | ||||||
|   auto dim1 = 1; |  | ||||||
|   auto norm1 = "forward"; |  | ||||||
|   std::vector<IValue> args2{self1, n1, dim1, norm1}; |  | ||||||
|   testStaticRuntime( |  | ||||||
|       script, |  | ||||||
|       args, |  | ||||||
|       args2, |  | ||||||
|       /*use_allclose=*/false, |  | ||||||
|       /*use_equalnan=*/false, |  | ||||||
|       /*check_resize=*/true); |  | ||||||
| } |  | ||||||
|  |  | ||||||
| TEST(StaticRuntime, autogen_fft_rfft) { |  | ||||||
|   const std::string script = R"IR( |  | ||||||
|     graph(%self: Tensor, %n: int?, %dim: int, %norm: str?): |  | ||||||
|         %bias: None = prim::Constant() |  | ||||||
|         %ret = aten::fft_rfft(%self, %n, %dim, %norm) |  | ||||||
|         %cloned = aten::clone(%ret, %bias) |  | ||||||
|         return (%cloned) |  | ||||||
|   )IR"; |  | ||||||
|  |  | ||||||
|   auto self0 = at::rand({6, 6, 6}); |  | ||||||
|   auto n0 = 1; |  | ||||||
|   auto dim0 = 1; |  | ||||||
|   auto norm0 = "forward"; |  | ||||||
|   std::vector<IValue> args{self0, n0, dim0, norm0}; |  | ||||||
|   testStaticRuntime( |  | ||||||
|       script, |  | ||||||
|       args, |  | ||||||
|       {}, |  | ||||||
|       /*use_allclose=*/false, |  | ||||||
|       /*use_equalnan=*/false, |  | ||||||
|       /*check_resize=*/true); |  | ||||||
|  |  | ||||||
|   auto self1 = at::rand({22, 22, 22}); |  | ||||||
|   auto n1 = 1; |  | ||||||
|   auto dim1 = 1; |  | ||||||
|   auto norm1 = "forward"; |  | ||||||
|   std::vector<IValue> args2{self1, n1, dim1, norm1}; |  | ||||||
|   testStaticRuntime( |  | ||||||
|       script, |  | ||||||
|       args, |  | ||||||
|       args2, |  | ||||||
|       /*use_allclose=*/false, |  | ||||||
|       /*use_equalnan=*/false, |  | ||||||
|       /*check_resize=*/true); |  | ||||||
| } |  | ||||||
|  |  | ||||||
| TEST(StaticRuntime, autogen_fft_irfft) { |  | ||||||
|   const std::string script = R"IR( |  | ||||||
|     graph(%self: Tensor, %n: int?, %dim: int, %norm: str?): |  | ||||||
|         %bias: None = prim::Constant() |  | ||||||
|         %ret = aten::fft_irfft(%self, %n, %dim, %norm) |  | ||||||
|         %cloned = aten::clone(%ret, %bias) |  | ||||||
|         return (%cloned) |  | ||||||
|   )IR"; |  | ||||||
|  |  | ||||||
|   auto self0 = at::rand({6, 6, 6}); |  | ||||||
|   auto n0 = 1; |  | ||||||
|   auto dim0 = 1; |  | ||||||
|   auto norm0 = "forward"; |  | ||||||
|   std::vector<IValue> args{self0, n0, dim0, norm0}; |  | ||||||
|   testStaticRuntime( |  | ||||||
|       script, |  | ||||||
|       args, |  | ||||||
|       {}, |  | ||||||
|       /*use_allclose=*/false, |  | ||||||
|       /*use_equalnan=*/false, |  | ||||||
|       /*check_resize=*/true); |  | ||||||
|  |  | ||||||
|   auto self1 = at::rand({22, 22, 22}); |  | ||||||
|   auto n1 = 1; |  | ||||||
|   auto dim1 = 1; |  | ||||||
|   auto norm1 = "forward"; |  | ||||||
|   std::vector<IValue> args2{self1, n1, dim1, norm1}; |  | ||||||
|   testStaticRuntime( |  | ||||||
|       script, |  | ||||||
|       args, |  | ||||||
|       args2, |  | ||||||
|       /*use_allclose=*/false, |  | ||||||
|       /*use_equalnan=*/false, |  | ||||||
|       /*check_resize=*/true); |  | ||||||
| } |  | ||||||
|  |  | ||||||
| TEST(StaticRuntime, autogen_fft_hfft) { |  | ||||||
|   const std::string script = R"IR( |  | ||||||
|     graph(%self: Tensor, %n: int?, %dim: int, %norm: str?): |  | ||||||
|         %bias: None = prim::Constant() |  | ||||||
|         %ret = aten::fft_hfft(%self, %n, %dim, %norm) |  | ||||||
|         %cloned = aten::clone(%ret, %bias) |  | ||||||
|         return (%cloned) |  | ||||||
|   )IR"; |  | ||||||
|  |  | ||||||
|   auto self0 = at::rand({6, 6, 6}); |  | ||||||
|   auto n0 = 1; |  | ||||||
|   auto dim0 = 1; |  | ||||||
|   auto norm0 = "forward"; |  | ||||||
|   std::vector<IValue> args{self0, n0, dim0, norm0}; |  | ||||||
|   testStaticRuntime( |  | ||||||
|       script, |  | ||||||
|       args, |  | ||||||
|       {}, |  | ||||||
|       /*use_allclose=*/false, |  | ||||||
|       /*use_equalnan=*/false, |  | ||||||
|       /*check_resize=*/true); |  | ||||||
|  |  | ||||||
|   auto self1 = at::rand({22, 22, 22}); |  | ||||||
|   auto n1 = 1; |  | ||||||
|   auto dim1 = 1; |  | ||||||
|   auto norm1 = "forward"; |  | ||||||
|   std::vector<IValue> args2{self1, n1, dim1, norm1}; |  | ||||||
|   testStaticRuntime( |  | ||||||
|       script, |  | ||||||
|       args, |  | ||||||
|       args2, |  | ||||||
|       /*use_allclose=*/false, |  | ||||||
|       /*use_equalnan=*/false, |  | ||||||
|       /*check_resize=*/true); |  | ||||||
| } |  | ||||||
|  |  | ||||||
| TEST(StaticRuntime, autogen_fft_ihfft) { |  | ||||||
|   const std::string script = R"IR( |  | ||||||
|     graph(%self: Tensor, %n: int?, %dim: int, %norm: str?): |  | ||||||
|         %bias: None = prim::Constant() |  | ||||||
|         %ret = aten::fft_ihfft(%self, %n, %dim, %norm) |  | ||||||
|         %cloned = aten::clone(%ret, %bias) |  | ||||||
|         return (%cloned) |  | ||||||
|   )IR"; |  | ||||||
|  |  | ||||||
|   auto self0 = at::rand({6, 6, 6}); |  | ||||||
|   auto n0 = 1; |  | ||||||
|   auto dim0 = 1; |  | ||||||
|   auto norm0 = "forward"; |  | ||||||
|   std::vector<IValue> args{self0, n0, dim0, norm0}; |  | ||||||
|   testStaticRuntime( |  | ||||||
|       script, |  | ||||||
|       args, |  | ||||||
|       {}, |  | ||||||
|       /*use_allclose=*/false, |  | ||||||
|       /*use_equalnan=*/false, |  | ||||||
|       /*check_resize=*/true); |  | ||||||
|  |  | ||||||
|   auto self1 = at::rand({22, 22, 22}); |  | ||||||
|   auto n1 = 1; |  | ||||||
|   auto dim1 = 1; |  | ||||||
|   auto norm1 = "forward"; |  | ||||||
|   std::vector<IValue> args2{self1, n1, dim1, norm1}; |  | ||||||
|   testStaticRuntime( |  | ||||||
|       script, |  | ||||||
|       args, |  | ||||||
|       args2, |  | ||||||
|       /*use_allclose=*/false, |  | ||||||
|       /*use_equalnan=*/false, |  | ||||||
|       /*check_resize=*/true); |  | ||||||
| } |  | ||||||
|  |  | ||||||
| TEST(StaticRuntime, autogen_linalg_cross) { | TEST(StaticRuntime, autogen_linalg_cross) { | ||||||
|   const std::string script = R"IR( |   const std::string script = R"IR( | ||||||
|     graph(%self: Tensor, %other: Tensor, %dim: int): |     graph(%self: Tensor, %other: Tensor, %dim: int): | ||||||
|  | |||||||
| @ -779,4 +779,5 @@ Tensor class reference | |||||||
|     Tensor.where |     Tensor.where | ||||||
|     Tensor.xlogy |     Tensor.xlogy | ||||||
|     Tensor.xlogy_ |     Tensor.xlogy_ | ||||||
|  |     Tensor.xpu | ||||||
|     Tensor.zero_ |     Tensor.zero_ | ||||||
|  | |||||||
| @ -80,6 +80,48 @@ class WorkerServerTest(TestCase): | |||||||
|             resp = pool.request("POST", "/handler/dump_nccl_trace_pickle") |             resp = pool.request("POST", "/handler/dump_nccl_trace_pickle") | ||||||
|             self.assertEqual(resp.status, 200) |             self.assertEqual(resp.status, 200) | ||||||
|             out = pickle.loads(resp.data) |             out = pickle.loads(resp.data) | ||||||
|  |             self.assertIsInstance(out, dict) | ||||||
|  |             self.assertIn("version", out) | ||||||
|  |  | ||||||
|  |     @requires_cuda | ||||||
|  |     def test_dump_nccl_trace_pickle_with_params(self) -> None: | ||||||
|  |         with local_worker_server() as pool: | ||||||
|  |             # bad key - not lower case | ||||||
|  |             resp = pool.request( | ||||||
|  |                 "POST", "/handler/dump_nccl_trace_pickle?includeCollectives=true" | ||||||
|  |             ) | ||||||
|  |             self.assertEqual(resp.status, 400) | ||||||
|  |             # unknown key | ||||||
|  |             resp = pool.request( | ||||||
|  |                 "POST", "/handler/dump_nccl_trace_pickle?unknownkey=true" | ||||||
|  |             ) | ||||||
|  |             self.assertEqual(resp.status, 400) | ||||||
|  |             # bad value - not a bool | ||||||
|  |             resp = pool.request( | ||||||
|  |                 "POST", "/handler/dump_nccl_trace_pickle?includecollectives=notabool" | ||||||
|  |             ) | ||||||
|  |             self.assertEqual(resp.status, 400) | ||||||
|  |             # bad value - value not lowercase | ||||||
|  |             resp = pool.request( | ||||||
|  |                 "POST", "/handler/dump_nccl_trace_pickle?includecollectives=True" | ||||||
|  |             ) | ||||||
|  |             self.assertEqual(resp.status, 400) | ||||||
|  |             # good key and value | ||||||
|  |             resp = pool.request( | ||||||
|  |                 "POST", "/handler/dump_nccl_trace_pickle?includecollectives=true" | ||||||
|  |             ) | ||||||
|  |             self.assertEqual(resp.status, 200) | ||||||
|  |             # good key and value | ||||||
|  |             resp = pool.request( | ||||||
|  |                 "POST", "/handler/dump_nccl_trace_pickle?includestacktraces=true" | ||||||
|  |             ) | ||||||
|  |             self.assertEqual(resp.status, 200) | ||||||
|  |             # multiple good keys and values | ||||||
|  |             resp = pool.request( | ||||||
|  |                 "POST", | ||||||
|  |                 "/handler/dump_nccl_trace_pickle?includecollectives=true&includestacktraces=false&onlyactive=true", | ||||||
|  |             ) | ||||||
|  |             self.assertEqual(resp.status, 200) | ||||||
|  |  | ||||||
|     def test_tcp(self) -> None: |     def test_tcp(self) -> None: | ||||||
|         import requests |         import requests | ||||||
|  | |||||||
| @ -13,7 +13,6 @@ import shutil | |||||||
| import subprocess | import subprocess | ||||||
| import sys | import sys | ||||||
| import tempfile | import tempfile | ||||||
| import unittest |  | ||||||
| import uuid | import uuid | ||||||
| from contextlib import closing | from contextlib import closing | ||||||
| from unittest import mock | from unittest import mock | ||||||
| @ -23,12 +22,13 @@ import torch.distributed.run as launch | |||||||
| from torch.distributed.elastic.agent.server.api import RunResult, WorkerState | from torch.distributed.elastic.agent.server.api import RunResult, WorkerState | ||||||
| from torch.distributed.elastic.multiprocessing import DefaultLogsSpecs | from torch.distributed.elastic.multiprocessing import DefaultLogsSpecs | ||||||
| from torch.distributed.elastic.multiprocessing.errors import ChildFailedError | from torch.distributed.elastic.multiprocessing.errors import ChildFailedError | ||||||
| from torch.distributed.elastic.rendezvous.etcd_server import EtcdServer |  | ||||||
| from torch.distributed.elastic.utils import get_socket_with_port | from torch.distributed.elastic.utils import get_socket_with_port | ||||||
| from torch.distributed.elastic.utils.distributed import get_free_port | from torch.distributed.elastic.utils.distributed import get_free_port | ||||||
| from torch.testing._internal.common_utils import ( | from torch.testing._internal.common_utils import ( | ||||||
|  |     run_tests, | ||||||
|     skip_but_pass_in_sandcastle_if, |     skip_but_pass_in_sandcastle_if, | ||||||
|     TEST_WITH_DEV_DBG_ASAN, |     TEST_WITH_DEV_DBG_ASAN, | ||||||
|  |     TestCase, | ||||||
| ) | ) | ||||||
| 
 | 
 | ||||||
| 
 | 
 | ||||||
| @ -63,19 +63,7 @@ class MockException(Exception): | |||||||
|     pass |     pass | ||||||
| 
 | 
 | ||||||
| 
 | 
 | ||||||
| class ElasticLaunchTest(unittest.TestCase): | class ElasticLaunchTest(TestCase): | ||||||
|     @classmethod |  | ||||||
|     def setUpClass(cls): |  | ||||||
|         # start a standalone, single process etcd server to use for all tests |  | ||||||
|         cls._etcd_server = EtcdServer() |  | ||||||
|         cls._etcd_server.start() |  | ||||||
|         cls._etcd_endpoint = cls._etcd_server.get_endpoint() |  | ||||||
| 
 |  | ||||||
|     @classmethod |  | ||||||
|     def tearDownClass(cls): |  | ||||||
|         # stop the standalone etcd server |  | ||||||
|         cls._etcd_server.stop() |  | ||||||
| 
 |  | ||||||
|     def setUp(self): |     def setUp(self): | ||||||
|         self.test_dir = tempfile.mkdtemp() |         self.test_dir = tempfile.mkdtemp() | ||||||
| 
 | 
 | ||||||
| @ -103,8 +91,6 @@ class ElasticLaunchTest(unittest.TestCase): | |||||||
|         args = [ |         args = [ | ||||||
|             f"--nnodes={nnodes}", |             f"--nnodes={nnodes}", | ||||||
|             f"--nproc-per-node={nproc_per_node}", |             f"--nproc-per-node={nproc_per_node}", | ||||||
|             "--rdzv-backend=etcd", |  | ||||||
|             f"--rdzv-endpoint={self._etcd_endpoint}", |  | ||||||
|             f"--rdzv-id={run_id}", |             f"--rdzv-id={run_id}", | ||||||
|             "--monitor-interval=1", |             "--monitor-interval=1", | ||||||
|             "--start-method=spawn", |             "--start-method=spawn", | ||||||
| @ -156,8 +142,6 @@ class ElasticLaunchTest(unittest.TestCase): | |||||||
|         args = [ |         args = [ | ||||||
|             f"--nnodes={nnodes}", |             f"--nnodes={nnodes}", | ||||||
|             f"--nproc-per-node={nproc_per_node}", |             f"--nproc-per-node={nproc_per_node}", | ||||||
|             "--rdzv-backend=etcd", |  | ||||||
|             f"--rdzv-endpoint={self._etcd_endpoint}", |  | ||||||
|             f"--rdzv-id={run_id}", |             f"--rdzv-id={run_id}", | ||||||
|             "--monitor-interval=1", |             "--monitor-interval=1", | ||||||
|             "--start-method=spawn", |             "--start-method=spawn", | ||||||
| @ -187,8 +171,6 @@ class ElasticLaunchTest(unittest.TestCase): | |||||||
|         world_size = 1 |         world_size = 1 | ||||||
|         args = [ |         args = [ | ||||||
|             f"--nnodes={nnodes}", |             f"--nnodes={nnodes}", | ||||||
|             "--rdzv-backend=etcd", |  | ||||||
|             f"--rdzv-endpoint={self._etcd_endpoint}", |  | ||||||
|             f"--rdzv-id={run_id}", |             f"--rdzv-id={run_id}", | ||||||
|             "--monitor-interval=1", |             "--monitor-interval=1", | ||||||
|             "--start-method=spawn", |             "--start-method=spawn", | ||||||
| @ -220,8 +202,6 @@ class ElasticLaunchTest(unittest.TestCase): | |||||||
| 
 | 
 | ||||||
|         os.environ["PET_NNODES"] = str(nnodes) |         os.environ["PET_NNODES"] = str(nnodes) | ||||||
|         os.environ["PET_NPROC_PER_NODE"] = str(nproc_per_node) |         os.environ["PET_NPROC_PER_NODE"] = str(nproc_per_node) | ||||||
|         os.environ["PET_RDZV_BACKEND"] = "etcd" |  | ||||||
|         os.environ["PET_RDZV_ENDPOINT"] = self._etcd_endpoint |  | ||||||
|         os.environ["PET_RDZV_ID"] = run_id |         os.environ["PET_RDZV_ID"] = run_id | ||||||
|         os.environ["PET_MONITOR_INTERVAL"] = "1" |         os.environ["PET_MONITOR_INTERVAL"] = "1" | ||||||
|         os.environ["PET_START_METHOD"] = "spawn" |         os.environ["PET_START_METHOD"] = "spawn" | ||||||
| @ -250,8 +230,6 @@ class ElasticLaunchTest(unittest.TestCase): | |||||||
|         args = [ |         args = [ | ||||||
|             f"--nnodes={nnodes}", |             f"--nnodes={nnodes}", | ||||||
|             f"--nproc-per-node={nproc_type}", |             f"--nproc-per-node={nproc_type}", | ||||||
|             "--rdzv-backend=etcd", |  | ||||||
|             f"--rdzv-endpoint={self._etcd_endpoint}", |  | ||||||
|             f"--rdzv-id={run_id}", |             f"--rdzv-id={run_id}", | ||||||
|             "--monitor-interval=1", |             "--monitor-interval=1", | ||||||
|             "--start-method=spawn", |             "--start-method=spawn", | ||||||
| @ -272,7 +250,8 @@ class ElasticLaunchTest(unittest.TestCase): | |||||||
|     @skip_but_pass_in_sandcastle_if( |     @skip_but_pass_in_sandcastle_if( | ||||||
|         TEST_WITH_DEV_DBG_ASAN, "test incompatible with dev/dbg asan" |         TEST_WITH_DEV_DBG_ASAN, "test incompatible with dev/dbg asan" | ||||||
|     ) |     ) | ||||||
|     def test_nproc_launch_auto_configurations(self): |     @patch("torch.cuda.is_available", return_value=False) | ||||||
|  |     def test_nproc_launch_auto_configurations(self, _mock1): | ||||||
|         self._test_nproc_launch_configuration("auto", os.cpu_count()) |         self._test_nproc_launch_configuration("auto", os.cpu_count()) | ||||||
| 
 | 
 | ||||||
|     @skip_but_pass_in_sandcastle_if( |     @skip_but_pass_in_sandcastle_if( | ||||||
| @ -310,8 +289,9 @@ class ElasticLaunchTest(unittest.TestCase): | |||||||
|         args = [ |         args = [ | ||||||
|             f"--nnodes={min_nodes}:{max_nodes}", |             f"--nnodes={min_nodes}:{max_nodes}", | ||||||
|             f"--nproc-per-node={nproc_per_node}", |             f"--nproc-per-node={nproc_per_node}", | ||||||
|             "--rdzv-backend=etcd", |             "--rdzv-backend=c10d", | ||||||
|             f"--rdzv-endpoint={self._etcd_endpoint}", |             f"--rdzv-endpoint=localhost:{get_free_port()}", | ||||||
|  |             "--rdzv-conf='join_timeout=5,last_call_timeout=1,timeout=5'", | ||||||
|             f"--rdzv-id={run_id}", |             f"--rdzv-id={run_id}", | ||||||
|             "--monitor-interval=1", |             "--monitor-interval=1", | ||||||
|             "--start-method=spawn", |             "--start-method=spawn", | ||||||
| @ -343,8 +323,9 @@ class ElasticLaunchTest(unittest.TestCase): | |||||||
|         args = [ |         args = [ | ||||||
|             f"--nnodes={min_nodes}:{max_nodes}", |             f"--nnodes={min_nodes}:{max_nodes}", | ||||||
|             f"--nproc-per-node={nproc_per_node}", |             f"--nproc-per-node={nproc_per_node}", | ||||||
|             "--rdzv-backend=etcd", |             "--rdzv-backend=c10d", | ||||||
|             f"--rdzv-endpoint={self._etcd_endpoint}", |             f"--rdzv-endpoint=localhost:{get_free_port()}", | ||||||
|  |             "--rdzv-conf='join_timeout=5,last_call_timeout=1,timeout=5'", | ||||||
|             f"--rdzv-id={run_id}", |             f"--rdzv-id={run_id}", | ||||||
|             "--monitor-interval=1", |             "--monitor-interval=1", | ||||||
|             "--max-restarts=0", |             "--max-restarts=0", | ||||||
| @ -376,8 +357,9 @@ class ElasticLaunchTest(unittest.TestCase): | |||||||
|         args = [ |         args = [ | ||||||
|             f"--nnodes={min_nodes}:{max_nodes}", |             f"--nnodes={min_nodes}:{max_nodes}", | ||||||
|             f"--nproc-per-node={nproc_per_node}", |             f"--nproc-per-node={nproc_per_node}", | ||||||
|             "--rdzv-backend=etcd", |             "--rdzv-backend=c10d", | ||||||
|             f"--rdzv-endpoint={self._etcd_endpoint}", |             f"--rdzv-endpoint=localhost:{get_free_port()}", | ||||||
|  |             "--rdzv_conf=timeout=5", | ||||||
|             f"--rdzv-id={run_id}", |             f"--rdzv-id={run_id}", | ||||||
|             "--monitor-interval=1", |             "--monitor-interval=1", | ||||||
|             "--max-restarts=0", |             "--max-restarts=0", | ||||||
| @ -452,8 +434,9 @@ class ElasticLaunchTest(unittest.TestCase): | |||||||
|         args = [ |         args = [ | ||||||
|             f"--nnodes={min_nodes}:{max_nodes}", |             f"--nnodes={min_nodes}:{max_nodes}", | ||||||
|             f"--nproc-per-node={nproc_per_node}", |             f"--nproc-per-node={nproc_per_node}", | ||||||
|             "--rdzv-backend=etcd", |             "--rdzv-backend=c10d", | ||||||
|             f"--rdzv-endpoint={self._etcd_endpoint}", |             f"--rdzv-endpoint=localhost:{get_free_port()}", | ||||||
|  |             "--rdzv_conf=timeout=5", | ||||||
|             f"--rdzv-id={run_id}", |             f"--rdzv-id={run_id}", | ||||||
|             "--monitor-interval=1", |             "--monitor-interval=1", | ||||||
|             "--start-method=spawn", |             "--start-method=spawn", | ||||||
| @ -608,21 +591,6 @@ class ElasticLaunchTest(unittest.TestCase): | |||||||
|                 is_torchelastic_launched = fp.readline() |                 is_torchelastic_launched = fp.readline() | ||||||
|                 self.assertEqual("False", is_torchelastic_launched) |                 self.assertEqual("False", is_torchelastic_launched) | ||||||
| 
 | 
 | ||||||
|     def test_init_method_tcp(self): |  | ||||||
|         port = get_free_port() |  | ||||||
|         with patch.object( |  | ||||||
|             sys, |  | ||||||
|             "argv", |  | ||||||
|             [ |  | ||||||
|                 path("bin/test_script_init_method.py"), |  | ||||||
|                 f"--init-method=tcp://localhost:{port}", |  | ||||||
|                 "--rank=0", |  | ||||||
|                 "--world-size=1", |  | ||||||
|             ], |  | ||||||
|         ): |  | ||||||
|             runpy.run_path(sys.argv[0], run_name="__main__") |  | ||||||
|             # nothing to validate, just make sure it runs |  | ||||||
| 
 |  | ||||||
|     @skip_but_pass_in_sandcastle_if( |     @skip_but_pass_in_sandcastle_if( | ||||||
|         TEST_WITH_DEV_DBG_ASAN, "test incompatible with dev/dbg asan" |         TEST_WITH_DEV_DBG_ASAN, "test incompatible with dev/dbg asan" | ||||||
|     ) |     ) | ||||||
| @ -642,27 +610,6 @@ class ElasticLaunchTest(unittest.TestCase): | |||||||
|         ) |         ) | ||||||
|         # nothing to validate, just make sure it runs |         # nothing to validate, just make sure it runs | ||||||
| 
 | 
 | ||||||
|     def test_init_method_env(self): |  | ||||||
|         port = get_free_port() |  | ||||||
|         with patch.dict( |  | ||||||
|             os.environ, |  | ||||||
|             { |  | ||||||
|                 "RANK": "0", |  | ||||||
|                 "WORLD_SIZE": "1", |  | ||||||
|                 "MASTER_ADDR": "localhost", |  | ||||||
|                 "MASTER_PORT": str(port), |  | ||||||
|             }, |  | ||||||
|         ), patch.object( |  | ||||||
|             sys, |  | ||||||
|             "argv", |  | ||||||
|             [ |  | ||||||
|                 path("bin/test_script_init_method.py"), |  | ||||||
|                 "--init-method=env://", |  | ||||||
|             ], |  | ||||||
|         ): |  | ||||||
|             runpy.run_path(sys.argv[0], run_name="__main__") |  | ||||||
|             # nothing to validate, just make sure it runs |  | ||||||
| 
 |  | ||||||
|     @skip_but_pass_in_sandcastle_if( |     @skip_but_pass_in_sandcastle_if( | ||||||
|         TEST_WITH_DEV_DBG_ASAN, "test incompatible with dev/dbg asan" |         TEST_WITH_DEV_DBG_ASAN, "test incompatible with dev/dbg asan" | ||||||
|     ) |     ) | ||||||
| @ -681,3 +628,7 @@ class ElasticLaunchTest(unittest.TestCase): | |||||||
|             ] |             ] | ||||||
|         ) |         ) | ||||||
|         # nothing to validate, just make sure it runs |         # nothing to validate, just make sure it runs | ||||||
|  | 
 | ||||||
|  | 
 | ||||||
|  | if __name__ == "__main__": | ||||||
|  |     run_tests() | ||||||
| @ -3662,7 +3662,8 @@ class NCCLTraceTest(NCCLTraceTestBase): | |||||||
|     @requires_nccl() |     @requires_nccl() | ||||||
|     @skip_but_pass_in_sandcastle_if(not TEST_MULTIGPU, "NCCL test requires 2+ GPUs") |     @skip_but_pass_in_sandcastle_if(not TEST_MULTIGPU, "NCCL test requires 2+ GPUs") | ||||||
|     @parametrize("timing_enabled", [True, False]) |     @parametrize("timing_enabled", [True, False]) | ||||||
|     def test_trace_while_active(self, timing_enabled): |     @parametrize("only_active", [True, False]) | ||||||
|  |     def test_trace_while_active(self, timing_enabled, only_active): | ||||||
|         if self.rank == self.MAIN_PROCESS_RANK: |         if self.rank == self.MAIN_PROCESS_RANK: | ||||||
|             for c in self.children_pipes: |             for c in self.children_pipes: | ||||||
|                 self.assertEqual(c.recv(), "next") |                 self.assertEqual(c.recv(), "next") | ||||||
| @ -3683,17 +3684,26 @@ class NCCLTraceTest(NCCLTraceTestBase): | |||||||
|             if self.rank != 0: |             if self.rank != 0: | ||||||
|                 pg.allreduce(a).wait() |                 pg.allreduce(a).wait() | ||||||
|             e.synchronize() |             e.synchronize() | ||||||
|             t = pickle.loads(torch._C._distributed_c10d._dump_nccl_trace()) |             t = pickle.loads( | ||||||
|  |                 torch._C._distributed_c10d._dump_nccl_trace(onlyActive=only_active) | ||||||
|  |             ) | ||||||
|             t = t["entries"] |             t = t["entries"] | ||||||
|             self.assertEqual(t[-1]["profiling_name"], "nccl:all_reduce") |             if only_active: | ||||||
|             if self.rank == 0: |                 if self.rank == 0: | ||||||
|                 self.assertEqual(t[-1]["collective_seq_id"], 1) |                     self.assertEqual(len(t), 0) | ||||||
|                 self.assertEqual(t[-1]["state"], "completed") |                 else: | ||||||
|             else: |                     self.assertEqual(len(t), 1) | ||||||
|                 self.assertEqual(t[-1]["collective_seq_id"], 2) |             if not only_active: | ||||||
|                 self.assertEqual( |                 if self.rank == 0: | ||||||
|                     t[-1]["state"], self.started_or_scheduled(timing_enabled) |                     self.assertEqual(t[-1]["profiling_name"], "nccl:all_reduce") | ||||||
|                 ) |                     self.assertEqual(t[-1]["collective_seq_id"], 1) | ||||||
|  |                     self.assertEqual(t[-1]["state"], "completed") | ||||||
|  |                 else: | ||||||
|  |                     self.assertEqual(t[-1]["profiling_name"], "nccl:all_reduce") | ||||||
|  |                     self.assertEqual(t[-1]["collective_seq_id"], 2) | ||||||
|  |                     self.assertEqual( | ||||||
|  |                         t[-1]["state"], self.started_or_scheduled(timing_enabled) | ||||||
|  |                     ) | ||||||
|  |  | ||||||
|             self.parent.send("next") |             self.parent.send("next") | ||||||
|             self.assertEqual("next", self.parent.recv()) |             self.assertEqual("next", self.parent.recv()) | ||||||
|  | |||||||
| @ -1084,14 +1084,12 @@ class TestSingleProc(DynamoDistributedSingleProcTestCase): | |||||||
|             # far from an exhaustive check of all the expected guards, just check a couple of them. |             # far from an exhaustive check of all the expected guards, just check a couple of them. | ||||||
|             FileCheck().check("""local "L['self']" TYPE_MATCH""").check( |             FileCheck().check("""local "L['self']" TYPE_MATCH""").check( | ||||||
|                 """local "L['self']" ID_MATCH""" |                 """local "L['self']" ID_MATCH""" | ||||||
|  |             ).check(f"""{expected_guard_source} "L['self'].net" TYPE_MATCH""").check( | ||||||
|  |                 f"""{expected_guard_source} "L['self'].net" ID_MATCH""" | ||||||
|             ).check( |             ).check( | ||||||
|                 f"""{expected_guard_source} "L['self']._modules['net']" TYPE_MATCH""" |                 f"""{expected_guard_source} "L['self'].net[0]" TYPE_MATCH""" | ||||||
|             ).check( |             ).check( | ||||||
|                 f"""{expected_guard_source} "L['self']._modules['net']" ID_MATCH""" |                 f"""{expected_guard_source} "L['self'].net[0]" ID_MATCH""" | ||||||
|             ).check( |  | ||||||
|                 f"""{expected_guard_source} "L['self']._modules['net']._modules['0']" TYPE_MATCH""" |  | ||||||
|             ).check( |  | ||||||
|                 f"""{expected_guard_source} "L['self']._modules['net']._modules['1']" ID_MATCH""" |  | ||||||
|             ).run( |             ).run( | ||||||
|                 GUARDS_FILE.getvalue() |                 GUARDS_FILE.getvalue() | ||||||
|             ) |             ) | ||||||
|  | |||||||
| @ -464,6 +464,44 @@ class DecoratorTests(torch._dynamo.test_case.TestCase): | |||||||
|  |  | ||||||
|         self.assertEqual(cnt.frame_count, 1) |         self.assertEqual(cnt.frame_count, 1) | ||||||
|  |  | ||||||
|  |     def test_assume_constant_result_on_user_defined_fn(self): | ||||||
|  |         @torch._dynamo.assume_constant_result | ||||||
|  |         def const_fn(n, s): | ||||||
|  |             return torch.full([n], s) | ||||||
|  |  | ||||||
|  |         def fn(B): | ||||||
|  |             B = const_fn(B.size(0), 13) | ||||||
|  |             X = B * 2 | ||||||
|  |             return X.tolist() | ||||||
|  |  | ||||||
|  |         B_list = [8] * 32 | ||||||
|  |  | ||||||
|  |         B = torch.tensor(B_list, dtype=torch.int32) | ||||||
|  |         torch._dynamo.decorators.mark_static(B, 0) | ||||||
|  |  | ||||||
|  |         torch._dynamo.config.capture_scalar_outputs = True | ||||||
|  |         torch._dynamo.config.capture_dynamic_output_shape_ops = True | ||||||
|  |  | ||||||
|  |         self.assertEqual( | ||||||
|  |             fn(B), torch.compile(fn, backend="eager", fullgraph=True, dynamic=True)(B) | ||||||
|  |         ) | ||||||
|  |  | ||||||
|  |     def test_assume_constant_result_on_computation_with_graph_input(self): | ||||||
|  |         @torch._dynamo.assume_constant_result | ||||||
|  |         def check(y): | ||||||
|  |             return y[0].item() == 1 | ||||||
|  |  | ||||||
|  |         def fn(x, y): | ||||||
|  |             if check(y): | ||||||
|  |                 return x + 2 | ||||||
|  |             else: | ||||||
|  |                 return x + 1 | ||||||
|  |  | ||||||
|  |         y = torch.tensor([1]) | ||||||
|  |         x = torch.tensor(1) | ||||||
|  |  | ||||||
|  |         self.assertEqual(fn(x, y), torch.compile(fn)(x, y)) | ||||||
|  |  | ||||||
|  |  | ||||||
| if __name__ == "__main__": | if __name__ == "__main__": | ||||||
|     from torch._dynamo.test_case import run_tests |     from torch._dynamo.test_case import run_tests | ||||||
|  | |||||||
| @ -253,6 +253,7 @@ Target Expressions: | |||||||
|   ==> (>= 0 s1) |   ==> (>= 0 s1) | ||||||
|   ==> (>= 0 s2) |   ==> (>= 0 s2) | ||||||
|   ==> (>= 0 s3) |   ==> (>= 0 s3) | ||||||
|  |   ==> (>= 9223372036854775806 s0) | ||||||
|  |  | ||||||
| Failed Source Expressions: | Failed Source Expressions: | ||||||
|   ==> (== (+ L['shape'][0] L['shape'][1] L['shape'][2]) L['x'].size()[0])""", |   ==> (== (+ L['shape'][0] L['shape'][1] L['shape'][2]) L['x'].size()[0])""", | ||||||
| @ -286,14 +287,14 @@ Failure occurred while running node: | |||||||
| Model: | Model: | ||||||
|   ==> L['shape'][0]: 1 |   ==> L['shape'][0]: 1 | ||||||
|   ==> L['shape'][1]: 1 |   ==> L['shape'][1]: 1 | ||||||
|   ==> L['shape'][2]: 0 |   ==> L['shape'][2]: 2 | ||||||
|   ==> L['x'].size()[0]: 3 |   ==> L['x'].size()[0]: 3 | ||||||
|   ==> L['x'].storage_offset(): 0 |   ==> L['x'].storage_offset(): 0 | ||||||
|   ==> L['x'].stride()[0]: 1 |   ==> L['x'].stride()[0]: 1 | ||||||
|   ==> s0: 3 |   ==> s0: 3 | ||||||
|   ==> s1: 1 |   ==> s1: 1 | ||||||
|   ==> s2: 1 |   ==> s2: 1 | ||||||
|   ==> s3: 0 |   ==> s3: 2 | ||||||
|  |  | ||||||
| Assertions: | Assertions: | ||||||
|   ==> (== 0 L['x'].storage_offset()) |   ==> (== 0 L['x'].storage_offset()) | ||||||
| @ -317,6 +318,10 @@ Target Expressions: | |||||||
|   ==> (== L['shape'][2] s3) |   ==> (== L['shape'][2] s3) | ||||||
|   ==> (== L['x'].size()[0] s0) |   ==> (== L['x'].size()[0] s0) | ||||||
|   ==> (> s0 0) |   ==> (> s0 0) | ||||||
|  |   ==> (>= 9223372036854775806 s0) | ||||||
|  |   ==> (>= 9223372036854775807 s1) | ||||||
|  |   ==> (>= 9223372036854775807 s2) | ||||||
|  |   ==> (>= 9223372036854775807 s3) | ||||||
|  |  | ||||||
| Failed Source Expressions: | Failed Source Expressions: | ||||||
|   ==> (== (+ L['shape'][0] L['shape'][1] L['shape'][2]) L['x'].size()[0])""", |   ==> (== (+ L['shape'][0] L['shape'][1] L['shape'][2]) L['x'].size()[0])""", | ||||||
|  | |||||||
| @ -3473,6 +3473,7 @@ class GraphModule(torch.nn.Module): | |||||||
|         ] |         ] | ||||||
|         false_guard_code = [ |         false_guard_code = [ | ||||||
|             "Ne(cast_symbool_to_symint_guardless(L['pred']), 1)", |             "Ne(cast_symbool_to_symint_guardless(L['pred']), 1)", | ||||||
|  |             "-9223372036854775808 <= cast_symbool_to_symint_guardless(L['pred'])", | ||||||
|         ] |         ] | ||||||
|         test_symbool_guards( |         test_symbool_guards( | ||||||
|             f, |             f, | ||||||
|  | |||||||
| @ -3,7 +3,6 @@ import enum | |||||||
| import functools | import functools | ||||||
| import pprint | import pprint | ||||||
| import re | import re | ||||||
| import sys |  | ||||||
| import unittest | import unittest | ||||||
| import warnings | import warnings | ||||||
|  |  | ||||||
| @ -2860,7 +2859,7 @@ class GraphModule(torch.nn.Module): | |||||||
|  |  | ||||||
|         _vjp_treespec_compare = torch._functorch.eager_transforms._vjp_treespec_compare(o, _add_batch_dim_1) |         _vjp_treespec_compare = torch._functorch.eager_transforms._vjp_treespec_compare(o, _add_batch_dim_1) | ||||||
|  |  | ||||||
|         _autograd_grad = torch._functorch.eager_transforms._autograd_grad([o], [diff_primals], [_add_batch_dim_1], retain_graph = True, create_graph = True);  _add_batch_dim_1 = None |         _autograd_grad = torch._functorch.eager_transforms._autograd_grad([o], [diff_primals], [_add_batch_dim_1], retain_graph = True, create_graph = True);  o = diff_primals = _add_batch_dim_1 = None | ||||||
|         batched_outputs = _autograd_grad[0];  _autograd_grad = None |         batched_outputs = _autograd_grad[0];  _autograd_grad = None | ||||||
|  |  | ||||||
|         chunked_result = torch._C._functorch._remove_batch_dim(batched_outputs, 3, 12, 0);  batched_outputs = None |         chunked_result = torch._C._functorch._remove_batch_dim(batched_outputs, 3, 12, 0);  batched_outputs = None | ||||||
| @ -2896,7 +2895,7 @@ class GraphModule(torch.nn.Module): | |||||||
|         jac_out_in: "f32[4, 3, 4, 3, 12]" = split_2[0];  split_2 = None |         jac_out_in: "f32[4, 3, 4, 3, 12]" = split_2[0];  split_2 = None | ||||||
|  |  | ||||||
|         unflatten: "f32[4, 3, 4, 3, 4, 3]" = jac_out_in.unflatten(-1, (4, 3));  jac_out_in = None |         unflatten: "f32[4, 3, 4, 3, 4, 3]" = jac_out_in.unflatten(-1, (4, 3));  jac_out_in = None | ||||||
|         return (unflatten, diff_primals, o) |         return (unflatten,) | ||||||
| """, | """, | ||||||
|         ) |         ) | ||||||
|  |  | ||||||
| @ -2964,8 +2963,8 @@ class GraphModule(torch.nn.Module): | |||||||
|         _saved_tensors_hooks_disable_2 = torch._C._autograd._saved_tensors_hooks_disable("torch.func transforms don't yet support saved tensor hooks. Please open an issue with your use case.") |         _saved_tensors_hooks_disable_2 = torch._C._autograd._saved_tensors_hooks_disable("torch.func transforms don't yet support saved tensor hooks. Please open an issue with your use case.") | ||||||
|         _grad_increment_nesting = torch._C._functorch._grad_increment_nesting() |         _grad_increment_nesting = torch._C._functorch._grad_increment_nesting() | ||||||
|  |  | ||||||
|         _wrap_for_grad_2 = torch._C._functorch._wrap_for_grad(child_2, 3) |         _wrap_for_grad_2 = torch._C._functorch._wrap_for_grad(child_2, 3);  child_2 = None | ||||||
|         child_4 = torch._C._functorch._wrap_for_grad(child_3, 3) |         child_4 = torch._C._functorch._wrap_for_grad(child_3, 3);  child_3 = None | ||||||
|  |  | ||||||
|         set_inplace_requires_grad_allowed = torch._C._functorch.set_inplace_requires_grad_allowed(True) |         set_inplace_requires_grad_allowed = torch._C._functorch.set_inplace_requires_grad_allowed(True) | ||||||
|  |  | ||||||
| @ -3002,7 +3001,7 @@ class GraphModule(torch.nn.Module): | |||||||
|  |  | ||||||
|         _vjp_treespec_compare = torch._functorch.eager_transforms._vjp_treespec_compare(o, _add_batch_dim_1) |         _vjp_treespec_compare = torch._functorch.eager_transforms._vjp_treespec_compare(o, _add_batch_dim_1) | ||||||
|  |  | ||||||
|         _autograd_grad = torch._functorch.eager_transforms._autograd_grad([o], [child_4], [_add_batch_dim_1], retain_graph = True, create_graph = True);  _add_batch_dim_1 = None |         _autograd_grad = torch._functorch.eager_transforms._autograd_grad([o], [child_4], [_add_batch_dim_1], retain_graph = True, create_graph = True);  o = child_4 = _add_batch_dim_1 = None | ||||||
|         child_5 = _autograd_grad[0];  _autograd_grad = None |         child_5 = _autograd_grad[0];  _autograd_grad = None | ||||||
|  |  | ||||||
|         child_6 = torch._C._functorch._remove_batch_dim(child_5, 3, 12, 0);  child_5 = None |         child_6 = torch._C._functorch._remove_batch_dim(child_5, 3, 12, 0);  child_5 = None | ||||||
| @ -3041,17 +3040,10 @@ class GraphModule(torch.nn.Module): | |||||||
|         unflatten: "f32[4, 3, 3, 4, 3, 4]" = jac_out_in.unflatten(-1, (3, 4));  jac_out_in = None""", |         unflatten: "f32[4, 3, 3, 4, 3, 4]" = jac_out_in.unflatten(-1, (3, 4));  jac_out_in = None""", | ||||||
|         ) |         ) | ||||||
|  |  | ||||||
|         # Python 3.10 and 3.11 produces slightly different graphs |         self.assertExpectedInline( | ||||||
|         if sys.version_info[:2] > (3, 10): |             actual.split("\n")[-2], | ||||||
|             self.assertExpectedInline( |             """        return (unflatten,)""", | ||||||
|                 actual.split("\n")[-2], |         ) | ||||||
|                 """        return (unflatten, child_2, _wrap_for_grad_1, child_3, child_4, o)""", |  | ||||||
|             ) |  | ||||||
|         else: |  | ||||||
|             self.assertExpectedInline( |  | ||||||
|                 actual.split("\n")[-2], |  | ||||||
|                 """        return (unflatten, child_3, child_2, _wrap_for_grad_1, child_4, o)""", |  | ||||||
|             ) |  | ||||||
|  |  | ||||||
|     @unittest.expectedFailure |     @unittest.expectedFailure | ||||||
|     def test_hessian_disable_capture(self): |     def test_hessian_disable_capture(self): | ||||||
| @ -3160,7 +3152,7 @@ class GraphModule(torch.nn.Module): | |||||||
|  |  | ||||||
|         _vjp_treespec_compare = torch._functorch.eager_transforms._vjp_treespec_compare(o, _add_batch_dim) |         _vjp_treespec_compare = torch._functorch.eager_transforms._vjp_treespec_compare(o, _add_batch_dim) | ||||||
|  |  | ||||||
|         _autograd_grad = torch._functorch.eager_transforms._autograd_grad([o], [diff_primals], [_add_batch_dim], retain_graph = True, create_graph = True);  _add_batch_dim = None |         _autograd_grad = torch._functorch.eager_transforms._autograd_grad([o], [diff_primals], [_add_batch_dim], retain_graph = True, create_graph = True);  o = diff_primals = _add_batch_dim = None | ||||||
|         batched_outputs = _autograd_grad[0];  _autograd_grad = None |         batched_outputs = _autograd_grad[0];  _autograd_grad = None | ||||||
|  |  | ||||||
|         chunked_result: "f32[12, 4, 3]" = torch._C._functorch._remove_batch_dim(batched_outputs, 1, 12, 0);  batched_outputs = None |         chunked_result: "f32[12, 4, 3]" = torch._C._functorch._remove_batch_dim(batched_outputs, 1, 12, 0);  batched_outputs = None | ||||||
| @ -3172,7 +3164,7 @@ class GraphModule(torch.nn.Module): | |||||||
|         split_1: "f32[12, 4, 3]" = split[0];  split = None |         split_1: "f32[12, 4, 3]" = split[0];  split = None | ||||||
|  |  | ||||||
|         output_input: "f32[4, 3, 4, 3]" = split_1.view((4, 3, 4, 3));  split_1 = None |         output_input: "f32[4, 3, 4, 3]" = split_1.view((4, 3, 4, 3));  split_1 = None | ||||||
|         return (output_input, diff_primals, o) |         return (output_input,) | ||||||
| """, | """, | ||||||
|         ) |         ) | ||||||
|  |  | ||||||
| @ -3243,7 +3235,7 @@ class GraphModule(torch.nn.Module): | |||||||
|  |  | ||||||
|         _vjp_treespec_compare = torch._functorch.eager_transforms._vjp_treespec_compare(o, _add_batch_dim) |         _vjp_treespec_compare = torch._functorch.eager_transforms._vjp_treespec_compare(o, _add_batch_dim) | ||||||
|  |  | ||||||
|         _autograd_grad = torch._functorch.eager_transforms._autograd_grad([o], [diff_primals], [_add_batch_dim], retain_graph = True, create_graph = True);  _add_batch_dim = None |         _autograd_grad = torch._functorch.eager_transforms._autograd_grad([o], [diff_primals], [_add_batch_dim], retain_graph = True, create_graph = True);  o = diff_primals = _add_batch_dim = None | ||||||
|         batched_outputs = _autograd_grad[0];  _autograd_grad = None |         batched_outputs = _autograd_grad[0];  _autograd_grad = None | ||||||
|  |  | ||||||
|         chunked_result: "f32[12, 3, 4]" = torch._C._functorch._remove_batch_dim(batched_outputs, 1, 12, 0);  batched_outputs = None |         chunked_result: "f32[12, 3, 4]" = torch._C._functorch._remove_batch_dim(batched_outputs, 1, 12, 0);  batched_outputs = None | ||||||
| @ -3255,7 +3247,7 @@ class GraphModule(torch.nn.Module): | |||||||
|         split_1: "f32[12, 3, 4]" = split[0];  split = None |         split_1: "f32[12, 3, 4]" = split[0];  split = None | ||||||
|  |  | ||||||
|         output_input: "f32[3, 4, 3, 4]" = split_1.view((3, 4, 3, 4));  split_1 = None |         output_input: "f32[3, 4, 3, 4]" = split_1.view((3, 4, 3, 4));  split_1 = None | ||||||
|         return (output_input, diff_primals, o) |         return (output_input,) | ||||||
| """, | """, | ||||||
|         ) |         ) | ||||||
|  |  | ||||||
| @ -3328,7 +3320,7 @@ class GraphModule(torch.nn.Module): | |||||||
|  |  | ||||||
|         _vjp_treespec_compare = torch._functorch.eager_transforms._vjp_treespec_compare(o, _add_batch_dim) |         _vjp_treespec_compare = torch._functorch.eager_transforms._vjp_treespec_compare(o, _add_batch_dim) | ||||||
|  |  | ||||||
|         _autograd_grad = torch._functorch.eager_transforms._autograd_grad([o], [diff_primals], [_add_batch_dim], retain_graph = True, create_graph = True);  _add_batch_dim = None |         _autograd_grad = torch._functorch.eager_transforms._autograd_grad([o], [diff_primals], [_add_batch_dim], retain_graph = True, create_graph = True);  o = diff_primals = _add_batch_dim = None | ||||||
|         batched_outputs = _autograd_grad[0];  _autograd_grad = None |         batched_outputs = _autograd_grad[0];  _autograd_grad = None | ||||||
|  |  | ||||||
|         chunked_result: "f32[12, 3, 4]" = torch._C._functorch._remove_batch_dim(batched_outputs, 1, 12, 0);  batched_outputs = None |         chunked_result: "f32[12, 3, 4]" = torch._C._functorch._remove_batch_dim(batched_outputs, 1, 12, 0);  batched_outputs = None | ||||||
| @ -3340,7 +3332,7 @@ class GraphModule(torch.nn.Module): | |||||||
|         split_1: "f32[12, 3, 4]" = split[0];  split = None |         split_1: "f32[12, 3, 4]" = split[0];  split = None | ||||||
|  |  | ||||||
|         output_input: "f32[3, 4, 3, 4]" = split_1.view((3, 4, 3, 4));  split_1 = None |         output_input: "f32[3, 4, 3, 4]" = split_1.view((3, 4, 3, 4));  split_1 = None | ||||||
|         return (output_input, aux_1, diff_primals, o) |         return (output_input, aux_1) | ||||||
| """, | """, | ||||||
|         ) |         ) | ||||||
|  |  | ||||||
| @ -3776,7 +3768,7 @@ class GraphModule(torch.nn.Module): | |||||||
|  |  | ||||||
|         _grad_decrement_nesting = torch._C._functorch._grad_decrement_nesting() |         _grad_decrement_nesting = torch._C._functorch._grad_decrement_nesting() | ||||||
|         _saved_tensors_hooks_enable = torch._C._autograd._saved_tensors_hooks_enable() |         _saved_tensors_hooks_enable = torch._C._autograd._saved_tensors_hooks_enable() | ||||||
|         return (grad_input_1, y) |         return (y, grad_input_1) | ||||||
| """, | """, | ||||||
|         ) |         ) | ||||||
|  |  | ||||||
| @ -5187,10 +5179,10 @@ class GraphModule(torch.nn.Module): | |||||||
|             actual, |             actual, | ||||||
|             """\ |             """\ | ||||||
| class GraphModule(torch.nn.Module): | class GraphModule(torch.nn.Module): | ||||||
|     def forward(self, L_self_buffers_tensor_constant0_: "f32[3, 3, 3]"): |     def forward(self, L_self_tensor_constant0: "f32[3, 3, 3]"): | ||||||
|         l_self_buffers_tensor_constant0_ = L_self_buffers_tensor_constant0_ |         l_self_tensor_constant0 = L_self_tensor_constant0 | ||||||
|  |  | ||||||
|         alias_default: "f32[3, 3, 3]" = torch.ops.aten.alias.default(l_self_buffers_tensor_constant0_);  l_self_buffers_tensor_constant0_ = None |         alias_default: "f32[3, 3, 3]" = torch.ops.aten.alias.default(l_self_tensor_constant0);  l_self_tensor_constant0 = None | ||||||
|  |  | ||||||
|         sin_default: "f32[3, 3, 3]" = torch.ops.aten.sin.default(alias_default) |         sin_default: "f32[3, 3, 3]" = torch.ops.aten.sin.default(alias_default) | ||||||
|  |  | ||||||
| @ -5209,16 +5201,16 @@ class GraphModule(torch.nn.Module): | |||||||
|             actual, |             actual, | ||||||
|             """\ |             """\ | ||||||
| class GraphModule(torch.nn.Module): | class GraphModule(torch.nn.Module): | ||||||
|     def forward(self, L_self_modules_FX_CONST_FOLDED_ATTRS_parameters_0_: "f32[3, 3, 3]", L_self_modules_FX_CONST_FOLDED_ATTRS_parameters_1_: "f32[3, 3, 3]", L_flat_tangents_1_: "f32[3, 3, 3]"): |     def forward(self, getattr_L_self_FX_CONST_FOLDED_ATTRS_0_: "f32[3, 3, 3]", getattr_L_self_FX_CONST_FOLDED_ATTRS_1_: "f32[3, 3, 3]", L_flat_tangents_1_: "f32[3, 3, 3]"): | ||||||
|         l_self_modules_fx_const_folded_attrs_parameters_0_ = L_self_modules_FX_CONST_FOLDED_ATTRS_parameters_0_ |         getattr_l_self_fx_const_folded_attrs_0_ = getattr_L_self_FX_CONST_FOLDED_ATTRS_0_ | ||||||
|         l_self_modules_fx_const_folded_attrs_parameters_1_ = L_self_modules_FX_CONST_FOLDED_ATTRS_parameters_1_ |         getattr_l_self_fx_const_folded_attrs_1_ = getattr_L_self_FX_CONST_FOLDED_ATTRS_1_ | ||||||
|         l_flat_tangents_1_ = L_flat_tangents_1_ |         l_flat_tangents_1_ = L_flat_tangents_1_ | ||||||
|  |  | ||||||
|         _new_zeros_with_same_feature_meta_default: "f32[3, 3, 3]" = torch.ops.aten._new_zeros_with_same_feature_meta.default(l_flat_tangents_1_, l_self_modules_fx_const_folded_attrs_parameters_0_);  l_self_modules_fx_const_folded_attrs_parameters_0_ = None |         _new_zeros_with_same_feature_meta_default: "f32[3, 3, 3]" = torch.ops.aten._new_zeros_with_same_feature_meta.default(l_flat_tangents_1_, getattr_l_self_fx_const_folded_attrs_0_);  getattr_l_self_fx_const_folded_attrs_0_ = None | ||||||
|  |  | ||||||
|         copy__default: "f32[3, 3, 3]" = torch.ops.aten.copy_.default(_new_zeros_with_same_feature_meta_default, l_flat_tangents_1_);  _new_zeros_with_same_feature_meta_default = l_flat_tangents_1_ = None |         copy__default: "f32[3, 3, 3]" = torch.ops.aten.copy_.default(_new_zeros_with_same_feature_meta_default, l_flat_tangents_1_);  _new_zeros_with_same_feature_meta_default = l_flat_tangents_1_ = None | ||||||
|  |  | ||||||
|         mul_tensor: "f32[3, 3, 3]" = torch.ops.aten.mul.Tensor(copy__default, l_self_modules_fx_const_folded_attrs_parameters_1_);  copy__default = l_self_modules_fx_const_folded_attrs_parameters_1_ = None |         mul_tensor: "f32[3, 3, 3]" = torch.ops.aten.mul.Tensor(copy__default, getattr_l_self_fx_const_folded_attrs_1_);  copy__default = getattr_l_self_fx_const_folded_attrs_1_ = None | ||||||
|         return (mul_tensor,) |         return (mul_tensor,) | ||||||
| """, | """, | ||||||
|         ) |         ) | ||||||
|  | |||||||
| @ -9309,7 +9309,7 @@ ShapeEnv not equal: field values don't match: | |||||||
|   >  Left: {0: 0, 1: 1, 2: s1, 3: s0} |   >  Left: {0: 0, 1: 1, 2: s1, 3: s0} | ||||||
|   > Right: {0: 0, 1: 1} |   > Right: {0: 0, 1: 1} | ||||||
| ==> var_to_range: values don't match. | ==> var_to_range: values don't match. | ||||||
|   >  Left: {s0: VR[2, int_oo], s1: VR[2, int_oo]} |   >  Left: {s0: VR[2, 9223372036854775806], s1: VR[2, 9223372036854775806]} | ||||||
|   > Right: {} |   > Right: {} | ||||||
| ==> var_to_sources: values don't match. | ==> var_to_sources: values don't match. | ||||||
|   >  Left: {s0: [TensorPropertySource(base=ConstantSource(source_name='x'), prop=<TensorProperty.SIZE: 0>, idx=0)], s1: [TensorPropertySource(base=ConstantSource(source_name='x'), prop=<TensorProperty.SIZE: 0>, idx=1)]} |   >  Left: {s0: [TensorPropertySource(base=ConstantSource(source_name='x'), prop=<TensorProperty.SIZE: 0>, idx=0)], s1: [TensorPropertySource(base=ConstantSource(source_name='x'), prop=<TensorProperty.SIZE: 0>, idx=1)]} | ||||||
| @ -9343,7 +9343,7 @@ ShapeEnv not equal: field values don't match: | |||||||
|   >  Left: 2 |   >  Left: 2 | ||||||
|   > Right: 0 |   > Right: 0 | ||||||
| ==> var_to_range: values don't match. | ==> var_to_range: values don't match. | ||||||
|   >  Left: {u0: VR[-int_oo, int_oo], u1: VR[0, 1], zuf0: VR[-oo, oo]} |   >  Left: {u0: VR[-9223372036854775808, 9223372036854775807], u1: VR[0, 1], zuf0: VR[-oo, oo]} | ||||||
|   > Right: {} |   > Right: {} | ||||||
| """, | """, | ||||||
|         ) |         ) | ||||||
| @ -9420,8 +9420,8 @@ ShapeEnv not equal: field values don't match: | |||||||
|   >  Left: {s0: 3} |   >  Left: {s0: 3} | ||||||
|   > Right: {} |   > Right: {} | ||||||
| ==> var_to_range: values don't match. | ==> var_to_range: values don't match. | ||||||
|   >  Left: {s0: VR[3, 3], s1: VR[2, int_oo]} |   >  Left: {s0: VR[3, 3], s1: VR[2, 9223372036854775806]} | ||||||
|   > Right: {s0: VR[2, int_oo], s1: VR[2, int_oo]} |   > Right: {s0: VR[2, 9223372036854775806], s1: VR[2, 9223372036854775806]} | ||||||
| """, | """, | ||||||
|         ) |         ) | ||||||
|         self._replay_and_check(main) |         self._replay_and_check(main) | ||||||
| @ -9458,8 +9458,8 @@ ShapeEnv not equal: field values don't match: | |||||||
|   >  Left: {_assert, ge, x_size_0_, x_size_1_, x_storage_offset, x_stride_0_, x_stride_1_} |   >  Left: {_assert, ge, x_size_0_, x_size_1_, x_storage_offset, x_stride_0_, x_stride_1_} | ||||||
|   > Right: {x_size_0_, x_size_1_, x_storage_offset, x_stride_0_, x_stride_1_} |   > Right: {x_size_0_, x_size_1_, x_storage_offset, x_stride_0_, x_stride_1_} | ||||||
| ==> var_to_range: values don't match. | ==> var_to_range: values don't match. | ||||||
|   >  Left: {s0: VR[3, int_oo], s1: VR[2, int_oo]} |   >  Left: {s0: VR[3, 9223372036854775806], s1: VR[2, 9223372036854775806]} | ||||||
|   > Right: {s0: VR[2, int_oo], s1: VR[2, int_oo]} |   > Right: {s0: VR[2, 9223372036854775806], s1: VR[2, 9223372036854775806]} | ||||||
| """, | """, | ||||||
|         ) |         ) | ||||||
|         self._replay_and_check(main) |         self._replay_and_check(main) | ||||||
|  | |||||||
| @ -101,6 +101,15 @@ class TestModelOutput(torch._dynamo.test_case.TestCase): | |||||||
|  |  | ||||||
|         self._common(fn, 2) |         self._common(fn, 2) | ||||||
|  |  | ||||||
|  |     @maybe_skip | ||||||
|  |     def test_mo_getattr_missing(self): | ||||||
|  |         def fn(obj: BaseModelOutput): | ||||||
|  |             if getattr(obj, "asdf", None) is not None: | ||||||
|  |                 obj.asdf += 1 | ||||||
|  |             return obj.attentions + 1 | ||||||
|  |  | ||||||
|  |         self._common(fn, 1) | ||||||
|  |  | ||||||
|     @maybe_skip |     @maybe_skip | ||||||
|     def test_mo_getitem(self): |     def test_mo_getitem(self): | ||||||
|         def fn(obj: BaseModelOutput): |         def fn(obj: BaseModelOutput): | ||||||
| @ -166,6 +175,59 @@ class TestModelOutput(torch._dynamo.test_case.TestCase): | |||||||
|         self.assertEqual(cnts.frame_count, 1) |         self.assertEqual(cnts.frame_count, 1) | ||||||
|         self.assertEqual(cnts.op_count, 2) |         self.assertEqual(cnts.op_count, 2) | ||||||
|  |  | ||||||
|  |     @maybe_skip | ||||||
|  |     def test_mo_init2(self): | ||||||
|  |         # this ModelOutput subclass runs a different __post_init__ codepath | ||||||
|  |         @dataclasses.dataclass | ||||||
|  |         class MyDataClass(ModelOutput): | ||||||
|  |             x: torch.FloatTensor = None | ||||||
|  |  | ||||||
|  |         def fn(x): | ||||||
|  |             obj = MyDataClass(x=x) | ||||||
|  |             return obj | ||||||
|  |  | ||||||
|  |         inp = torch.randn(3, 3) | ||||||
|  |         opt_fn = torch._dynamo.optimize("eager", nopython=True)(fn) | ||||||
|  |         self.assertEqual(fn(inp).x, opt_fn(inp).x) | ||||||
|  |  | ||||||
|  |     @maybe_skip | ||||||
|  |     def test_mo_init_with_disable(self): | ||||||
|  |         # Can result in "non-function or method super: <slot wrapper '__setattr__' of 'object' objects>" | ||||||
|  |         # graph breaks (although it may not be the first) | ||||||
|  |         # Minimal repro for https://github.com/pytorch/pytorch/issues/126028 | ||||||
|  |         @dataclasses.dataclass | ||||||
|  |         class MyDataClass(ModelOutput): | ||||||
|  |             x: torch.FloatTensor = None | ||||||
|  |  | ||||||
|  |         @torch._dynamo.disable(recursive=False) | ||||||
|  |         def fn(x): | ||||||
|  |             return MyDataClass(x=x) | ||||||
|  |  | ||||||
|  |         inp = torch.randn(3, 3) | ||||||
|  |         opt_fn = torch._dynamo.optimize("eager")(fn) | ||||||
|  |         self.assertEqual(fn(inp).x, opt_fn(inp).x) | ||||||
|  |  | ||||||
|  |     @maybe_skip | ||||||
|  |     def test_mo_newkey(self): | ||||||
|  |         obj = BaseModelOutput() | ||||||
|  |  | ||||||
|  |         def fn(obj): | ||||||
|  |             return obj["wwww"] + 1 | ||||||
|  |  | ||||||
|  |         inp = torch.randn(3, 3) | ||||||
|  |         obj["wwww"] = inp | ||||||
|  |         opt_fn = torch._dynamo.optimize("eager", nopython=True)(fn) | ||||||
|  |         self.assertEqual(fn(obj), opt_fn(obj)) | ||||||
|  |  | ||||||
|  |     @maybe_skip | ||||||
|  |     def test_mo_from_outside(self): | ||||||
|  |         def fn(obj): | ||||||
|  |             return obj.attentions + 1 | ||||||
|  |  | ||||||
|  |         obj = BaseModelOutput(attentions=torch.randn(3, 3)) | ||||||
|  |         opt_fn = torch._dynamo.optimize("eager", nopython=True)(fn) | ||||||
|  |         self.assertEqual(fn(obj), opt_fn(obj)) | ||||||
|  |  | ||||||
|     @maybe_skip |     @maybe_skip | ||||||
|     def test_HF_bert_model_output(self): |     def test_HF_bert_model_output(self): | ||||||
|         class BertPooler(torch.nn.Module): |         class BertPooler(torch.nn.Module): | ||||||
|  | |||||||
| @ -22,7 +22,6 @@ from torch._dynamo.debug_utils import same_two_models | |||||||
| from torch._dynamo.eval_frame import unsupported | from torch._dynamo.eval_frame import unsupported | ||||||
| from torch._dynamo.mutation_guard import GenerationTracker | from torch._dynamo.mutation_guard import GenerationTracker | ||||||
| from torch._dynamo.testing import expectedFailureDynamic, same | from torch._dynamo.testing import expectedFailureDynamic, same | ||||||
| from torch._dynamo.utils import ifdynstaticdefault |  | ||||||
| from torch.nn.modules.lazy import LazyModuleMixin | from torch.nn.modules.lazy import LazyModuleMixin | ||||||
| from torch.nn.parameter import Parameter, UninitializedParameter | from torch.nn.parameter import Parameter, UninitializedParameter | ||||||
|  |  | ||||||
| @ -1108,37 +1107,6 @@ class UnspecNonInlinableToplevelModule(torch.nn.Module): | |||||||
|         return self.m(x) |         return self.m(x) | ||||||
|  |  | ||||||
|  |  | ||||||
| class ModuleWithIntAttr(torch.nn.Module): |  | ||||||
|     def __init__(self): |  | ||||||
|         super().__init__() |  | ||||||
|         self.layer = torch.nn.Linear(4, 4) |  | ||||||
|         self.step = 10 |  | ||||||
|  |  | ||||||
|     def forward(self, x: torch.Tensor) -> torch.Tensor: |  | ||||||
|         x = x + 1 |  | ||||||
|         self.step += 1 |  | ||||||
|         return self.layer(x) + self.step |  | ||||||
|  |  | ||||||
|  |  | ||||||
| class UnspecInlinableModule(torch.nn.Module): |  | ||||||
|     torchdynamo_force_dynamic = True  # forced to be a UnspecializedNNModule |  | ||||||
|  |  | ||||||
|     def forward(self, x): |  | ||||||
|         return torch.sin(x) |  | ||||||
|  |  | ||||||
|  |  | ||||||
| class UnspecModuleWithIntAttr(torch.nn.Module): |  | ||||||
|     def __init__(self): |  | ||||||
|         super().__init__() |  | ||||||
|         self.layer = UnspecInlinableModule() |  | ||||||
|         self.step = 10 |  | ||||||
|  |  | ||||||
|     def forward(self, x: torch.Tensor) -> torch.Tensor: |  | ||||||
|         x = x + 1 |  | ||||||
|         self.step += 1 |  | ||||||
|         return self.layer(x) + self.step |  | ||||||
|  |  | ||||||
|  |  | ||||||
| def make_test(fn, expected_ops=None): | def make_test(fn, expected_ops=None): | ||||||
|     def test_fn(self): |     def test_fn(self): | ||||||
|         return torch._dynamo.testing.standard_test( |         return torch._dynamo.testing.standard_test( | ||||||
| @ -1392,31 +1360,6 @@ class NNModuleTests(torch._dynamo.test_case.TestCase): | |||||||
|         self.assertTrue(torch._dynamo.testing.same(pre, opt_pre)) |         self.assertTrue(torch._dynamo.testing.same(pre, opt_pre)) | ||||||
|         self.assertTrue(torch._dynamo.testing.same(out1, out_post)) |         self.assertTrue(torch._dynamo.testing.same(out1, out_post)) | ||||||
|  |  | ||||||
|     def test_nn_module_unspec_int_attr(self): |  | ||||||
|         for module_class in [ModuleWithIntAttr, UnspecModuleWithIntAttr]: |  | ||||||
|             mod = module_class() |  | ||||||
|             cnt = torch._dynamo.testing.CompileCounter() |  | ||||||
|             opt_mod = torch.compile(backend=cnt)(copy.deepcopy(mod)) |  | ||||||
|             x = torch.randn(3, 4) |  | ||||||
|  |  | ||||||
|             # Compiling self.step as static. |  | ||||||
|             ref1 = mod(x) |  | ||||||
|             res1 = opt_mod(x) |  | ||||||
|             self.assertTrue(torch.allclose(ref1, res1)) |  | ||||||
|             self.assertEqual(cnt.frame_count, 1) |  | ||||||
|  |  | ||||||
|             # Compiling self.step as dynamic. |  | ||||||
|             ref2 = mod(x) |  | ||||||
|             res2 = opt_mod(x) |  | ||||||
|             self.assertTrue(torch.allclose(ref2, res2)) |  | ||||||
|             self.assertEqual(cnt.frame_count, ifdynstaticdefault(2, 1)) |  | ||||||
|  |  | ||||||
|             # No re-compilation! |  | ||||||
|             ref3 = mod(x) |  | ||||||
|             res3 = opt_mod(x) |  | ||||||
|             self.assertTrue(torch.allclose(ref3, res3)) |  | ||||||
|             self.assertEqual(cnt.frame_count, ifdynstaticdefault(2, 1)) |  | ||||||
|  |  | ||||||
|     # RuntimeError: SymIntArrayRef expected to contain only concrete integers |     # RuntimeError: SymIntArrayRef expected to contain only concrete integers | ||||||
|     @expectedFailureDynamic |     @expectedFailureDynamic | ||||||
|     def test_lazy_module1(self): |     def test_lazy_module1(self): | ||||||
|  | |||||||
| @ -201,19 +201,6 @@ class TestDynamismExpression(TestCase): | |||||||
|                 dynamic_shapes={"x": {0: dim_x}}, |                 dynamic_shapes={"x": {0: dim_x}}, | ||||||
|             ) |             ) | ||||||
|  |  | ||||||
|     def test_export_slice_maxsize(self): |  | ||||||
|         class Slice(torch.nn.Module): |  | ||||||
|             def forward(self, *args): |  | ||||||
|                 return torch.ops.aten.slice.Tensor(*args) |  | ||||||
|  |  | ||||||
|         inp = (torch.rand((10, 3, 224, 224)), 0, 0, 9223372036854775807) |  | ||||||
|         dynamic_shapes = (({0: Dim("dim")}, None, None, None),) |  | ||||||
|         torch.export.export( |  | ||||||
|             Slice(), |  | ||||||
|             inp, |  | ||||||
|             dynamic_shapes=dynamic_shapes, |  | ||||||
|         ) |  | ||||||
|  |  | ||||||
|     def test_export_constraints_error(self): |     def test_export_constraints_error(self): | ||||||
|         class ConflictingConstraints(torch.nn.Module): |         class ConflictingConstraints(torch.nn.Module): | ||||||
|             def forward(self, x): |             def forward(self, x): | ||||||
| @ -5196,7 +5183,7 @@ def forward(self, x, y): | |||||||
|         } |         } | ||||||
|         export(f, (inputs,), dynamic_shapes=dynamic_shapes) |         export(f, (inputs,), dynamic_shapes=dynamic_shapes) | ||||||
|  |  | ||||||
|     def test_disable_forced_specializations_ok(self): |     def test_disable_forced_specializations(self): | ||||||
|         # check that _disable_forced_specializations and _allow_complex_guards_as_runtime_asserts flags |         # check that _disable_forced_specializations and _allow_complex_guards_as_runtime_asserts flags | ||||||
|         # both behave correctly, avoiding forced specializations and deferring to runtime. |         # both behave correctly, avoiding forced specializations and deferring to runtime. | ||||||
|         # case 1: modulo guards |         # case 1: modulo guards | ||||||
|  | |||||||
| @ -312,6 +312,31 @@ class TestUnflatten(TestCase): | |||||||
|             export_module.module(), unflattened, (torch.randn((2, 3)),) |             export_module.module(), unflattened, (torch.randn((2, 3)),) | ||||||
|         ) |         ) | ||||||
|  |  | ||||||
|  |     @unittest.skipIf(IS_WINDOWS, "Windows not supported for this test") | ||||||
|  |     def test_unflatten_preserve_with_unused_input(self): | ||||||
|  |         class M1(torch.nn.Module): | ||||||
|  |             def forward(self, x, a, b): | ||||||
|  |                 return x + a, b | ||||||
|  |  | ||||||
|  |         class M(torch.nn.Module): | ||||||
|  |             def __init__(self): | ||||||
|  |                 super().__init__() | ||||||
|  |                 self.m1 = M1() | ||||||
|  |  | ||||||
|  |             def forward(self, x, y): | ||||||
|  |                 a, b = torch.topk(y, 2) | ||||||
|  |                 return self.m1(x, a, b)[0] | ||||||
|  |  | ||||||
|  |         ep = torch.export.export( | ||||||
|  |             M(), | ||||||
|  |             (torch.randn(2), torch.randn(5)), | ||||||
|  |             preserve_module_call_signature=("m1",), | ||||||
|  |             strict=False, | ||||||
|  |         ) | ||||||
|  |         ep.graph.eliminate_dead_code() | ||||||
|  |         unflattened = unflatten(ep) | ||||||
|  |         self.compare_outputs(ep.module(), unflattened, (torch.randn(2), torch.randn(5))) | ||||||
|  |  | ||||||
|     def test_unflatten_wrong_input(self): |     def test_unflatten_wrong_input(self): | ||||||
|         class Mod(torch.nn.Module): |         class Mod(torch.nn.Module): | ||||||
|             def __init__(self): |             def __init__(self): | ||||||
|  | |||||||
							
								
								
									
										53
									
								
								test/fx/test_partitioner_order.py
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										53
									
								
								test/fx/test_partitioner_order.py
									
									
									
									
									
										Normal file
									
								
							| @ -0,0 +1,53 @@ | |||||||
|  | # Owner(s): ["module: fx"] | ||||||
|  |  | ||||||
|  | import unittest | ||||||
|  |  | ||||||
|  | from typing import Mapping | ||||||
|  |  | ||||||
|  | import torch | ||||||
|  | from torch.fx.passes.infra.partitioner import CapabilityBasedPartitioner | ||||||
|  | from torch.fx.passes.operator_support import OperatorSupport | ||||||
|  | from torch.testing._internal.common_utils import TestCase | ||||||
|  |  | ||||||
|  |  | ||||||
|  | class DummyDevOperatorSupport(OperatorSupport): | ||||||
|  |     def is_node_supported( | ||||||
|  |         self, submodules: Mapping[str, torch.nn.Module], node: torch.fx.Node | ||||||
|  |     ) -> bool: | ||||||
|  |         return True | ||||||
|  |  | ||||||
|  |  | ||||||
|  | class DummyPartitioner(CapabilityBasedPartitioner): | ||||||
|  |     def __init__(self, graph_module: torch.fx.GraphModule): | ||||||
|  |         super().__init__( | ||||||
|  |             graph_module, | ||||||
|  |             DummyDevOperatorSupport(), | ||||||
|  |             allows_single_node_partition=True, | ||||||
|  |         ) | ||||||
|  |  | ||||||
|  |  | ||||||
|  | class AddModule(torch.nn.Module): | ||||||
|  |     def forward(self, x): | ||||||
|  |         y = torch.add(x, x) | ||||||
|  |         z = torch.add(y, x) | ||||||
|  |         return z | ||||||
|  |  | ||||||
|  |  | ||||||
|  | class TestPartitionerOrder(TestCase): | ||||||
|  |     # partitoner test to check graph node order | ||||||
|  |     def test_partitioner_order(self): | ||||||
|  |         m = AddModule() | ||||||
|  |         traced_m = torch.fx.symbolic_trace(m) | ||||||
|  |         partions = DummyPartitioner(traced_m).propose_partitions() | ||||||
|  |         partion_nodes = [list(partition.nodes) for partition in partions] | ||||||
|  |         node_order = [n.name for n in partion_nodes[0]] | ||||||
|  |         for _ in range(10): | ||||||
|  |             traced_m = torch.fx.symbolic_trace(m) | ||||||
|  |             new_partion = DummyPartitioner(traced_m).propose_partitions() | ||||||
|  |             new_partion_nodes = [list(partition.nodes) for partition in new_partion] | ||||||
|  |             new_node_order = [n.name for n in new_partion_nodes[0]] | ||||||
|  |             self.assertTrue(node_order == new_node_order) | ||||||
|  |  | ||||||
|  |  | ||||||
|  | if __name__ == "__main__": | ||||||
|  |     unittest.main() | ||||||
| @ -3761,6 +3761,20 @@ class CPUReproTests(TestCase): | |||||||
|             exactly=True, |             exactly=True, | ||||||
|         ).run(code) |         ).run(code) | ||||||
|  |  | ||||||
|  |     def test_repeated_exp(self): | ||||||
|  |         def fn(x): | ||||||
|  |             y = x.sigmoid() | ||||||
|  |             return y + 1, y.sum(-1) | ||||||
|  |  | ||||||
|  |         x = torch.randn(1000, 1000) | ||||||
|  |         opt_fn = torch.compile(fn) | ||||||
|  |         _, code = run_and_get_cpp_code(opt_fn, x) | ||||||
|  |         FileCheck().check_count( | ||||||
|  |             ".exp()", | ||||||
|  |             1, | ||||||
|  |             exactly=True, | ||||||
|  |         ).run(code) | ||||||
|  |  | ||||||
|  |  | ||||||
| if __name__ == "__main__": | if __name__ == "__main__": | ||||||
|     from torch._inductor.test_case import run_tests |     from torch._inductor.test_case import run_tests | ||||||
|  | |||||||
| @ -5537,6 +5537,14 @@ class CommonTemplate: | |||||||
|         for dtype in all_types(): |         for dtype in all_types(): | ||||||
|             self.common(fn, (make_tensor(8, dtype=dtype, device=self.device),)) |             self.common(fn, (make_tensor(8, dtype=dtype, device=self.device),)) | ||||||
|  |  | ||||||
|  |     def test_full_boolean(self): | ||||||
|  |         def fn(n): | ||||||
|  |             x = torch.full((1,), n >= 1024, device=self.device) | ||||||
|  |             return x, x + 1 | ||||||
|  |  | ||||||
|  |         self.common(fn, (1024,)) | ||||||
|  |         self.common(fn, (1023,)) | ||||||
|  |  | ||||||
|     def test_index1(self): |     def test_index1(self): | ||||||
|         def fn(a, b, c): |         def fn(a, b, c): | ||||||
|             return aten.index(a, [b, c]) |             return aten.index(a, [b, c]) | ||||||
|  | |||||||
Some files were not shown because too many files have changed in this diff Show More
		Reference in New Issue
	
	Block a user
	