[lint] autoformat test/cpp and torch/csrc

Let's have some fun.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/78828

Approved by: https://github.com/ezyang
Michael Suo
2022-06-11 10:22:58 -07:00
committed by PyTorch MergeBot
parent 1ec30a6647
commit 30fb2c4aba
564 changed files with 41930 additions and 27082 deletions


@@ -1,40 +1,49 @@
-#include <torch/script.h>
#include <gtest/gtest.h>
#include <test/cpp/api/support.h>
+#include <torch/script.h>
using namespace torch::autograd;
using namespace torch::test;
namespace {
-torch::Tensor functional_op(torch::Tensor& x) {
-return x * x;
-}
-void inplace_op(torch::Tensor& x) {
-x.mul_(1);
-}
-torch::Tensor view_op(torch::Tensor& x) {
-return x.view({2, 3});
-}
-/*
-Only the following combos of Autograd & ADInplaceOrView keys on tensors are valid:
-  - Autograd=true, ADInplaceOrView=true (normal tensor)
-  - Autograd=false, ADInplaceOrView=false (inference tensor)
-Tensors created in InferenceMode are mostly inference tensors. The only exception
-is that view of normal tensors created in InferenceMode still produce normal tensor.
-*/
-void assert_TLS_states(bool inference_mode) {
-ASSERT_EQ(InferenceMode::is_enabled(), inference_mode);
-ASSERT_FALSE(c10::impl::tls_is_dispatch_key_excluded(c10::DispatchKey::ADInplaceOrView));
-ASSERT_FALSE(c10::impl::tls_is_dispatch_keyset_included(c10::autograd_dispatch_keyset));
-ASSERT_EQ(c10::impl::tls_is_dispatch_keyset_excluded(c10::autograd_dispatch_keyset), inference_mode);
-ASSERT_EQ(c10::impl::tls_is_dispatch_key_included(c10::DispatchKey::ADInplaceOrView), !inference_mode);
-ASSERT_EQ(GradMode::is_enabled(), !inference_mode);
-}
+torch::Tensor functional_op(torch::Tensor& x) {
+return x * x;
+}
+void inplace_op(torch::Tensor& x) {
+x.mul_(1);
+}
+torch::Tensor view_op(torch::Tensor& x) {
+return x.view({2, 3});
+}
+/*
+Only the following combos of Autograd & ADInplaceOrView keys on tensors are
+valid:
+  - Autograd=true, ADInplaceOrView=true (normal tensor)
+  - Autograd=false, ADInplaceOrView=false (inference tensor)
+Tensors created in InferenceMode are mostly inference tensors. The only
+exception is that view of normal tensors created in InferenceMode still
+produce normal tensor.
+*/
+void assert_TLS_states(bool inference_mode) {
+ASSERT_EQ(InferenceMode::is_enabled(), inference_mode);
+ASSERT_FALSE(c10::impl::tls_is_dispatch_key_excluded(
+c10::DispatchKey::ADInplaceOrView));
+ASSERT_FALSE(c10::impl::tls_is_dispatch_keyset_included(
+c10::autograd_dispatch_keyset));
+ASSERT_EQ(
+c10::impl::tls_is_dispatch_keyset_excluded(c10::autograd_dispatch_keyset),
+inference_mode);
+ASSERT_EQ(
+c10::impl::tls_is_dispatch_key_included(
+c10::DispatchKey::ADInplaceOrView),
+!inference_mode);
+ASSERT_EQ(GradMode::is_enabled(), !inference_mode);
+}
} // namespace
TEST(InferenceModeTest, TestTLSState) {
assert_TLS_states(false);
{
@@ -57,7 +66,8 @@ TEST(InferenceModeTest, TestInferenceTensorCreation) {
ASSERT_FALSE(c.requires_grad());
ASSERT_TRUE(c.is_inference());
-// requires_grad doesn't change inference tensor behavior inside InferenceMode.
+// requires_grad doesn't change inference tensor behavior inside
+// InferenceMode.
torch::Tensor tmp = torch::ones({1, 2, 3}).set_requires_grad(true);
ASSERT_TRUE(tmp.requires_grad());
ASSERT_TRUE(tmp.is_inference());
@@ -78,9 +88,11 @@ TEST(InferenceModeTest, TestExistingAutogradSession) {
InferenceMode guard;
inplace_op(a);
}
-// Performing backward should trigger error since `a`'s version has been bumped.
-ASSERT_THROWS_WITH(out.backward(torch::ones_like(out)),
-"one of the variables needed for gradient computation has been modified by an inplace operation")
+// Performing backward should trigger error since `a`'s version has been
+// bumped.
+ASSERT_THROWS_WITH(
+out.backward(torch::ones_like(out)),
+"one of the variables needed for gradient computation has been modified by an inplace operation")
}
TEST(InferenceModeTest, TestInferenceTensorInInferenceModeFunctionalOp) {
@@ -88,7 +100,7 @@ TEST(InferenceModeTest, TestInferenceTensorInInferenceModeFunctionalOp) {
for (bool requires_grad : {true, false}) {
torch::Tensor c = torch::ones({1, 2, 3}).set_requires_grad(requires_grad);
-torch::Tensor func_out = functional_op(c); // go through kernels: CPU
+torch::Tensor func_out = functional_op(c); // go through kernels: CPU
ASSERT_TRUE(func_out.is_inference());
ASSERT_FALSE(func_out.requires_grad());
}
@@ -99,7 +111,7 @@ TEST(InferenceModeTest, TestInferenceTensorInInferenceModeInplaceOp) {
for (bool requires_grad : {true, false}) {
torch::Tensor c = torch::ones({1, 2, 3}).set_requires_grad(requires_grad);
-inplace_op(c); // go through kernels: CPU
+inplace_op(c); // go through kernels: CPU
ASSERT_TRUE(c.is_inference());
ASSERT_EQ(c.requires_grad(), requires_grad);
}
@@ -110,7 +122,7 @@ TEST(InferenceModeTest, TestInferenceTensorInInferenceModeViewOp) {
for (bool requires_grad : {true, false}) {
torch::Tensor c = torch::ones({1, 2, 3}).set_requires_grad(requires_grad);
-torch::Tensor view_out = view_op(c); // go through kernels: CPU
+torch::Tensor view_out = view_op(c); // go through kernels: CPU
ASSERT_TRUE(view_out.is_inference());
// Note this is different from NoGradMode but makes sense.
ASSERT_FALSE(view_out.requires_grad());
@@ -120,17 +132,20 @@ TEST(InferenceModeTest, TestInferenceTensorInInferenceModeViewOp) {
TEST(InferenceModeTest, TestInferenceTensorInNormalModeFunctionalOp) {
torch::Tensor inference_tensor;
-for (bool requires_grad: {true, false}) {
+for (bool requires_grad : {true, false}) {
{
InferenceMode guard;
-inference_tensor = torch::ones({1, 2, 3}).set_requires_grad(requires_grad);
+inference_tensor =
+torch::ones({1, 2, 3}).set_requires_grad(requires_grad);
}
-// Due to issue #54614, this might run slower compared to InferenceMode since
-// intermediate tensors are normal tensors, and they might dispatch to VariableType
-// kernels. This is fine since users can easily fix it by moving
-// it inside InferenceMode block.
-torch::Tensor tmp = functional_op(inference_tensor); // go through kernels: ADInplaceOrView(fallthrough), CPU
+// Due to issue #54614, this might run slower compared to InferenceMode
+// since intermediate tensors are normal tensors, and they might dispatch to
+// VariableType kernels. This is fine since users can easily fix it by
+// moving it inside InferenceMode block.
+torch::Tensor tmp =
+functional_op(inference_tensor); // go through kernels:
+// ADInplaceOrView(fallthrough), CPU
ASSERT_FALSE(tmp.is_inference());
ASSERT_FALSE(tmp.requires_grad());
}
@@ -138,24 +153,29 @@ TEST(InferenceModeTest, TestInferenceTensorInNormalModeFunctionalOp) {
TEST(InferenceModeTest, TestInferenceTensorInNormalModeInplaceOp) {
torch::Tensor inference_tensor;
-for (bool requires_grad: {true, false}) {
+for (bool requires_grad : {true, false}) {
{
InferenceMode guard;
-inference_tensor = torch::ones({1, 2, 3}).set_requires_grad(requires_grad);
+inference_tensor =
+torch::ones({1, 2, 3}).set_requires_grad(requires_grad);
}
-ASSERT_THROWS_WITH(inplace_op(inference_tensor), // go through kernels: ADInplaceOrView, CPU
-"Inplace update to inference tensor outside InferenceMode is not allowed");
+ASSERT_THROWS_WITH(
+inplace_op(
+inference_tensor), // go through kernels: ADInplaceOrView, CPU
+"Inplace update to inference tensor outside InferenceMode is not allowed");
}
}
TEST(InferenceModeTest, TestInferenceTensorInNormalModeViewOp) {
torch::Tensor inference_tensor;
-for (bool requires_grad: {true, false}) {
+for (bool requires_grad : {true, false}) {
{
InferenceMode guard;
-inference_tensor = torch::ones({1, 2, 3}).set_requires_grad(requires_grad);
+inference_tensor =
+torch::ones({1, 2, 3}).set_requires_grad(requires_grad);
}
-torch::Tensor out = view_op(inference_tensor); // go through kernels: ADInplaceOrView, CPU
+torch::Tensor out =
+view_op(inference_tensor); // go through kernels: ADInplaceOrView, CPU
ASSERT_TRUE(out.is_inference());
ASSERT_FALSE(out.requires_grad());
ASSERT_FALSE(out.is_view());
@@ -164,24 +184,25 @@ TEST(InferenceModeTest, TestInferenceTensorInNormalModeViewOp) {
}
TEST(InferenceModeTest, TestNormalTensorInplaceOutputInInferenceMode) {
-for (bool requires_grad: {true, false}) {
+for (bool requires_grad : {true, false}) {
torch::Tensor s = torch::ones({1, 2, 3}).set_requires_grad(requires_grad);
torch::Tensor a = s.clone();
{
c10::InferenceMode guard;
-inplace_op(a); // go through kernels: ADInplaceOrView, CPU
+inplace_op(a); // go through kernels: ADInplaceOrView, CPU
ASSERT_FALSE(a.is_inference());
ASSERT_EQ(a.requires_grad(), requires_grad);
// inplace -> inplace
-inplace_op(a); // go through kernels: ADInplaceOrView, CPU
+inplace_op(a); // go through kernels: ADInplaceOrView, CPU
ASSERT_FALSE(a.is_inference());
ASSERT_EQ(a.requires_grad(), requires_grad);
// inplace -> inplace -> view
-torch::Tensor view_out = view_op(a); // go through kernels: ADInplaceOrView, CPU
+torch::Tensor view_out =
+view_op(a); // go through kernels: ADInplaceOrView, CPU
ASSERT_FALSE(view_out.is_inference());
ASSERT_EQ(view_out.requires_grad(), requires_grad);
}
@@ -189,19 +210,20 @@ TEST(InferenceModeTest, TestNormalTensorInplaceOutputInInferenceMode) {
}
TEST(InferenceModeTest, TestNormalTensorInplaceOutputInNormalMode) {
-for (bool requires_grad: {true, false}) {
+for (bool requires_grad : {true, false}) {
torch::Tensor s = torch::ones({1, 2, 3}).set_requires_grad(requires_grad);
torch::Tensor a = s.clone();
{
c10::InferenceMode guard;
-inplace_op(a); // go through kernels: ADInplaceOrView, CPU
+inplace_op(a); // go through kernels: ADInplaceOrView, CPU
ASSERT_FALSE(a.is_inference());
ASSERT_EQ(a.requires_grad(), requires_grad);
}
-torch::Tensor tmp = functional_op(a); // go through kernels: VariableType, ADInplaceOrView(fallthrough), CPU
+torch::Tensor tmp = functional_op(a); // go through kernels: VariableType,
+// ADInplaceOrView(fallthrough), CPU
ASSERT_FALSE(tmp.is_inference());
ASSERT_EQ(tmp.requires_grad(), requires_grad);
@@ -209,14 +231,14 @@ TEST(InferenceModeTest, TestNormalTensorInplaceOutputInNormalMode) {
ASSERT_FALSE(a.is_inference());
ASSERT_EQ(a.requires_grad(), requires_grad);
-tmp = view_op(a); // go through kernels: VariableType, ADInplaceOrView, CPU
+tmp = view_op(a); // go through kernels: VariableType, ADInplaceOrView, CPU
ASSERT_FALSE(tmp.is_inference());
ASSERT_EQ(tmp.requires_grad(), requires_grad);
}
}
TEST(InferenceModeTest, TestNormalTensorViewOutputInInferenceMode) {
-for (bool requires_grad: {true, false}) {
+for (bool requires_grad : {true, false}) {
torch::Tensor s = torch::ones({1, 2, 3}).set_requires_grad(requires_grad);
torch::Tensor a = s.clone();
torch::Tensor view_out, tmp;
@@ -231,25 +253,26 @@ TEST(InferenceModeTest, TestNormalTensorViewOutputInInferenceMode) {
// Storage(self.storage()), self.key_set(), self.dtype());
// ```
// In addition, these view output tensors are normal in the sense they
-// have both Autograd and ADInplaceOrView keys. But they're still special
-// since they'll have CreationMeta::INFERENCE_MODE. In other words they behave
-// exactly the same as a view tensor created in no_grad mode.
+// have both Autograd and ADInplaceOrView keys. But they're still
+// special since they'll have CreationMeta::INFERENCE_MODE. In other
+// words they behave exactly the same as a view tensor created in
+// no_grad mode.
-view_out = view_op(a); // go through kernels: ADInplaceOrView, CPU
+view_out = view_op(a); // go through kernels: ADInplaceOrView, CPU
ASSERT_FALSE(view_out.is_inference());
assert_tensor_creation_meta(view_out, CreationMeta::INFERENCE_MODE);
ASSERT_EQ(view_out.requires_grad(), requires_grad);
ASSERT_TRUE(view_out.is_leaf());
// view -> view
-tmp = view_op(view_out); // go through kernels: ADInplaceOrView, CPU
+tmp = view_op(view_out); // go through kernels: ADInplaceOrView, CPU
ASSERT_FALSE(tmp.is_inference());
assert_tensor_creation_meta(tmp, CreationMeta::INFERENCE_MODE);
ASSERT_EQ(tmp.requires_grad(), requires_grad);
ASSERT_TRUE(tmp.is_leaf());
// view -> view -> inplace
-inplace_op(tmp); // kernels: ADInplaceOrView, CPU
+inplace_op(tmp); // kernels: ADInplaceOrView, CPU
assert_tensor_creation_meta(tmp, CreationMeta::INFERENCE_MODE);
ASSERT_FALSE(tmp.is_inference());
ASSERT_EQ(tmp.requires_grad(), requires_grad);
@@ -260,14 +283,14 @@ TEST(InferenceModeTest, TestNormalTensorViewOutputInInferenceMode) {
}
TEST(InferenceModeTest, TestNormalTensorViewOutputInNormalMode) {
-for (bool requires_grad: {true, false}) {
+for (bool requires_grad : {true, false}) {
torch::Tensor s = torch::ones({1, 2, 3}).set_requires_grad(requires_grad);
torch::Tensor a = s.clone();
torch::Tensor view_out, tmp;
{
c10::InferenceMode guard;
-view_out = view_op(a); // go through kernels: ADInplaceOrView, CPU
+view_out = view_op(a); // go through kernels: ADInplaceOrView, CPU
ASSERT_FALSE(view_out.is_inference());
assert_tensor_creation_meta(view_out, CreationMeta::INFERENCE_MODE);
ASSERT_EQ(view_out.requires_grad(), requires_grad);
@@ -279,8 +302,10 @@ TEST(InferenceModeTest, TestNormalTensorViewOutputInNormalMode) {
ASSERT_EQ(tmp.requires_grad(), requires_grad);
if (requires_grad) {
-ASSERT_THROWS_WITH(inplace_op(view_out), // go through kernels: VariableType, ADInplaceOrView, CPU
-"A view was created in inference mode and is being modified inplace")
+ASSERT_THROWS_WITH(
+inplace_op(view_out), // go through kernels: VariableType,
+// ADInplaceOrView, CPU
+"A view was created in inference mode and is being modified inplace")
} else {
inplace_op(view_out);
}
@@ -292,7 +317,7 @@ TEST(InferenceModeTest, TestNormalTensorViewOutputInNormalMode) {
}
TEST(InferenceModeTest, TestMixInferenceAndNormalTensorFunctionalOp) {
-for (bool requires_grad: {true, false}) {
+for (bool requires_grad : {true, false}) {
torch::Tensor s = torch::ones({1, 2, 3}).set_requires_grad(requires_grad);
torch::Tensor c;
{
@@ -300,8 +325,10 @@ TEST(InferenceModeTest, TestMixInferenceAndNormalTensorFunctionalOp) {
c = torch::ones({1, 2, 3}).set_requires_grad(requires_grad);
}
-// add(Tensor, Tensor) is safe with inference tensor since it doesn't save any variable for backward.
-torch::Tensor out = c.add(s); // go through kernels: VariableType, ADInplaceOrView(fallthrough), CPU
+// add(Tensor, Tensor) is safe with inference tensor since it doesn't save
+// any variable for backward.
+torch::Tensor out = c.add(s); // go through kernels: VariableType,
+// ADInplaceOrView(fallthrough), CPU
ASSERT_FALSE(out.is_inference());
ASSERT_EQ(out.requires_grad(), requires_grad);
if (requires_grad) {
@@ -313,19 +340,21 @@ TEST(InferenceModeTest, TestMixInferenceAndNormalTensorFunctionalOp) {
if (requires_grad) {
// mul(self, other) saves variable when requires_grad=true
-ASSERT_THROWS_WITH(c.mul(s),
-"Inference tensors cannot be saved for backward.");
+ASSERT_THROWS_WITH(
+c.mul(s), "Inference tensors cannot be saved for backward.");
// Inference tensor in TensorList input
std::vector<torch::Tensor> inputs = {s, c};
-ASSERT_THROWS_WITH(torch::stack(inputs), // go through kernels: VariableType(ERROR)!, ADInplaceOrView(fallthrough), CPU
-"Inference tensors cannot be saved for backward.")
+ASSERT_THROWS_WITH(
+torch::stack(inputs), // go through kernels: VariableType(ERROR)!,
+// ADInplaceOrView(fallthrough), CPU
+"Inference tensors cannot be saved for backward.")
}
}
}
TEST(InferenceModeTest, TestMixInferenceAndNormalTensorInplaceOp) {
-for (bool requires_grad: {true, false}) {
+for (bool requires_grad : {true, false}) {
torch::Tensor s = torch::ones({1, 2, 3}).set_requires_grad(requires_grad);
torch::Tensor a = s.clone();
torch::Tensor c;
@@ -335,22 +364,29 @@ TEST(InferenceModeTest, TestMixInferenceAndNormalTensorInplaceOp) {
}
if (requires_grad) {
-ASSERT_THROWS_WITH(a.mul_(c), // go through kernels: VariableType(ERROR!), InferenceMode, CPU
-"Inference tensors cannot be saved for backward.");
+ASSERT_THROWS_WITH(
+a.mul_(c), // go through kernels: VariableType(ERROR!), InferenceMode,
+// CPU
+"Inference tensors cannot be saved for backward.");
-ASSERT_THROWS_WITH(torch::mul_out(/*out=*/c, s, s), // go through kernels: VariableType(ERROR!), ADInplaceOrView, CPU
-"out=... arguments don't support automatic differentiation, but one of the arguments requires grad")
+ASSERT_THROWS_WITH(
+torch::mul_out(
+/*out=*/c, s, s), // go through kernels: VariableType(ERROR!),
+// ADInplaceOrView, CPU
+"out=... arguments don't support automatic differentiation, but one of the arguments requires grad")
} else {
a.mul_(c);
-ASSERT_THROWS_WITH(torch::mul_out(/*out=*/c, s, s), // go through kernels: VariableType, ADInplaceOrView(ERROR!), CPU
-"Inplace update to inference tensor outside InferenceMode is not allowed");
+ASSERT_THROWS_WITH(
+torch::mul_out(/*out=*/c, s, s), // go through kernels: VariableType,
+// ADInplaceOrView(ERROR!), CPU
+"Inplace update to inference tensor outside InferenceMode is not allowed");
}
}
}
TEST(InferenceModeTest, TestMixInferenceAndNormalTensorViewOp) {
-for (bool requires_grad: {true, false}) {
+for (bool requires_grad : {true, false}) {
torch::Tensor s = torch::ones({1, 2, 3}).set_requires_grad(requires_grad);
torch::Tensor c;
{
@@ -358,32 +394,36 @@ TEST(InferenceModeTest, TestMixInferenceAndNormalTensorViewOp) {
c = torch::ones({1, 2, 3});
}
-// view_as is a composite op which calls view() with only one tensor argument.
-// So there isn't a mixed inference tensor and normal tensor inputs for view ops.
-torch::Tensor tmp1 = c.view_as(s); // go through kernels: ADInplaceOrView, CPU
+// view_as is a composite op which calls view() with only one tensor
+// argument. So there isn't a mixed inference tensor and normal tensor
+// inputs for view ops.
+torch::Tensor tmp1 =
+c.view_as(s); // go through kernels: ADInplaceOrView, CPU
ASSERT_TRUE(tmp1.is_inference());
ASSERT_FALSE(tmp1.requires_grad());
// This is fine since it's equivalent as s.view(c.sizes()) which
// isn't a mixed input scenario.
-torch::Tensor tmp2 = s.view_as(c); // go through kernels: VariableType, ADInplaceOrView, CPU
+torch::Tensor tmp2 =
+s.view_as(c); // go through kernels: VariableType, ADInplaceOrView, CPU
ASSERT_FALSE(tmp2.is_inference());
ASSERT_EQ(tmp2.requires_grad(), requires_grad);
}
}
TEST(InferenceModeTest, TestHandleDirectViewOnRebase) {
-for (bool requires_grad: {true, false}) {
+for (bool requires_grad : {true, false}) {
torch::Tensor s = torch::ones({1, 2, 3}).set_requires_grad(requires_grad);
torch::Tensor a = s.clone();
torch::Tensor view_out;
{
InferenceMode guard;
-view_out = view_op(a); // go through kernels: ADInplaceOrView, CPU
+view_out = view_op(a); // go through kernels: ADInplaceOrView, CPU
}
if (requires_grad) {
-ASSERT_THROWS_WITH(inplace_op(view_out),
-"A view was created in inference mode and is being modified inplace")
+ASSERT_THROWS_WITH(
+inplace_op(view_out),
+"A view was created in inference mode and is being modified inplace")
} else {
inplace_op(view_out);
}
@@ -391,18 +431,19 @@ TEST(InferenceModeTest, TestHandleDirectViewOnRebase) {
}
TEST(InferenceModeTest, TestHandleInDirectViewOnRebase) {
-for (bool requires_grad: {true, false}) {
+for (bool requires_grad : {true, false}) {
torch::Tensor s = torch::ones({1, 2, 3}).set_requires_grad(requires_grad);
torch::Tensor a = s.clone();
torch::Tensor view_out;
{
InferenceMode guard;
-view_out = view_op(a); // go through kernels: ADInplaceOrView, CPU
+view_out = view_op(a); // go through kernels: ADInplaceOrView, CPU
}
inplace_op(a);
if (requires_grad) {
-ASSERT_THROWS_WITH(view_out.grad_fn(),
-"A view was created in inference mode and its base or another view of its base has been modified inplace");
+ASSERT_THROWS_WITH(
+view_out.grad_fn(),
+"A view was created in inference mode and its base or another view of its base has been modified inplace");
} else {
view_out.grad_fn();
}
@@ -416,14 +457,16 @@ TEST(InferenceModeTest, TestCreationMetaPropagation) {
InferenceMode guard;
b = s.view_as(s);
}
-ASSERT_THROWS_WITH(b.add_(1),
-"A view was created in inference mode and is being modified inplace");
+ASSERT_THROWS_WITH(
+b.add_(1),
+"A view was created in inference mode and is being modified inplace");
{
AutoGradMode mode(false);
c = b.view_as(b);
}
-ASSERT_THROWS_WITH(c.add_(1),
-"A view was created in inference mode and is being modified inplace");
+ASSERT_THROWS_WITH(
+c.add_(1),
+"A view was created in inference mode and is being modified inplace");
}
TEST(InferenceModeTest, TestCreationMetaPropagationInput) {
@@ -437,20 +480,22 @@ TEST(InferenceModeTest, TestCreationMetaPropagationInput) {
s = s.view_as(s);
c = s.split_with_sizes({1, 1});
}
-for (auto& b_el: b) {
+for (auto& b_el : b) {
assert_tensor_creation_meta(b_el, CreationMeta::INFERENCE_MODE);
-ASSERT_THROWS_WITH(b_el.add_(1),
-"A view was created in inference mode and is being modified inplace");
+ASSERT_THROWS_WITH(
+b_el.add_(1),
+"A view was created in inference mode and is being modified inplace");
}
-for (auto& c_el: c) {
+for (auto& c_el : c) {
assert_tensor_creation_meta(c_el, CreationMeta::INFERENCE_MODE);
-ASSERT_THROWS_WITH(c_el.add_(1),
-"A view was created in inference mode and is being modified inplace");
+ASSERT_THROWS_WITH(
+c_el.add_(1),
+"A view was created in inference mode and is being modified inplace");
}
}
TEST(InferenceModeTest, TestInplaceCopyOnInferenceTensor) {
-for (bool requires_grad: {true, false}) {
+for (bool requires_grad : {true, false}) {
torch::Tensor s = torch::ones({1, 2, 3}).set_requires_grad(requires_grad);
torch::Tensor t;
{
@@ -461,8 +506,9 @@ TEST(InferenceModeTest, TestInplaceCopyOnInferenceTensor) {
ASSERT_FALSE(t.requires_grad());
}
-ASSERT_THROWS_WITH(t.copy_(s),
-"Inplace update to inference tensor outside InferenceMode is not allowed");
+ASSERT_THROWS_WITH(
+t.copy_(s),
+"Inplace update to inference tensor outside InferenceMode is not allowed");
}
}
@@ -473,8 +519,9 @@ TEST(InferenceModeTest, TestSetRequiresGradInNormalMode) {
t = torch::ones({1, 2, 3});
}
t.set_requires_grad(false);
-ASSERT_THROWS_WITH(t.set_requires_grad(true),
-"Setting requires_grad=True on inference tensor outside InferenceMode is not allowed.");
+ASSERT_THROWS_WITH(
+t.set_requires_grad(true),
+"Setting requires_grad=True on inference tensor outside InferenceMode is not allowed.");
}
TEST(InferenceModeTest, TestAccessVersionCounter) {
@@ -482,19 +529,23 @@ TEST(InferenceModeTest, TestAccessVersionCounter) {
{
InferenceMode guard;
t = torch::ones({1, 2, 3});
-ASSERT_THROWS_WITH(t.unsafeGetTensorImpl()->version_counter().current_version(),
-"Inference tensors do not track version counter.");
+ASSERT_THROWS_WITH(
+t.unsafeGetTensorImpl()->version_counter().current_version(),
+"Inference tensors do not track version counter.");
t.unsafeGetTensorImpl()->bump_version();
}
-ASSERT_THROWS_WITH(t.unsafeGetTensorImpl()->version_counter().current_version(),
-"Inference tensors do not track version counter.");
-ASSERT_THROWS_WITH(t.unsafeGetTensorImpl()->bump_version(),
-"Inplace update to inference tensor outside InferenceMode is not allowed.");
+ASSERT_THROWS_WITH(
+t.unsafeGetTensorImpl()->version_counter().current_version(),
+"Inference tensors do not track version counter.");
+ASSERT_THROWS_WITH(
+t.unsafeGetTensorImpl()->bump_version(),
+"Inplace update to inference tensor outside InferenceMode is not allowed.");
// Suggested workaround
torch::Tensor c = t.clone();
uint32_t v = c.unsafeGetTensorImpl()->version_counter().current_version();
c.unsafeGetTensorImpl()->bump_version();
-ASSERT_EQ(c.unsafeGetTensorImpl()->version_counter().current_version(), v + 1);
+ASSERT_EQ(
+c.unsafeGetTensorImpl()->version_counter().current_version(), v + 1);
}
TEST(InferenceModeTest, TestInplaceUpdateInferenceTensorWithNormalTensor) {
@@ -511,11 +562,13 @@ TEST(InferenceModeTest, TestInplaceUpdateInferenceTensorWithNormalTensor) {
}
s.copy_(t);
s.add_(t);
-ASSERT_THROWS_WITH(t.copy_(s),
-"Inplace update to inference tensor outside InferenceMode is not allowed");
+ASSERT_THROWS_WITH(
+t.copy_(s),
+"Inplace update to inference tensor outside InferenceMode is not allowed");
-ASSERT_THROWS_WITH(t.add_(s),
-"Inplace update to inference tensor outside InferenceMode is not allowed");
+ASSERT_THROWS_WITH(
+t.add_(s),
+"Inplace update to inference tensor outside InferenceMode is not allowed");
}
TEST(InferenceModeTest, TestComplexViewInInferenceMode) {
@@ -552,18 +605,27 @@ TEST(InferenceModeTest, TestComplexViewInNormalMode) {
TEST(InferenceModeTest, TestCustomFunction) {
struct MyFunction : public Function<MyFunction> {
-static Variable forward(AutogradContext *ctx, Variable var1, int mul, Variable var2) {
+static Variable forward(
+AutogradContext* ctx,
+Variable var1,
+int mul,
+Variable var2) {
ctx->saved_data["mul"] = mul;
ctx->save_for_backward({var1, var2});
-return var1 + mul*var2 + var1*var2;
+return var1 + mul * var2 + var1 * var2;
}
-static variable_list backward(AutogradContext *ctx, variable_list grad_output) {
+static variable_list backward(
+AutogradContext* ctx,
+variable_list grad_output) {
int mul = ctx->saved_data["mul"].toInt();
auto saved = ctx->get_saved_variables();
auto var1 = saved[0];
auto var2 = saved[1];
-variable_list output = {grad_output[0] + grad_output[0]*var2, Variable(), grad_output[0] * mul + grad_output[0] * var1};
+variable_list output = {
+grad_output[0] + grad_output[0] * var2,
+Variable(),
+grad_output[0] * mul + grad_output[0] * var1};
return output;
}
};
@@ -586,5 +648,6 @@ TEST(InferenceModeTest, TestLegacyAutoNonVariableTypeModeWarning) {
WarningCapture warnings;
at::AutoNonVariableTypeMode guard;
ASSERT_TRUE(
-warnings.str().find("AutoNonVariableTypeMode is deprecated") != std::string::npos);
+warnings.str().find("AutoNonVariableTypeMode is deprecated") !=
+std::string::npos);
}
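
Aside for readers of this formatting-only diff: the behavior under test is unchanged. For quick reference, here is a minimal standalone sketch (not part of the commit; the exact checks are an assumption based on the comment block in the test file above) of the InferenceMode rules those comments describe: tensors created inside the guard are inference tensors, and the one exception is that a view of a normal tensor taken inside the guard is still a normal tensor.

#include <torch/torch.h>

// Minimal sketch, not from this commit: illustrates the valid key combos
// documented in the test file's comment block (normal tensor has both
// Autograd and ADInplaceOrView keys; inference tensor has neither).
int main() {
  torch::Tensor normal = torch::ones({1, 2, 3});
  {
    c10::InferenceMode guard;
    // Created inside the guard -> inference tensor
    // (Autograd=false, ADInplaceOrView=false).
    torch::Tensor inf = torch::ones({1, 2, 3});
    TORCH_CHECK(inf.is_inference());
    // The documented exception: a view of a normal tensor taken inside
    // InferenceMode still produces a normal tensor.
    torch::Tensor v = normal.view({2, 3});
    TORCH_CHECK(!v.is_inference());
  }
}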