[doc] improve code in fake tensor doc (#140329)

Pull Request resolved: https://github.com/pytorch/pytorch/pull/140329
Approved by: https://github.com/soulitzer
This commit is contained in:
Tongzhou Wang
2024-11-13 05:14:54 +00:00
committed by PyTorch MergeBot
parent d6b3ad4de2
commit 4c6eebf4e2

View File

@ -69,7 +69,7 @@ PT2 pre-AOTAutograd usage (this is unusual, you probably don't want to do this):
converter = fake_mode.fake_tensor_converter
fake_args = [converter.from_real_tensor(fake_mode, arg) for arg in args]
with fake_mode:
... do stuff with the fake args, if needed ...
... # do stuff with the fake args, if needed ...
detect_fake_mode will search a number of locations to try to find "the" fake tensor mode associated with the lifecycle. Typically it will be pulled off of the tracing context.
@ -77,6 +77,7 @@ PT2 post-AOTAutograd usage:
.. code:: python
# Fake mode is enabled! example_inputs is typically fake already
# TODO: we probably want to change this
# Still do this to access fake mode
@ -89,14 +90,13 @@ Other useful stuff:
from torch._subclasses.fake_tensor import unset_fake_temporarily
with unset_fake_temporarily():
# fake mode is disabled here, you can do real tensor compute
... # fake mode is disabled here, you can do real tensor compute
When might you want to disable fake tensor mode? Usually you don't want to do this. One niche case where we've found it useful is to implement constant propagation on fake tensors: in this case, we need to do some actual tensor computation even though we're in a fake tensor mode.
.. code:: python
FakeTensorProp
from torch.fx.passes.fake_tensor_prop
import FakeTensorProp from torch.fx.passes.fake_tensor_prop
gm: GraphModule
real_inputs: List[Tensor]
FakeTensorProp(gm).propagate(*real_inputs)
@ -116,7 +116,7 @@ Originally, FakeTensorMode would not automatically fakeify real tensors if you t
.. code:: python
with FakeTensorMode():
real_tensor.t_()
real_tensor.t_()
What should this code do? It would be surprising if we actually modified the metadata on the real tensor. But at the same time, there isn't any obvious opportunity to create a FakeTensor. So we conservatively decided to make this raise an error: "Invoking operators with non-Fake Tensor inputs in FakeTensorMode is not yet supported. Please convert all Tensors to FakeTensors first."