Mirror of https://github.com/pytorch/pytorch.git (synced 2025-10-31 20:34:54 +08:00)

Compare commits: v1.10.0-rc ... v1.10.1 (13 commits)
| SHA1 |
|---|
| 302ee7bfb6 |
| 0c91a7063d |
| eadb03895a |
| 8416d630c9 |
| c78ceadbb0 |
| 70af72c794 |
| 36449ea931 |
| b544cbddfa |
| ddf3092581 |
| cc360fa38f |
| 3c134b8b1e |
| 4a514dd81e |
| c3ea586e32 |

.github/workflows/test_tools.yml (2 changes, vendored)
							| @ -14,7 +14,7 @@ jobs: | ||||
|       - name: Setup Python | ||||
|         uses: actions/setup-python@v2 | ||||
|         with: | ||||
|           python-version: 3.x | ||||
|           python-version: '3.6 - 3.9' | ||||
|           architecture: x64 | ||||
|       - name: Checkout PyTorch | ||||
|         uses: zhouzhuojie/checkout@05b13c9a0d21f08f6d5e64a1d5042246d13619d9 | ||||
|  | ||||
| @ -170,9 +170,7 @@ if [[ "$BUILD_ENVIRONMENT" == *onnx* ]]; then | ||||
|   # JIT C++ extensions require ninja, so put it into PATH. | ||||
|   export PATH="/var/lib/jenkins/.local/bin:$PATH" | ||||
|   if [[ "$BUILD_ENVIRONMENT" == *py3* ]]; then | ||||
|     pip install -q --user flatbuffers==2.0 | ||||
|     wget https://ortpypackage.blob.core.windows.net/ort-nightly/ort_nightly-1.8.0.dev202107131-cp36-cp36m-manylinux_2_17_x86_64.manylinux2014_x86_64.whl | ||||
|     pip install -q --user ort_nightly-1.8.0.dev202107131-cp36-cp36m-manylinux_2_17_x86_64.manylinux2014_x86_64.whl | ||||
|     pip install -q --user flatbuffers==2.0 onnxruntime==1.9.0 | ||||
|   fi | ||||
|   "$ROOT_DIR/scripts/onnx/test.sh" | ||||
| fi | ||||
|  | ||||
| @ -230,6 +230,8 @@ test_aten() { | ||||
| test_without_numpy() { | ||||
|   pushd "$(dirname "${BASH_SOURCE[0]}")" | ||||
|   python -c "import sys;sys.path.insert(0, 'fake_numpy');from unittest import TestCase;import torch;x=torch.randn(3,3);TestCase().assertRaises(RuntimeError, lambda: x.numpy())" | ||||
|   # Regression test for https://github.com/pytorch/pytorch/issues/66353 | ||||
|   python -c "import sys;sys.path.insert(0, 'fake_numpy');import torch;print(torch.tensor([torch.tensor(0.), torch.tensor(1.)]))" | ||||
|   popd | ||||
| } | ||||
|  | ||||
|  | ||||
| @ -2,21 +2,12 @@ | ||||
| #include <ATen/native/MathBitFallThroughLists.h> | ||||
|  | ||||
| namespace at { | ||||
|  | ||||
| namespace native { | ||||
| struct ConjFallback : MathOpFallback { | ||||
|   ConjFallback() : MathOpFallback(DispatchKey::Conjugate, "conjugate") {} | ||||
|   bool is_bit_set(const Tensor& tensor) override { | ||||
|     return tensor.is_conj(); | ||||
|   } | ||||
|   void _set_bit(const Tensor& tensor, bool value) override { | ||||
|     return tensor._set_conj(value); | ||||
|   } | ||||
|   Tensor resolve_bit(const Tensor& tensor) override { | ||||
|     return at::resolve_conj(tensor); | ||||
|   } | ||||
|   Tensor& math_op_(Tensor& tensor) override { | ||||
|     return at::conj_physical_(tensor); | ||||
|   } | ||||
| }; | ||||
|  | ||||
| void conjugateFallback(const c10::OperatorHandle& op, DispatchKeySet dispatch_keys, torch::jit::Stack* stack) { | ||||
| @ -60,4 +51,5 @@ TORCH_LIBRARY_IMPL(aten, Conjugate, m) { | ||||
|   TENSOR_UTILITIES_AND_CONSTRUCTORS(m) | ||||
| } | ||||
|  | ||||
| } | ||||
| } // namespace at | ||||
|  | ||||
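For context, a hedged sketch of the user-visible conjugate-bit behavior this fallback implements, using only standard torch APIs (`conj`, `is_conj`, `resolve_conj`):

```python
import torch

x = torch.tensor([1 + 2j, 3 - 4j])
y = x.conj()                 # lazy: sets the conjugate bit, no data copy
assert y.is_conj()
z = y.resolve_conj()         # materializes the conjugation into fresh memory
assert not z.is_conj()
assert z.tolist() == [1 - 2j, 3 + 4j]
```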
| @ -146,6 +146,10 @@ public: | ||||
|   inline operator T*() { | ||||
|     return values; | ||||
|   } | ||||
|   // Return the values as char* for type punning | ||||
|   auto as_bytes() const -> const char* { | ||||
|     return reinterpret_cast<const char*>(values); | ||||
|   } | ||||
|   template <int64_t mask_> | ||||
|   static Vectorized<T> blend(const Vectorized<T>& a, const Vectorized<T>& b) { | ||||
|     int64_t mask = mask_; | ||||
| @ -735,15 +739,33 @@ inline Vectorized<T> operator^(const Vectorized<T>& a, const Vectorized<T>& b) { | ||||
|  | ||||
| #else | ||||
|  | ||||
| template <typename T> | ||||
| auto load(char const* data) -> T { | ||||
|   T ret; | ||||
|   std::memcpy(&ret, data, sizeof(ret)); | ||||
|   return ret; | ||||
| } | ||||
|  | ||||
| template<class T, typename Op> | ||||
| static inline Vectorized<T> bitwise_binary_op(const Vectorized<T> &a, const Vectorized<T> &b, Op op) { | ||||
|   static constexpr uint32_t element_no = VECTOR_WIDTH / sizeof(intmax_t); | ||||
|   __at_align__ intmax_t buffer[element_no]; | ||||
|   const intmax_t *a_ptr = reinterpret_cast<const intmax_t*>((const T*) a); | ||||
|   const intmax_t *b_ptr = reinterpret_cast<const intmax_t*>((const T*) b); | ||||
|   for (uint32_t i = 0U; i < element_no; ++ i) { | ||||
|     buffer[i] = op(a_ptr[i], b_ptr[i]); | ||||
|   static_assert(VECTOR_WIDTH % sizeof(intmax_t) == 0, "VECTOR_WIDTH not a multiple of sizeof(intmax_t)"); | ||||
|   static_assert(sizeof(buffer) == sizeof(Vectorized<T>), "sizeof(buffer) must match sizeof(Vectorized<T>)"); | ||||
|   // We should be using memcpy in order to respect the strict aliasing rule | ||||
|   // see: https://github.com/pytorch/pytorch/issues/66119 | ||||
|   // Using char* is defined in the C11 standard 6.5 Expression paragraph 7 | ||||
|   // (http://www.open-std.org/jtc1/sc22/wg14/www/docs/n1570.pdf) | ||||
|   const auto* a_data = a.as_bytes(); | ||||
|   const auto* b_data = b.as_bytes(); | ||||
|   // load each intmax_t chunk and process; increase pointers by sizeof(intmax_t) | ||||
|   for (auto& out : buffer) { | ||||
|     out = op(load<intmax_t>(a_data), load<intmax_t>(b_data)); | ||||
|     a_data += sizeof(intmax_t); | ||||
|     b_data += sizeof(intmax_t); | ||||
|   } | ||||
|   assert(a_data == a.as_bytes() + sizeof(a)); | ||||
|   assert(b_data == b.as_bytes() + sizeof(b)); | ||||
|   return Vectorized<T>::loadu(buffer); | ||||
| } | ||||
|  | ||||
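The memcpy-based `load` above exists to honor strict aliasing: each intmax_t-sized chunk is copied out of the byte buffer instead of casting the buffer pointer. Below is a minimal Python sketch of the same chunked load-process-store pattern, with `struct` standing in for memcpy; it is illustrative only, not the ATen code:

```python
import struct

def bitwise_binary_op(a: bytes, b: bytes, op):
    # Walk two equal-sized byte buffers in 8-byte (intmax_t-sized) chunks.
    # Each chunk is copied out via struct.unpack_from (the memcpy analogue),
    # combined with `op`, and written into the output buffer.
    assert len(a) == len(b) and len(a) % 8 == 0
    out = bytearray(len(a))
    for off in range(0, len(a), 8):
        x, = struct.unpack_from("<Q", a, off)
        y, = struct.unpack_from("<Q", b, off)
        struct.pack_into("<Q", out, off, op(x, y) & 0xFFFFFFFFFFFFFFFF)
    return bytes(out)

assert bitwise_binary_op(b"\x0f" * 16, b"\xf0" * 16, lambda x, y: x | y) == b"\xff" * 16
```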
|  | ||||
| @ -238,7 +238,7 @@ template<template<typename> class normal_kernel, typename RNG> | ||||
| Tensor& normal_out_impl(Tensor& output, const Tensor& mean, const Tensor& std, c10::optional<Generator> gen) { | ||||
|   TORCH_CHECK(!std.is_complex(), "normal expects standard deviation to be non-complex"); | ||||
|   TORCH_CHECK( | ||||
|     std.min().ge(0).item<bool>(), | ||||
|     std.numel() == 0 || std.min().ge(0).item<bool>(), | ||||
|     "normal expects all elements of std >= 0.0"); | ||||
|   bool is_deprecated_th_impl = resize_output_for_normal(output, mean, std); | ||||
|   normal_impl_<normal_kernel, RNG>(output, 0, 1, gen); | ||||
|  | ||||
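A hedged sketch of the two behaviors the updated check covers: a negative std still raises, while an empty std no longer trips the `min()` reduction (which is undefined on empty tensors):

```python
import torch

try:
    torch.normal(mean=torch.zeros(3), std=torch.tensor([-1.0, 1.0, 1.0]))
except RuntimeError as e:
    print(e)  # normal expects all elements of std >= 0.0

# Empty mean/std broadcast to an empty result instead of erroring.
out = torch.normal(mean=torch.empty(0, 2), std=torch.empty(0, 1))
print(out.shape)  # torch.Size([0, 2])
```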
| @ -50,7 +50,8 @@ namespace at { | ||||
|   m.impl("vsplit.array", torch::CppFunction::makeFallthrough()); \ | ||||
|   m.impl("conj", torch::CppFunction::makeFallthrough()); \ | ||||
|   m.impl("_conj", torch::CppFunction::makeFallthrough()); \ | ||||
|   m.impl("_unsafe_view", torch::CppFunction::makeFallthrough()); | ||||
|   m.impl("_unsafe_view", torch::CppFunction::makeFallthrough()); \ | ||||
|   m.impl("resize_", torch::CppFunction::makeFallthrough()); | ||||
|  | ||||
| #define TENSOR_UTILITIES_AND_CONSTRUCTORS(m) \ | ||||
|   m.impl("empty_like", torch::CppFunction::makeFallthrough()); \ | ||||
|  | ||||
| @ -3,42 +3,49 @@ | ||||
| #include <ATen/core/op_registration/op_registration.h> | ||||
| #include <ATen/native/UnaryOps.h> | ||||
| #include <ATen/NativeFunctions.h> | ||||
| #include <ATen/native/Resize.h> | ||||
| #include <c10/util/irange.h> | ||||
| #include <torch/library.h> | ||||
|  | ||||
| namespace at { | ||||
|  | ||||
| namespace native { | ||||
| // This fallback should only be used for operations that are self inverse and have a corresponding tensor | ||||
| // bit (internally implemented using DispatchKey) that maintains this state on the tensor. | ||||
| // Currently there are two tensor bits that trigger this fallback: conjugate bit and negative bit. | ||||
| // Conjugate bit is set on a tensor when `.conj()` is called and neg bit is set on a tensor when `.conj().imag` is called. | ||||
|  | ||||
| // NOTE: To use this fallback, `clone` and `copy_` should fully understand and be able to correctly handle the semantics of your math bit. | ||||
| struct MathOpFallback { | ||||
|   MathOpFallback(DispatchKey key_, string op_name_) : key(key_), op_name(op_name_) {} | ||||
|   virtual bool is_bit_set(const Tensor&) = 0; | ||||
|   virtual void _set_bit(const Tensor&, bool) = 0; | ||||
|   // materializes the bit, i.e., returns a new tensor containing the true output | ||||
|   // (after performing the math operation corresponding to the tensor bit) if the bit is set to 1 | ||||
|   // else returns self. | ||||
|   virtual Tensor resolve_bit(const Tensor&) = 0; | ||||
|   // in-place operation corresponding to the math op represented by the bit. In the future, if this class | ||||
|   // is generalized for ops that are not self inverse, this must be replaced by op_inverse_inplace | ||||
|   virtual Tensor& math_op_(Tensor&) = 0; | ||||
|   void fallback_impl(const c10::OperatorHandle& op, DispatchKeySet dispatch_keys, torch::jit::Stack* stack) { | ||||
|     // Situations to handle: | ||||
|     //  1. Out-of-place operation.  Easy: materialize all inputs and | ||||
|     //     call it a day. | ||||
|     //  2. Inplace operation.  Desugar x.add_(2) into x.conj_().add_(2).conj_(). | ||||
|     //     Materialize other inputs as in (1). | ||||
|     //  3. out= operation.  Desugar add(x, 2, out=y) into y.copy_(add(x, 2)) | ||||
|     //  Materialize other inputs as in (1). | ||||
|     // | ||||
|     //  It is important to be able to tell if we READ from an argument and if we | ||||
|     //  WRITE from an argument.  Conservative approach is to assume that we always | ||||
|     //  READ from an argument, but in out-of-place operations you can skip | ||||
|     //  conjugating inputs on entry that never get used.  In current schema we | ||||
|     //  can't easily tell if inplace situation has happened, so don't do it. | ||||
|     /* | ||||
|       Situations to handle: | ||||
|         1. Out-of-place operation.  Easy: materialize all inputs and | ||||
|           call it a day. | ||||
|         2. Inplace operation.  Desugar x.add_(2) into x.conj_().add_(2).conj_(). | ||||
|           Materialize other inputs as in (1). | ||||
|         3. out= operation.  Desugar add(x, 2, out=y) into y.copy_(add(x, 2)) | ||||
|         Materialize other inputs as in (1). | ||||
|  | ||||
|         It is important to be able to tell if we READ from an argument and if we | ||||
|         WRITE to an argument.  Conservative approach is to assume that we always | ||||
|         READ from an argument, but in out= operations you can skip | ||||
|         conjugating inputs on entry that never get used. In the current schema we | ||||
|         can't easily tell if the operation is an in-place or an out= operation. | ||||
|  | ||||
|         Note: | ||||
|         1. Mutable tensorlists containing tensors whose math bit is set to true are disallowed. | ||||
|         2. Mutable tensors with the math bit set to true are unconditionally cloned to ensure | ||||
|            correct behavior when the mutable tensor shares memory with non-mutable arguments. | ||||
|  | ||||
|            If we were to resolve the math bit in-place for mutable inputs, then non-mutable inputs sharing partial or full memory | ||||
|            with these mutable inputs would read wrong values in the following cases: | ||||
|            1. Non-mutable inputs have their math bit set to false. | ||||
|            2. The math bit for mutable input(s) is resolved before the non-mutable inputs (with the bit set to true and sharing memory | ||||
|               with one or more mutable arg(s)) are cloned. | ||||
|            At the end, the final values of the mutable arguments on the stack are copied back into the original mutable tensor inputs. | ||||
|     */ | ||||
|     const auto& arguments = op.schema().arguments(); | ||||
|     const auto num_arguments = arguments.size(); | ||||
|     const auto stack_start = stack->size() - num_arguments; | ||||
| @ -72,9 +79,8 @@ struct MathOpFallback { | ||||
|       return; | ||||
|     } | ||||
|  | ||||
|     // Mutable inputs to be tracked separately | ||||
|     std::vector<Tensor> mutable_inputs; | ||||
|  | ||||
|     // Mutable inputs with math bit set to True and their clones | ||||
|     std::vector<std::pair<Tensor, Tensor>> mutable_inputs_with_their_clones; | ||||
|     for (const auto i : c10::irange(num_arguments)) { | ||||
|       auto& ivalue = (*stack)[stack_start + i]; | ||||
|       if (!(ivalue.isTensor() || ivalue.isTensorList())) { | ||||
| @ -91,31 +97,26 @@ struct MathOpFallback { | ||||
|         if (!is_bit_set(ivalue.toTensor())) { | ||||
|           continue; | ||||
|         } | ||||
|  | ||||
|         auto tensor = std::move(ivalue).toTensor(); | ||||
|         TORCH_CHECK_NOT_IMPLEMENTED(!tensor.is_meta(), op_name, " fallback does not support meta tensors."); | ||||
|         auto resolved_tensor = at::clone(tensor); | ||||
|         if (mut_arg) { | ||||
|           // TODO: This is a waste if the argument is write only | ||||
|           _set_bit(tensor, false); | ||||
|           math_op_(tensor); | ||||
|           mutable_inputs.emplace_back(tensor); | ||||
|         } else { | ||||
|           tensor = resolve_bit(tensor); | ||||
|           TORCH_CHECK(mutable_inputs_with_their_clones.empty(), op_name, " fallback does not support operators with more than one mutable tensor with ", | ||||
|             op_name, " bit set to true."); | ||||
|           mutable_inputs_with_their_clones.emplace_back(std::make_pair(std::move(tensor), resolved_tensor)); | ||||
|         } | ||||
|         (*stack)[stack_start + i] = std::move(tensor); | ||||
|         (*stack)[stack_start + i] = std::move(resolved_tensor); | ||||
|       } else if (ivalue.isTensorList()) { | ||||
|         auto tensors = std::move(ivalue).toTensorList(); | ||||
|         if (mut_arg) { | ||||
|           for(const auto j : c10::irange(tensors.size())) { | ||||
|             Tensor t = tensors[j]; | ||||
|             _set_bit(t, false); | ||||
|             math_op_(t); | ||||
|             mutable_inputs.emplace_back(t); | ||||
|           } | ||||
|         } else { | ||||
|           for(const auto j : c10::irange(tensors.size())) { | ||||
|             tensors[j] = resolve_bit(tensors[j]); | ||||
|         for(const auto j : c10::irange(tensors.size())) { | ||||
|           const auto& tensor = tensors[j]; | ||||
|           if (!is_bit_set(tensor)) { | ||||
|             continue; | ||||
|           } | ||||
|           TORCH_CHECK(!mut_arg, " fallback doesn't currently support mutable TensorLists with ", | ||||
|               op_name, " inputs. Please materialize all the ", op_name, " input tensor(s) in the mutable TensorList inputs before calling ", | ||||
|               op.schema().name()); | ||||
|           tensors[j] = at::clone(tensor); | ||||
|         } | ||||
|         (*stack)[stack_start + i] = std::move(tensors); | ||||
|       } | ||||
| @ -123,9 +124,22 @@ struct MathOpFallback { | ||||
|  | ||||
|     op.redispatchBoxed(dispatch_keys & c10::DispatchKeySet(DispatchKeySet::FULL_AFTER, key), stack); | ||||
|  | ||||
|     for (auto& mutable_input : mutable_inputs) { | ||||
|       math_op_(mutable_input); | ||||
|       _set_bit(mutable_input, true); | ||||
|     TORCH_INTERNAL_ASSERT(mutable_inputs_with_their_clones.size() <= 1); | ||||
|  | ||||
|     for (std::pair<Tensor, Tensor> mut_tensors: mutable_inputs_with_their_clones) { | ||||
|       auto& mutable_input = mut_tensors.first; | ||||
|       auto& cloned_mutable_input = mut_tensors.second; | ||||
|       auto& ivalue = (*stack)[stack_start]; | ||||
|       auto returned_output = std::move(ivalue).toTensor(); | ||||
|  | ||||
|       // sanity check to ensure that the tensor on the stack is the cloned_mutable_input | ||||
|       TORCH_INTERNAL_ASSERT(cloned_mutable_input.is_same(returned_output)); | ||||
|  | ||||
|       // necessary for out= arg | ||||
|       at::native::resize_output(mutable_input, returned_output.sizes()); | ||||
|  | ||||
|       mutable_input.copy_(returned_output); | ||||
|       (*stack)[stack_start] = std::move(mutable_input); | ||||
|     } | ||||
|   } | ||||
|  | ||||
| @ -134,5 +148,5 @@ struct MathOpFallback { | ||||
|   DispatchKey key; | ||||
|   string op_name; | ||||
| }; | ||||
|  | ||||
| } // namespace at | ||||
| } | ||||
| } // namespace at | ||||
|  | ||||
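A small sketch of what the clone-and-copy-back dance above buys: in-place ops on a lazy conjugate view stay correct even though the view aliases the original storage. The printed values below are worked out by hand, not captured output:

```python
import torch

a = torch.tensor([1 + 1j, 2 - 2j])
b = a.conj()     # shares storage with a; only the conjugate bit differs
b.add_(1)        # routed through the fallback; the mutable arg is cloned first
print(b)         # tensor([2.-1.j, 3.+2.j])
print(a)         # a now holds conj(b): tensor([2.+1.j, 3.-2.j])
```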
| @ -2,21 +2,12 @@ | ||||
| #include <ATen/native/MathBitFallThroughLists.h> | ||||
|  | ||||
| namespace at { | ||||
|  | ||||
| namespace native { | ||||
| struct NegFallback : MathOpFallback { | ||||
|   NegFallback() : MathOpFallback(DispatchKey::Negative, "negation") {} | ||||
|   bool is_bit_set(const Tensor& tensor) override { | ||||
|     return tensor.is_neg(); | ||||
|   } | ||||
|   void _set_bit(const Tensor& tensor, bool value) override { | ||||
|     return tensor._set_neg(value); | ||||
|   } | ||||
|   Tensor resolve_bit(const Tensor& tensor) override { | ||||
|     return at::resolve_neg(tensor); | ||||
|   } | ||||
|   Tensor& math_op_(Tensor& tensor) override { | ||||
|     return tensor.neg_(); | ||||
|   } | ||||
| }; | ||||
|  | ||||
| void negationFallback(const c10::OperatorHandle& op, DispatchKeySet dispatch_keys, torch::jit::Stack* stack) { | ||||
| @ -42,4 +33,5 @@ TORCH_LIBRARY_IMPL(aten, Negative, m) { | ||||
|   TENSOR_UTILITIES_AND_CONSTRUCTORS(m) | ||||
| } | ||||
|  | ||||
| } | ||||
| } // namespace at | ||||
|  | ||||
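The negative-bit analogue of the conjugate sketch earlier: `.conj().imag` produces a view with the neg bit set, and `resolve_neg()` materializes it:

```python
import torch

x = torch.tensor([1 + 2j, 3 - 4j])
v = x.conj().imag       # view of x.imag with the negative bit set
assert v.is_neg()
print(v.resolve_neg())  # tensor([-2., 4.])
```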
| @ -211,6 +211,9 @@ const Tensor& indices) { | ||||
|   int64_t osizeH = output_size[0]; | ||||
|   int64_t osizeW = output_size[1]; | ||||
|  | ||||
|   const at::Tensor output_c = output.is_contiguous() ? output : at::empty(output.sizes(), output.options()); | ||||
|   const at::Tensor indices_c = indices.is_contiguous() ? indices : at::empty(indices.sizes(), indices.options()); | ||||
|  | ||||
|   if (input.ndimension() == 3) { | ||||
|     int64_t sizeD = input.size(0); | ||||
|     int64_t isizeH = input.size(1); | ||||
| @ -223,8 +226,8 @@ const Tensor& indices) { | ||||
|     AT_DISPATCH_FLOATING_TYPES_AND2( | ||||
|         kHalf, kBFloat16, input.scalar_type(), "adaptive_max_pool2d_cuda", [&] { | ||||
|           scalar_t* input_data = input.data_ptr<scalar_t>(); | ||||
|           scalar_t* output_data = output.data_ptr<scalar_t>(); | ||||
|           int64_t* indices_data = indices.data_ptr<int64_t>(); | ||||
|           scalar_t* output_data = output_c.data_ptr<scalar_t>(); | ||||
|           int64_t* indices_data = indices_c.data_ptr<int64_t>(); | ||||
|  | ||||
|           // cuda blocks & threads: | ||||
|           int blocksH = (int)(16L / sizeD); | ||||
| @ -268,8 +271,8 @@ const Tensor& indices) { | ||||
|         "adaptive_max_pool2d_cuda", | ||||
|         [&] { | ||||
|           scalar_t* input_data = input_.data_ptr<scalar_t>(); | ||||
|           scalar_t* output_data = output.data_ptr<scalar_t>(); | ||||
|           int64_t* indices_data = indices.data_ptr<int64_t>(); | ||||
|           scalar_t* output_data = output_c.data_ptr<scalar_t>(); | ||||
|           int64_t* indices_data = indices_c.data_ptr<int64_t>(); | ||||
|  | ||||
|           // cuda blocks & threads: | ||||
|           int blocksH = (int)(16L / sizeD); | ||||
| @ -296,6 +299,13 @@ const Tensor& indices) { | ||||
|           C10_CUDA_KERNEL_LAUNCH_CHECK(); | ||||
|         }); | ||||
|   } | ||||
|  | ||||
|   if (!output.is_contiguous()) { | ||||
|     output.copy_(output_c); | ||||
|   } | ||||
|   if (!indices.is_contiguous()) { | ||||
|     indices.copy_(indices_c); | ||||
|   } | ||||
| } | ||||
|  | ||||
| TORCH_IMPL_FUNC(adaptive_max_pool2d_backward_out_cuda) | ||||
| @ -322,7 +332,9 @@ TORCH_IMPL_FUNC(adaptive_max_pool2d_backward_out_cuda) | ||||
|   bool atomic = | ||||
|       true; // suboptimal, but without atomic it doesn't pass the tests | ||||
|  | ||||
|   Tensor gradOutput_ = gradOutput.contiguous(); | ||||
|   const at::Tensor gradOutput_ = gradOutput.contiguous(); | ||||
|   const at::Tensor indices_ = indices.contiguous(); | ||||
|   const at::Tensor gradInput_c = gradInput.is_contiguous() ? gradInput : at::empty(gradInput.sizes(), gradInput.options()); | ||||
|  | ||||
|   if (input.ndimension() == 3) { | ||||
|     int64_t sizeD = input.size(0); | ||||
| @ -334,7 +346,7 @@ TORCH_IMPL_FUNC(adaptive_max_pool2d_backward_out_cuda) | ||||
|  | ||||
|     // bool atomic = (isizeH%osizeH != 0) || (isizeW%osizeW != 0); | ||||
|  | ||||
|     gradInput.zero_(); | ||||
|     gradInput_c.zero_(); | ||||
|  | ||||
|     AT_DISPATCH_FLOATING_TYPES_AND2( | ||||
|         kHalf, | ||||
| @ -342,9 +354,9 @@ TORCH_IMPL_FUNC(adaptive_max_pool2d_backward_out_cuda) | ||||
|         input.scalar_type(), | ||||
|         "adaptive_max_pool2d_backward_cuda", | ||||
|         [&] { | ||||
|           scalar_t* gradInput_data = gradInput.data_ptr<scalar_t>(); | ||||
|           scalar_t* gradInput_data = gradInput_c.data_ptr<scalar_t>(); | ||||
|           scalar_t* gradOutput_data = gradOutput_.data_ptr<scalar_t>(); | ||||
|           int64_t* indices_data = indices.data_ptr<int64_t>(); | ||||
|           int64_t* indices_data = indices_.data_ptr<int64_t>(); | ||||
|  | ||||
|           // cuda blocks & threads: | ||||
|           int blocksH = (int)(16L / sizeD); | ||||
| @ -393,7 +405,7 @@ TORCH_IMPL_FUNC(adaptive_max_pool2d_backward_out_cuda) | ||||
|     int64_t osizeH = gradOutput_.size(2); | ||||
|     int64_t osizeW = gradOutput_.size(3); | ||||
|  | ||||
|     gradInput.zero_(); | ||||
|     gradInput_c.zero_(); | ||||
|  | ||||
|     // bool atomic = (isizeH%osizeH != 0) || (isizeW%osizeW != 0); | ||||
|  | ||||
| @ -403,9 +415,9 @@ TORCH_IMPL_FUNC(adaptive_max_pool2d_backward_out_cuda) | ||||
|         input.scalar_type(), | ||||
|         "adaptive_max_pool2d_backward_cuda", | ||||
|         [&] { | ||||
|           scalar_t* gradInput_data = gradInput.data_ptr<scalar_t>(); | ||||
|           scalar_t* gradInput_data = gradInput_c.data_ptr<scalar_t>(); | ||||
|           scalar_t* gradOutput_data = gradOutput_.data_ptr<scalar_t>(); | ||||
|           int64_t* indices_data = indices.data_ptr<int64_t>(); | ||||
|           int64_t* indices_data = indices_.data_ptr<int64_t>(); | ||||
|  | ||||
|           // cuda blocks & threads: | ||||
|           int blocksH = (int)(16L / sizeD); | ||||
| @ -446,6 +458,10 @@ TORCH_IMPL_FUNC(adaptive_max_pool2d_backward_out_cuda) | ||||
|           } | ||||
|         }); | ||||
|   } | ||||
|  | ||||
|   if (!gradInput.is_contiguous()) { | ||||
|     gradInput.copy_(gradInput_c); | ||||
|   } | ||||
|  } | ||||
| } // at::native | ||||
| } // at | ||||
|  | ||||
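A hedged repro sketch for the copy-back path above (needs a CUDA device): a channels_last input yields non-contiguous output/indices buffers, which the kernel now fills via contiguous temporaries:

```python
import torch
import torch.nn.functional as F

x = torch.randn(2, 3, 8, 8, device="cuda").to(memory_format=torch.channels_last)
out, idx = F.adaptive_max_pool2d(x, (4, 4), return_indices=True)
ref, ref_idx = F.adaptive_max_pool2d(x.contiguous(), (4, 4), return_indices=True)
assert torch.equal(out, ref) and torch.equal(idx, ref_idx)
```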
| @ -276,11 +276,14 @@ __global__ void nll_loss_forward_reduce_cuda_kernel_2d( | ||||
| void nll_loss_forward_out_cuda_template( | ||||
|     const Tensor& output, | ||||
|     const Tensor& total_weight, | ||||
|     const Tensor& input, | ||||
|     const Tensor& target, | ||||
|     const Tensor& input_, | ||||
|     const Tensor& target_, | ||||
|     const Tensor& weight, | ||||
|     int64_t reduction, | ||||
|     int64_t ignore_index) { | ||||
|   auto input = *input_.expect_contiguous(); | ||||
|   auto target = *target_.expect_contiguous(); | ||||
|  | ||||
|   int64_t n_classes = input.size(-1); | ||||
|   int64_t n_dims = input.dim(); | ||||
|   int64_t batch_size = n_dims == 1 ? 1 : input.size(0); | ||||
| @ -327,9 +330,6 @@ void nll_loss_forward_out_cuda_template( | ||||
|   output.resize_({}); | ||||
|   total_weight.resize_({}); | ||||
|  | ||||
|   auto input_ = input.contiguous(); | ||||
|   auto target_ = target.contiguous(); | ||||
|  | ||||
|   if (n_dims == 1) { | ||||
|     AT_DISPATCH_FLOATING_TYPES_AND2( | ||||
|         at::ScalarType::Half, | ||||
| @ -345,8 +345,8 @@ void nll_loss_forward_out_cuda_template( | ||||
|                     <<<1, 1, 0, at::cuda::getCurrentCUDAStream()>>>( | ||||
|                         output.data_ptr<scalar_t>(), | ||||
|                         total_weight.data_ptr<scalar_t>(), | ||||
|                         input_.data_ptr<scalar_t>(), | ||||
|                         target_.data_ptr<index_t>(), | ||||
|                         input.data_ptr<scalar_t>(), | ||||
|                         target.data_ptr<index_t>(), | ||||
|                         weight_.defined() ? weight_.data_ptr<scalar_t>() | ||||
|                                           : nullptr, | ||||
|                         reduction == at::Reduction::Mean, | ||||
| @ -374,8 +374,8 @@ void nll_loss_forward_out_cuda_template( | ||||
|                        at::cuda::getCurrentCUDAStream()>>>( | ||||
|                         output.data_ptr<scalar_t>(), | ||||
|                         total_weight.data_ptr<scalar_t>(), | ||||
|                         input_.data_ptr<scalar_t>(), | ||||
|                         target_.data_ptr<index_t>(), | ||||
|                         input.data_ptr<scalar_t>(), | ||||
|                         target.data_ptr<index_t>(), | ||||
|                         weight_.defined() ? weight_.data_ptr<scalar_t>() | ||||
|                                           : nullptr, | ||||
|                         reduction == at::Reduction::Mean, | ||||
| @ -459,14 +459,19 @@ __global__ void nll_loss_backward_reduce_cuda_kernel_2d( | ||||
| }; | ||||
|  | ||||
| void nll_loss_backward_out_cuda_template( | ||||
|     const Tensor& grad_input, | ||||
|     const Tensor& grad_output, | ||||
|     const Tensor& input, | ||||
|     const Tensor& target, | ||||
|     const Tensor& grad_input_, | ||||
|     const Tensor& grad_output_, | ||||
|     const Tensor& input_, | ||||
|     const Tensor& target_, | ||||
|     const Tensor& total_weight, | ||||
|     const Tensor& weight, | ||||
|     int64_t reduction, | ||||
|     int64_t ignore_index) { | ||||
|   auto target = *target_.expect_contiguous(); | ||||
|   auto input = *input_.expect_contiguous(); | ||||
|   auto grad_input = *grad_input_.expect_contiguous(); | ||||
|   auto grad_output = *grad_output_.expect_contiguous(); | ||||
|  | ||||
|   int64_t n_dims = input.dim(); | ||||
|   int64_t n_classes = input.size(-1); | ||||
|   int64_t batch_size = n_dims == 1 ? 1 : input.size(0); | ||||
| @ -508,7 +513,6 @@ void nll_loss_backward_out_cuda_template( | ||||
|     return; | ||||
|   } | ||||
|  | ||||
|   auto target_ = target.contiguous(); | ||||
|   TORCH_CHECK(grad_output.numel() == 1); | ||||
|  | ||||
|   if (n_dims == 1) { | ||||
|  | ||||
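A hedged sketch (CUDA required) of the case `expect_contiguous()` now handles inside the kernel: a strided, non-contiguous log-prob input no longer needs to be made contiguous by the caller:

```python
import torch
import torch.nn.functional as F

logp = F.log_softmax(torch.randn(4, 20, device="cuda"), dim=1)[:, ::2]  # non-contiguous
target = torch.randint(0, 10, (4,), device="cuda")
loss = F.nll_loss(logp, target)
assert torch.allclose(loss, F.nll_loss(logp.contiguous(), target))
```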
| @ -15,7 +15,7 @@ import tempfile | ||||
| import time | ||||
| import unittest | ||||
| from itertools import product | ||||
| from typing import Dict, List, Union, Callable | ||||
| from typing import Callable, Dict, List, Union | ||||
| from unittest import mock | ||||
| from unittest.mock import patch | ||||
|  | ||||
| @ -24,25 +24,25 @@ import torch.multiprocessing as mp | ||||
| from torch.distributed.elastic.multiprocessing import ProcessFailure, start_processes | ||||
| from torch.distributed.elastic.multiprocessing.api import ( | ||||
|     MultiprocessContext, | ||||
|     SignalException, | ||||
|     RunProcsResult, | ||||
|     SignalException, | ||||
|     Std, | ||||
|     _validate_full_rank, | ||||
|     to_map, | ||||
|     _wrap, | ||||
|     to_map, | ||||
| ) | ||||
| from torch.distributed.elastic.multiprocessing.errors.error_handler import _write_error | ||||
| from torch.testing._internal.common_utils import ( | ||||
|     IS_IN_CI, | ||||
|     IS_MACOS, | ||||
|     IS_WINDOWS, | ||||
|     NO_MULTIPROCESSING_SPAWN, | ||||
|     TEST_WITH_ASAN, | ||||
|     TEST_WITH_TSAN, | ||||
|     TEST_WITH_DEV_DBG_ASAN, | ||||
|     IS_IN_CI, | ||||
|     IS_WINDOWS, | ||||
|     IS_MACOS, | ||||
|     TEST_WITH_TSAN, | ||||
|     run_tests, | ||||
|     sandcastle_skip_if, | ||||
| ) | ||||
| from torch.testing._internal.common_utils import run_tests | ||||
|  | ||||
|  | ||||
| class RunProcResultsTest(unittest.TestCase): | ||||
| @ -224,6 +224,7 @@ def start_processes_zombie_test( | ||||
|  | ||||
| # tests incompatible with tsan or asan | ||||
| if not (TEST_WITH_DEV_DBG_ASAN or IS_WINDOWS or IS_MACOS): | ||||
|  | ||||
|     class StartProcessesTest(unittest.TestCase): | ||||
|         def setUp(self): | ||||
|             self.test_dir = tempfile.mkdtemp(prefix=f"{self.__class__.__name__}_") | ||||
| @ -251,12 +252,15 @@ if not (TEST_WITH_DEV_DBG_ASAN or IS_WINDOWS or IS_MACOS): | ||||
|  | ||||
|         def test_to_map(self): | ||||
|             local_world_size = 2 | ||||
|             self.assertEqual({0: Std.OUT, 1: Std.OUT}, to_map(Std.OUT, local_world_size)) | ||||
|             self.assertEqual( | ||||
|                 {0: Std.OUT, 1: Std.OUT}, to_map(Std.OUT, local_world_size) | ||||
|             ) | ||||
|             self.assertEqual( | ||||
|                 {0: Std.NONE, 1: Std.OUT}, to_map({1: Std.OUT}, local_world_size) | ||||
|             ) | ||||
|             self.assertEqual( | ||||
|                 {0: Std.ERR, 1: Std.OUT}, to_map({0: Std.ERR, 1: Std.OUT}, local_world_size) | ||||
|                 {0: Std.ERR, 1: Std.OUT}, | ||||
|                 to_map({0: Std.ERR, 1: Std.OUT}, local_world_size), | ||||
|             ) | ||||
|  | ||||
|         def test_invalid_log_dir(self): | ||||
| @ -382,9 +386,7 @@ if not (TEST_WITH_DEV_DBG_ASAN or IS_WINDOWS or IS_MACOS): | ||||
|                     results = pc.wait(period=0.1) | ||||
|                     self.assertEqual({0: None, 1: None}, results.return_values) | ||||
|  | ||||
|         @sandcastle_skip_if( | ||||
|             TEST_WITH_DEV_DBG_ASAN, "tests incompatible with asan" | ||||
|         ) | ||||
|         @sandcastle_skip_if(TEST_WITH_DEV_DBG_ASAN, "tests incompatible with asan") | ||||
|         def test_function_large_ret_val(self): | ||||
|             # python's multiprocessing.Queue uses pipes (the queues are actually PipedQueues) | ||||
|             # This means that if a single object is greater than a pipe size | ||||
| @ -439,7 +441,9 @@ if not (TEST_WITH_DEV_DBG_ASAN or IS_WINDOWS or IS_MACOS): | ||||
|                     self.assertEqual(1, failure.exitcode) | ||||
|                     self.assertEqual("<N/A>", failure.signal_name()) | ||||
|                     self.assertEqual(pc.pids()[0], failure.pid) | ||||
|                     self.assertEqual(os.path.join(log_dir, "0", "error.json"), error_file) | ||||
|                     self.assertEqual( | ||||
|                         os.path.join(log_dir, "0", "error.json"), error_file | ||||
|                     ) | ||||
|                     self.assertEqual( | ||||
|                         int(error_file_data["message"]["extraInfo"]["timestamp"]), | ||||
|                         int(failure.timestamp), | ||||
| @ -541,17 +545,22 @@ if not (TEST_WITH_DEV_DBG_ASAN or IS_WINDOWS or IS_MACOS): | ||||
|                 run_result = mp_context._poll() | ||||
|                 self.assertEqual(1, len(run_result.failures)) | ||||
|                 failure = run_result.failures[0] | ||||
|                 self.assertEqual("Signal 1 (SIGHUP) received by PID 123", failure.message) | ||||
|                 self.assertEqual( | ||||
|                     "Signal 1 (SIGHUP) received by PID 123", failure.message | ||||
|                 ) | ||||
|  | ||||
|  | ||||
| # tests incompatible with tsan or asan; the redirect functionality does not work on macos or windows | ||||
| if not (TEST_WITH_DEV_DBG_ASAN or IS_WINDOWS or IS_MACOS): | ||||
|  | ||||
|     class StartProcessesListTest(StartProcessesTest): | ||||
|         ######################################## | ||||
|         # start_processes as binary tests | ||||
|         ######################################## | ||||
|         def test_function(self): | ||||
|             for start_method, redirs in product(self._start_methods, redirects_oss_test()): | ||||
|             for start_method, redirs in product( | ||||
|                 self._start_methods, redirects_oss_test() | ||||
|             ): | ||||
|                 with self.subTest(start_method=start_method, redirs=redirs): | ||||
|                     pc = start_processes( | ||||
|                         name="echo", | ||||
| @ -644,6 +653,7 @@ if not (TEST_WITH_DEV_DBG_ASAN or IS_WINDOWS or IS_MACOS): | ||||
|  | ||||
| # tests incompatible with tsan or asan; the redirect functionality does not work on macos or windows | ||||
| if not (TEST_WITH_DEV_DBG_ASAN or IS_WINDOWS or IS_MACOS or IS_IN_CI): | ||||
|  | ||||
|     class StartProcessesNotCITest(StartProcessesTest): | ||||
|         def test_wrap_bad(self): | ||||
|             none = "" | ||||
| @ -796,7 +806,8 @@ if not (TEST_WITH_DEV_DBG_ASAN or IS_WINDOWS or IS_MACOS or IS_IN_CI): | ||||
|                     self.assertEqual(pc.pids()[0], failure.pid) | ||||
|                     self.assertEqual("<N/A>", error_file) | ||||
|                     self.assertEqual( | ||||
|                         f"Process failed with exitcode {FAIL}", failure.message | ||||
|                         "To enable traceback see: https://pytorch.org/docs/stable/elastic/errors.html", | ||||
|                         failure.message, | ||||
|                     ) | ||||
|                     self.assertLessEqual(failure.timestamp, int(time.time())) | ||||
|  | ||||
|  | ||||
| @ -115,7 +115,10 @@ class ApiTest(unittest.TestCase): | ||||
|         pf = self.failure_without_error_file(exitcode=138) | ||||
|         self.assertEqual("<N/A>", pf.signal_name()) | ||||
|         self.assertEqual("<N/A>", pf.error_file) | ||||
|         self.assertEqual("Process failed with exitcode 138", pf.message) | ||||
|         self.assertEqual( | ||||
|             "To enable traceback see: https://pytorch.org/docs/stable/elastic/errors.html", | ||||
|             pf.message, | ||||
|         ) | ||||
|  | ||||
|     def test_child_failed_error(self): | ||||
|         pf0 = self.failure_with_error_file(exception=SentinelError("rank 0")) | ||||
| @ -134,7 +137,7 @@ class ApiTest(unittest.TestCase): | ||||
|           rank: 0 (local_rank: 0) | ||||
|           exitcode: 1 (pid: 997) | ||||
|           error_file: /tmp/ApiTesttbb37ier/error.json | ||||
|           msg: "SentinelError: rank 0" | ||||
|           traceback: "SentinelError: rank 0" | ||||
|         ============================================= | ||||
|         Other Failures: | ||||
|         [1]: | ||||
| @ -148,7 +151,7 @@ class ApiTest(unittest.TestCase): | ||||
|           rank: 2 (local_rank: 0) | ||||
|           exitcode: 138 (pid: 997) | ||||
|           error_file: <N/A> | ||||
|           msg: "Process failed with exitcode 138" | ||||
|           traceback: To enable traceback see: https://pytorch.org/docs/stable/elastic/errors.html | ||||
|         ********************************************* | ||||
|         """ | ||||
|         print(ex) | ||||
|  | ||||
| @ -1,4 +1,5 @@ | ||||
| from inspect import signature | ||||
| from itertools import product | ||||
| from inspect import signature, isgenerator | ||||
| from copy import deepcopy | ||||
| import tempfile | ||||
|  | ||||
| @ -205,6 +206,116 @@ class TestModule(TestCase): | ||||
|             output_ip.backward(grad) | ||||
|             self.assertEqual(input_args[0].grad, input_arg_copy[0].grad) | ||||
|  | ||||
|     def _traverse_obj(self, obj, func): | ||||
|         if isinstance(obj, (tuple, list)): | ||||
|             return type(obj)(self._traverse_obj(o, func) for o in obj) | ||||
|         elif isgenerator(obj): | ||||
|             return tuple(self._traverse_obj(o, func) for o in obj) | ||||
|         elif isinstance(obj, dict): | ||||
|             return {name: self._traverse_obj(o, func) for name, o in obj.items()} | ||||
|         elif isinstance(obj, (torch.Tensor, torch.nn.Parameter)): | ||||
|             return func(obj) | ||||
|  | ||||
|     def _retain_grad(self, obj): | ||||
|         # gradients need to be retained to check for grad. This is useful when | ||||
|         # non-leaf tensors are present in the graph. | ||||
|         def inner_retain_grad(obj): | ||||
|             if obj.requires_grad: | ||||
|                 obj.retain_grad() | ||||
|         self._traverse_obj(obj, inner_retain_grad) | ||||
|  | ||||
|     def _get_grads(self, obj): | ||||
|         def inner_get_grad(obj): | ||||
|             if obj.requires_grad: | ||||
|                 return obj.grad | ||||
|         return self._traverse_obj(obj, inner_get_grad) | ||||
|  | ||||
|     def _zero_grad(self, obj): | ||||
|         def inner_zero_grad(obj): | ||||
|             if obj.grad is not None: | ||||
|                 obj.grad = None | ||||
|         self._traverse_obj(obj, inner_zero_grad) | ||||
|  | ||||
|     @modules(module_db) | ||||
|     def test_non_contiguous_tensors(self, device, dtype, module_info): | ||||
|         # Check modules work with non-contiguous tensors | ||||
|  | ||||
|         module_cls = module_info.module_cls | ||||
|         module_inputs = module_info.module_inputs_func(module_info, device=device, dtype=dtype, | ||||
|                                                        requires_grad=True) | ||||
|  | ||||
|         def _make_non_contiguous(obj): | ||||
|             def inner_make_non_contiguous(obj): | ||||
|                 # Scalar tensors cannot be made non-contiguous | ||||
|                 if not isinstance(obj, torch.Tensor) or obj.dim() == 0: | ||||
|                     return obj | ||||
|  | ||||
|                 out = torch.repeat_interleave(obj, 2, dim=-1) | ||||
|                 out = out[..., ::2].detach() | ||||
|                 out.requires_grad = obj.requires_grad | ||||
|                 return out | ||||
|             return self._traverse_obj(obj, inner_make_non_contiguous) | ||||
|  | ||||
|         def _can_be_noncontiguous(obj): | ||||
|             if isinstance(obj, (tuple, list)): | ||||
|                 return any(_can_be_noncontiguous(o) for o in obj) | ||||
|             elif isinstance(obj, dict): | ||||
|                 return any(_can_be_noncontiguous(o) for o in obj.values()) | ||||
|             # scalar tensors cannot be non-contiguous | ||||
|             if not isinstance(obj, torch.Tensor) or obj.dim() == 0: | ||||
|                 return False | ||||
|             return True | ||||
|  | ||||
|  | ||||
|         for module_input in module_inputs: | ||||
|             if module_input.forward_input is None: | ||||
|                 continue | ||||
|  | ||||
|             input_args, input_kwargs = module_input.forward_input.args, module_input.forward_input.kwargs | ||||
|             if not (_can_be_noncontiguous(input_args) or _can_be_noncontiguous(input_kwargs)): | ||||
|                 continue | ||||
|  | ||||
|             # === Instantiate the module. === | ||||
|             args, kwargs = module_input.constructor_input.args, module_input.constructor_input.kwargs | ||||
|             m = module_cls(*args, **kwargs) | ||||
|             m.to(device).to(dtype) | ||||
|  | ||||
|             self._retain_grad((input_args, input_kwargs)) | ||||
|  | ||||
|             # === Forward with default input | ||||
|             with freeze_rng_state(): | ||||
|                 default_output = m(*input_args, **input_kwargs) | ||||
|                 grad_output = default_output.clone().detach_().normal_() | ||||
|                 default_output.backward(grad_output, retain_graph=True) | ||||
|  | ||||
|             default_input_args_grad, default_input_kwargs_grad = deepcopy(self._get_grads((input_args, input_kwargs))) | ||||
|             default_param_grad = deepcopy([p.grad for p in m.parameters()]) | ||||
|  | ||||
|             # === Construct non-contiguous tensors === | ||||
|             nc_input_args, nc_input_kwargs = _make_non_contiguous((input_args, input_kwargs)) | ||||
|             nc_grad_output = _make_non_contiguous(grad_output) | ||||
|  | ||||
|             # === Compare results with non-contiguous and contiguous tensors === | ||||
|             inputs = [(input_args, input_kwargs), (nc_input_args, nc_input_kwargs)] | ||||
|             grads = [grad_output, nc_grad_output] | ||||
|  | ||||
|             for (in_args, in_kwargs), g_out in product(inputs, grads): | ||||
|                 g_out_copy = deepcopy(g_out) | ||||
|                 self._zero_grad((in_args, in_kwargs)) | ||||
|                 self._zero_grad(m.parameters()) | ||||
|  | ||||
|                 with freeze_rng_state(): | ||||
|                     out = m(*in_args, **in_kwargs) | ||||
|                     out.backward(g_out_copy, retain_graph=True) | ||||
|  | ||||
|                 input_args_grad, input_kwargs_grad = self._get_grads((in_args, in_kwargs)) | ||||
|                 self.assertEqual(out, default_output) | ||||
|                 self.assertEqual(input_args_grad, default_input_args_grad, atol=1e-4, rtol=0) | ||||
|                 self.assertEqual(input_kwargs_grad, default_input_kwargs_grad, atol=1e-4, rtol=0) | ||||
|  | ||||
|                 param_grad = [p.grad for p in m.parameters()] | ||||
|                 self.assertEqual(param_grad, default_param_grad) | ||||
|  | ||||
|  | ||||
| instantiate_device_type_tests(TestModule, globals()) | ||||
|  | ||||
|  | ||||
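The non-contiguity trick used by `_make_non_contiguous` above, in isolation: repeat each element along the last dim, then take a stride-2 slice, which preserves values but not contiguity:

```python
import torch

t = torch.arange(6.0).reshape(2, 3)
nc = torch.repeat_interleave(t, 2, dim=-1)[..., ::2]
assert torch.equal(nc, t) and not nc.is_contiguous()
```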
| @ -14622,7 +14622,6 @@ class TestNNDeviceType(NNTestCase): | ||||
|  | ||||
|                             self.assertEqual(a_cuda.grad, a_cpu.grad) | ||||
|  | ||||
|     @onlyCPU | ||||
|     @dtypes(torch.float, torch.double) | ||||
|     def test_adaptive_pooling_max_nhwc(self, device, dtype): | ||||
|         def helper(n, c, h, w, output_height, output_width, contig): | ||||
|  | ||||
| @ -449,5 +449,15 @@ $6 = torch._ops.aten.add_($1, $5)''') | ||||
|                 with enable_python_mode(LoggingTensor): | ||||
|                     pass | ||||
|  | ||||
|     def test_tolist_numpy_with_python_mode(self) -> None: | ||||
|         x = LoggingTensor(torch.tensor([2.0, 3.0])) | ||||
|         with self.assertRaisesRegex(RuntimeError, "is not supported for tensor subclasses."): | ||||
|             x.tolist() | ||||
|         with self.assertRaisesRegex(RuntimeError, "is not supported for tensor subclasses."): | ||||
|             x.numpy() | ||||
|         with self.assertRaises(AssertionError): | ||||
|             self.assertEqual(x, None) | ||||
|  | ||||
|  | ||||
| if __name__ == '__main__': | ||||
|     run_tests() | ||||
|  | ||||
| @ -3258,6 +3258,10 @@ class TestRandomTensorCreation(TestCase): | ||||
|             self.assertEqual(t_transform(r[:, :50]).std(), std_transform(4), atol=0.3, rtol=0) | ||||
|             self.assertEqual(t_transform(r[:, 50:]).std(), std_transform(1), atol=0.2, rtol=0) | ||||
|  | ||||
|             # test empty mean/std | ||||
|             out = torch.normal(mean=torch.empty((0, 2)), std=torch.empty((0, 1))) | ||||
|             self.assertEqual(out.size(), torch.Size([0, 2])) | ||||
|  | ||||
|             r.fill_(42) | ||||
|             r = torch.normal(2, 3, (100, 100), dtype=dtype, device=device) | ||||
|             self.assertEqual(r.dtype, dtype) | ||||
|  | ||||
| @ -8399,6 +8399,13 @@ class TestTorch(AbstractTestCases._TestTorchMixin): | ||||
|         finally: | ||||
|             torch.set_num_threads(num_threads) | ||||
|  | ||||
|     def test_conj_neg_tolist(self): | ||||
|         x = torch.randn(2, dtype=torch.cfloat) | ||||
|         y1 = x.conj() | ||||
|         y1_expect = x.conj_physical() | ||||
|         y2 = y1.imag | ||||
|         self.assertEqual(y1, y1_expect.tolist()) | ||||
|         self.assertEqual(y2, y1_expect.imag.tolist()) | ||||
|  | ||||
| # TODO: these empty classes are temporarily instantiated for XLA compatibility | ||||
| #   once XLA updates their test suite it should be removed | ||||
|  | ||||
| @ -365,6 +365,16 @@ class TestViewOps(TestCase): | ||||
|             self.assertEqual(v_imag, t_numpy_conj.imag) | ||||
|             self.assertTrue(v_imag.is_neg()) | ||||
|  | ||||
|     @onlyOnCPUAndCUDA | ||||
|     def test_conj_view_with_shared_memory(self, device) -> None: | ||||
|         a = _make_tensor((4, 5,), torch.cfloat, device) | ||||
|         b = a.conj() | ||||
|         c = a.conj() | ||||
|  | ||||
|         self.assertEqual(torch.add(a, b), a.add_(b)) | ||||
|         self.assertEqual(torch.add(b, c), torch.add(b, c, out=a)) | ||||
|         self.assertEqual(torch.add(b, c), b.add_(c)) | ||||
|  | ||||
|     @onlyOnCPUAndCUDA | ||||
|     @dtypes(*product(get_all_complex_dtypes(), get_all_dtypes())) | ||||
|     @suppress_warnings | ||||
|  | ||||
| @ -7,6 +7,13 @@ from tools.codegen.code_template import CodeTemplate | ||||
|  | ||||
| import yaml | ||||
|  | ||||
| # Use the fast C YAML SafeLoader if it is available | ||||
| try: | ||||
|     from yaml import CSafeLoader as Loader | ||||
| except ImportError: | ||||
|     from yaml import SafeLoader as Loader  # type: ignore[misc] | ||||
|  | ||||
|  | ||||
| if_condition_template_str = """if (kernel_tag_sv.compare("$kernel_tag_name") == 0) { | ||||
|   return $dtype_checks; | ||||
| }""" | ||||
| @ -121,7 +128,7 @@ def main() -> None: | ||||
|     print("Loading yaml file: ", model_file_name) | ||||
|     loaded_model = {} | ||||
|     with open(model_file_name, "rb") as model_file: | ||||
|         loaded_model = yaml.load(model_file) | ||||
|         loaded_model = yaml.load(model_file, Loader=Loader) | ||||
|  | ||||
|  | ||||
|     root_operators_set = set(loaded_model) | ||||
|  | ||||
| @ -30,7 +30,8 @@ static PyObject* recursive_to_list( | ||||
| } | ||||
|  | ||||
| PyObject* tensor_to_list(const Tensor& tensor) { | ||||
|   Tensor data = tensor; | ||||
|   TORCH_CHECK(!tensor.unsafeGetTensorImpl()->is_python_dispatch(), ".tolist() is not supported for tensor subclasses."); | ||||
|   Tensor data = tensor.resolve_conj().resolve_neg(); | ||||
|   if (!data.device().is_cpu()) { | ||||
|     pybind11::gil_scoped_release no_gil; | ||||
|     data = data.toBackend(Backend::CPU); | ||||
|  | ||||
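With the change above, `.tolist()` resolves both math bits before reading storage, so a lazy conj/neg view round-trips to the values it logically holds. A small sketch mirroring `test_conj_neg_tolist`:

```python
import torch

x = torch.tensor([1 + 2j])
assert x.conj().tolist() == [1 - 2j]       # conj bit resolved, not raw storage
assert x.conj().imag.tolist() == [-2.0]    # neg bit resolved as well
```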
| @ -194,7 +194,7 @@ void recursive_store(char* data, IntArrayRef sizes, IntArrayRef strides, int64_t | ||||
|   PyObject** items = PySequence_Fast_ITEMS(seq.get()); | ||||
|   for(const auto i : c10::irange(n)) { | ||||
| #ifdef USE_NUMPY | ||||
|     if (PyArray_Check(items[i])) { | ||||
|     if (is_numpy_available() && PyArray_Check(items[i])) { | ||||
|       TORCH_WARN_ONCE( | ||||
|         "Creating a tensor from a list of numpy.ndarrays is extremely slow. " | ||||
|         "Please consider converting the list to a single numpy.ndarray with " | ||||
|  | ||||
| @ -130,6 +130,8 @@ PyObject* tensor_to_numpy(const at::Tensor& tensor) { | ||||
|       "Can't call numpy() on Tensor that has negative bit set. " | ||||
|       "Use tensor.resolve_neg().numpy() instead."); | ||||
|  | ||||
|   TORCH_CHECK(!tensor.unsafeGetTensorImpl()->is_python_dispatch(), ".numpy() is not supported for tensor subclasses."); | ||||
|  | ||||
|   auto dtype = aten_to_numpy_dtype(tensor.scalar_type()); | ||||
|   auto sizes = to_numpy_shape(tensor.sizes()); | ||||
|   auto strides = to_numpy_shape(tensor.strides()); | ||||
|  | ||||
| @ -21,10 +21,11 @@ from torch.distributed.elastic.agent.server.api import ( | ||||
|     WorkerState, | ||||
| ) | ||||
| from torch.distributed.elastic.metrics.api import prof | ||||
| from torch.distributed.elastic.multiprocessing import start_processes, PContext | ||||
| from torch.distributed.elastic.multiprocessing import PContext, start_processes | ||||
| from torch.distributed.elastic.utils import macros | ||||
| from torch.distributed.elastic.utils.logging import get_logger | ||||
|  | ||||
|  | ||||
| log = get_logger() | ||||
|  | ||||
|  | ||||
|  | ||||
| @ -51,6 +51,7 @@ children, and propagates the one with the **smallest** timestamp (e.g. the **fir | ||||
| import json | ||||
| import os | ||||
| import signal | ||||
| import socket | ||||
| import time | ||||
| import warnings | ||||
| from dataclasses import dataclass, field | ||||
| @ -109,7 +110,7 @@ class ProcessFailure: | ||||
|             try: | ||||
|                 with open(self.error_file, "r") as fp: | ||||
|                     self.error_file_data = json.load(fp) | ||||
|                     log.info( | ||||
|                     log.debug( | ||||
|                         f"User process failed with error data: {json.dumps(self.error_file_data, indent=2)}" | ||||
|                     ) | ||||
|                     self.message, self.timestamp = self._get_error_data( | ||||
| @ -130,7 +131,7 @@ class ProcessFailure: | ||||
|                     f" received by PID {self.pid}" | ||||
|                 ) | ||||
|             else: | ||||
|                 self.message = f"Process failed with exitcode {self.exitcode}" | ||||
|                 self.message = "To enable traceback see: https://pytorch.org/docs/stable/elastic/errors.html" | ||||
|  | ||||
|     def _get_error_data(self, error_file_data: Dict[str, Any]) -> Tuple[str, int]: | ||||
|         message = error_file_data["message"] | ||||
| @ -162,24 +163,24 @@ class ProcessFailure: | ||||
| GlobalRank = int | ||||
|  | ||||
| _FAILURE_FORMAT_TEMPLATE = """[${idx}]: | ||||
|   time: ${time} | ||||
|   rank: ${rank} (local_rank: ${local_rank}) | ||||
|   exitcode: ${exitcode} (pid: ${pid}) | ||||
|   time      : ${time} | ||||
|   host      : ${hostname} | ||||
|   rank      : ${rank} (local_rank: ${local_rank}) | ||||
|   exitcode  : ${exitcode} (pid: ${pid}) | ||||
|   error_file: ${error_file} | ||||
|   msg: ${message}""" | ||||
|   traceback : ${message}""" | ||||
|  | ||||
| # extra new lines before and after are intentional | ||||
| _MSG_FORMAT_TEMPLATE = """ | ||||
| ${boarder} | ||||
| ${title} | ||||
| ${section} | ||||
| Root Cause: | ||||
| ${root_failure} | ||||
| ${section} | ||||
| Other Failures: | ||||
| Failures: | ||||
| ${other_failures} | ||||
| ${boarder} | ||||
| """ | ||||
| ${section} | ||||
| Root Cause (first observed failure): | ||||
| ${root_failure} | ||||
| ${boarder}""" | ||||
|  | ||||
|  | ||||
| class ChildFailedError(Exception): | ||||
| @ -230,8 +231,8 @@ class ChildFailedError(Exception): | ||||
|         rank = min(self.failures.keys(), key=lambda r: self.failures[r].timestamp) | ||||
|         return rank, self.failures[rank] | ||||
|  | ||||
|     def format_msg(self, boarder_delim="*", section_delim="="): | ||||
|         title = f"  {self.name} FAILED  " | ||||
|     def format_msg(self, boarder_delim="=", section_delim="-"): | ||||
|         title = f"{self.name} FAILED" | ||||
|         root_rank, root_failure = self.get_first_failure() | ||||
|  | ||||
|         root_failure_fmt: str = "" | ||||
| @ -246,11 +247,11 @@ class ChildFailedError(Exception): | ||||
|                 other_failures_fmt.append(fmt) | ||||
|  | ||||
|         # upper boundary on width | ||||
|         width = min(width, 80) | ||||
|         width = min(width, 60) | ||||
|  | ||||
|         return Template(_MSG_FORMAT_TEMPLATE).substitute( | ||||
|             boarder=boarder_delim * width, | ||||
|             title=title.center(width), | ||||
|             title=title, | ||||
|             section=section_delim * width, | ||||
|             root_failure=root_failure_fmt, | ||||
|             other_failures="\n".join(other_failures_fmt or ["  <NO_OTHER_FAILURES>"]), | ||||
| @ -279,6 +280,7 @@ class ChildFailedError(Exception): | ||||
|         fmt = Template(_FAILURE_FORMAT_TEMPLATE).substitute( | ||||
|             idx=idx, | ||||
|             time=failure.timestamp_isoformat(), | ||||
|             hostname=socket.getfqdn(), | ||||
|             rank=rank, | ||||
|             local_rank=failure.local_rank, | ||||
|             exitcode=failure.exitcode, | ||||
| @ -292,32 +294,6 @@ class ChildFailedError(Exception): | ||||
|         return fmt, width | ||||
|  | ||||
|  | ||||
| def _no_error_file_warning_msg(rank: int, failure: ProcessFailure) -> str: | ||||
|     msg = [ | ||||
|         "CHILD PROCESS FAILED WITH NO ERROR_FILE", | ||||
|         f"Child process {failure.pid} (local_rank {rank}) FAILED (exitcode {failure.exitcode})", | ||||
|         f"Error msg: {failure.message}", | ||||
|         f"Without writing an error file to {failure.error_file}.", | ||||
|         "While this DOES NOT affect the correctness of your application,", | ||||
|         "no trace information about the error will be available for inspection.", | ||||
|         "Consider decorating your top level entrypoint function with", | ||||
|         "torch.distributed.elastic.multiprocessing.errors.record. Example:", | ||||
|         "", | ||||
|         r"  from torch.distributed.elastic.multiprocessing.errors import record", | ||||
|         "", | ||||
|         r"  @record", | ||||
|         r"  def trainer_main(args):", | ||||
|         r"     # do train", | ||||
|     ] | ||||
|     width = 0 | ||||
|     for line in msg: | ||||
|         width = max(width, len(line)) | ||||
|  | ||||
|     boarder = "*" * width | ||||
|     header = "CHILD PROCESS FAILED WITH NO ERROR_FILE".center(width) | ||||
|     return "\n".join(["\n", boarder, header, boarder, *msg, boarder]) | ||||
|  | ||||
|  | ||||
| def record( | ||||
|     fn: Callable[..., T], error_handler: Optional[ErrorHandler] = None | ||||
| ) -> Callable[..., T]: | ||||
| @ -372,7 +348,13 @@ def record( | ||||
|                 if failure.error_file != _NOT_AVAILABLE: | ||||
|                     error_handler.dump_error_file(failure.error_file, failure.exitcode) | ||||
|                 else: | ||||
|                     warnings.warn(_no_error_file_warning_msg(rank, failure)) | ||||
|                     log.info( | ||||
|                         ( | ||||
|                             f"local_rank {rank} FAILED with no error file." | ||||
|                             f" Decorate your entrypoint fn with @record for traceback info." | ||||
|                             f" See: https://pytorch.org/docs/stable/elastic/errors.html" | ||||
|                         ) | ||||
|                     ) | ||||
|                 raise | ||||
|             except Exception as e: | ||||
|                 error_handler.record_exception(e) | ||||
|  | ||||
| @ -107,7 +107,7 @@ class ErrorHandler: | ||||
|                 else: | ||||
|                     rootcause_error["message"]["errorCode"] = error_code | ||||
|  | ||||
|             log.info( | ||||
|             log.debug( | ||||
|                 f"child error file ({rootcause_error_file}) contents:\n" | ||||
|                 f"{json.dumps(rootcause_error, indent=2)}" | ||||
|             ) | ||||
|  | ||||
| @ -304,6 +304,27 @@ utility | ||||
|  | ||||
|       if should_checkpoint: | ||||
|         save_checkpoint(checkpoint_path) | ||||
|  | ||||
| 9. (Recommended) On worker errors, this tool will summarize the details of the error | ||||
|    (e.g. time, rank, host, pid, traceback, etc). On each node, the first error (by timestamp) | ||||
|    is heuristically reported as the "Root Cause" error. To get tracebacks as part of this | ||||
|    error summary printout, you must decorate your main entrypoint function in your | ||||
|    training script as shown in the example below. If not decorated, then the summary | ||||
|    will not include the traceback of the exception and will only contain the exitcode. | ||||
|    For details on torchelastic error handling see: https://pytorch.org/docs/stable/elastic/errors.html | ||||
|  | ||||
| :: | ||||
|  | ||||
|   from torch.distributed.elastic.multiprocessing.errors import record | ||||
|  | ||||
|   @record | ||||
|   def main(): | ||||
|       # do train | ||||
|       pass | ||||
|  | ||||
|   if __name__ == "__main__": | ||||
|       main() | ||||
|  | ||||
| """ | ||||
| import logging | ||||
| import os | ||||
| @ -597,7 +618,7 @@ def config_from_args(args) -> Tuple[LaunchConfig, Union[Callable, str], List[str | ||||
|     if "OMP_NUM_THREADS" not in os.environ and nproc_per_node > 1: | ||||
|         omp_num_threads = 1 | ||||
|         log.warning( | ||||
|             f"*****************************************\n" | ||||
|             f"\n*****************************************\n" | ||||
|             f"Setting OMP_NUM_THREADS environment variable for each process to be " | ||||
|             f"{omp_num_threads} in default, to avoid your system being overloaded, " | ||||
|             f"please further tune the variable for optimal performance in " | ||||
|  | ||||
| @ -180,13 +180,16 @@ def module_inputs_torch_nn_Linear(module_info, device, dtype, requires_grad, **k | ||||
|  | ||||
| def module_inputs_torch_nn_NLLLoss(module_info, device, dtype, requires_grad, **kwargs): | ||||
|     make_input = partial(make_tensor, device=device, dtype=dtype, requires_grad=requires_grad) | ||||
|     make_weight = partial(make_tensor, device=device, dtype=dtype, requires_grad=False) | ||||
|  | ||||
|     cases: List[Tuple[str, dict]] = [ | ||||
|         ('', {}), | ||||
|         ('reduction_sum', {'reduction': 'sum'}), | ||||
|         ('reduction_none', {'reduction': 'none'}), | ||||
|         ('ignore_index', {'ignore_index': 2}), | ||||
|         ('weights', {'weight': make_input(10)}), | ||||
|         ('weights_ignore_index', {'weight': make_input(10), 'ignore_index': 2}), | ||||
|         ('weights_ignore_index_neg', {'weight': make_input(10), 'ignore_index': -1}) | ||||
|         ('weights', {'weight': make_weight(10).abs()}), | ||||
|         ('weights_ignore_index', {'weight': make_weight(10).abs(), 'ignore_index': 2}), | ||||
|         ('weights_ignore_index_neg', {'weight': make_weight(10).abs(), 'ignore_index': -1}) | ||||
|     ] | ||||
|     module_inputs = [] | ||||
|     for desc, constructor_kwargs in cases: | ||||
|  | ||||
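For reference, a hedged usage sketch of NLLLoss class weights matching the inputs above: the weight is a plain non-learnable tensor, and taking `.abs()` keeps every class weight non-negative:

```python
import torch
import torch.nn.functional as F

weight = torch.rand(10)                           # non-negative, requires_grad=False
logp = F.log_softmax(torch.randn(4, 10), dim=1)
target = torch.randint(0, 10, (4,))
loss = F.nll_loss(logp, target, weight=weight)
```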
| @ -1781,8 +1781,10 @@ class TestCase(expecttest.TestCase): | ||||
|         assert (atol is None) == (rtol is None), "If one of atol or rtol is specified, then the other must be too" | ||||
|         debug_msg: Optional[str] = None | ||||
|  | ||||
|         if x is None or y is None: | ||||
|             self.assertTrue(x is None and y is None) | ||||
|         # Tensor x Number and Number x Tensor comparisons | ||||
|         if isinstance(x, torch.Tensor) and isinstance(y, Number): | ||||
|         elif isinstance(x, torch.Tensor) and isinstance(y, Number): | ||||
|             self.assertEqual(x.item(), y, atol=atol, rtol=rtol, msg=msg, | ||||
|                              exact_dtype=exact_dtype, exact_device=exact_device) | ||||
|         elif isinstance(y, torch.Tensor) and isinstance(x, Number): | ||||
|  | ||||