mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 21:14:14 +08:00
[BE] Improve torch.inference_mode docs and error message (#161164)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/161164 Approved by: https://github.com/sfc-gh-sbekman, https://github.com/janeyx99
This commit is contained in:
committed by
PyTorch MergeBot
parent
b2db293abc
commit
e06d1d6610
@ -210,29 +210,37 @@ class set_grad_enabled(_DecoratorContextManager):
|
||||
|
||||
|
||||
class inference_mode(_DecoratorContextManager):
|
||||
r"""Context-manager that enables or disables inference mode.
|
||||
r"""Context manager that enables or disables inference mode.
|
||||
|
||||
InferenceMode is a context manager analogous to :class:`~no_grad`
|
||||
to be used when you are certain your operations will have no interactions
|
||||
with autograd (e.g., model training). Code run under this mode gets better
|
||||
performance by disabling view tracking and version counter bumps. Note that
|
||||
unlike some other mechanisms that locally enable or disable grad,
|
||||
entering inference_mode also disables to :ref:`forward-mode AD <forward-mode-ad>`.
|
||||
InferenceMode is analogous to :class:`~no_grad` and should be used
|
||||
when you are certain your operations will not interact with autograd
|
||||
(e.g., during data loading or model evaluation). Compared to
|
||||
:class:`~no_grad`, it removes additional overhead by disabling view
|
||||
tracking and version counter bumps. It is also more restrictive, in
|
||||
that tensors created in this mode cannot be used in computations
|
||||
recorded by autograd.
|
||||
|
||||
This context manager is thread local; it will not affect computation
|
||||
This context manager is thread-local; it does not affect computation
|
||||
in other threads.
|
||||
|
||||
Also functions as a decorator.
|
||||
|
||||
.. note::
|
||||
Inference mode is one of several mechanisms that can enable or
|
||||
disable gradients locally see :ref:`locally-disable-grad-doc` for
|
||||
more information on how they compare.
|
||||
Inference mode is one of several mechanisms that can locally enable
|
||||
or disable gradients. See :ref:`locally-disable-grad-doc` for a
|
||||
comparison. If avoiding the use of tensors created in inference mode
|
||||
in autograd-tracked regions is difficult, consider benchmarking your
|
||||
code with and without inference mode to weigh the performance benefits
|
||||
against the trade-offs. You can always use :class:`~no_grad` instead.
|
||||
|
||||
.. note::
|
||||
Unlike some other mechanisms that locally enable or disable grad,
|
||||
entering inference_mode also disables :ref:`forward-mode AD <forward-mode-ad>`.
|
||||
|
||||
Args:
|
||||
mode (bool or function): Either a boolean flag whether to enable or
|
||||
disable inference mode or a Python function to decorate with
|
||||
inference mode enabled
|
||||
mode (bool or function): Either a boolean flag to enable or disable
|
||||
inference mode, or a Python function to decorate with inference
|
||||
mode enabled.
|
||||
|
||||
Example::
|
||||
>>> # xdoctest: +REQUIRES(env:TORCH_DOCTEST_AUTOGRAD)
|
||||
|
@ -39,8 +39,11 @@ SavedVariable::SavedVariable(
|
||||
// follow.
|
||||
TORCH_CHECK(
|
||||
!variable.is_inference(),
|
||||
"Inference tensors cannot be saved for backward. To work around "
|
||||
"you can make a clone to get a normal tensor and use it in autograd.")
|
||||
"Inference tensors cannot be saved for backward. Please do not use "
|
||||
"Tensors created in inference mode in computation tracked by autograd. "
|
||||
"To work around this, you can make a clone to get a normal tensor and "
|
||||
"use it in autograd, or use `torch.no_grad()` instead of "
|
||||
"`torch.inference_mode()`.");
|
||||
|
||||
was_default_constructed_ = false;
|
||||
saved_version_ = variable._version();
|
||||
|
Reference in New Issue
Block a user