Move warning from item to specific number conversions (#152709)
Follow-up to https://github.com/pytorch/pytorch/pull/143261 so that a plain .item() call does not warn.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/152709
Approved by: https://github.com/malfet, https://github.com/ngimel
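In short, a plain .item() stays silent, while converting a requires_grad tensor to a Python number still warns. A minimal sketch of that behavior (not part of the diff; it only restates what the tests below check, with the warning text taken from the hunks that follow):

    import warnings
    import torch

    x = torch.tensor(2.0, requires_grad=True)

    with warnings.catch_warnings(record=True) as w:
        warnings.simplefilter("always")

        x.item()   # plain .item(): no warning after this change
        float(x)   # Python number conversion: still warns (at most once per process, via TORCH_WARN_ONCE)

        # Expect exactly one UserWarning mentioning requires_grad=True,
        # assuming this is the first such conversion in the process.
        assert len(w) == 1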
@@ -11,17 +11,11 @@
 #include <ATen/ops/item_native.h>
 #endif

-#include <ATen/core/grad_mode.h>
-
 namespace at::native {

 Scalar item(const Tensor& self) {
   auto numel = self.sym_numel();
   TORCH_CHECK(numel == 1, "a Tensor with ", numel, " elements cannot be converted to Scalar");
-  if (at::GradMode::is_enabled() && self.requires_grad()) {
-    TORCH_WARN_ONCE("Converting a tensor with requires_grad=True to a scalar may lead to unexpected behavior.\n"
-                    "Consider using tensor.detach() first.");
-  }
   if (self.is_sparse()) {
     if (self._nnz() == 0) return Scalar(0);
     if (self.is_coalesced()) return at::_local_scalar_dense(self._values());
@@ -43,7 +43,7 @@ from torch.testing._internal.common_utils import ( # type: ignore[attr-defined]
     skipIfRocm, skipIfNoSciPy, TemporaryFileName, TemporaryDirectoryName,
     wrapDeterministicFlagAPITest, DeterministicGuard, CudaSyncGuard,
     bytes_to_scalar, parametrize, skipIfMPS, noncontiguous_like,
-    AlwaysWarnTypedStorageRemoval, TEST_WITH_TORCHDYNAMO, xfailIfTorchDynamo)
+    AlwaysWarnTypedStorageRemoval, TEST_WITH_TORCHDYNAMO, xfailIfTorchDynamo, set_warn_always_context)
 from multiprocessing.reduction import ForkingPickler
 from torch.testing._internal.common_device_type import (
     expectedFailureMeta,
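set_warn_always_context is a helper from torch.testing._internal.common_utils; it is used below so that TORCH_WARN_ONCE warnings are not swallowed when they already fired earlier in the test process. A rough, hypothetical sketch of what such a context manager does, built on the public torch.set_warn_always API (the real helper may differ):

    from contextlib import contextmanager
    import torch

    @contextmanager
    def set_warn_always_context(value: bool):
        # Hypothetical stand-in for the imported helper: temporarily force
        # "warn always" so once-only warnings fire on every call, then
        # restore the previous setting.
        prev = torch.is_warn_always_enabled()
        torch.set_warn_always(value)
        try:
            yield
        finally:
            torch.set_warn_always(prev)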
@@ -10833,8 +10833,8 @@ tensor([[[1.+1.j, 1.+1.j, 1.+1.j, ..., 1.+1.j, 1.+1.j, 1.+1.j],
         self.assertFalse(torch.cuda.is_bf16_supported())

     def test_tensor_with_grad_to_scalar_warning(self) -> None:
-
-        with warnings.catch_warnings(record=True) as w:
+        with (warnings.catch_warnings(record=True) as w,
+              set_warn_always_context(True)):
             warnings.simplefilter("always")

             x = torch.tensor(2.0, requires_grad=True)
@@ -10847,8 +10847,17 @@ tensor([[[1.+1.j, 1.+1.j, 1.+1.j, ..., 1.+1.j, 1.+1.j, 1.+1.j],
                 str(w[0].message)
             )

             _ = math.pow(x, 3)  # calling it again does not result in a second warning
             self.assertEqual(len(w), 1)
+
+    def test_tensor_item_no_warning(self):
+        with (warnings.catch_warnings(record=True) as w,
+              set_warn_always_context(True)):
+            warnings.simplefilter("always")
+
+            x = torch.tensor(2.0, requires_grad=True)
+            max(x, 3)  # No warning
+            x.item()  # No warning
+
+            self.assertEqual(len(w), 0)

 # The following block extends TestTorch with negative dim wrapping tests
 # FIXME: replace these with OpInfo sample inputs or systemic OpInfo tests
@@ -36,6 +36,7 @@
 #include "torch/csrc/autograd/generated/python_return_types.h"

 #include <ATen/core/Tensor.h>
+#include <ATen/core/grad_mode.h>
 #include <ATen/FuncTorchTLS.h>
 #include "c10/core/Stream.h"

@@ -291,6 +292,13 @@ static Tensor dispatch_copy_(const Tensor & self, const Tensor & other, bool non_blocking) {
   return self.copy_(other, non_blocking);
 }

+static void maybe_warn_requires_grad(const Tensor & self) {
+  if (at::GradMode::is_enabled() && self.requires_grad()) {
+    TORCH_WARN_ONCE("Converting a tensor with requires_grad=True to a scalar may lead to unexpected behavior.\n"
+                    "Consider using tensor.detach() first.");
+  }
+}
+
 static PyObject * THPVariable_copy_(PyObject* self, PyObject* args, PyObject* kwargs)
 {
   HANDLE_TH_ERRORS
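The warning text added here recommends tensor.detach(); a small illustration of that remedy (an assumption based only on the requires_grad check above, not part of the diff):

    import torch

    x = torch.tensor(2.0, requires_grad=True)
    float(x)           # warns (once per process): the tensor still requires grad
    float(x.detach())  # silent: the detached tensor has requires_grad=False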
@@ -325,6 +333,7 @@ static PyObject * THPVariable_float_scalar(PyObject* self, PyObject* args) {
   }
   jit::tracer::warn("Converting a tensor to a Python float", jit::tracer::WARN_PYTHON_DATAFLOW);
   auto& self_ = THPVariable_Unpack(self);
+  maybe_warn_requires_grad(self_);
   return wrap(dispatch_to<double>(self_));
   END_HANDLE_TH_ERRORS
 }
@@ -336,6 +345,7 @@ static PyObject * THPVariable_complex_scalar(PyObject* self, PyObject* args) {
   }
   jit::tracer::warn("Converting a tensor to a Python complex", jit::tracer::WARN_PYTHON_DATAFLOW);
   auto& self_ = THPVariable_Unpack(self);
+  maybe_warn_requires_grad(self_);
   return wrap(dispatch_to<c10::complex<double>>(self_));
   END_HANDLE_TH_ERRORS
 }
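These two entry points back the tensor-to-Python float and complex conversions, so anything that coerces a tensor to a Python number hits the new check; per the test above, even an implicit conversion such as math.pow(x, 3) triggers it. A sketch of the resulting behavior (warn-always is enabled here only so the once-only warning repeats for demonstration):

    import math
    import torch

    torch.set_warn_always(True)  # demonstration only: let TORCH_WARN_ONCE fire every time

    x = torch.tensor(2.0, requires_grad=True)
    float(x)        # warns via THPVariable_float_scalar
    complex(x)      # warns via THPVariable_complex_scalar
    math.pow(x, 3)  # warns: math.pow converts its argument to a Python float
    x.item()        # no warning: item() no longer checks requires_grad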