Summary:
When we divide a FakeTensor by an integer using the fast op implementation, the type promotion should be `ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT` so we get a float when dividing an int FakeTensor by an integer.
```
FAST = get_fast_op_impls()
fast_div = FAST[torch.ops.aten.div.Tensor]
fast_div(fake_tensor, some_int)
```
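A toy illustration of the promotion rule, in plain Python rather than PyTorch. `PromotionKind`, `INTEGER_DTYPES` and `result_dtype` are hypothetical names made up for this sketch (dtypes are modelled as strings); it only shows why an int tensor divided by a Python int should come out as a float under `INT_TO_FLOAT`, assuming the default float dtype is float32.

```python
from enum import Enum


class PromotionKind(Enum):
    # Hypothetical stand-in for ELEMENTWISE_TYPE_PROMOTION_KIND.
    DEFAULT = 0
    INT_TO_FLOAT = 1


# Dtypes modelled as strings; this is a toy, not the real dtype lattice.
INTEGER_DTYPES = {"bool", "uint8", "int8", "int16", "int32", "int64"}


def result_dtype(tensor_dtype, kind, default_float="float32"):
    """Result dtype of `tensor <op> python_int` under the given promotion kind."""
    if kind is PromotionKind.INT_TO_FLOAT and tensor_dtype in INTEGER_DTYPES:
        # Integer inputs are promoted to the default float dtype.
        return default_float
    # A Python int scalar never widens the tensor's own dtype.
    return tensor_dtype


print(result_dtype("int32", PromotionKind.DEFAULT))
print(result_dtype("int32", PromotionKind.INT_TO_FLOAT))
```

Float inputs are unaffected by the promotion kind in this sketch; only integer (and bool) inputs change.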
Test Plan:
```
python test/test_fake_tensor.py -k test_fast_div
```
Differential Revision: D72667430
Pull Request resolved: https://github.com/pytorch/pytorch/pull/150874
Approved by: https://github.com/angelayi
Fixing this is actually a bit annoying:
(1) FakeTensorMode sees a function where all of its inputs are real tensors, so it tries to run the real compute before converting the output to a FakeTensor
(2) we don't actually want this, because the "real compute" is supposed to error normally when you do `meta_tensor.to(device='cpu')`. Instead, we want FakeTensor to skip constant prop and run the normal FakeTensor implementation, which will not error
Pull Request resolved: https://github.com/pytorch/pytorch/pull/146729
Approved by: https://github.com/zou3519, https://github.com/SherlockNoMad, https://github.com/albanD
ghstack dependencies: #146642
context here: https://fb.workplace.com/groups/326136610199609/permalink/495389539940981/
This PR is an attempt to make it such that if you create a tensor from an external buffer (using `UntypedStorage.from_buffer(buf)`), we can generate a proper fake tensor for you out of the box.
The annoying bit is that there are not any dispatcher ops to interpose on and change behavior. So instead, I took the manual C binding and tweaked the storage device to be "meta" if we see an active fake mode.
Put "poc" in the title since I... think this is hopefully reasonable, but I can be convinced that it's not :)
```
from torch._subclasses.fake_tensor import FakeTensorMode
import pickle
import io
import torch
from contextlib import nullcontext
use_fake_tensor = True
with FakeTensorMode() if use_fake_tensor else nullcontext():
    obj = [1, 2]
    f = io.BytesIO()
    pickle.Pickler(f).dump(obj)
    byte_storage = torch.ByteStorage._from_buffer(f.getvalue())  # type: ignore[attr-defined]
    t = torch.ByteTensor(byte_storage)
```
Pull Request resolved: https://github.com/pytorch/pytorch/pull/146642
Approved by: https://github.com/zou3519
Fixes #133605
**Summary**
This PR adds support for FP8 data types to the `index_cuda` op.
It uses `AT_DISPATCH_V2`, a new macro that can handle an arbitrary number of dtypes, as opposed to the old implementations, which had a separate macro for each possible number of dtype arguments (e.g. `AT_DISPATCH_ALL_TYPES_AND_COMPLEX_AND{2,3,4,5...}`).
**Test plan**
Updated test `index_cuda_with_cpu` in `test/test_fake_tensor.py` to have cases for all dtypes handled by `index_cuda`, including fp8 dtypes.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/144747
Approved by: https://github.com/vkuzo
Previously the split decomp would return the input when there were no splits. This errors in torch.compile (or FakeTensorMode) with:
> RuntimeError: View operation returned a tensor that is the same as the input base tensor. This is no longer allowed; you must explicitly create a new tensor (e.g., using .detach()). As a user, you could have made a mistake implementing __torch_dispatch__ or a Python operator decomposition or meta registration; if that's not the case, please report a bug to PyTorch or the backend you are using.
Fix for https://github.com/pytorch/pytorch/issues/133394
Differential Revision: [D65635070](https://our.internmc.facebook.com/intern/diff/D65635070)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/140065
Approved by: https://github.com/bdhirsh
Summary:
While testing exportability for PT2 Inference models, we found various cases of invalid op inputs during tracing, for example errors like: `a and b must have same reduction dim`, `expected scalar type Long but found Int`, etc. Looking more closely, these were due to the same few meta kernels & eager kernels producing mismatched outputs upstream (e.g. different output tensor dtype, int output).
Adding checks to catch mismatched outputs in real tensor prop upstream, so errors are raised at the mismatched op, instead of the downstream ops taking them as inputs. Relies a lot on utils from [CrossRefFakeMode](929797dedb/torch/_subclasses/fake_utils.py (L78))
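A minimal sketch of the idea in plain Python: compare the metadata of the fake (meta kernel) output against the real (eager kernel) output right after each op, so the error names the op that diverged. `TensorMeta`, `MetadataMismatchError` and `check_outputs` are hypothetical names for illustration only, not the utilities in fake_utils.py.

```python
from dataclasses import dataclass


@dataclass(frozen=True)
class TensorMeta:
    # Toy stand-in for the metadata a fake or real tensor output carries.
    shape: tuple
    dtype: str


class MetadataMismatchError(RuntimeError):
    pass


def check_outputs(op_name, fake_out, real_out):
    """Raise at the op whose meta kernel and eager kernel disagree."""
    for field in ("shape", "dtype"):
        fake_val = getattr(fake_out, field)
        real_val = getattr(real_out, field)
        if fake_val != real_val:
            raise MetadataMismatchError(
                f"{op_name}: fake {field}={fake_val!r} but real {field}={real_val!r}"
            )


# Matching outputs pass silently.
check_outputs("aten.add", TensorMeta((2, 3), "float32"), TensorMeta((2, 3), "float32"))
```

Without such a check, a dtype mismatch like Int vs Long would only surface later, at whichever downstream op happens to reject it.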
Follow ups: could add more checks, and maybe have a flag to only enable these for cases like draft mode, so perf doesn't suffer?
Test Plan: test_export, test_fake_tensor
Differential Revision: D64210055
Pull Request resolved: https://github.com/pytorch/pytorch/pull/137747
Approved by: https://github.com/zou3519
This PR introduces the following changes:
1. Before this PR, the subgraph's output was ([], []). This PR changes it to a flattened list for easier codegen and consistency with other control flow operators.
2. Before this PR, the combine_fn of scan took a sliced input but kept the sliced dimension. For example, suppose xs = torch.randn(3, 4, 5) and we scan over dim 0; the combine_fn looked like:
```
# x.shape = (1, 4, 5) instead of (4, 5)
def combine_fn(carry, x):
    ...
```
This PR fixes this and also simplifies some of the slicing logic.
3. This diff also makes sure we always stack ys on the first dimension.
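The intended semantics can be sketched with nested Python lists standing in for tensors. This is a toy model, not the actual scan higher-order op: iterating over the outer list plays the role of scanning over dim 0, so each `x` has the leading dimension dropped, and the per-step outputs are collected along a new first dimension. `scan` and `running_sum` here are hypothetical helpers.

```python
def scan(combine_fn, init, xs):
    """Toy scan over the leading dimension of a nested list."""
    carry = init
    ys = []
    for x in xs:  # each x is one slice with the scanned dimension removed
        carry, y = combine_fn(carry, x)
        ys.append(y)  # outputs are stacked along the first dimension
    return carry, ys


def running_sum(carry, x):
    # Elementwise add of two equal-length rows.
    new_carry = [c + v for c, v in zip(carry, x)]
    return new_carry, new_carry


xs = [[1, 2], [3, 4], [5, 6]]  # "shape" (3, 2)
final, ys = scan(running_sum, [0, 0], xs)
print(final)
print(ys)
```

Here `combine_fn` sees rows of "shape" (2,), not (1, 2), and `ys` has one entry per scanned step.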
Pull Request resolved: https://github.com/pytorch/pytorch/pull/135601
Approved by: https://github.com/zou3519
ghstack dependencies: #135600
Adds support for SymInts in the FakeTensor cache.
A couple notes:
1. When a SymInt is present in the input key for a FakeTensor operation we cache on the ShapeEnv instead of using the FakeTensorMode cache. This is necessary so we don't have to remember and check the guards. It reduces the cache hits but there's diminishing return on how much work we can do before the cache becomes more of a burden than a gain.
2. We need to be careful when caching an output SymInt that is a direct copy of an input: on a cache hit we must copy the SymNode from the input to the output. This is important because the fx-graph building code uses SymNode ids while building the graph, so constructing a same-content-but-different-id SymNode will fail.
3. In the cache key we store SymInts as a _PySymInputStub. These represent SymInt (and friends) but support `__hash__` and `__eq__` (which SymInt does not).
4. In the cache entry we store SymInts as a _SymIntOutputStub.
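Notes 2 and 3 can be illustrated with a toy cache in plain Python. All names below (`ToySymInt`, `ToyInputStub`, `ToyOutputFromInput`, `cached_call`) are hypothetical stand-ins, not the real `_PySymInputStub` / `_SymIntOutputStub` classes. The sketch assumes a symbolic int is unhashable, wraps it in a hashable stub for the key, and records a pass-through output by position so that a cache hit returns the caller's own object rather than a stale copy.

```python
class ToySymInt:
    """Stand-in for a symbolic int; deliberately unhashable."""
    __hash__ = None

    def __init__(self, expr):
        self.expr = expr


class ToyInputStub:
    """Hashable, comparable wrapper used only inside cache keys."""

    def __init__(self, sym):
        self.expr = sym.expr

    def __hash__(self):
        return hash(self.expr)

    def __eq__(self, other):
        return isinstance(other, ToyInputStub) and self.expr == other.expr


class ToyOutputFromInput:
    """Cache entry meaning: the output is the input at this position."""

    def __init__(self, index):
        self.index = index


cache = {}


def make_key(op, args):
    stubs = tuple(ToyInputStub(a) if isinstance(a, ToySymInt) else a for a in args)
    return (op, stubs)


def cached_call(op, fn, *args):
    key = make_key(op, args)
    if key not in cache:
        out = fn(*args)
        for i, a in enumerate(args):
            if out is a:  # output is a direct copy of an input
                out = ToyOutputFromInput(i)
                break
        cache[key] = out
    entry = cache[key]
    if isinstance(entry, ToyOutputFromInput):
        return args[entry.index]  # hand back the *current* input object
    return entry


s0 = ToySymInt("s0")
first = cached_call("size0", lambda n, k: n, s0, 4)
s0_again = ToySymInt("s0")  # same content, different identity
second = cached_call("size0", lambda n, k: n, s0_again, 4)
```

The second call is a cache hit, yet it returns `s0_again` rather than the `s0` object seen when the entry was created, which is the identity-preserving behavior note 2 asks for.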
Perf example:
```
python benchmarks/dynamo/timm_models.py --ci --accuracy --timing \
  --explain --inductor --dynamic-shapes --dynamic-batch-only --device cuda \
  --training --amp --total-partitions 2 --partition-id 0 --output \
  /tmp/training_timm_models.csv --filter crossvit_9_240
```
fake tensor cache before:
```
INFO: FakeTensor cache stats:
INFO: cache_hits: 68137
INFO: cache_misses: 837
INFO: cache_bypasses:
INFO: symbolic shape: 48224
INFO: CompositeImplicitAutograd: 917
INFO: non-fake tensor: 70
INFO: non-FakeTensor output: 62
INFO: non-builtin: 8
INFO: dynamic output shape: 1
```
and after:
```
INFO: FakeTensor cache stats:
INFO: cache_hits: 88187
INFO: cache_misses: 14233
INFO: cache_bypasses:
INFO: CompositeImplicitAutograd: 1037
INFO: non-FakeTensor output: 602
INFO: non-fake tensor: 70
INFO: unsafe view: 36
INFO: non-builtin: 8
INFO: dynamic output shape: 1
```
Pull Request resolved: https://github.com/pytorch/pytorch/pull/127596
Approved by: https://github.com/eellison
ghstack dependencies: #131014, #129780
Changes:
1. Make some arguments positional-only as we only support Python 3.8+
2. Clean up `torch.typename(obj)` implementation.
3. Update type annotations, especially `is_tensor()` and `is_masked_tensor()` using `TypeGuard`.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/129001
Approved by: https://github.com/malfet
Internal xref: https://fb.workplace.com/groups/6829516587176185/posts/7228787720582401/
There are a few improvements here, which luckily fix some xfails:
* In general, it can be unsafe to call operations on Tensors under a `no_dispatch()` mode that is purely trying to disable ambient modes, because this ALSO disables tensor subclass handling. So we test to see if there is a tensor subclass and don't propagate real tensors if that's the case. Another acceptable outcome might be to try to only disable the ambient fake tensor mode; this would help us propagate real tensors through more exotic tensor types, but I'm not going to do it until someone asks for it.
* We're graph breaking for wrapped tensors too late. Pull it up earlier so we do it before we try to muck around with the real tensor.
* I noticed that occasionally when I do `storage.copy_(real_storage)`, the sizes mismatch. Careful code reading suggests that I should just copy in the real data when the tensor was initially allocated, so that's what I do now, eliminating the need for a storage copy.
Signed-off-by: Edward Z. Yang <ezyang@meta.com>
Pull Request resolved: https://github.com/pytorch/pytorch/pull/126281
Approved by: https://github.com/Skylion007
Internal xref:
https://fb.workplace.com/groups/6829516587176185/posts/7211398545654652/
Previously I did it in a crappy way using clone_input in the callback,
but this results in tensors that don't have quite the same
size/stride/storage offset and there was an internal test case where
not having completely accurate information was causing a downstream
problem in propagation. So now I make real tensors as similar to their
fake equivalents as much as possible. Though... I don't bother with
autograd lol.
Signed-off-by: Edward Z. Yang <ezyang@meta.com>
Pull Request resolved: https://github.com/pytorch/pytorch/pull/126175
Approved by: https://github.com/albanD
A common complaint when working with data-dependent code in PyTorch is that it's hard to tell how far you are from the finish line: every time a GuardOnDataDependentSymNode error is hit, you have to somehow fix or work around it to see the next one.
This PR adds a new mode `torch._functorch.config.fake_tensor_propagate_real_tensors` which modifies fake tensors to also propagate real tensors. This means that when we try to guard on a data-dependent SymNode, we can actually produce a real result. We also produce a warning which you should consult to figure out what the crux points are.
I ran this on vision_maskrcnn. In the baseline (without this mode), the model has 27 graph breaks, resulting in 40 graphs. With this mode on, the model has only 11 graph breaks, resulting in 15 graphs (the remaining graph breaks are due to missing functionality for item() on float tensor and some other Dynamo missing features.) You get a list of things that would have errored like this:
```
WARNING:torch.fx.experimental.symbolic_shapes:propagate_real_tensors evaluate_expr(Max(1, u1) < 2) -> True
WARNING:torch.fx.experimental.symbolic_shapes:propagate_real_tensors evaluate_expr(Eq(Max(1, u1), 1)) -> True
WARNING:torch.fx.experimental.symbolic_shapes:propagate_real_tensors evaluate_expr(Eq(Max(1, u1), 1)) -> True
WARNING:torch.fx.experimental.symbolic_shapes:propagate_real_tensors evaluate_expr(Ne(Max(1, u1), 1)) -> False
WARNING:torch.fx.experimental.symbolic_shapes:propagate_real_tensors evaluate_expr(Max(1, u0) < 2) -> True
WARNING:torch.fx.experimental.symbolic_shapes:propagate_real_tensors evaluate_expr(Eq(Max(1, u0), 1)) -> True
WARNING:torch.fx.experimental.symbolic_shapes:propagate_real_tensors evaluate_expr(Eq(Max(1, u0), 1)) -> True
WARNING:torch.fx.experimental.symbolic_shapes:propagate_real_tensors evaluate_expr(Ne(Max(1, u0), 1)) -> False
WARNING:torch.fx.experimental.symbolic_shapes:propagate_real_tensors evaluate_expr(Max(1, u1) < 2) -> True
WARNING:torch.fx.experimental.symbolic_shapes:propagate_real_tensors evaluate_expr(Eq(Max(1, u1), 1)) -> True
WARNING:torch.fx.experimental.symbolic_shapes:propagate_real_tensors evaluate_expr(Eq(Max(1, u1), 1)) -> True
WARNING:torch.fx.experimental.symbolic_shapes:propagate_real_tensors evaluate_expr(Ne(Max(1, u1), 1)) -> False
WARNING:torch.fx.experimental.symbolic_shapes:propagate_real_tensors evaluate_expr(Max(1, u0) < 2) -> True
WARNING:torch.fx.experimental.symbolic_shapes:propagate_real_tensors evaluate_expr(Eq(Max(1, u0), 1)) -> True
WARNING:torch.fx.experimental.symbolic_shapes:propagate_real_tensors evaluate_expr(Eq(Max(1, u0), 1)) -> True
WARNING:torch.fx.experimental.symbolic_shapes:propagate_real_tensors evaluate_expr(Ne(Max(1, u0), 1)) -> False
WARNING:torch.fx.experimental.symbolic_shapes:propagate_real_tensors evaluate_expr(Max(1, u1) < 2) -> False
WARNING:torch.fx.experimental.symbolic_shapes:propagate_real_tensors evaluate_expr(Eq(Max(1, u1), 1)) -> False
WARNING:torch.fx.experimental.symbolic_shapes:propagate_real_tensors evaluate_expr(Ne(Max(1, u1), 1)) -> True
WARNING:torch.fx.experimental.symbolic_shapes:propagate_real_tensors evaluate_expr(Max(1, u0) < 2) -> False
WARNING:torch.fx.experimental.symbolic_shapes:propagate_real_tensors evaluate_expr(Eq(Max(1, u0), 1)) -> False
```
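The mechanism can be sketched in plain Python. This is a toy model under stated assumptions, not the symbolic_shapes implementation: `ToyUnbackedInt`, `DataDependentError` and `evaluate_lt` are hypothetical names. An unbacked symbol normally cannot be guarded on; when a real value has been propagated alongside it, the guard is answered from that value and a warning is logged instead of raising.

```python
import logging

log = logging.getLogger("toy_symbolic_shapes")


class DataDependentError(RuntimeError):
    """Toy analogue of a guard-on-data-dependent-symbol failure."""


class ToyUnbackedInt:
    def __init__(self, name, real_value=None):
        self.name = name
        # Only populated when real values are propagated next to fake ones.
        self.real_value = real_value


def evaluate_lt(sym, bound, propagate_real=False):
    """Evaluate `sym < bound`, falling back to the real value if allowed."""
    if not propagate_real or sym.real_value is None:
        raise DataDependentError(f"cannot guard on {sym.name} < {bound}")
    result = sym.real_value < bound
    log.warning(
        "propagate_real_tensors evaluate_expr(%s < %s) -> %s",
        sym.name, bound, result,
    )
    return result


u0 = ToyUnbackedInt("u0", real_value=1)
answer = evaluate_lt(u0, 2, propagate_real=True)
```

Each logged line marks a place where tracing would otherwise have stopped, which is what makes the list usable as a to-do list of data-dependent guards.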
Potential later follow ups:
* Improve the warning messages (in particular, should provide user frames)
* GC real tensors when they are no longer needed by tracing. Right now, this will use A LOT of memory, equal to as if your GC was broken and every intermediate tensor was kept live
Signed-off-by: Edward Z. Yang <ezyang@meta.com>
Pull Request resolved: https://github.com/pytorch/pytorch/pull/125115
Approved by: https://github.com/IvanKobzarev
This PR fixes an issue where calling `aten.alias(int)` raises a `TypeError`.
```python
import torch
import torch.autograd.forward_ad as fwAD

def f(x):
    return 4312491 * x

device = "cpu"
with torch._subclasses.fake_tensor.FakeTensorMode():
    with fwAD.dual_level():
        x = torch.randn(3, device=device)
        y = torch.ones_like(x)
        dual = fwAD.make_dual(x, y)
        f(dual)
```
The test case above illustrates this bug.
1) `4312491` turns into a tensor that is a wrapped number
2) Forward mode AD calls `aten::alias` internally
3) The wrapped number (`4312491`) becomes a python integer
4) `aten.alias(int)` raises a `TypeError`
Pull Request resolved: https://github.com/pytorch/pytorch/pull/124774
Approved by: https://github.com/albanD, https://github.com/zou3519
Also partially fixes #122109
This PR:
- We add a C++ flag (only_lift_cpu_tensors) to toggle the
torch.tensor(1, device='cuda') ctor strategy.
When false (default), it does the current PyTorch behavior
of unconditionally constructing a concrete CUDA tensor then calling
lift_fresh on it. When true, we instead construct a concrete CPU
tensor, call lift_fresh, and then call Tensor.to(device) (under any ambient
modes).
- FakeTensorMode flips this flag depending on if CUDA is available or
not. We don't unconditionally set the flag to True because that is
likely BC-breaking.
Test Plan:
- existing tests
Pull Request resolved: https://github.com/pytorch/pytorch/pull/124413
Approved by: https://github.com/eellison
Fixes https://github.com/pytorch/pytorch/issues/121085
This PR is pretty involved so pay attention to this description. At a high
level, the refactor is intended to be mechanical: anywhere in
MetaConverter where previously we took a Tensor as argument, we now take
a MetaTensorDesc, which contains all of the information that we would
have queried off of the Tensor, but placed into a separate data
structure which we can serialize or use to recreate a fake tensor in
a separate fake tensor mode in exact fidelity to the original.
However, this transformation is not always entirely mechanical. Here
is what you need to pay attention to:
- The memo table from real Tensor -> meta/fake Tensor is now broken
into two memo tables: real Tensor -> stable int id -> meta/fake
Tensor. The stable int id is needed so that when we do serialization,
we know when tensors/storages alias each other and can ensure we preserve
this aliasing upon deserialization.
  The way I have implemented this changes the weak reference behavior.
Previously, when either the real Tensor OR the meta/fake Tensor went
dead, we would remove the entry from the memo table. Now, this only
removes entries from one of the two memo tables. This semantically
makes sense, because the user may have held on to the stable int id
out of band, and may expect a real Tensor to continue to be numbered
consistently / expect to be able to lookup a meta/fake tensor from
this id. If this is unacceptable, it may be possible to rejigger
the memo tables so that we have real Tensor -> stable int id
and real Tensor -> meta/fake Tensor, but TBH I find the new
implementation a lot simpler, and arranging the memo tables in this
way means that I have to muck around with the real tensor to save
to the memo table; in the current implementation, I never pass the
Tensor to meta_tensor function AT ALL, which means it is impossible
to accidentally depend on it.
- When I fill in the fields of MetaTensorDesc in describe_tensor, I need
to be careful not to poke fields when they are not valid. Previously,
preconditions were implicitly checked via the conditional structure
("is this sparse? is this nested?") that is tested before we start
reading attributes. This structure has to be replicated in
describe_tensor, and I have almost assuredly gotten it wrong on my
first try (I'll be grinding through it on CI; a careful audit will
help too, by auditing that I've tested all the same conditionals that
the original access was guarded by.)
- I originally submitted https://github.com/pytorch/pytorch/pull/121821
for the symbolic shapes change, but it turned out the way I did it
there didn't actually work so well for this PR. I ended up just
inlining the symbolic shapes allocation logic into MetaConverter
(look for calls to maybe_specialize_sym_int_with_hint), maybe there
is a better way to structure it, but what I really want is to
just read sizes/strides/offset directly off of MetaTensorDesc; I
don't want another intermediate data structure.
- Some fields aren't serializable. These are documented as "NOT
serializable". ctx/type should morally be serializable and I just
  need to set up a contract with subclasses to let them be serialized.
The fake_mode is used solely to test if we are refakefying with
a pre-existing ShapeEnv and we want to reuse the SymInt
directly--serializing this case is hopeless but I am kind of hoping
after this refactor we do not need this at all. view_func is not
serializable because it's a bound C implemented method. Joel has
promised me that this is not too difficult to actually expose as a
true data structure, but this is the edgiest of edge cases and there
is no reason to deal with it right now.
Signed-off-by: Edward Z. Yang <ezyang@meta.com>
Pull Request resolved: https://github.com/pytorch/pytorch/pull/122044
Approved by: https://github.com/eellison
ghstack dependencies: #122018