#include <torch/script.h>
#include <gtest/gtest.h>
#include <test/cpp/api/support.h>

using namespace torch::autograd;
using namespace torch::test;

namespace {
  // Functional op: allocates and returns a fresh output tensor.
  torch::Tensor functional_op(torch::Tensor& x) {
    return x * x;
  }

  // In-place op: mutates its input (and, for normal tensors, bumps the
  // version counter).
  void inplace_op(torch::Tensor& x) {
    x.mul_(1);
  }

  // View op: returns a tensor aliasing x's storage.
  torch::Tensor view_op(torch::Tensor& x) {
    return x.view({2, 3});
  }

  /*
    Only the following combinations of the Autograd & ADInplaceOrView keys on tensors are valid:
      - Autograd=true, ADInplaceOrView=true (normal tensor)
      - Autograd=false, ADInplaceOrView=false (inference tensor)
    Tensors created in InferenceMode are mostly inference tensors. The only
    exception is that a view of a normal tensor created in InferenceMode still
    produces a normal tensor.
  */
  void assert_TLS_states(bool inference_mode) {
    ASSERT_EQ(InferenceMode::is_enabled(), inference_mode);
    ASSERT_FALSE(c10::impl::tls_is_dispatch_key_excluded(c10::DispatchKey::ADInplaceOrView));
    ASSERT_FALSE(c10::impl::tls_is_dispatch_keyset_included(c10::autograd_dispatch_keyset));
    ASSERT_EQ(c10::impl::tls_is_dispatch_keyset_excluded(c10::autograd_dispatch_keyset), inference_mode);
    ASSERT_EQ(c10::impl::tls_is_dispatch_key_included(c10::DispatchKey::ADInplaceOrView), !inference_mode);
    ASSERT_EQ(GradMode::is_enabled(), !inference_mode);
  }
}

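// A minimal usage sketch of the guard under test (illustrative only; `model`
// and "model.pt" are hypothetical placeholders, not defined in this file):
//
//   torch::jit::script::Module model = torch::jit::load("model.pt");
//   std::vector<torch::jit::IValue> inputs = {torch::rand({1, 3, 224, 224})};
//   {
//     c10::InferenceMode guard;
//     torch::Tensor out = model.forward(inputs).toTensor();
//     // `out` is an inference tensor: no autograd tracking, no version counter.
//   }
//
// The tests below pin down exactly which tensor states and op combinations
// the guard allows.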
TEST(InferenceModeTest, TestTLSState) {
  assert_TLS_states(false);
  {
    InferenceMode guard;
    assert_TLS_states(true);
    {
      InferenceMode guard(false);
      assert_TLS_states(false);
    }
    assert_TLS_states(true);
  }
  assert_TLS_states(false);
}

TEST(InferenceModeTest, TestInferenceTensorCreation) {
  {
    InferenceMode guard;
    // New tensors created through constructors are inference tensors.
    torch::Tensor c = torch::ones({1, 2, 3});
    ASSERT_FALSE(c.requires_grad());
    ASSERT_TRUE(c.is_inference());

    // requires_grad doesn't change inference tensor behavior inside InferenceMode.
    torch::Tensor tmp = torch::ones({1, 2, 3}).set_requires_grad(true);
    ASSERT_TRUE(tmp.requires_grad());
    ASSERT_TRUE(tmp.is_inference());

    tmp = torch::ones({1, 2, 3}).set_requires_grad(false);
    ASSERT_FALSE(tmp.requires_grad());
    ASSERT_TRUE(tmp.is_inference());
  }
}

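// Autograd records the version of every tensor it saves for backward and
// re-checks it when backward runs. An in-place op on a normal tensor still
// bumps its version even inside InferenceMode (it dispatches through
// ADInplaceOrView), so a pending backward that saved the old version must
// fail, as verified below.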
TEST(InferenceModeTest, TestExistingAutogradSession) {
  torch::Tensor s = torch::ones({1, 2, 3}).set_requires_grad(true);
  torch::Tensor a = s.clone();

  // Save `a` in an existing autograd session.
  torch::Tensor out = a * a;
  {
    InferenceMode guard;
    inplace_op(a);
  }
  // Performing backward should trigger an error since `a`'s version has been bumped.
  ASSERT_THROWS_WITH(out.backward(torch::ones_like(out)),
    "one of the variables needed for gradient computation has been modified by an inplace operation")
}

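// Inside InferenceMode, inference tensors can be fed to functional, in-place,
// and view ops alike; the next three tests check that every output is itself
// an inference tensor with no grad_fn, regardless of requires_grad.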
TEST(InferenceModeTest, TestInferenceTensorInInferenceModeFunctionalOp) {
  c10::InferenceMode guard;
  for (bool requires_grad : {true, false}) {
    torch::Tensor c = torch::ones({1, 2, 3}).set_requires_grad(requires_grad);

    torch::Tensor func_out = functional_op(c);  // go through kernels: CPU
    ASSERT_TRUE(func_out.is_inference());
    ASSERT_FALSE(func_out.requires_grad());
  }
}

TEST(InferenceModeTest, TestInferenceTensorInInferenceModeInplaceOp) {
  c10::InferenceMode guard;
  for (bool requires_grad : {true, false}) {
    torch::Tensor c = torch::ones({1, 2, 3}).set_requires_grad(requires_grad);

    inplace_op(c);  // go through kernels: CPU
    ASSERT_TRUE(c.is_inference());
    ASSERT_EQ(c.requires_grad(), requires_grad);
  }
}

TEST(InferenceModeTest, TestInferenceTensorInInferenceModeViewOp) {
  c10::InferenceMode guard;
  for (bool requires_grad : {true, false}) {
    torch::Tensor c = torch::ones({1, 2, 3}).set_requires_grad(requires_grad);

    torch::Tensor view_out = view_op(c);  // go through kernels: CPU
    ASSERT_TRUE(view_out.is_inference());
    // Note this is different from NoGradMode, where the output would still be
    // recorded as a view of its base.
    ASSERT_FALSE(view_out.requires_grad());
    ASSERT_FALSE(view_out.is_view());
  }
}

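// Outside InferenceMode, inference tensors may still be used as inputs to
// functional and view ops (views of them stay inference tensors), but any
// in-place mutation is rejected: with no version counter, autograd could not
// detect the update.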
TEST(InferenceModeTest, TestInferenceTensorInNormalModeFunctionalOp) {
  torch::Tensor inference_tensor;
  for (bool requires_grad: {true, false}) {
    {
      InferenceMode guard;
      inference_tensor = torch::ones({1, 2, 3}).set_requires_grad(requires_grad);
    }

    // Due to issue #54614, this might run slower than inside InferenceMode,
    // since intermediate tensors are normal tensors and might dispatch to
    // VariableType kernels. This is fine: users can easily fix it by moving
    // the call inside an InferenceMode block.
    torch::Tensor tmp = functional_op(inference_tensor); // go through kernels: ADInplaceOrView(fallthrough), CPU
    ASSERT_FALSE(tmp.is_inference());
    ASSERT_FALSE(tmp.requires_grad());
  }
}

TEST(InferenceModeTest, TestInferenceTensorInNormalModeInplaceOp) {
  torch::Tensor inference_tensor;
  for (bool requires_grad: {true, false}) {
    {
      InferenceMode guard;
      inference_tensor = torch::ones({1, 2, 3}).set_requires_grad(requires_grad);
    }
    ASSERT_THROWS_WITH(inplace_op(inference_tensor), // go through kernels: ADInplaceOrView, CPU
      "Inplace update to inference tensor outside InferenceMode is not allowed");
  }
}

TEST(InferenceModeTest, TestInferenceTensorInNormalModeViewOp) {
  torch::Tensor inference_tensor;
  for (bool requires_grad: {true, false}) {
    {
      InferenceMode guard;
      inference_tensor = torch::ones({1, 2, 3}).set_requires_grad(requires_grad);
    }
    torch::Tensor out = view_op(inference_tensor); // go through kernels: ADInplaceOrView, CPU
    ASSERT_TRUE(out.is_inference());
    ASSERT_FALSE(out.requires_grad());
    ASSERT_FALSE(out.is_view());
    ASSERT_TRUE(out.is_leaf());
  }
}

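// Normal tensors stay normal inside the guard: in-place ops and views on them
// still dispatch through ADInplaceOrView, so version counting and view
// metadata keep working and requires_grad is preserved on the results.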
TEST(InferenceModeTest, TestNormalTensorInplaceOutputInInferenceMode) {
  for (bool requires_grad: {true, false}) {
    torch::Tensor s = torch::ones({1, 2, 3}).set_requires_grad(requires_grad);
    torch::Tensor a = s.clone();

    {
      c10::InferenceMode guard;

      inplace_op(a);  // go through kernels: ADInplaceOrView, CPU
      ASSERT_FALSE(a.is_inference());
      ASSERT_EQ(a.requires_grad(), requires_grad);

      // inplace -> inplace
      inplace_op(a);  // go through kernels: ADInplaceOrView, CPU
      ASSERT_FALSE(a.is_inference());
      ASSERT_EQ(a.requires_grad(), requires_grad);

      // inplace -> inplace -> view
      torch::Tensor view_out = view_op(a);  // go through kernels: ADInplaceOrView, CPU
      ASSERT_FALSE(view_out.is_inference());
      ASSERT_EQ(view_out.requires_grad(), requires_grad);
    }
  }
}

TEST(InferenceModeTest, TestNormalTensorInplaceOutputInNormalMode) {
  for (bool requires_grad: {true, false}) {
    torch::Tensor s = torch::ones({1, 2, 3}).set_requires_grad(requires_grad);
    torch::Tensor a = s.clone();

    {
      c10::InferenceMode guard;

      inplace_op(a);  // go through kernels: ADInplaceOrView, CPU
      ASSERT_FALSE(a.is_inference());
      ASSERT_EQ(a.requires_grad(), requires_grad);
    }

    torch::Tensor tmp = functional_op(a);  // go through kernels: VariableType, ADInplaceOrView(fallthrough), CPU
    ASSERT_FALSE(tmp.is_inference());
    ASSERT_EQ(tmp.requires_grad(), requires_grad);

    inplace_op(a); // go through kernels: VariableType, ADInplaceOrView, CPU
    ASSERT_FALSE(a.is_inference());
    ASSERT_EQ(a.requires_grad(), requires_grad);

    tmp = view_op(a);  // go through kernels: VariableType, ADInplaceOrView, CPU
    ASSERT_FALSE(tmp.is_inference());
    ASSERT_EQ(tmp.requires_grad(), requires_grad);
  }
}

TEST(InferenceModeTest, TestNormalTensorViewOutputInInferenceMode) {
  for (bool requires_grad: {true, false}) {
    torch::Tensor s = torch::ones({1, 2, 3}).set_requires_grad(requires_grad);
    torch::Tensor a = s.clone();
    torch::Tensor view_out, tmp;

    {
      c10::InferenceMode guard;
      // View ops on normal tensors produce normal tensors as output.
      // - The output carries both dispatch keys due to the way we create view
      //   tensors in alias_with_sizes_and_strides:
      //   ```
      //     auto impl = c10::make_intrusive<TensorImpl>(
      //         Storage(self.storage()), self.key_set(), self.dtype());
      //   ```
      //   These view outputs are normal in the sense that they have both the
      //   Autograd and ADInplaceOrView keys. But they're still special, since
      //   they carry CreationMeta::INFERENCE_MODE. In other words, they
      //   behave exactly like a view tensor created in no_grad mode.

      view_out = view_op(a);  // go through kernels: ADInplaceOrView, CPU
      ASSERT_FALSE(view_out.is_inference());
      assert_tensor_creation_meta(view_out, CreationMeta::INFERENCE_MODE);
      ASSERT_EQ(view_out.requires_grad(), requires_grad);
      ASSERT_TRUE(view_out.is_leaf());

      // view -> view
      tmp = view_op(view_out);  // go through kernels: ADInplaceOrView, CPU
      ASSERT_FALSE(tmp.is_inference());
      assert_tensor_creation_meta(tmp, CreationMeta::INFERENCE_MODE);
      ASSERT_EQ(tmp.requires_grad(), requires_grad);
      ASSERT_TRUE(tmp.is_leaf());

      // view -> view -> inplace
      inplace_op(tmp);  // go through kernels: ADInplaceOrView, CPU
      assert_tensor_creation_meta(tmp, CreationMeta::INFERENCE_MODE);
      ASSERT_FALSE(tmp.is_inference());
      ASSERT_EQ(tmp.requires_grad(), requires_grad);
      ASSERT_TRUE(tmp.is_leaf());
      ASSERT_EQ(a._version(), tmp._version());
    }
  }
}

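// Once the guard exits, a view tagged CreationMeta::INFERENCE_MODE behaves
// like a view created under no_grad: reading or re-viewing it is fine, but
// mutating it in place errors out when gradients are required.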
TEST(InferenceModeTest, TestNormalTensorViewOutputInNormalMode) {
  for (bool requires_grad: {true, false}) {
    torch::Tensor s = torch::ones({1, 2, 3}).set_requires_grad(requires_grad);
    torch::Tensor a = s.clone();
    torch::Tensor view_out, tmp;

    {
      c10::InferenceMode guard;
      view_out = view_op(a);  // go through kernels: ADInplaceOrView, CPU
      ASSERT_FALSE(view_out.is_inference());
      assert_tensor_creation_meta(view_out, CreationMeta::INFERENCE_MODE);
      ASSERT_EQ(view_out.requires_grad(), requires_grad);
      ASSERT_TRUE(view_out.is_leaf());
    }

    tmp = functional_op(view_out);
    ASSERT_FALSE(tmp.is_inference());
    ASSERT_EQ(tmp.requires_grad(), requires_grad);

    if (requires_grad) {
      ASSERT_THROWS_WITH(inplace_op(view_out),  // go through kernels: VariableType, ADInplaceOrView, CPU
        "A view was created in inference mode and is being modified inplace")
    } else {
      inplace_op(view_out);
    }

    tmp = view_op(view_out);
    ASSERT_FALSE(tmp.is_inference());
    ASSERT_EQ(tmp.requires_grad(), requires_grad);
  }
}

TEST(InferenceModeTest, TestMixInferenceAndNormalTensorFunctionalOp) {
  for (bool requires_grad: {true, false}) {
    torch::Tensor s = torch::ones({1, 2, 3}).set_requires_grad(requires_grad);
    torch::Tensor c;
    {
      InferenceMode guard;
      c = torch::ones({1, 2, 3}).set_requires_grad(requires_grad);
    }

    // add(Tensor, Tensor) is safe with an inference tensor since it doesn't save any variables for backward.
    torch::Tensor out = c.add(s);  // go through kernels: VariableType, ADInplaceOrView(fallthrough), CPU
    ASSERT_FALSE(out.is_inference());
    ASSERT_EQ(out.requires_grad(), requires_grad);
    if (requires_grad) {
      // A leaf inference tensor with requires_grad=true can still receive a gradient.
      // Note this behavior is different from NoGradMode, which leaves grad empty.
      out.backward(torch::ones_like(out));
      assert_tensor_equal(c.grad(), torch::ones_like(c));
    }

    if (requires_grad) {
      // mul(self, other) saves variables when requires_grad=true.
      ASSERT_THROWS_WITH(c.mul(s),
        "Inference tensors cannot be saved for backward.");

      // Inference tensor in a TensorList input.
      std::vector<torch::Tensor> inputs = {s, c};
      ASSERT_THROWS_WITH(torch::stack(inputs), // go through kernels: VariableType(ERROR!), ADInplaceOrView(fallthrough), CPU
        "Inference tensors cannot be saved for backward.")
    }
  }
}

TEST(InferenceModeTest, TestMixInferenceAndNormalTensorInplaceOp) {
  for (bool requires_grad: {true, false}) {
    torch::Tensor s = torch::ones({1, 2, 3}).set_requires_grad(requires_grad);
    torch::Tensor a = s.clone();
    torch::Tensor c;
    {
      InferenceMode guard;
      c = torch::ones({1, 2, 3});
    }

    if (requires_grad) {
      ASSERT_THROWS_WITH(a.mul_(c), // go through kernels: VariableType(ERROR!), ADInplaceOrView, CPU
        "Inference tensors cannot be saved for backward.");

      ASSERT_THROWS_WITH(torch::mul_out(/*out=*/c, s, s), // go through kernels: VariableType(ERROR!), ADInplaceOrView, CPU
        "out=... arguments don't support automatic differentiation, but one of the arguments requires grad")
    } else {
      a.mul_(c);

      ASSERT_THROWS_WITH(torch::mul_out(/*out=*/c, s, s), // go through kernels: VariableType, ADInplaceOrView(ERROR!), CPU
        "Inplace update to inference tensor outside InferenceMode is not allowed");
    }
  }
}

TEST(InferenceModeTest, TestMixInferenceAndNormalTensorViewOp) {
  for (bool requires_grad: {true, false}) {
    torch::Tensor s = torch::ones({1, 2, 3}).set_requires_grad(requires_grad);
    torch::Tensor c;
    {
      InferenceMode guard;
      c = torch::ones({1, 2, 3});
    }

    // view_as is a composite op that calls view() with only one tensor argument,
    // so view ops never see mixed inference and normal tensor inputs.
    torch::Tensor tmp1 = c.view_as(s); // go through kernels: ADInplaceOrView, CPU
    ASSERT_TRUE(tmp1.is_inference());
    ASSERT_FALSE(tmp1.requires_grad());

    // This is fine since it's equivalent to s.view(c.sizes()), which
    // isn't a mixed-input scenario.
    torch::Tensor tmp2 = s.view_as(c); // go through kernels: VariableType, ADInplaceOrView, CPU
    ASSERT_FALSE(tmp2.is_inference());
    ASSERT_EQ(tmp2.requires_grad(), requires_grad);
  }
}

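// INFERENCE_MODE views can trip autograd in two ways when gradients are
// required: mutating the view itself fails right at the in-place op (direct
// case, below), while mutating its base fails later, when a grad_fn is
// materialized for the now-stale view (indirect case, the test after it).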
TEST(InferenceModeTest, TestHandleDirectViewOnRebase) {
  for (bool requires_grad: {true, false}) {
    torch::Tensor s = torch::ones({1, 2, 3}).set_requires_grad(requires_grad);
    torch::Tensor a = s.clone();
    torch::Tensor view_out;
    {
      InferenceMode guard;
      view_out = view_op(a);  // go through kernels: ADInplaceOrView, CPU
    }
    if (requires_grad) {
      ASSERT_THROWS_WITH(inplace_op(view_out),
        "A view was created in inference mode and is being modified inplace")
    } else {
      inplace_op(view_out);
    }
  }
}

TEST(InferenceModeTest, TestHandleInDirectViewOnRebase) {
  for (bool requires_grad: {true, false}) {
    torch::Tensor s = torch::ones({1, 2, 3}).set_requires_grad(requires_grad);
    torch::Tensor a = s.clone();
    torch::Tensor view_out;
    {
      InferenceMode guard;
      view_out = view_op(a);  // go through kernels: ADInplaceOrView, CPU
    }
    inplace_op(a);
    if (requires_grad) {
      ASSERT_THROWS_WITH(view_out.grad_fn(),
        "A view was created in inference mode and its base or another view of its base has been modified inplace");
    } else {
      view_out.grad_fn();
    }
  }
}

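// CreationMeta must survive re-viewing: a view taken of an INFERENCE_MODE
// view later on, even under plain no_grad, stays protected against in-place
// updates.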
TEST(InferenceModeTest, TestCreationMetaPropagation) {
  torch::Tensor s = torch::ones({1, 2, 3}).set_requires_grad(true);
  torch::Tensor b, c;
  {
    InferenceMode guard;
    b = s.view_as(s);
  }
  ASSERT_THROWS_WITH(b.add_(1),
    "A view was created in inference mode and is being modified inplace");
  {
    AutoGradMode mode(false);
    c = b.view_as(b);
  }
  ASSERT_THROWS_WITH(c.add_(1),
    "A view was created in inference mode and is being modified inplace");
}

TEST(InferenceModeTest, TestCreationMetaPropagationInput) {
  torch::Tensor s = torch::ones({2, 2, 3}).set_requires_grad(true);
  auto s_view = s.view_as(s);
  std::vector<at::Tensor> b, c;
  {
    InferenceMode guard;
    b = s_view.split_with_sizes({1, 1});

    s = s.view_as(s);
    c = s.split_with_sizes({1, 1});
  }
  for (auto& b_el: b) {
    assert_tensor_creation_meta(b_el, CreationMeta::INFERENCE_MODE);
    ASSERT_THROWS_WITH(b_el.add_(1),
      "A view was created in inference mode and is being modified inplace");
  }
  for (auto& c_el: c) {
    assert_tensor_creation_meta(c_el, CreationMeta::INFERENCE_MODE);
    ASSERT_THROWS_WITH(c_el.add_(1),
      "A view was created in inference mode and is being modified inplace");
  }
}

TEST(InferenceModeTest, TestInplaceCopyOnInferenceTensor) {
  for (bool requires_grad: {true, false}) {
    torch::Tensor s = torch::ones({1, 2, 3}).set_requires_grad(requires_grad);
    torch::Tensor t;
    {
      InferenceMode guard;
      t = torch::ones({1, 2, 3});
      t.copy_(s);
      ASSERT_TRUE(t.is_inference());
      ASSERT_FALSE(t.requires_grad());
    }

    ASSERT_THROWS_WITH(t.copy_(s),
      "Inplace update to inference tensor outside InferenceMode is not allowed");
  }
}

TEST(InferenceModeTest, TestSetRequiresGradInNormalMode) {
  torch::Tensor t;
  {
    InferenceMode guard;
    t = torch::ones({1, 2, 3});
  }
  t.set_requires_grad(false);
  ASSERT_THROWS_WITH(t.set_requires_grad(true),
    "Setting requires_grad=True on inference tensor outside InferenceMode is not allowed.");
}

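// Inference tensors allocate no version counter at all, so even reading the
// current version throws. The workaround exercised at the end of this test is
// to clone() first, which yields a normal tensor with a fresh counter.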
TEST(InferenceModeTest, TestAccessVersionCounter) {
  torch::Tensor t;
  {
    InferenceMode guard;
    t = torch::ones({1, 2, 3});
    ASSERT_THROWS_WITH(t.unsafeGetTensorImpl()->version_counter().current_version(),
      "Inference tensors do not track version counter.");
    t.unsafeGetTensorImpl()->bump_version();
  }
  ASSERT_THROWS_WITH(t.unsafeGetTensorImpl()->version_counter().current_version(),
    "Inference tensors do not track version counter.");
  ASSERT_THROWS_WITH(t.unsafeGetTensorImpl()->bump_version(),
    "Inplace update to inference tensor outside InferenceMode is not allowed.");
  // Suggested workaround
  torch::Tensor c = t.clone();
  uint32_t v = c.unsafeGetTensorImpl()->version_counter().current_version();
  c.unsafeGetTensorImpl()->bump_version();
  ASSERT_EQ(c.unsafeGetTensorImpl()->version_counter().current_version(), v + 1);
}

TEST(InferenceModeTest, TestInplaceUpdateInferenceTensorWithNormalTensor) {
  torch::Tensor s = torch::ones({1, 2, 3});
  torch::Tensor t;
  {
    InferenceMode guard;
    t = torch::ones({1, 2, 3});
    // Testing both copy_ from VariableTypeManual and add_ from generated code.
    s.copy_(t);
    s.add_(t);
    t.add_(s);
    t.copy_(s);
  }
  s.copy_(t);
  s.add_(t);
  ASSERT_THROWS_WITH(t.copy_(s),
    "Inplace update to inference tensor outside InferenceMode is not allowed");

  ASSERT_THROWS_WITH(t.add_(s),
    "Inplace update to inference tensor outside InferenceMode is not allowed");
}

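// view_as_real/view_as_complex follow the same rule as other view ops: the
// output inherits inference-ness from its base, whether the view is taken
// inside or outside the guard.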
TEST(InferenceModeTest, TestComplexViewInInferenceMode) {
  torch::Tensor s = torch::ones({3, 3, 2});
  torch::Tensor t = torch::view_as_complex(s);
  {
    InferenceMode guard;
    torch::Tensor tmp;

    tmp = torch::view_as_real(t);
    ASSERT_FALSE(tmp.is_inference());
    tmp = torch::view_as_complex(s);
    ASSERT_FALSE(tmp.is_inference());

    torch::Tensor e = torch::ones({3, 3, 2});
    tmp = torch::view_as_complex(e);
    ASSERT_TRUE(tmp.is_inference());
    tmp = torch::view_as_real(tmp);
    ASSERT_TRUE(tmp.is_inference());
  }
}

TEST(InferenceModeTest, TestComplexViewInNormalMode) {
  torch::Tensor s;
  {
    InferenceMode guard;
    s = torch::ones({3, 3, 2});
  }
  torch::Tensor tmp = torch::view_as_complex(s);
  ASSERT_TRUE(tmp.is_inference());
  tmp = torch::view_as_real(tmp);
  ASSERT_TRUE(tmp.is_inference());
}

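// Custom autograd functions can run inside the guard: grad mode is disabled,
// so apply() builds no graph node and save_for_backward is benign even though
// its arguments are inference tensors (see the comment inside the test).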
TEST(InferenceModeTest, TestCustomFunction) {
  struct MyFunction : public Function<MyFunction> {
    static Variable forward(AutogradContext *ctx, Variable var1, int mul, Variable var2) {
      ctx->saved_data["mul"] = mul;
      ctx->save_for_backward({var1, var2});
      return var1 + mul*var2 + var1*var2;
    }

    static variable_list backward(AutogradContext *ctx, variable_list grad_output) {
      int mul = ctx->saved_data["mul"].toInt();
      auto saved = ctx->get_saved_variables();
      auto var1 = saved[0];
      auto var2 = saved[1];
      variable_list output = {grad_output[0] + grad_output[0]*var2, Variable(), grad_output[0] * mul + grad_output[0] * var1};
      return output;
    }
  };

  {
    InferenceMode guard;
    torch::Tensor var1 = torch::ones({3, 3}).set_requires_grad(true);
    auto var2 = var1.clone();
    int mul = 2;
    // If InferenceMode didn't set NoGradGuard automatically, this line
    // would error out when trying to save `var1` and `var2` for backward.
    auto y = MyFunction::apply(var1, mul, var2);
    torch::Tensor expected = var1 + mul * var2 + var1 * var2;
    assert_tensor_equal(y, expected);
  }
}

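// at::AutoNonVariableTypeMode is the deprecated predecessor of InferenceMode;
// constructing it should now emit a deprecation warning.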
TEST(InferenceModeTest, TestLegacyAutoNonVariableTypeModeWarning) {
  c10::Warning::WarnAlways warn_always(true);
  WarningCapture warnings;
  at::AutoNonVariableTypeMode guard;
  ASSERT_TRUE(
    warnings.str().find("AutoNonVariableTypeMode is deprecated") != std::string::npos);
}