Move warning from item to specific number conversions (#152709)

Follow up to https://github.com/pytorch/pytorch/pull/143261 to not warn when a plain .item() is done.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/152709
Approved by: https://github.com/malfet, https://github.com/ngimel
This commit is contained in:
albanD
2025-05-05 20:46:01 +00:00
committed by PyTorch MergeBot
parent 3bc69cc08d
commit 22d1359bc6
3 changed files with 24 additions and 11 deletions

View File

@ -11,17 +11,11 @@
#include <ATen/ops/item_native.h>
#endif
#include <ATen/core/grad_mode.h>
namespace at::native {
Scalar item(const Tensor& self) {
auto numel = self.sym_numel();
TORCH_CHECK(numel == 1, "a Tensor with ", numel, " elements cannot be converted to Scalar");
if (at::GradMode::is_enabled() && self.requires_grad()) {
TORCH_WARN_ONCE("Converting a tensor with requires_grad=True to a scalar may lead to unexpected behavior.\n"
"Consider using tensor.detach() first.");
}
if (self.is_sparse()) {
if (self._nnz() == 0) return Scalar(0);
if (self.is_coalesced()) return at::_local_scalar_dense(self._values());

View File

@ -43,7 +43,7 @@ from torch.testing._internal.common_utils import ( # type: ignore[attr-defined]
skipIfRocm, skipIfNoSciPy, TemporaryFileName, TemporaryDirectoryName,
wrapDeterministicFlagAPITest, DeterministicGuard, CudaSyncGuard,
bytes_to_scalar, parametrize, skipIfMPS, noncontiguous_like,
AlwaysWarnTypedStorageRemoval, TEST_WITH_TORCHDYNAMO, xfailIfTorchDynamo)
AlwaysWarnTypedStorageRemoval, TEST_WITH_TORCHDYNAMO, xfailIfTorchDynamo, set_warn_always_context)
from multiprocessing.reduction import ForkingPickler
from torch.testing._internal.common_device_type import (
expectedFailureMeta,
@ -10833,8 +10833,8 @@ tensor([[[1.+1.j, 1.+1.j, 1.+1.j, ..., 1.+1.j, 1.+1.j, 1.+1.j],
self.assertFalse(torch.cuda.is_bf16_supported())
def test_tensor_with_grad_to_scalar_warning(self) -> None:
with warnings.catch_warnings(record=True) as w:
with (warnings.catch_warnings(record=True) as w,
set_warn_always_context(True)):
warnings.simplefilter("always")
x = torch.tensor(2.0, requires_grad=True)
@ -10847,8 +10847,17 @@ tensor([[[1.+1.j, 1.+1.j, 1.+1.j, ..., 1.+1.j, 1.+1.j, 1.+1.j],
str(w[0].message)
)
_ = math.pow(x, 3) # calling it again does not result in a second warning
self.assertEqual(len(w), 1)
def test_tensor_item_no_warning(self):
with (warnings.catch_warnings(record=True) as w,
set_warn_always_context(True)):
warnings.simplefilter("always")
x = torch.tensor(2.0, requires_grad=True)
max(x, 3) # No warning
x.item() # No warning
self.assertEqual(len(w), 0)
# The following block extends TestTorch with negative dim wrapping tests
# FIXME: replace these with OpInfo sample inputs or systemic OpInfo tests

View File

@ -36,6 +36,7 @@
#include "torch/csrc/autograd/generated/python_return_types.h"
#include <ATen/core/Tensor.h>
#include <ATen/core/grad_mode.h>
#include <ATen/FuncTorchTLS.h>
#include "c10/core/Stream.h"
@ -291,6 +292,13 @@ static Tensor dispatch_copy_(const Tensor & self, const Tensor & other, bool non
return self.copy_(other, non_blocking);
}
static void maybe_warn_requires_grad(const Tensor & self) {
if (at::GradMode::is_enabled() && self.requires_grad()) {
TORCH_WARN_ONCE("Converting a tensor with requires_grad=True to a scalar may lead to unexpected behavior.\n"
"Consider using tensor.detach() first.");
}
}
static PyObject * THPVariable_copy_(PyObject* self, PyObject* args, PyObject* kwargs)
{
HANDLE_TH_ERRORS
@ -325,6 +333,7 @@ static PyObject * THPVariable_float_scalar(PyObject* self, PyObject* args) {
}
jit::tracer::warn("Converting a tensor to a Python float", jit::tracer::WARN_PYTHON_DATAFLOW);
auto& self_ = THPVariable_Unpack(self);
maybe_warn_requires_grad(self_);
return wrap(dispatch_to<double>(self_));
END_HANDLE_TH_ERRORS
}
@ -336,6 +345,7 @@ static PyObject * THPVariable_complex_scalar(PyObject* self, PyObject* args) {
}
jit::tracer::warn("Converting a tensor to a Python complex", jit::tracer::WARN_PYTHON_DATAFLOW);
auto& self_ = THPVariable_Unpack(self);
maybe_warn_requires_grad(self_);
return wrap(dispatch_to<c10::complex<double>>(self_));
END_HANDLE_TH_ERRORS
}