mirror of
https://github.com/pytorch/pytorch.git
synced 2025-11-02 14:34:54 +08:00
[lint] autoformat test/cpp and torch/csrc
Let's have some fun. Pull Request resolved: https://github.com/pytorch/pytorch/pull/78828 Approved by: https://github.com/ezyang
This commit is contained in:
committed by
PyTorch MergeBot
parent
1ec30a6647
commit
30fb2c4aba
@ -1,40 +1,49 @@
|
||||
#include <torch/script.h>
|
||||
#include <gtest/gtest.h>
|
||||
#include <test/cpp/api/support.h>
|
||||
#include <torch/script.h>
|
||||
|
||||
using namespace torch::autograd;
|
||||
using namespace torch::test;
|
||||
|
||||
namespace {
|
||||
torch::Tensor functional_op(torch::Tensor& x) {
|
||||
return x * x;
|
||||
}
|
||||
|
||||
void inplace_op(torch::Tensor& x) {
|
||||
x.mul_(1);
|
||||
}
|
||||
|
||||
torch::Tensor view_op(torch::Tensor& x) {
|
||||
return x.view({2, 3});
|
||||
}
|
||||
|
||||
/*
|
||||
Only the following combos of Autograd & ADInplaceOrView keys on tensors are valid:
|
||||
- Autograd=true, ADInplaceOrView=true (normal tensor)
|
||||
- Autograd=false, ADInplaceOrView=false (inference tensor)
|
||||
Tensors created in InferenceMode are mostly inference tensors. The only exception
|
||||
is that view of normal tensors created in InferenceMode still produce normal tensor.
|
||||
*/
|
||||
void assert_TLS_states(bool inference_mode) {
|
||||
ASSERT_EQ(InferenceMode::is_enabled(), inference_mode);
|
||||
ASSERT_FALSE(c10::impl::tls_is_dispatch_key_excluded(c10::DispatchKey::ADInplaceOrView));
|
||||
ASSERT_FALSE(c10::impl::tls_is_dispatch_keyset_included(c10::autograd_dispatch_keyset));
|
||||
ASSERT_EQ(c10::impl::tls_is_dispatch_keyset_excluded(c10::autograd_dispatch_keyset), inference_mode);
|
||||
ASSERT_EQ(c10::impl::tls_is_dispatch_key_included(c10::DispatchKey::ADInplaceOrView), !inference_mode);
|
||||
ASSERT_EQ(GradMode::is_enabled(), !inference_mode);
|
||||
}
|
||||
torch::Tensor functional_op(torch::Tensor& x) {
|
||||
return x * x;
|
||||
}
|
||||
|
||||
void inplace_op(torch::Tensor& x) {
|
||||
x.mul_(1);
|
||||
}
|
||||
|
||||
torch::Tensor view_op(torch::Tensor& x) {
|
||||
return x.view({2, 3});
|
||||
}
|
||||
|
||||
/*
|
||||
Only the following combos of Autograd & ADInplaceOrView keys on tensors are
|
||||
valid:
|
||||
- Autograd=true, ADInplaceOrView=true (normal tensor)
|
||||
- Autograd=false, ADInplaceOrView=false (inference tensor)
|
||||
Tensors created in InferenceMode are mostly inference tensors. The only
|
||||
exception is that view of normal tensors created in InferenceMode still
|
||||
produce normal tensor.
|
||||
*/
|
||||
void assert_TLS_states(bool inference_mode) {
|
||||
ASSERT_EQ(InferenceMode::is_enabled(), inference_mode);
|
||||
ASSERT_FALSE(c10::impl::tls_is_dispatch_key_excluded(
|
||||
c10::DispatchKey::ADInplaceOrView));
|
||||
ASSERT_FALSE(c10::impl::tls_is_dispatch_keyset_included(
|
||||
c10::autograd_dispatch_keyset));
|
||||
ASSERT_EQ(
|
||||
c10::impl::tls_is_dispatch_keyset_excluded(c10::autograd_dispatch_keyset),
|
||||
inference_mode);
|
||||
ASSERT_EQ(
|
||||
c10::impl::tls_is_dispatch_key_included(
|
||||
c10::DispatchKey::ADInplaceOrView),
|
||||
!inference_mode);
|
||||
ASSERT_EQ(GradMode::is_enabled(), !inference_mode);
|
||||
}
|
||||
} // namespace
|
||||
|
||||
TEST(InferenceModeTest, TestTLSState) {
|
||||
assert_TLS_states(false);
|
||||
{
|
||||
@ -57,7 +66,8 @@ TEST(InferenceModeTest, TestInferenceTensorCreation) {
|
||||
ASSERT_FALSE(c.requires_grad());
|
||||
ASSERT_TRUE(c.is_inference());
|
||||
|
||||
// requires_grad doesn't change inference tensor behavior inside InferenceMode.
|
||||
// requires_grad doesn't change inference tensor behavior inside
|
||||
// InferenceMode.
|
||||
torch::Tensor tmp = torch::ones({1, 2, 3}).set_requires_grad(true);
|
||||
ASSERT_TRUE(tmp.requires_grad());
|
||||
ASSERT_TRUE(tmp.is_inference());
|
||||
@ -78,9 +88,11 @@ TEST(InferenceModeTest, TestExistingAutogradSession) {
|
||||
InferenceMode guard;
|
||||
inplace_op(a);
|
||||
}
|
||||
// Performing backward should trigger error since `a`'s version has been bumped.
|
||||
ASSERT_THROWS_WITH(out.backward(torch::ones_like(out)),
|
||||
"one of the variables needed for gradient computation has been modified by an inplace operation")
|
||||
// Performing backward should trigger error since `a`'s version has been
|
||||
// bumped.
|
||||
ASSERT_THROWS_WITH(
|
||||
out.backward(torch::ones_like(out)),
|
||||
"one of the variables needed for gradient computation has been modified by an inplace operation")
|
||||
}
|
||||
|
||||
TEST(InferenceModeTest, TestInferenceTensorInInferenceModeFunctionalOp) {
|
||||
@ -88,7 +100,7 @@ TEST(InferenceModeTest, TestInferenceTensorInInferenceModeFunctionalOp) {
|
||||
for (bool requires_grad : {true, false}) {
|
||||
torch::Tensor c = torch::ones({1, 2, 3}).set_requires_grad(requires_grad);
|
||||
|
||||
torch::Tensor func_out = functional_op(c); // go through kernels: CPU
|
||||
torch::Tensor func_out = functional_op(c); // go through kernels: CPU
|
||||
ASSERT_TRUE(func_out.is_inference());
|
||||
ASSERT_FALSE(func_out.requires_grad());
|
||||
}
|
||||
@ -99,7 +111,7 @@ TEST(InferenceModeTest, TestInferenceTensorInInferenceModeInplaceOp) {
|
||||
for (bool requires_grad : {true, false}) {
|
||||
torch::Tensor c = torch::ones({1, 2, 3}).set_requires_grad(requires_grad);
|
||||
|
||||
inplace_op(c); // go through kernels: CPU
|
||||
inplace_op(c); // go through kernels: CPU
|
||||
ASSERT_TRUE(c.is_inference());
|
||||
ASSERT_EQ(c.requires_grad(), requires_grad);
|
||||
}
|
||||
@ -110,7 +122,7 @@ TEST(InferenceModeTest, TestInferenceTensorInInferenceModeViewOp) {
|
||||
for (bool requires_grad : {true, false}) {
|
||||
torch::Tensor c = torch::ones({1, 2, 3}).set_requires_grad(requires_grad);
|
||||
|
||||
torch::Tensor view_out = view_op(c); // go through kernels: CPU
|
||||
torch::Tensor view_out = view_op(c); // go through kernels: CPU
|
||||
ASSERT_TRUE(view_out.is_inference());
|
||||
// Note this is different from NoGradMode but makes sense.
|
||||
ASSERT_FALSE(view_out.requires_grad());
|
||||
@ -120,17 +132,20 @@ TEST(InferenceModeTest, TestInferenceTensorInInferenceModeViewOp) {
|
||||
|
||||
TEST(InferenceModeTest, TestInferenceTensorInNormalModeFunctionalOp) {
|
||||
torch::Tensor inference_tensor;
|
||||
for (bool requires_grad: {true, false}) {
|
||||
for (bool requires_grad : {true, false}) {
|
||||
{
|
||||
InferenceMode guard;
|
||||
inference_tensor = torch::ones({1, 2, 3}).set_requires_grad(requires_grad);
|
||||
inference_tensor =
|
||||
torch::ones({1, 2, 3}).set_requires_grad(requires_grad);
|
||||
}
|
||||
|
||||
// Due to issue #54614, this might run slower compared to InferenceMode since
|
||||
// intermediate tensors are normal tensors, and they might dispatch to VariableType
|
||||
// kernels. This is fine since users can easily fix it by moving
|
||||
// it inside InferenceMode block.
|
||||
torch::Tensor tmp = functional_op(inference_tensor); // go through kernels: ADInplaceOrView(fallthrough), CPU
|
||||
// Due to issue #54614, this might run slower compared to InferenceMode
|
||||
// since intermediate tensors are normal tensors, and they might dispatch to
|
||||
// VariableType kernels. This is fine since users can easily fix it by
|
||||
// moving it inside InferenceMode block.
|
||||
torch::Tensor tmp =
|
||||
functional_op(inference_tensor); // go through kernels:
|
||||
// ADInplaceOrView(fallthrough), CPU
|
||||
ASSERT_FALSE(tmp.is_inference());
|
||||
ASSERT_FALSE(tmp.requires_grad());
|
||||
}
|
||||
@ -138,24 +153,29 @@ TEST(InferenceModeTest, TestInferenceTensorInNormalModeFunctionalOp) {
|
||||
|
||||
TEST(InferenceModeTest, TestInferenceTensorInNormalModeInplaceOp) {
|
||||
torch::Tensor inference_tensor;
|
||||
for (bool requires_grad: {true, false}) {
|
||||
for (bool requires_grad : {true, false}) {
|
||||
{
|
||||
InferenceMode guard;
|
||||
inference_tensor = torch::ones({1, 2, 3}).set_requires_grad(requires_grad);
|
||||
inference_tensor =
|
||||
torch::ones({1, 2, 3}).set_requires_grad(requires_grad);
|
||||
}
|
||||
ASSERT_THROWS_WITH(inplace_op(inference_tensor), // go through kernels: ADInplaceOrView, CPU
|
||||
"Inplace update to inference tensor outside InferenceMode is not allowed");
|
||||
ASSERT_THROWS_WITH(
|
||||
inplace_op(
|
||||
inference_tensor), // go through kernels: ADInplaceOrView, CPU
|
||||
"Inplace update to inference tensor outside InferenceMode is not allowed");
|
||||
}
|
||||
}
|
||||
|
||||
TEST(InferenceModeTest, TestInferenceTensorInNormalModeViewOp) {
|
||||
torch::Tensor inference_tensor;
|
||||
for (bool requires_grad: {true, false}) {
|
||||
for (bool requires_grad : {true, false}) {
|
||||
{
|
||||
InferenceMode guard;
|
||||
inference_tensor = torch::ones({1, 2, 3}).set_requires_grad(requires_grad);
|
||||
inference_tensor =
|
||||
torch::ones({1, 2, 3}).set_requires_grad(requires_grad);
|
||||
}
|
||||
torch::Tensor out = view_op(inference_tensor); // go through kernels: ADInplaceOrView, CPU
|
||||
torch::Tensor out =
|
||||
view_op(inference_tensor); // go through kernels: ADInplaceOrView, CPU
|
||||
ASSERT_TRUE(out.is_inference());
|
||||
ASSERT_FALSE(out.requires_grad());
|
||||
ASSERT_FALSE(out.is_view());
|
||||
@ -164,24 +184,25 @@ TEST(InferenceModeTest, TestInferenceTensorInNormalModeViewOp) {
|
||||
}
|
||||
|
||||
TEST(InferenceModeTest, TestNormalTensorInplaceOutputInInferenceMode) {
|
||||
for (bool requires_grad: {true, false}) {
|
||||
for (bool requires_grad : {true, false}) {
|
||||
torch::Tensor s = torch::ones({1, 2, 3}).set_requires_grad(requires_grad);
|
||||
torch::Tensor a = s.clone();
|
||||
|
||||
{
|
||||
c10::InferenceMode guard;
|
||||
|
||||
inplace_op(a); // go through kernels: ADInplaceOrView, CPU
|
||||
inplace_op(a); // go through kernels: ADInplaceOrView, CPU
|
||||
ASSERT_FALSE(a.is_inference());
|
||||
ASSERT_EQ(a.requires_grad(), requires_grad);
|
||||
|
||||
// inplace -> inplace
|
||||
inplace_op(a); // go through kernels: ADInplaceOrView, CPU
|
||||
inplace_op(a); // go through kernels: ADInplaceOrView, CPU
|
||||
ASSERT_FALSE(a.is_inference());
|
||||
ASSERT_EQ(a.requires_grad(), requires_grad);
|
||||
|
||||
// inplace -> inplace -> view
|
||||
torch::Tensor view_out = view_op(a); // go through kernels: ADInplaceOrView, CPU
|
||||
torch::Tensor view_out =
|
||||
view_op(a); // go through kernels: ADInplaceOrView, CPU
|
||||
ASSERT_FALSE(view_out.is_inference());
|
||||
ASSERT_EQ(view_out.requires_grad(), requires_grad);
|
||||
}
|
||||
@ -189,19 +210,20 @@ TEST(InferenceModeTest, TestNormalTensorInplaceOutputInInferenceMode) {
|
||||
}
|
||||
|
||||
TEST(InferenceModeTest, TestNormalTensorInplaceOutputInNormalMode) {
|
||||
for (bool requires_grad: {true, false}) {
|
||||
for (bool requires_grad : {true, false}) {
|
||||
torch::Tensor s = torch::ones({1, 2, 3}).set_requires_grad(requires_grad);
|
||||
torch::Tensor a = s.clone();
|
||||
|
||||
{
|
||||
c10::InferenceMode guard;
|
||||
|
||||
inplace_op(a); // go through kernels: ADInplaceOrView, CPU
|
||||
inplace_op(a); // go through kernels: ADInplaceOrView, CPU
|
||||
ASSERT_FALSE(a.is_inference());
|
||||
ASSERT_EQ(a.requires_grad(), requires_grad);
|
||||
}
|
||||
|
||||
torch::Tensor tmp = functional_op(a); // go through kernels: VariableType, ADInplaceOrView(fallthrough), CPU
|
||||
torch::Tensor tmp = functional_op(a); // go through kernels: VariableType,
|
||||
// ADInplaceOrView(fallthrough), CPU
|
||||
ASSERT_FALSE(tmp.is_inference());
|
||||
ASSERT_EQ(tmp.requires_grad(), requires_grad);
|
||||
|
||||
@ -209,14 +231,14 @@ TEST(InferenceModeTest, TestNormalTensorInplaceOutputInNormalMode) {
|
||||
ASSERT_FALSE(a.is_inference());
|
||||
ASSERT_EQ(a.requires_grad(), requires_grad);
|
||||
|
||||
tmp = view_op(a); // go through kernels: VariableType, ADInplaceOrView, CPU
|
||||
tmp = view_op(a); // go through kernels: VariableType, ADInplaceOrView, CPU
|
||||
ASSERT_FALSE(tmp.is_inference());
|
||||
ASSERT_EQ(tmp.requires_grad(), requires_grad);
|
||||
}
|
||||
}
|
||||
|
||||
TEST(InferenceModeTest, TestNormalTensorViewOutputInInferenceMode) {
|
||||
for (bool requires_grad: {true, false}) {
|
||||
for (bool requires_grad : {true, false}) {
|
||||
torch::Tensor s = torch::ones({1, 2, 3}).set_requires_grad(requires_grad);
|
||||
torch::Tensor a = s.clone();
|
||||
torch::Tensor view_out, tmp;
|
||||
@ -231,25 +253,26 @@ TEST(InferenceModeTest, TestNormalTensorViewOutputInInferenceMode) {
|
||||
// Storage(self.storage()), self.key_set(), self.dtype());
|
||||
// ```
|
||||
// In addition, these view output tensors are normal in the sense they
|
||||
// have both Autograd and ADInplaceOrView keys. But they're still special
|
||||
// since they'll have CreationMeta::INFERENCE_MODE. In other words they behave
|
||||
// exactly the same as a view tensor created in no_grad mode.
|
||||
// have both Autograd and ADInplaceOrView keys. But they're still
|
||||
// special since they'll have CreationMeta::INFERENCE_MODE. In other
|
||||
// words they behave exactly the same as a view tensor created in
|
||||
// no_grad mode.
|
||||
|
||||
view_out = view_op(a); // go through kernels: ADInplaceOrView, CPU
|
||||
view_out = view_op(a); // go through kernels: ADInplaceOrView, CPU
|
||||
ASSERT_FALSE(view_out.is_inference());
|
||||
assert_tensor_creation_meta(view_out, CreationMeta::INFERENCE_MODE);
|
||||
ASSERT_EQ(view_out.requires_grad(), requires_grad);
|
||||
ASSERT_TRUE(view_out.is_leaf());
|
||||
|
||||
// view -> view
|
||||
tmp = view_op(view_out); // go through kernels: ADInplaceOrView, CPU
|
||||
tmp = view_op(view_out); // go through kernels: ADInplaceOrView, CPU
|
||||
ASSERT_FALSE(tmp.is_inference());
|
||||
assert_tensor_creation_meta(tmp, CreationMeta::INFERENCE_MODE);
|
||||
ASSERT_EQ(tmp.requires_grad(), requires_grad);
|
||||
ASSERT_TRUE(tmp.is_leaf());
|
||||
|
||||
// view -> view -> inplace
|
||||
inplace_op(tmp); // kernels: ADInplaceOrView, CPU
|
||||
inplace_op(tmp); // kernels: ADInplaceOrView, CPU
|
||||
assert_tensor_creation_meta(tmp, CreationMeta::INFERENCE_MODE);
|
||||
ASSERT_FALSE(tmp.is_inference());
|
||||
ASSERT_EQ(tmp.requires_grad(), requires_grad);
|
||||
@ -260,14 +283,14 @@ TEST(InferenceModeTest, TestNormalTensorViewOutputInInferenceMode) {
|
||||
}
|
||||
|
||||
TEST(InferenceModeTest, TestNormalTensorViewOutputInNormalMode) {
|
||||
for (bool requires_grad: {true, false}) {
|
||||
for (bool requires_grad : {true, false}) {
|
||||
torch::Tensor s = torch::ones({1, 2, 3}).set_requires_grad(requires_grad);
|
||||
torch::Tensor a = s.clone();
|
||||
torch::Tensor view_out, tmp;
|
||||
|
||||
{
|
||||
c10::InferenceMode guard;
|
||||
view_out = view_op(a); // go through kernels: ADInplaceOrView, CPU
|
||||
view_out = view_op(a); // go through kernels: ADInplaceOrView, CPU
|
||||
ASSERT_FALSE(view_out.is_inference());
|
||||
assert_tensor_creation_meta(view_out, CreationMeta::INFERENCE_MODE);
|
||||
ASSERT_EQ(view_out.requires_grad(), requires_grad);
|
||||
@ -279,8 +302,10 @@ TEST(InferenceModeTest, TestNormalTensorViewOutputInNormalMode) {
|
||||
ASSERT_EQ(tmp.requires_grad(), requires_grad);
|
||||
|
||||
if (requires_grad) {
|
||||
ASSERT_THROWS_WITH(inplace_op(view_out), // go through kernels: VariableType, ADInplaceOrView, CPU
|
||||
"A view was created in inference mode and is being modified inplace")
|
||||
ASSERT_THROWS_WITH(
|
||||
inplace_op(view_out), // go through kernels: VariableType,
|
||||
// ADInplaceOrView, CPU
|
||||
"A view was created in inference mode and is being modified inplace")
|
||||
} else {
|
||||
inplace_op(view_out);
|
||||
}
|
||||
@ -292,7 +317,7 @@ TEST(InferenceModeTest, TestNormalTensorViewOutputInNormalMode) {
|
||||
}
|
||||
|
||||
TEST(InferenceModeTest, TestMixInferenceAndNormalTensorFunctionalOp) {
|
||||
for (bool requires_grad: {true, false}) {
|
||||
for (bool requires_grad : {true, false}) {
|
||||
torch::Tensor s = torch::ones({1, 2, 3}).set_requires_grad(requires_grad);
|
||||
torch::Tensor c;
|
||||
{
|
||||
@ -300,8 +325,10 @@ TEST(InferenceModeTest, TestMixInferenceAndNormalTensorFunctionalOp) {
|
||||
c = torch::ones({1, 2, 3}).set_requires_grad(requires_grad);
|
||||
}
|
||||
|
||||
// add(Tensor, Tensor) is safe with inference tensor since it doesn't save any variable for backward.
|
||||
torch::Tensor out = c.add(s); // go through kernels: VariableType, ADInplaceOrView(fallthrough), CPU
|
||||
// add(Tensor, Tensor) is safe with inference tensor since it doesn't save
|
||||
// any variable for backward.
|
||||
torch::Tensor out = c.add(s); // go through kernels: VariableType,
|
||||
// ADInplaceOrView(fallthrough), CPU
|
||||
ASSERT_FALSE(out.is_inference());
|
||||
ASSERT_EQ(out.requires_grad(), requires_grad);
|
||||
if (requires_grad) {
|
||||
@ -313,19 +340,21 @@ TEST(InferenceModeTest, TestMixInferenceAndNormalTensorFunctionalOp) {
|
||||
|
||||
if (requires_grad) {
|
||||
// mul(self, other) saves variable when requires_grad=true
|
||||
ASSERT_THROWS_WITH(c.mul(s),
|
||||
"Inference tensors cannot be saved for backward.");
|
||||
ASSERT_THROWS_WITH(
|
||||
c.mul(s), "Inference tensors cannot be saved for backward.");
|
||||
|
||||
// Inference tensor in TensorList input
|
||||
std::vector<torch::Tensor> inputs = {s, c};
|
||||
ASSERT_THROWS_WITH(torch::stack(inputs), // go through kernels: VariableType(ERROR)!, ADInplaceOrView(fallthrough), CPU
|
||||
"Inference tensors cannot be saved for backward.")
|
||||
ASSERT_THROWS_WITH(
|
||||
torch::stack(inputs), // go through kernels: VariableType(ERROR)!,
|
||||
// ADInplaceOrView(fallthrough), CPU
|
||||
"Inference tensors cannot be saved for backward.")
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
TEST(InferenceModeTest, TestMixInferenceAndNormalTensorInplaceOp) {
|
||||
for (bool requires_grad: {true, false}) {
|
||||
for (bool requires_grad : {true, false}) {
|
||||
torch::Tensor s = torch::ones({1, 2, 3}).set_requires_grad(requires_grad);
|
||||
torch::Tensor a = s.clone();
|
||||
torch::Tensor c;
|
||||
@ -335,22 +364,29 @@ TEST(InferenceModeTest, TestMixInferenceAndNormalTensorInplaceOp) {
|
||||
}
|
||||
|
||||
if (requires_grad) {
|
||||
ASSERT_THROWS_WITH(a.mul_(c), // go through kernels: VariableType(ERROR!), InferenceMode, CPU
|
||||
"Inference tensors cannot be saved for backward.");
|
||||
ASSERT_THROWS_WITH(
|
||||
a.mul_(c), // go through kernels: VariableType(ERROR!), InferenceMode,
|
||||
// CPU
|
||||
"Inference tensors cannot be saved for backward.");
|
||||
|
||||
ASSERT_THROWS_WITH(torch::mul_out(/*out=*/c, s, s), // go through kernels: VariableType(ERROR!), ADInplaceOrView, CPU
|
||||
"out=... arguments don't support automatic differentiation, but one of the arguments requires grad")
|
||||
ASSERT_THROWS_WITH(
|
||||
torch::mul_out(
|
||||
/*out=*/c, s, s), // go through kernels: VariableType(ERROR!),
|
||||
// ADInplaceOrView, CPU
|
||||
"out=... arguments don't support automatic differentiation, but one of the arguments requires grad")
|
||||
} else {
|
||||
a.mul_(c);
|
||||
|
||||
ASSERT_THROWS_WITH(torch::mul_out(/*out=*/c, s, s), // go through kernels: VariableType, ADInplaceOrView(ERROR!), CPU
|
||||
"Inplace update to inference tensor outside InferenceMode is not allowed");
|
||||
ASSERT_THROWS_WITH(
|
||||
torch::mul_out(/*out=*/c, s, s), // go through kernels: VariableType,
|
||||
// ADInplaceOrView(ERROR!), CPU
|
||||
"Inplace update to inference tensor outside InferenceMode is not allowed");
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
TEST(InferenceModeTest, TestMixInferenceAndNormalTensorViewOp) {
|
||||
for (bool requires_grad: {true, false}) {
|
||||
for (bool requires_grad : {true, false}) {
|
||||
torch::Tensor s = torch::ones({1, 2, 3}).set_requires_grad(requires_grad);
|
||||
torch::Tensor c;
|
||||
{
|
||||
@ -358,32 +394,36 @@ TEST(InferenceModeTest, TestMixInferenceAndNormalTensorViewOp) {
|
||||
c = torch::ones({1, 2, 3});
|
||||
}
|
||||
|
||||
// view_as is a composite op which calls view() with only one tensor argument.
|
||||
// So there isn't a mixed inference tensor and normal tensor inputs for view ops.
|
||||
torch::Tensor tmp1 = c.view_as(s); // go through kernels: ADInplaceOrView, CPU
|
||||
// view_as is a composite op which calls view() with only one tensor
|
||||
// argument. So there isn't a mixed inference tensor and normal tensor
|
||||
// inputs for view ops.
|
||||
torch::Tensor tmp1 =
|
||||
c.view_as(s); // go through kernels: ADInplaceOrView, CPU
|
||||
ASSERT_TRUE(tmp1.is_inference());
|
||||
ASSERT_FALSE(tmp1.requires_grad());
|
||||
|
||||
// This is fine since it's equivalent as s.view(c.sizes()) which
|
||||
// isn't a mixed input scenario.
|
||||
torch::Tensor tmp2 = s.view_as(c); // go through kernels: VariableType, ADInplaceOrView, CPU
|
||||
torch::Tensor tmp2 =
|
||||
s.view_as(c); // go through kernels: VariableType, ADInplaceOrView, CPU
|
||||
ASSERT_FALSE(tmp2.is_inference());
|
||||
ASSERT_EQ(tmp2.requires_grad(), requires_grad);
|
||||
}
|
||||
}
|
||||
|
||||
TEST(InferenceModeTest, TestHandleDirectViewOnRebase) {
|
||||
for (bool requires_grad: {true, false}) {
|
||||
for (bool requires_grad : {true, false}) {
|
||||
torch::Tensor s = torch::ones({1, 2, 3}).set_requires_grad(requires_grad);
|
||||
torch::Tensor a = s.clone();
|
||||
torch::Tensor view_out;
|
||||
{
|
||||
InferenceMode guard;
|
||||
view_out = view_op(a); // go through kernels: ADInplaceOrView, CPU
|
||||
view_out = view_op(a); // go through kernels: ADInplaceOrView, CPU
|
||||
}
|
||||
if (requires_grad) {
|
||||
ASSERT_THROWS_WITH(inplace_op(view_out),
|
||||
"A view was created in inference mode and is being modified inplace")
|
||||
ASSERT_THROWS_WITH(
|
||||
inplace_op(view_out),
|
||||
"A view was created in inference mode and is being modified inplace")
|
||||
} else {
|
||||
inplace_op(view_out);
|
||||
}
|
||||
@ -391,18 +431,19 @@ TEST(InferenceModeTest, TestHandleDirectViewOnRebase) {
|
||||
}
|
||||
|
||||
TEST(InferenceModeTest, TestHandleInDirectViewOnRebase) {
|
||||
for (bool requires_grad: {true, false}) {
|
||||
for (bool requires_grad : {true, false}) {
|
||||
torch::Tensor s = torch::ones({1, 2, 3}).set_requires_grad(requires_grad);
|
||||
torch::Tensor a = s.clone();
|
||||
torch::Tensor view_out;
|
||||
{
|
||||
InferenceMode guard;
|
||||
view_out = view_op(a); // go through kernels: ADInplaceOrView, CPU
|
||||
view_out = view_op(a); // go through kernels: ADInplaceOrView, CPU
|
||||
}
|
||||
inplace_op(a);
|
||||
if (requires_grad) {
|
||||
ASSERT_THROWS_WITH(view_out.grad_fn(),
|
||||
"A view was created in inference mode and its base or another view of its base has been modified inplace");
|
||||
ASSERT_THROWS_WITH(
|
||||
view_out.grad_fn(),
|
||||
"A view was created in inference mode and its base or another view of its base has been modified inplace");
|
||||
} else {
|
||||
view_out.grad_fn();
|
||||
}
|
||||
@ -416,14 +457,16 @@ TEST(InferenceModeTest, TestCreationMetaPropagation) {
|
||||
InferenceMode guard;
|
||||
b = s.view_as(s);
|
||||
}
|
||||
ASSERT_THROWS_WITH(b.add_(1),
|
||||
"A view was created in inference mode and is being modified inplace");
|
||||
ASSERT_THROWS_WITH(
|
||||
b.add_(1),
|
||||
"A view was created in inference mode and is being modified inplace");
|
||||
{
|
||||
AutoGradMode mode(false);
|
||||
c = b.view_as(b);
|
||||
}
|
||||
ASSERT_THROWS_WITH(c.add_(1),
|
||||
"A view was created in inference mode and is being modified inplace");
|
||||
ASSERT_THROWS_WITH(
|
||||
c.add_(1),
|
||||
"A view was created in inference mode and is being modified inplace");
|
||||
}
|
||||
|
||||
TEST(InferenceModeTest, TestCreationMetaPropagationInput) {
|
||||
@ -437,20 +480,22 @@ TEST(InferenceModeTest, TestCreationMetaPropagationInput) {
|
||||
s = s.view_as(s);
|
||||
c = s.split_with_sizes({1, 1});
|
||||
}
|
||||
for (auto& b_el: b) {
|
||||
for (auto& b_el : b) {
|
||||
assert_tensor_creation_meta(b_el, CreationMeta::INFERENCE_MODE);
|
||||
ASSERT_THROWS_WITH(b_el.add_(1),
|
||||
"A view was created in inference mode and is being modified inplace");
|
||||
ASSERT_THROWS_WITH(
|
||||
b_el.add_(1),
|
||||
"A view was created in inference mode and is being modified inplace");
|
||||
}
|
||||
for (auto& c_el: c) {
|
||||
for (auto& c_el : c) {
|
||||
assert_tensor_creation_meta(c_el, CreationMeta::INFERENCE_MODE);
|
||||
ASSERT_THROWS_WITH(c_el.add_(1),
|
||||
"A view was created in inference mode and is being modified inplace");
|
||||
ASSERT_THROWS_WITH(
|
||||
c_el.add_(1),
|
||||
"A view was created in inference mode and is being modified inplace");
|
||||
}
|
||||
}
|
||||
|
||||
TEST(InferenceModeTest, TestInplaceCopyOnInferenceTensor) {
|
||||
for (bool requires_grad: {true, false}) {
|
||||
for (bool requires_grad : {true, false}) {
|
||||
torch::Tensor s = torch::ones({1, 2, 3}).set_requires_grad(requires_grad);
|
||||
torch::Tensor t;
|
||||
{
|
||||
@ -461,8 +506,9 @@ TEST(InferenceModeTest, TestInplaceCopyOnInferenceTensor) {
|
||||
ASSERT_FALSE(t.requires_grad());
|
||||
}
|
||||
|
||||
ASSERT_THROWS_WITH(t.copy_(s),
|
||||
"Inplace update to inference tensor outside InferenceMode is not allowed");
|
||||
ASSERT_THROWS_WITH(
|
||||
t.copy_(s),
|
||||
"Inplace update to inference tensor outside InferenceMode is not allowed");
|
||||
}
|
||||
}
|
||||
|
||||
@ -473,8 +519,9 @@ TEST(InferenceModeTest, TestSetRequiresGradInNormalMode) {
|
||||
t = torch::ones({1, 2, 3});
|
||||
}
|
||||
t.set_requires_grad(false);
|
||||
ASSERT_THROWS_WITH(t.set_requires_grad(true),
|
||||
"Setting requires_grad=True on inference tensor outside InferenceMode is not allowed.");
|
||||
ASSERT_THROWS_WITH(
|
||||
t.set_requires_grad(true),
|
||||
"Setting requires_grad=True on inference tensor outside InferenceMode is not allowed.");
|
||||
}
|
||||
|
||||
TEST(InferenceModeTest, TestAccessVersionCounter) {
|
||||
@ -482,19 +529,23 @@ TEST(InferenceModeTest, TestAccessVersionCounter) {
|
||||
{
|
||||
InferenceMode guard;
|
||||
t = torch::ones({1, 2, 3});
|
||||
ASSERT_THROWS_WITH(t.unsafeGetTensorImpl()->version_counter().current_version(),
|
||||
"Inference tensors do not track version counter.");
|
||||
ASSERT_THROWS_WITH(
|
||||
t.unsafeGetTensorImpl()->version_counter().current_version(),
|
||||
"Inference tensors do not track version counter.");
|
||||
t.unsafeGetTensorImpl()->bump_version();
|
||||
}
|
||||
ASSERT_THROWS_WITH(t.unsafeGetTensorImpl()->version_counter().current_version(),
|
||||
"Inference tensors do not track version counter.");
|
||||
ASSERT_THROWS_WITH(t.unsafeGetTensorImpl()->bump_version(),
|
||||
"Inplace update to inference tensor outside InferenceMode is not allowed.");
|
||||
ASSERT_THROWS_WITH(
|
||||
t.unsafeGetTensorImpl()->version_counter().current_version(),
|
||||
"Inference tensors do not track version counter.");
|
||||
ASSERT_THROWS_WITH(
|
||||
t.unsafeGetTensorImpl()->bump_version(),
|
||||
"Inplace update to inference tensor outside InferenceMode is not allowed.");
|
||||
// Suggested workaround
|
||||
torch::Tensor c = t.clone();
|
||||
uint32_t v = c.unsafeGetTensorImpl()->version_counter().current_version();
|
||||
c.unsafeGetTensorImpl()->bump_version();
|
||||
ASSERT_EQ(c.unsafeGetTensorImpl()->version_counter().current_version(), v + 1);
|
||||
ASSERT_EQ(
|
||||
c.unsafeGetTensorImpl()->version_counter().current_version(), v + 1);
|
||||
}
|
||||
|
||||
TEST(InferenceModeTest, TestInplaceUpdateInferenceTensorWithNormalTensor) {
|
||||
@ -511,11 +562,13 @@ TEST(InferenceModeTest, TestInplaceUpdateInferenceTensorWithNormalTensor) {
|
||||
}
|
||||
s.copy_(t);
|
||||
s.add_(t);
|
||||
ASSERT_THROWS_WITH(t.copy_(s),
|
||||
"Inplace update to inference tensor outside InferenceMode is not allowed");
|
||||
ASSERT_THROWS_WITH(
|
||||
t.copy_(s),
|
||||
"Inplace update to inference tensor outside InferenceMode is not allowed");
|
||||
|
||||
ASSERT_THROWS_WITH(t.add_(s),
|
||||
"Inplace update to inference tensor outside InferenceMode is not allowed");
|
||||
ASSERT_THROWS_WITH(
|
||||
t.add_(s),
|
||||
"Inplace update to inference tensor outside InferenceMode is not allowed");
|
||||
}
|
||||
|
||||
TEST(InferenceModeTest, TestComplexViewInInferenceMode) {
|
||||
@ -552,18 +605,27 @@ TEST(InferenceModeTest, TestComplexViewInNormalMode) {
|
||||
|
||||
TEST(InferenceModeTest, TestCustomFunction) {
|
||||
struct MyFunction : public Function<MyFunction> {
|
||||
static Variable forward(AutogradContext *ctx, Variable var1, int mul, Variable var2) {
|
||||
static Variable forward(
|
||||
AutogradContext* ctx,
|
||||
Variable var1,
|
||||
int mul,
|
||||
Variable var2) {
|
||||
ctx->saved_data["mul"] = mul;
|
||||
ctx->save_for_backward({var1, var2});
|
||||
return var1 + mul*var2 + var1*var2;
|
||||
return var1 + mul * var2 + var1 * var2;
|
||||
}
|
||||
|
||||
static variable_list backward(AutogradContext *ctx, variable_list grad_output) {
|
||||
static variable_list backward(
|
||||
AutogradContext* ctx,
|
||||
variable_list grad_output) {
|
||||
int mul = ctx->saved_data["mul"].toInt();
|
||||
auto saved = ctx->get_saved_variables();
|
||||
auto var1 = saved[0];
|
||||
auto var2 = saved[1];
|
||||
variable_list output = {grad_output[0] + grad_output[0]*var2, Variable(), grad_output[0] * mul + grad_output[0] * var1};
|
||||
variable_list output = {
|
||||
grad_output[0] + grad_output[0] * var2,
|
||||
Variable(),
|
||||
grad_output[0] * mul + grad_output[0] * var1};
|
||||
return output;
|
||||
}
|
||||
};
|
||||
@ -586,5 +648,6 @@ TEST(InferenceModeTest, TestLegacyAutoNonVariableTypeModeWarning) {
|
||||
WarningCapture warnings;
|
||||
at::AutoNonVariableTypeMode guard;
|
||||
ASSERT_TRUE(
|
||||
warnings.str().find("AutoNonVariableTypeMode is deprecated") != std::string::npos);
|
||||
warnings.str().find("AutoNonVariableTypeMode is deprecated") !=
|
||||
std::string::npos);
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user