This PR adds universal support for ndarray methods. After #100839 each `NumpyNdarrayVariable` should wrap a `torch.Tensor`. This PR adds a `numpy_method_wrapper` which converts the `torch.Tensor` to `torch_np.ndarray` and then call the numpy ndarray method. Then we also try to return a `torch.Tensor` (return as-is if the value is not ndarray-like)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/97537
Approved by: https://github.com/ezyang
Issue: #93684
In previous PRs #95849#99560 we redirect `numpy.*`, `<tensor>.numpy()` calls to `torch_np.*` methods and attributes, by creating `NumpyNdarrayVariable` for those calls.
We need to handle `NumpyNdarrayVariable` when graph break happens.
This PR did 2 things:
1. In `codegen.py` we made sure we can reconstruct the value wrapped by `NumpyNdarrayVariable`, to be `torch_np.ndarray` in the stack whenerver we recompiles the subgraph.
2. In `builder.py` we can wrap the value to be `NumpyNdarrayVariable` and save it as graph input.
-----
Starting from commit 6:
## A new design for supporting numpy in dynamo
In short the core concept doesn't change: we still convert `numpy` API calls to `torch_np` API calls. However, instead of wrapping a `torch_np.ndarray` in `NumpyNdarrayVariable`, the new design wraps a `torch.Tensor`.
The reason for doing this change is because we need to keep `torch.Tensor` everywhere in the captured graph, so that it works well with the backend of dynamo. See discussions in https://github.com/Quansight-Labs/numpy_pytorch_interop/issues/142 for details.
### Flow
This is an example showing how do we think about dynamo working on a simple function:
```python
def f(x: torch.Tensor, y: torch.Tensor):
a, b = x.numpy(), y.numpy()
c = np.add(x, y)
return torch.from_numpy(c)
```
```
+------------+ +------------+
torch.Tensor | |numpy.ndarray| |
-------------- .numpy() --------------| |
| | | | +------------------+
+------------+ | numpy.add |numpy.ndarray| |torch.Tensor
+------------+ | --------------| torch.from_numpy --------------
torch.Tensor | |numpy.ndarray| | | |
-------------- .numpy() --------------| | +------------------+
| | | |
+------------+ +------------+
+------------+ +----------------+
torch.Tensor | |torch.Tensor | |
-------------- .detach() --------------| |
| | | | +----------------+ +------------+
+------------+ | |torch_np.ndarray| |torch.Tensor| |torch.Tensor
| torch_np.add -----------------| util.to_tensor -------------| .detach() --------------
+------------+ | | | | | |
torch.Tensor | |torch.Tensor | | +----------------+ +------------+
-------------- .detach() --------------| |
| | | |
+------------+ | +----------------+ |
| wrapper on torch_np.add |
+--------------------------------------------------------+
```
### Approach
`torch_np` APIs can take both `torch_np.ndarray` as well as `torch.Tensor`. What we need to do is to have a wrapper for these APIs to convert the return value back to `torch.Tensor`. This way only the wrapper is showing up in the captured graph, with `torch.Tensor`s as input and `torch.Tensor` as output.
If we have a graph break or we've traced to the end of the program, we need to inspect all the `NumpyNdarrayVariable` in the stack and convert them back to `numpy.ndarray`, to make sure the compiled version is still behaving the same as the eager version.
### Examples
Here's an example of the graph generated:
```python
def fn(x: np.ndarray, y: np.ndarray):
a = x.real
b = y.real
torch._dynamo.graph_break()
return np.add(a, 1), np.add(b, 1)
```
Graph generated:
```
[2023-05-16 10:31:48,737] torch._dynamo.output_graph.__graph: [DEBUG] TRACED GRAPH
__compiled_fn_0 <eval_with_key>.0 opcode name target args kwargs
------------- -------------- ---------------------------------------------------------- ---------------------- --------
placeholder l_x_ L_x_ () {}
placeholder l_y_ L_y_ () {}
call_function from_numpy <built-in method from_numpy of type object at 0x12b1fdc80> (l_x_,) {}
call_function from_numpy_1 <built-in method from_numpy of type object at 0x12b1fdc80> (l_y_,) {}
call_function attr_wrapper <function attr_wrapper at 0x12e8693a0> (from_numpy, 'real') {}
call_function attr_wrapper_1 <function attr_wrapper at 0x12e8693a0> (from_numpy_1, 'real') {}
output output output ((),) {}
[2023-05-16 10:31:48,908] torch._dynamo.output_graph.__graph: [DEBUG] TRACED GRAPH
__compiled_fn_2 <eval_with_key>.1 opcode name target args kwargs
------------- ------------- ---------------------------------------------------------- ------------------------------- --------
placeholder l_a_ L_a_ () {}
placeholder l_b_ L_b_ () {}
call_function from_numpy <built-in method from_numpy of type object at 0x12b1fdc80> (l_a_,) {}
call_function from_numpy_1 <built-in method from_numpy of type object at 0x12b1fdc80> (l_b_,) {}
call_function wrapped_add <Wrapped function <original add>> (from_numpy, 1) {}
call_function wrapped_add_1 <Wrapped function <original add>> (from_numpy_1, 1) {}
output output output ((wrapped_add, wrapped_add_1),) {}
```
### Changes
* `codegen.py`: reconstruct `numpy.ndarray` from `NumpyNdarrayVariable` by adding bytecode to call `utils.to_numpy_helper()`.
* `output_graph.py`: getting rid of legacy code that does exactly what `codegen.py` does, which only handling return case but not graph break case.
* `utils.py`: added helpers to convert `numpy.ndarray` to `torch.Tensor` and vice versa. Also adding a wrapper class that takes in a function. In `__call__` it calls the function and converts its out to `torch.Tensor` (or a list of it).
* `builder.py`: add method to wrap `numpy.ndarray` graph inputs into `NumpyNdarrayVariable`, by calling `torch.numpy` in the proxy.
* `misc.py`: `numpy` API calls goes into `NumpyVariable` and we find the function with the same name in `torch_np` module, then wrap it with the wrapper defined in `utils.py`.
* `tensor.py`, `torch.py`: proxy `tensor.numpy()` to be `torch.detach()` but wrap it with `NumpyNdarrayVariable`. Similarly, `torch.from_numpy()` -> `torch.detach()` but wrap it with `TensorVariable`. In `NumpyNdarrayVariable`, do the similar `torch_np.ndarray` to `torch.Tensor` wrapping for attributes.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/100839
Approved by: https://github.com/ezyang
On top of #95849 this PR is trying to handle the special case when dealing with numpy.
Consider the following example:
```
def f(x: torch.Tensor) -> np.ndarray:
a = x.numpy()
return a.T
```
In previous PR this will error out because we translate `a.T` to be a method call on `torch_np.ndarray.T` which is also a `torch_np.ndarray`.
This PR handles this case, by conditionally converting a `torch_np.ndarray` to `np.ndarray` before returning, to match the original behavior.
The compiled version will be:
```
def f(x):
___tmp_0 = __compiled_fn_0(x)
if isinstance(___tmp_0, torch_np.ndarray):
return ___tmp_0.tensor.numpy()
else:
return ___tmp_0
```
Pull Request resolved: https://github.com/pytorch/pytorch/pull/99560
Approved by: https://github.com/jansel, https://github.com/yanboliang
Issue: #93684
# Problem
Reduce graph breaks when dynamo compiles python functions containing numpy functions and ndarray operations.
# Design (as I know it)
* Use torch_np.ndarray(a wrapper of tensor) to back a `VariableTracker`: `NumpyTensorVariable`.
* Translate all attributes and methods calls, on ndarray, to torch_np.ndarray equivalent.
This PR adds `NumpyTensorVariable` and supports:
1. tensor to ndarray, ndarray to tensor
2. numpy functions such as numpy.meshgrid()
3. ndarray attributes such as `itemsize`, `stride`
Next PR will handle returning `np.ndarray` and add support for ndarray methods
Pull Request resolved: https://github.com/pytorch/pytorch/pull/95849
Approved by: https://github.com/ezyang
It's common to call ```dict()``` or ```collections.OrderedDict()``` inside of ```forward``` function, so we should not graph break.
This pattern has been used in many places including:
* The use case in [torchvision](
928b05cad3/torchvision/models/_utils.py (L66-L73)).
* It causes ~100 model failures(nopython=True) in the 14k github models.
* Also it hits several Meta internal use cases.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/95250
Approved by: https://github.com/jansel
**Background:** Before this PR, support in dynamo for tensor attributes (e.g. `x.H`, `x.T`, ...) need to be individually implemented one-by-one. This could potentially lead to errors, e.g. if the implementation in [variables/tensor.py](21c7c7c72f/torch/_dynamo/variables/tensor.py (L160)) differs from the implementation from a direct call to the attribute. For attributes that were not special-cased in tensor.py, dynamo tracing would fail. This PR adds generic support for tensor attributes that return tensors without needing to specially handle them. (Notably, for x.real and x.imag, which previously weren't supported).
**In this PR:** This directly creates a proxy node for a `"call_function"` node with `target=getattr`, and feeds it into wrap_fx_proxy. This will produce a TensorVariable for the attribute returned.
This also removes the implementations for H, T, mH, mT which were broken (previously `torch.relu(x.T)` would fail). They now fall back to this default implementation (for which `torch.relu(x.T)` passes).
**Further context**:
* Ed's original suggestion in [90463](https://github.com/pytorch/pytorch/pull/90463#discussion_r1043398340) is to use `torch.Tensor.H.__get__(x)`. I wasn't able to get this to work; fx compilation fails with `getset_descriptor does not have attribute __module__`. Basically, the `__module__` attribute which is available on most python attributes, is not available on `getset_descriptor` objects. (i.e., these are implemented in C++ as attributes on torch.Tensor, so they don't obey some assumptions made by fx)
* Although both tensor attributes and methods (like `x.relu()`) both go through this, this PR should only handle attributes (e.g. see the `"getset_descriptor"` in variables/tensor.py). Methods are handled already by by GetAttrVariable.
* Prior to this PR, we already returned GetAttrVariables for unsupported attrs: the parent caller would catch the NotImplementedError and fallback to returning a GetAttrVariable. But if this GetAttrVariable was ever passed into a torch.\* function (as it could quite possibly be, since most of these attrs are tensors), it would fail because its proxy node would be missing an [example_value](https://github.com/pytorch/pytorch/blob/master/torch/_dynamo/utils.py#L1017). So: before, for some tensor x, `x.real` would work fine; but `torch.relu(x.real)` would fail.
**Testing**: added tests in test_misc.py for x.real, x.imag, x.T, x.real.T.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/91840
Approved by: https://github.com/ezyang
for some tensor x, x.type(torch.FloatTensor) will essentially do the same thing as x.to(torch.float). x.type can be called with at least 3 types of inputs:
* a string "torch.FloatTensor"
* a dtype torch.float
* a tensor type torch.FloatTensor
the third option (torch.FloatTensor) fails in fx, because fx cannot trace torch.FloatTensor objects. So this PR will replace the torch.FloatTensor type with a string "torch.FloatTensor"
Why not fix this in fx? Well, it's possible, but I'm not sure a nice way to do it. We would want to update [torch.fx.node.BaseArgumentTypes](d88bc38b0c/torch/fx/node.py (L17)) to contain torch.FloatTensor etc. We could hard-code a list of tensor types there (the types vary depending on build type, e.g. whether or not cuda tensors are available), but that's not great in case our hardcoded list differs from the actual list registered by python_tensor.cpp. Another option is to dynamically populate the list of types with `Union[tuple(...)])`, and fill the tuple with `torch._tensor_classes` (which is directly populated by python_tensor.cpp), but apparently this breaks most typecheckers.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/93043
Approved by: https://github.com/jansel
Handle tensor default func/method args when inlining
Previously, when inlining a function, its default arguments
were only wrapped with VariableTrackers if non-tensor. Now,
tensor default args are also handled by adding them to the
parent InstructionTranslator as an attribute.
- also patches up a missing source in nnmodule call_function,
needed to properly guard on a default arg in its methods
- adds new 'DefaultsSource' type which guards either a `__defaults__`
or `__kwdefaults__` entry on a function
Fixes#90361https://github.com/pytorch/torchdynamo/issues/1968
Pull Request resolved: https://github.com/pytorch/pytorch/pull/90575
Approved by: https://github.com/voznesenskym
Rewrite inplace addcdiv to a div, mul and inplace add to avoid graph break
Rewrite inplace add to a mul and inplace add to avoid graph break
Needed to close optimizer graph breaks
Pull Request resolved: https://github.com/pytorch/pytorch/pull/90330
Approved by: https://github.com/jansel
Fix errors from [7k github models](https://github.com/pytorch/torchdynamo/issues/1884)
```
Traceback (most recent call last):
File "/scratch/ybliang/work/repos/pytorch/torch/_dynamo/utils.py", line 1062, in get_fake_value
return wrap_fake_exception(
File "/scratch/ybliang/work/repos/pytorch/torch/_dynamo/utils.py", line 739, in wrap_fake_exception
return fn()
File "/scratch/ybliang/work/repos/pytorch/torch/_dynamo/utils.py", line 1063, in <lambda>
lambda: run_node(tx.output, node, args, kwargs, nnmodule)
File "/scratch/ybliang/work/repos/pytorch/torch/_dynamo/utils.py", line 1112, in run_node
raise RuntimeError(
RuntimeError: Failed running call_function <function einsum at 0x7fd8f246a4c0>(*('i,j->ij', FakeTensor(FakeTensor(..., device='meta', size=(4,)), cpu), FakeTensor(FakeTensor(..., device='meta', size=(2,)), cuda:0)), **{}):
Unhandled FakeTensor Device Propagation for aten.mul.Tensor, found two different devices cpu, cuda:0
(scroll up for backtrace)
```
The root cause is: ```tensor.type()``` should return ```torch.cuda.FloatTensor``` rather than ```torch.FloatTensor``` if it's on GPU.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/90021
Approved by: https://github.com/jansel
This is a group of bug fixes for [7k github models](https://github.com/pytorch/torchdynamo/issues/1884), it would fix 30+ model tests.
* Support ```tensor.type()```.
* Support ```tensor.get_device()```.
* Support ```torch.nn.functional._Reduction.get_enum```.
* Support ```torch._utils._get_device_index()```.
* Fallback ```tensor.data_ptr()```.
* ```FakeTensor``` always returns 0
* For no fake tensor propagation, we ```clone``` the input tensor, which makes no sense to track the original ```data_ptr```. And I don't think this is a very popular API.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/89486
Approved by: https://github.com/jansel
**Introduces symbolic shape guards into dynamo.**
In this PR, we take the existing fake tensor infra and plumbing in dynamo and we start passing a shape_env around. This shape_env does not get plumbed down to middle layers / backend yet - it only collects expressions from frontend invocations at the moment. We then translate these expressions into guards at the point where we take other guards installed throughout dynamo - and add them to check_fn.
Part 1 of https://docs.google.com/document/d/1QJ-M4zfMkD-fjHIqW089RptjLl9EgozZGCceUbvmgfY/edit#
cc @jansel @lezcano @fdrocha @mlazos @soumith @yanboliang @penguinwu @anijain2305
Pull Request resolved: https://github.com/pytorch/pytorch/pull/87570
Approved by: https://github.com/ezyang