mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-21 13:44:15 +08:00
Also adds `TORCH_WARN_WITH` and `TORCH_WARN_DEPRECATION` macros Part of #72948 Pull Request resolved: https://github.com/pytorch/pytorch/pull/84101 Approved by: https://github.com/albanD
658 lines
21 KiB
C++
658 lines
21 KiB
C++
#include <gtest/gtest.h>
|
|
#include <test/cpp/api/support.h>
|
|
#include <torch/script.h>
|
|
|
|
using namespace torch::autograd;
|
|
using namespace torch::test;
|
|
|
|
namespace {
|
|
torch::Tensor functional_op(torch::Tensor& x) {
|
|
return x * x;
|
|
}
|
|
|
|
void inplace_op(torch::Tensor& x) {
|
|
x.mul_(1);
|
|
}
|
|
|
|
torch::Tensor view_op(torch::Tensor& x) {
|
|
return x.view({2, 3});
|
|
}
|
|
|
|
/*
|
|
Only the following combos of Autograd & ADInplaceOrView keys on tensors are
|
|
valid:
|
|
- Autograd=true, ADInplaceOrView=true (normal tensor)
|
|
- Autograd=false, ADInplaceOrView=false (inference tensor)
|
|
Tensors created in InferenceMode are mostly inference tensors. The only
|
|
exception is that view of normal tensors created in InferenceMode still
|
|
produce normal tensor.
|
|
*/
|
|
void assert_TLS_states(bool inference_mode) {
|
|
ASSERT_EQ(InferenceMode::is_enabled(), inference_mode);
|
|
ASSERT_FALSE(c10::impl::tls_is_dispatch_key_excluded(
|
|
c10::DispatchKey::ADInplaceOrView));
|
|
ASSERT_FALSE(c10::impl::tls_is_dispatch_keyset_included(
|
|
c10::autograd_dispatch_keyset));
|
|
ASSERT_EQ(
|
|
c10::impl::tls_is_dispatch_keyset_excluded(c10::autograd_dispatch_keyset),
|
|
inference_mode);
|
|
ASSERT_EQ(
|
|
c10::impl::tls_is_dispatch_key_included(
|
|
c10::DispatchKey::ADInplaceOrView),
|
|
!inference_mode);
|
|
ASSERT_EQ(GradMode::is_enabled(), !inference_mode);
|
|
}
|
|
} // namespace
|
|
|
|
// Verifies that InferenceMode guards toggle the thread-local state and that
// nested guards restore the enclosing state when they go out of scope.
TEST(InferenceModeTest, TestTLSState) {
  // No guard active yet: normal-mode TLS state.
  assert_TLS_states(false);
  {
    InferenceMode outer_guard;
    assert_TLS_states(true);
    {
      // A nested guard constructed with `false` switches InferenceMode off
      // for its own scope...
      InferenceMode inner_guard(false);
      assert_TLS_states(false);
    }
    // ...and the enclosing inference state comes back once it is destroyed.
    assert_TLS_states(true);
  }
  // With the outer guard gone, the original state is restored.
  assert_TLS_states(false);
}
|
|
|
|
// Tensors produced by factory functions under InferenceMode are inference
// tensors, whatever their requires_grad flag is set to.
TEST(InferenceModeTest, TestInferenceTensorCreation) {
  InferenceMode guard;

  torch::Tensor fresh = torch::ones({1, 2, 3});
  ASSERT_FALSE(fresh.requires_grad());
  ASSERT_TRUE(fresh.is_inference());

  // Setting requires_grad either way inside InferenceMode sticks, but it does
  // not change the tensor's inference-ness.
  for (const bool grad_flag : {true, false}) {
    torch::Tensor flagged = torch::ones({1, 2, 3}).set_requires_grad(grad_flag);
    ASSERT_EQ(flagged.requires_grad(), grad_flag);
    ASSERT_TRUE(flagged.is_inference());
  }
}
|
|
|
|
// A normal tensor saved by an existing autograd graph and then mutated in
// place inside InferenceMode gets its version bumped. The later backward()
// outside InferenceMode must therefore report the in-place modification.
TEST(InferenceModeTest, TestExistingAutogradSession) {
  torch::Tensor s = torch::ones({1, 2, 3}).set_requires_grad(true);
  torch::Tensor a = s.clone();

  // Save `a` in an existing autograd session (mul saves its operands).
  torch::Tensor out = a * a;
  {
    InferenceMode guard;
    inplace_op(a);
  }
  // Performing backward should trigger an error since `a`'s version has been
  // bumped.
  ASSERT_THROWS_WITH(
      out.backward(torch::ones_like(out)),
      "one of the variables needed for gradient computation has been modified by an inplace operation");
}
|
|
|
|
// Functional ops on inference tensors inside InferenceMode produce inference
// tensors that never require grad.
TEST(InferenceModeTest, TestInferenceTensorInInferenceModeFunctionalOp) {
  c10::InferenceMode guard;
  for (const bool grad_flag : {true, false}) {
    torch::Tensor input = torch::ones({1, 2, 3}).set_requires_grad(grad_flag);

    // Kernels: CPU only.
    torch::Tensor result = functional_op(input);
    ASSERT_TRUE(result.is_inference());
    ASSERT_FALSE(result.requires_grad());
  }
}
|
|
|
|
// In-place ops on inference tensors inside InferenceMode keep the tensor an
// inference tensor and leave its requires_grad flag untouched.
TEST(InferenceModeTest, TestInferenceTensorInInferenceModeInplaceOp) {
  c10::InferenceMode guard;
  for (const bool grad_flag : {true, false}) {
    torch::Tensor target = torch::ones({1, 2, 3}).set_requires_grad(grad_flag);

    // Kernels: CPU only.
    inplace_op(target);
    ASSERT_TRUE(target.is_inference());
    ASSERT_EQ(target.requires_grad(), grad_flag);
  }
}
|
|
|
|
// View ops on inference tensors inside InferenceMode yield inference tensors
// that neither require grad nor are tracked as views.
TEST(InferenceModeTest, TestInferenceTensorInInferenceModeViewOp) {
  c10::InferenceMode guard;
  for (const bool grad_flag : {true, false}) {
    torch::Tensor base = torch::ones({1, 2, 3}).set_requires_grad(grad_flag);

    // Kernels: CPU only.
    torch::Tensor viewed = view_op(base);
    ASSERT_TRUE(viewed.is_inference());
    // Unlike NoGradMode, the view does not inherit requires_grad; this is
    // intended for inference tensors.
    ASSERT_FALSE(viewed.requires_grad());
    ASSERT_FALSE(viewed.is_view());
  }
}
|
|
|
|
// A functional op applied outside InferenceMode to an inference tensor
// returns a normal tensor that does not require grad.
TEST(InferenceModeTest, TestInferenceTensorInNormalModeFunctionalOp) {
  for (const bool grad_flag : {true, false}) {
    torch::Tensor inference_tensor;
    {
      InferenceMode guard;
      inference_tensor = torch::ones({1, 2, 3}).set_requires_grad(grad_flag);
    }

    // Kernels: ADInplaceOrView (fallthrough), CPU.
    // Because of issue #54614 this may be slower than running under
    // InferenceMode: intermediate tensors are normal tensors and may dispatch
    // to VariableType kernels. Users can avoid that by moving the computation
    // into an InferenceMode block.
    torch::Tensor out = functional_op(inference_tensor);
    ASSERT_FALSE(out.is_inference());
    ASSERT_FALSE(out.requires_grad());
  }
}
|
|
|
|
// In-place updates of an inference tensor are rejected once InferenceMode is
// no longer active.
TEST(InferenceModeTest, TestInferenceTensorInNormalModeInplaceOp) {
  for (const bool grad_flag : {true, false}) {
    torch::Tensor inference_tensor;
    {
      InferenceMode guard;
      inference_tensor = torch::ones({1, 2, 3}).set_requires_grad(grad_flag);
    }

    // Kernels: ADInplaceOrView (raises), CPU.
    ASSERT_THROWS_WITH(
        inplace_op(inference_tensor),
        "Inplace update to inference tensor outside InferenceMode is not allowed");
  }
}
|
|
|
|
// Viewing an inference tensor outside InferenceMode yields another inference
// tensor. That tensor is a leaf that does not require grad and is not tracked
// as an autograd view.
TEST(InferenceModeTest, TestInferenceTensorInNormalModeViewOp) {
  for (const bool grad_flag : {true, false}) {
    torch::Tensor inference_tensor;
    {
      InferenceMode guard;
      inference_tensor = torch::ones({1, 2, 3}).set_requires_grad(grad_flag);
    }

    // Kernels: ADInplaceOrView, CPU.
    torch::Tensor viewed = view_op(inference_tensor);
    ASSERT_TRUE(viewed.is_inference());
    ASSERT_FALSE(viewed.requires_grad());
    ASSERT_FALSE(viewed.is_view());
    ASSERT_TRUE(viewed.is_leaf());
  }
}
|
|
|
|
// In-place and view ops applied inside InferenceMode to a normal tensor keep
// producing normal tensors with the original requires_grad flag.
TEST(InferenceModeTest, TestNormalTensorInplaceOutputInInferenceMode) {
  for (const bool grad_flag : {true, false}) {
    torch::Tensor leaf = torch::ones({1, 2, 3}).set_requires_grad(grad_flag);
    torch::Tensor normal = leaf.clone();

    {
      c10::InferenceMode guard;

      // inplace, then inplace -> inplace. Kernels: ADInplaceOrView, CPU.
      for (int step = 0; step < 2; ++step) {
        inplace_op(normal);
        ASSERT_FALSE(normal.is_inference());
        ASSERT_EQ(normal.requires_grad(), grad_flag);
      }

      // inplace -> inplace -> view. Kernels: ADInplaceOrView, CPU.
      torch::Tensor viewed = view_op(normal);
      ASSERT_FALSE(viewed.is_inference());
      ASSERT_EQ(viewed.requires_grad(), grad_flag);
    }
  }
}
|
|
|
|
// A normal tensor updated in place inside InferenceMode behaves like any
// other normal tensor once InferenceMode is exited.
TEST(InferenceModeTest, TestNormalTensorInplaceOutputInNormalMode) {
  for (const bool grad_flag : {true, false}) {
    torch::Tensor leaf = torch::ones({1, 2, 3}).set_requires_grad(grad_flag);
    torch::Tensor normal = leaf.clone();

    {
      c10::InferenceMode guard;
      // Kernels: ADInplaceOrView, CPU.
      inplace_op(normal);
      ASSERT_FALSE(normal.is_inference());
      ASSERT_EQ(normal.requires_grad(), grad_flag);
    }

    // Kernels: VariableType, ADInplaceOrView (fallthrough), CPU.
    torch::Tensor derived = functional_op(normal);
    ASSERT_FALSE(derived.is_inference());
    ASSERT_EQ(derived.requires_grad(), grad_flag);

    // Kernels: VariableType, ADInplaceOrView, CPU.
    inplace_op(normal);
    ASSERT_FALSE(normal.is_inference());
    ASSERT_EQ(normal.requires_grad(), grad_flag);

    // Kernels: VariableType, ADInplaceOrView, CPU.
    derived = view_op(normal);
    ASSERT_FALSE(derived.is_inference());
    ASSERT_EQ(derived.requires_grad(), grad_flag);
  }
}
|
|
|
|
// Views of a normal tensor taken inside InferenceMode are normal tensors
// tagged with CreationMeta::INFERENCE_MODE, and views of those views keep
// the tag.
TEST(InferenceModeTest, TestNormalTensorViewOutputInInferenceMode) {
  for (const bool grad_flag : {true, false}) {
    torch::Tensor leaf = torch::ones({1, 2, 3}).set_requires_grad(grad_flag);
    torch::Tensor base = leaf.clone();
    torch::Tensor first_view;
    torch::Tensor nested_view;

    {
      c10::InferenceMode guard;
      // View ops on a normal tensor produce normal tensors. The output keeps
      // both the Autograd and ADInplaceOrView keys because
      // alias_with_sizes_and_strides builds the view from the base's key set:
      // ```
      // auto impl = c10::make_intrusive<TensorImpl>(
      //     Storage(self.storage()), self.key_set(), self.dtype());
      // ```
      // These outputs are still special: they carry
      // CreationMeta::INFERENCE_MODE, so they behave exactly like a view
      // created in no_grad mode.

      // Kernels: ADInplaceOrView, CPU.
      first_view = view_op(base);
      ASSERT_FALSE(first_view.is_inference());
      assert_tensor_creation_meta(first_view, CreationMeta::INFERENCE_MODE);
      ASSERT_EQ(first_view.requires_grad(), grad_flag);
      ASSERT_TRUE(first_view.is_leaf());

      // view -> view. Kernels: ADInplaceOrView, CPU.
      nested_view = view_op(first_view);
      ASSERT_FALSE(nested_view.is_inference());
      assert_tensor_creation_meta(nested_view, CreationMeta::INFERENCE_MODE);
      ASSERT_EQ(nested_view.requires_grad(), grad_flag);
      ASSERT_TRUE(nested_view.is_leaf());

      // view -> view -> inplace. Kernels: ADInplaceOrView, CPU.
      // Afterwards the nested view reports the same version as `base`.
      inplace_op(nested_view);
      assert_tensor_creation_meta(nested_view, CreationMeta::INFERENCE_MODE);
      ASSERT_FALSE(nested_view.is_inference());
      ASSERT_EQ(nested_view.requires_grad(), grad_flag);
      ASSERT_TRUE(nested_view.is_leaf());
      ASSERT_EQ(base._version(), nested_view._version());
    }
  }
}
|
|
|
|
// A view of a normal tensor created inside InferenceMode can be consumed by
// functional and view ops after InferenceMode exits, and their outputs are
// normal tensors. In-place updates of that view are rejected when it requires
// grad and allowed otherwise.
TEST(InferenceModeTest, TestNormalTensorViewOutputInNormalMode) {
  for (bool requires_grad : {true, false}) {
    torch::Tensor s = torch::ones({1, 2, 3}).set_requires_grad(requires_grad);
    torch::Tensor a = s.clone();
    torch::Tensor view_out, tmp;

    {
      c10::InferenceMode guard;
      view_out = view_op(a); // go through kernels: ADInplaceOrView, CPU
      ASSERT_FALSE(view_out.is_inference());
      assert_tensor_creation_meta(view_out, CreationMeta::INFERENCE_MODE);
      ASSERT_EQ(view_out.requires_grad(), requires_grad);
      ASSERT_TRUE(view_out.is_leaf());
    }

    // Check the op's output (not its already-verified input): a functional
    // op run outside InferenceMode must return a normal tensor.
    tmp = functional_op(view_out);
    ASSERT_FALSE(tmp.is_inference());
    ASSERT_EQ(tmp.requires_grad(), requires_grad);

    if (requires_grad) {
      ASSERT_THROWS_WITH(
          inplace_op(view_out), // go through kernels: VariableType,
                                // ADInplaceOrView, CPU
          "A view was created in inference mode and is being modified inplace");
    } else {
      inplace_op(view_out);
    }

    // Likewise, a view taken outside InferenceMode is a normal tensor.
    tmp = view_op(view_out);
    ASSERT_FALSE(tmp.is_inference());
    ASSERT_EQ(tmp.requires_grad(), requires_grad);
  }
}
|
|
|
|
// Functional ops that mix an inference tensor with a normal tensor outside
// InferenceMode work as long as no inference tensor has to be saved for
// backward.
TEST(InferenceModeTest, TestMixInferenceAndNormalTensorFunctionalOp) {
  for (const bool grad_flag : {true, false}) {
    torch::Tensor normal = torch::ones({1, 2, 3}).set_requires_grad(grad_flag);
    torch::Tensor inference;
    {
      InferenceMode guard;
      inference = torch::ones({1, 2, 3}).set_requires_grad(grad_flag);
    }

    // add(Tensor, Tensor) saves no tensors for backward, so mixing the two
    // kinds of tensors is safe.
    // Kernels: VariableType, ADInplaceOrView (fallthrough), CPU.
    torch::Tensor sum = inference.add(normal);
    ASSERT_FALSE(sum.is_inference());
    ASSERT_EQ(sum.requires_grad(), grad_flag);

    if (!grad_flag) {
      continue;
    }

    // A leaf inference tensor with requires_grad=true still receives a
    // gradient. This differs from NoGradMode, which leaves grad empty.
    sum.backward(torch::ones_like(sum));
    assert_tensor_equal(inference.grad(), torch::ones_like(inference));

    // mul(self, other) saves its inputs when grad is required, and inference
    // tensors cannot be saved.
    ASSERT_THROWS_WITH(
        inference.mul(normal),
        "Inference tensors cannot be saved for backward.");

    // Inference tensor in a TensorList input: stack no longer saves its
    // inputs, so this check is disabled.
    // TODO: find an alternative Function that captures a list (maybe a
    // custom fn).
    /*
    std::vector<torch::Tensor> inputs = {normal, inference};
    ASSERT_THROWS_WITH(
        torch::stack(inputs), // go through kernels: VariableType(ERROR)!,
                              // ADInplaceOrView(fallthrough), CPU
        "Inference tensors cannot be saved for backward.")
    */
  }
}
|
|
|
|
// In-place/out= ops that mix an inference tensor with normal tensors outside
// InferenceMode:
//  - with requires_grad=true, VariableType rejects them (saving an inference
//    tensor, or out= with a grad-requiring input);
//  - with requires_grad=false, updating a normal tensor from an inference
//    tensor is allowed, but writing into the inference tensor is rejected by
//    ADInplaceOrView.
TEST(InferenceModeTest, TestMixInferenceAndNormalTensorInplaceOp) {
  for (bool requires_grad : {true, false}) {
    torch::Tensor s = torch::ones({1, 2, 3}).set_requires_grad(requires_grad);
    torch::Tensor a = s.clone();
    torch::Tensor c;
    {
      InferenceMode guard;
      c = torch::ones({1, 2, 3});
    }

    if (requires_grad) {
      ASSERT_THROWS_WITH(
          a.mul_(c), // go through kernels: VariableType(ERROR!),
                     // ADInplaceOrView, CPU
          "Inference tensors cannot be saved for backward.");

      ASSERT_THROWS_WITH(
          torch::mul_out(
              /*out=*/c, s, s), // go through kernels: VariableType(ERROR!),
                                // ADInplaceOrView, CPU
          "out=... arguments don't support automatic differentiation, but one of the arguments requires grad");
    } else {
      a.mul_(c);

      ASSERT_THROWS_WITH(
          torch::mul_out(/*out=*/c, s, s), // go through kernels: VariableType,
                                           // ADInplaceOrView(ERROR!), CPU
          "Inplace update to inference tensor outside InferenceMode is not allowed");
    }
  }
}
|
|
|
|
// view_as with one inference and one normal tensor: the result takes after
// the tensor that is actually viewed.
TEST(InferenceModeTest, TestMixInferenceAndNormalTensorViewOp) {
  for (const bool grad_flag : {true, false}) {
    torch::Tensor normal = torch::ones({1, 2, 3}).set_requires_grad(grad_flag);
    torch::Tensor inference;
    {
      InferenceMode guard;
      inference = torch::ones({1, 2, 3});
    }

    // view_as is a composite op that calls view() with a single tensor
    // argument, so view ops never see mixed inference/normal inputs.
    // Kernels: ADInplaceOrView, CPU.
    torch::Tensor from_inference = inference.view_as(normal);
    ASSERT_TRUE(from_inference.is_inference());
    ASSERT_FALSE(from_inference.requires_grad());

    // Equivalent to normal.view(inference.sizes()), which is not a mixed-input
    // case either. Kernels: VariableType, ADInplaceOrView, CPU.
    torch::Tensor from_normal = normal.view_as(inference);
    ASSERT_FALSE(from_normal.is_inference());
    ASSERT_EQ(from_normal.requires_grad(), grad_flag);
  }
}
|
|
|
|
// Take a view of a normal tensor inside InferenceMode, then modify that view
// in place outside InferenceMode. The update is rejected when the view
// requires grad and allowed otherwise.
TEST(InferenceModeTest, TestHandleDirectViewOnRebase) {
  for (bool requires_grad : {true, false}) {
    torch::Tensor s = torch::ones({1, 2, 3}).set_requires_grad(requires_grad);
    torch::Tensor a = s.clone();
    torch::Tensor view_out;
    {
      InferenceMode guard;
      view_out = view_op(a); // go through kernels: ADInplaceOrView, CPU
    }
    if (requires_grad) {
      ASSERT_THROWS_WITH(
          inplace_op(view_out),
          "A view was created in inference mode and is being modified inplace");
    } else {
      inplace_op(view_out);
    }
  }
}
|
|
|
|
// Take a view of a normal tensor inside InferenceMode, then modify its base
// in place outside InferenceMode. Accessing the view's grad_fn then fails
// when the view requires grad.
TEST(InferenceModeTest, TestHandleInDirectViewOnRebase) {
  for (const bool grad_flag : {true, false}) {
    torch::Tensor leaf = torch::ones({1, 2, 3}).set_requires_grad(grad_flag);
    torch::Tensor base = leaf.clone();
    torch::Tensor view_out;
    {
      InferenceMode guard;
      view_out = view_op(base); // Kernels: ADInplaceOrView, CPU.
    }

    // Mutate the base, not the view.
    inplace_op(base);

    if (!grad_flag) {
      view_out.grad_fn();
      continue;
    }
    ASSERT_THROWS_WITH(
        view_out.grad_fn(),
        "A view was created in inference mode and its base or another view of its base has been modified inplace");
  }
}
|
|
|
|
// Two-step check of creation-meta propagation:
//  1. A view of a normal tensor taken under InferenceMode rejects in-place
//     updates.
//  2. A view of that view taken under no-grad mode still reports the
//     inference-mode error.
TEST(InferenceModeTest, TestCreationMetaPropagation) {
  torch::Tensor leaf = torch::ones({1, 2, 3}).set_requires_grad(true);

  torch::Tensor inference_view;
  {
    InferenceMode guard;
    inference_view = leaf.view_as(leaf);
  }
  ASSERT_THROWS_WITH(
      inference_view.add_(1),
      "A view was created in inference mode and is being modified inplace");

  torch::Tensor nograd_view;
  {
    AutoGradMode mode(false);
    nograd_view = inference_view.view_as(inference_view);
  }
  ASSERT_THROWS_WITH(
      nograd_view.add_(1),
      "A view was created in inference mode and is being modified inplace");
}
|
|
|
|
// Multi-output views (split_with_sizes) taken under InferenceMode get
// CreationMeta::INFERENCE_MODE. This holds both when the input is a view
// created in normal mode and when it is a view created in InferenceMode.
TEST(InferenceModeTest, TestCreationMetaPropagationInput) {
  torch::Tensor base = torch::ones({2, 2, 3}).set_requires_grad(true);
  auto base_view = base.view_as(base);
  std::vector<at::Tensor> splits_of_outer_view;
  std::vector<at::Tensor> splits_of_inner_view;
  {
    InferenceMode guard;
    splits_of_outer_view = base_view.split_with_sizes({1, 1});

    // Rebind `base` to a view created inside InferenceMode, then split that.
    base = base.view_as(base);
    splits_of_inner_view = base.split_with_sizes({1, 1});
  }

  for (std::vector<at::Tensor>* pieces :
       {&splits_of_outer_view, &splits_of_inner_view}) {
    for (auto& piece : *pieces) {
      assert_tensor_creation_meta(piece, CreationMeta::INFERENCE_MODE);
      ASSERT_THROWS_WITH(
          piece.add_(1),
          "A view was created in inference mode and is being modified inplace");
    }
  }
}
|
|
|
|
// copy_ into an inference tensor is allowed only inside InferenceMode.
TEST(InferenceModeTest, TestInplaceCopyOnInferenceTensor) {
  for (const bool grad_flag : {true, false}) {
    torch::Tensor source = torch::ones({1, 2, 3}).set_requires_grad(grad_flag);
    torch::Tensor dest;
    {
      InferenceMode guard;
      dest = torch::ones({1, 2, 3});
      // Inside InferenceMode the copy succeeds. The destination stays an
      // inference tensor and does not pick up requires_grad from the source.
      dest.copy_(source);
      ASSERT_TRUE(dest.is_inference());
      ASSERT_FALSE(dest.requires_grad());
    }

    ASSERT_THROWS_WITH(
        dest.copy_(source),
        "Inplace update to inference tensor outside InferenceMode is not allowed");
  }
}
|
|
|
|
// Outside InferenceMode, an inference tensor's requires_grad may be set to
// false, but setting it to true is rejected.
TEST(InferenceModeTest, TestSetRequiresGradInNormalMode) {
  torch::Tensor inference_tensor;
  {
    InferenceMode guard;
    inference_tensor = torch::ones({1, 2, 3});
  }

  inference_tensor.set_requires_grad(false);
  ASSERT_THROWS_WITH(
      inference_tensor.set_requires_grad(true),
      "Setting requires_grad=True on inference tensor outside InferenceMode is not allowed.");
}
|
|
|
|
// Inference tensors have no readable version counter. It can be bumped only
// inside InferenceMode. Cloning gives a normal tensor whose version is
// tracked.
TEST(InferenceModeTest, TestAccessVersionCounter) {
  torch::Tensor inference_tensor;
  {
    InferenceMode guard;
    inference_tensor = torch::ones({1, 2, 3});
    ASSERT_THROWS_WITH(
        inference_tensor.unsafeGetTensorImpl()
            ->version_counter()
            .current_version(),
        "Inference tensors do not track version counter.");
    // Bumping is permitted while InferenceMode is active.
    inference_tensor.unsafeGetTensorImpl()->bump_version();
  }

  // Outside InferenceMode both reading and bumping fail.
  ASSERT_THROWS_WITH(
      inference_tensor.unsafeGetTensorImpl()
          ->version_counter()
          .current_version(),
      "Inference tensors do not track version counter.");
  ASSERT_THROWS_WITH(
      inference_tensor.unsafeGetTensorImpl()->bump_version(),
      "Inplace update to inference tensor outside InferenceMode is not allowed.");

  // Suggested workaround: clone into a normal tensor.
  torch::Tensor normal_copy = inference_tensor.clone();
  auto* impl = normal_copy.unsafeGetTensorImpl();
  const uint32_t before = impl->version_counter().current_version();
  impl->bump_version();
  ASSERT_EQ(impl->version_counter().current_version(), before + 1);
}
|
|
|
|
// Checks in-place updates between a normal tensor and an inference tensor.
// Inside InferenceMode either one may be updated from the other. Outside
// InferenceMode only the normal tensor may be.
TEST(InferenceModeTest, TestInplaceUpdateInferenceTensorWithNormalTensor) {
  torch::Tensor normal = torch::ones({1, 2, 3});
  torch::Tensor inference;
  {
    InferenceMode guard;
    inference = torch::ones({1, 2, 3});
    // Cover both copy_ (VariableTypeManual) and add_ (generated code).
    normal.copy_(inference);
    normal.add_(inference);

    inference.add_(normal);
    inference.copy_(normal);
  }

  normal.copy_(inference);
  normal.add_(inference);

  ASSERT_THROWS_WITH(
      inference.copy_(normal),
      "Inplace update to inference tensor outside InferenceMode is not allowed");
  ASSERT_THROWS_WITH(
      inference.add_(normal),
      "Inplace update to inference tensor outside InferenceMode is not allowed");
}
|
|
|
|
// view_as_real / view_as_complex inside InferenceMode take after their input.
// Views of normal tensors stay normal, and views of inference tensors are
// inference tensors.
TEST(InferenceModeTest, TestComplexViewInInferenceMode) {
  torch::Tensor real_normal = torch::ones({3, 3, 2});
  torch::Tensor complex_normal = torch::view_as_complex(real_normal);
  {
    InferenceMode guard;
    torch::Tensor out;

    // Inputs created outside InferenceMode.
    out = torch::view_as_real(complex_normal);
    ASSERT_FALSE(out.is_inference());
    out = torch::view_as_complex(real_normal);
    ASSERT_FALSE(out.is_inference());

    // Input created inside InferenceMode.
    torch::Tensor real_inference = torch::ones({3, 3, 2});
    out = torch::view_as_complex(real_inference);
    ASSERT_TRUE(out.is_inference());
    out = torch::view_as_real(out);
    ASSERT_TRUE(out.is_inference());
  }
}
|
|
|
|
// Outside InferenceMode, complex/real views of an inference tensor are still
// inference tensors.
TEST(InferenceModeTest, TestComplexViewInNormalMode) {
  torch::Tensor real_inference;
  {
    InferenceMode guard;
    real_inference = torch::ones({3, 3, 2});
  }

  torch::Tensor as_complex = torch::view_as_complex(real_inference);
  ASSERT_TRUE(as_complex.is_inference());
  torch::Tensor back_to_real = torch::view_as_real(as_complex);
  ASSERT_TRUE(back_to_real.is_inference());
}
|
|
|
|
// A custom autograd Function that saves tensors for backward can run inside
// InferenceMode, because InferenceMode also disables grad mode.
TEST(InferenceModeTest, TestCustomFunction) {
  // Computes var1 + mul * var2 + var1 * var2. It saves both tensor inputs
  // and the integer multiplier for backward.
  struct MyFunction : public Function<MyFunction> {
    static Variable forward(
        AutogradContext* ctx,
        Variable var1,
        int mul,
        Variable var2) {
      ctx->saved_data["mul"] = mul;
      ctx->save_for_backward({var1, var2});
      return var1 + mul * var2 + var1 * var2;
    }

    static variable_list backward(
        AutogradContext* ctx,
        variable_list grad_output) {
      const int factor = ctx->saved_data["mul"].toInt();
      const auto saved = ctx->get_saved_variables();
      const auto& grad = grad_output[0];
      // Gradients are returned in forward-argument order: var1, mul (an int,
      // so an undefined Variable), var2.
      return {
          grad + grad * saved[1],
          Variable(),
          grad * factor + grad * saved[0]};
    }
  };

  {
    InferenceMode guard;
    torch::Tensor var1 = torch::ones({3, 3}).set_requires_grad(true);
    auto var2 = var1.clone();
    int mul = 2;
    // If InferenceMode did not also disable grad mode, saving `var1` and
    // `var2` for backward inside apply() would throw here.
    auto result = MyFunction::apply(var1, mul, var2);
    torch::Tensor expected = var1 + mul * var2 + var1 * var2;
    assert_tensor_equal(result, expected);
  }
}
|
|
|
|
// Constructing the legacy at::AutoNonVariableTypeMode guard must emit a
// deprecation warning. WarnAlways(true) is enabled, presumably so that a
// warn-once message is not suppressed by an earlier emission (TODO confirm).
TEST(InferenceModeTest, TestLegacyAutoNonVariableTypeModeWarning) {
  c10::WarningUtils::WarnAlways warn_always(true);
  WarningCapture warnings;
  at::AutoNonVariableTypeMode guard;
  const auto captured = warnings.str();
  // ASSERT_NE (rather than ASSERT_TRUE) plus the streamed text makes a
  // failure show what was actually captured.
  ASSERT_NE(
      captured.find("AutoNonVariableTypeMode is deprecated"),
      std::string::npos)
      << "captured warnings: " << captured;
}
|