Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/66048
Previously, `create_arg` would fail if it encountered a non-`None` layout argument. Adding the layout type to the `BaseArgumentTypes` list should be enough to fix that.
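For illustration only (not the PR's unit test), a minimal sketch of a trace that hands `create_arg` a non-`None` layout value:
```python
import torch
import torch.fx

class M(torch.nn.Module):
    def forward(self, x):
        # torch.strided is a torch.layout constant, so create_arg has to
        # accept it as a kwarg value when building the graph.
        return torch.ones_like(x, layout=torch.strided)

gm = torch.fx.symbolic_trace(M())
print(gm.graph)
```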
Test Plan: Added unittest
Reviewed By: jamesr66a
Differential Revision: D31362662
fbshipit-source-id: 20049971e18c17e9c75e50540500c567266daa55
Summary:
Previously this resulted in `AttributeError: module 'operator' has no attribute 'and'`.
`and`/`or` are Python keywords, so the corresponding functions are named `operator.and_` and `operator.or_`.
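As a purely illustrative sketch of where the underscore-suffixed names show up (my example, not necessarily the exact code path this PR touched): tracing a bitwise `&` records `operator.and_` as the call target.
```python
import operator
import torch
import torch.fx

class M(torch.nn.Module):
    def forward(self, x, y):
        return x & y  # there is no operator.and; the builtin is operator.and_

gm = torch.fx.symbolic_trace(M())
assert any(n.target is operator.and_ for n in gm.graph.nodes)
```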
Pull Request resolved: https://github.com/pytorch/pytorch/pull/65196
Reviewed By: Chillee
Differential Revision: D31020336
Pulled By: jansel
fbshipit-source-id: 51d888151fe78c0c1197ecaf161976b219c59694
Summary:
OpInfo tracker: https://github.com/pytorch/pytorch/issues/54261
- Eliminated duplicated testing logic in test_autograd
- Moved tests that rely on this testing logic to use OpInfos
- `cat` already has an OpInfo (no action needed)
- Created OpInfos for `block_diag` and `broadcast_tensors`
Ran into some FX errors; added the op to the skip list and created an issue here: https://github.com/pytorch/pytorch/issues/64997
Both `block_diag` and `broadcast_tensors` are variadic, so `test_variant_consistency_jit` is skipped (judging from comments on other OpInfos, the JIT does not support variadic tensor arguments)
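For reference, a quick illustration of the variadic signatures (my example, values approximate):
```python
import torch

a, b = torch.randn(2, 2), torch.randn(3, 3)
# block_diag accepts any number of tensors and builds a block-diagonal matrix.
torch.block_diag(a, b).shape  # torch.Size([5, 5])

# broadcast_tensors likewise takes a variable number of tensors.
x, y = torch.broadcast_tensors(torch.randn(1, 3), torch.randn(4, 1))
x.shape, y.shape  # (torch.Size([4, 3]), torch.Size([4, 3]))
```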
Pull Request resolved: https://github.com/pytorch/pytorch/pull/64993
Reviewed By: jbschlosser
Differential Revision: D30961736
Pulled By: soulitzer
fbshipit-source-id: e169305384a683acae1178c4e12e9e214a67226a
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/64271
Closes #60417
Modified `emit_node()` in `fx/graph.py` to generate a `getattr()` call with a default value when `len(node.args) != 2`, instead of emitting plain attribute access.
Added test_torch_fx_getattr() in test/test_fx.py.
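A sketch of the case the change targets (my illustration, not the new test): a `call_function` node on builtin `getattr` with three args, for which code generation now emits a real `getattr(...)` call rather than attribute access.
```python
import torch
import torch.fx

g = torch.fx.Graph()
x = g.placeholder("x")
# Three-argument getattr: the generated code should call
# getattr(x, "not_there", 0) instead of emitting x.not_there.
val = g.call_function(getattr, (x, "not_there", 0))
g.output(val)

gm = torch.fx.GraphModule(torch.nn.Module(), g)
print(gm.code)
```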
Test Plan:
pytest test/test_fx.py
Imported from OSS
Reviewed By: jamesr66a
Differential Revision: D30671265
fbshipit-source-id: f2db9ea47e0cb247547e200684f715aab006c374
Summary:
Implements feature request https://github.com/pytorch/pytorch/issues/62021
Test it out with
```python
from torch import fx
from torch import nn
def fx_int(x):
    return int(x)

class MyModule(nn.Module):
    def forward(self, x):
        return fx_int(x.shape[0] / 2)

tracer = fx.Tracer(autowrap_functions=(fx_int,))  # or remove the kwarg to demonstrate the symbolic trace error
tracer.trace(MyModule())
```
First-time contributor, so please advise if I could have done anything to make life easier for next time.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/62106
Reviewed By: SplitInfinity, driazati
Differential Revision: D30080834
Pulled By: jamesr66a
fbshipit-source-id: 68fadf8c881ea7930e7afd62b642874010fe4903
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/62436
## Problem
Given two modules and a tracer that indiscriminately marks all modules as leaves:
```
class InnerModule(torch.nn.Module):
    def forward(self, t):
        return t + t

class MyModule(torch.nn.Module):
    def __init__(self, inner):
        super().__init__()
        self.inner = inner

    def forward(self, t):
        x = self.inner(t)
        y = self.inner(t)
        return x + y

class MyTracer(torch.fx.Tracer):
    def is_leaf_module(self, module, name):
        return True
```
One might generally expect the following behavior (note call_module nodes):
```
print(">> Outer GraphModule (with inner module as nn.Module):")
inner = InnerModule()
m = MyModule(inner)
gm = torch.fx.GraphModule(m, MyTracer().trace(m))
print(gm.graph.print_tabular())
>> Outer GraphModule (with inner module as nn.Module):
opcode name target args kwargs
------------- ------- ----------------------- ---------------- --------
placeholder t t () {}
call_module inner inner (t,) {}
call_module inner_1 inner (t,) {}
call_function add <built-in function add> (inner, inner_1) {}
output output output (add,) {}
None
```
However, when the inner module is first symbolically traced, the symbolic trace of the outer module ignores `is_leaf_module` entirely and traces through the whole module (note the call_function nodes):
```
print(">> Inner module as GraphModule:")
inner = InnerModule()
inner_gm = torch.fx.GraphModule(inner, MyTracer().trace(inner))
print(inner_gm.graph.print_tabular())
print(">> Outer GraphModule (with inner module as GraphModule):")
m = MyModule(inner_gm)
gm = torch.fx.GraphModule(m, MyTracer().trace(m))
print(gm.graph.print_tabular())
>> Inner module as GraphModule:
opcode name target args kwargs
------------- ------ ----------------------- ------ --------
placeholder t t () {}
call_function add <built-in function add> (t, t) {}
output output output (add,) {}
None
>> Outer GraphModule (with inner module as GraphModule):
opcode name target args kwargs
------------- ------ ----------------------- ------------ --------
placeholder t t () {}
call_function add <built-in function add> (t, t) {}
call_function add_1 <built-in function add> (t, t) {}
call_function add_2 <built-in function add> (add, add_1) {}
output output output (add_2,) {}
None
```
This is surprising behavior and at first glance violates the tracer's intent. As I understand it, `torch.fx.symbolic_trace.Tracer.trace` intends to patch `torch.nn.Module.__call__` with a `module_call_wrapper()` that records a `call_module` node if the module is a leaf, and otherwise executes `torch.fx._symbolic_trace._orig_module_call` (a reference to `torch.nn.Module.__call__` captured at module load time).
**Every submodule should be a leaf, but no `call_module` nodes are created when that submodule is a `GraphModule`. Why?**
Upon further inspection, I found:
- The constructor for `GraphModule` includes a path to `GraphModule.recompile()` via the setter for its `graph` attribute (an `fx.Graph`):
```
inner_gm = torch.fx.GraphModule(inner, MyTracer().trace(inner))
File "/torch/fx/graph_module.py", line 252, in __init__
self.graph = graph
File "/torch/nn/modules/module.py", line 1183, in __setattr__
object.__setattr__(self, name, value)
File "/torch/fx/graph_module.py", line 277, in graph
self.recompile()
```
- `recompile()` wraps the `__call__` method by holding a reference to the `__call__` method at the time of recompilation:
```
cls = type(self)
cls_call = cls.__call__
...
def wrapped_call(self, *args, **kwargs):
    try:
        return cls_call(self, *args, **kwargs)
    except Exception as e:
        ...

cls.__call__ = wrapped_call
```
- Recompilation of the inner GraphModule happens on initialization, before creation or tracing of the outer module. Adding some old-fashioned print debug statements gives:
```
Inner Module:
_orig_module_call: <function Module._call_impl at 0x7faaebfee8b0>
recompile: cls.__call__ now wraps _orig_module_call, <function Module._call_impl at 0x7faaebfee8b0>
Outer Module:
_orig_module_call: <function Module._call_impl at 0x7faaebfee8b0>
tracing: patching method <class 'torch.nn.modules.module.Module'>.__call__ <function Module._call_impl at 0x7faaebfee8b0> with <function Module._call_impl at 0x7fa9d42bce50>
outer module MRO before tracing:
(0) <class '__main__.MyModule'>: <function Module._call_impl at 0x7faaebfee8b0>
(1) <class 'torch.nn.modules.module.Module'>: <function Module._call_impl at 0x7faaebfee8b0>
(2) <class 'object'>: <method-wrapper '__call__' of type object at 0x7fac3cd15f00>
outer module MRO during tracing:
(0) <class '__main__.MyModule'>: <function Module._call_impl at 0x7fa9d42bce50>
(1) <class 'torch.nn.modules.module.Module'>: <function Module._call_impl at 0x7fa9d42bce50>
(2) <class 'object'>: <method-wrapper '__call__' of type object at 0x7fac3cd15f00>
inner module MRO before tracing:
(0) <class 'torch.fx.graph_module.GraphModule.__new__.<locals>.GraphModuleImpl'>: <function x.y.z.wrapped_call at 0x7fa9d42a8670>
(1) <class 'torch.fx.graph_module.GraphModule'>: <function Module._call_impl at 0x7faaebfee8b0>
(2) <class 'torch.nn.modules.module.Module'>: <function Module._call_impl at 0x7faaebfee8b0>
(3) <class 'object'>: <method-wrapper '__call__' of type object at 0x7fac3cd15f00>
inner module MRO during tracing:
(0) <class 'torch.fx.graph_module.GraphModule.__new__.<locals>.GraphModuleImpl'>: <function x.y.z.wrapped_call at 0x7fa9d42a8670>
(1) <class 'torch.fx.graph_module.GraphModule'>: <function Module._call_impl at 0x7fa9d42bce50>
(2) <class 'torch.nn.modules.module.Module'>: <function Module._call_impl at 0x7fa9d42bce50>
(3) <class 'object'>: <method-wrapper '__call__' of type object at 0x7fac3cd15f00>
```
- The outer module is patched correctly, but the first element in the inner module's MRO is the `wrapped_call` from `recompile`, which still invokes `<function Module._call_impl at 0x7faaebfee8b0>` directly. Therefore, no call_module nodes are created.
## In Practice
In practice, this behavior affects the ability of `torch.package` to package `GraphModules` whose submodules are `GraphModules`. In our case, the `GraphModule` submodules are not passed through a constructor, but created separately and installed on the root `GraphModule` via `setattr`. This means that prior to packaging, there appear to be no issues with the module, since the root's graph was created before any call_module targets were replaced with `GraphModules`.
When unpackaging such a model with `torch.package`, `torch.fx.graph_module._deserialize_graph_module` uses an inline `KeepModules` tracer that sets all submodules to leaves; the unpackaged module is implicitly and surprisingly inlined in the process.
## Potential Solution
This behavior was previously not understood by us, so our current workaround is a gnarly process of wrapping every submodule in an `nn.Module` with a manually installed forward method.
Changing `wrapped_call` to call `super(type(self), self).__call__(*args, **kwargs)` whenever `__call__` is inherited at least appears to solve the issue. Does this seem like an acceptable approach?
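A rough sketch of my reading of that proposal (not the final patch); detecting "inherited" via `"__call__" in vars(cls)` is an assumption on my part:
```python
# Inside GraphModule.recompile(), before cls.__call__ is replaced:
cls = type(self)
# Keep a reference to __call__ only if it is defined on this class itself;
# None means it is inherited.
cls_call = cls.__call__ if "__call__" in vars(cls) else None

def wrapped_call(self, *args, **kwargs):
    try:
        if cls_call is not None:
            return cls_call(self, *args, **kwargs)
        # __call__ was inherited, so dispatch through the MRO; this picks up
        # the patched torch.nn.Module.__call__ installed by symbolic tracing.
        return super(type(self), self).__call__(*args, **kwargs)
    except Exception:
        raise

cls.__call__ = wrapped_call
```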
## Other Thoughts
- Repeated calls to `recompile` create nested `wrapped_call`s, all for the purpose of error handling. This seems unnecessary ¯\\_(ツ)\_/¯
- If a root module with an overridden `__call__` method is symbolically traced, the override is ignored
Test Plan:
```
buck test:
✓ ListingSuccess: caffe2/test:fx - main (12.570)
✓ Pass: caffe2/test:fx - test_tracing_graphmodules_as_leaf_submodules (test_fx.TestFX) (11.982)
```
Reviewed By: ansley
Differential Revision: D29997935
fbshipit-source-id: 1988fbb025b14188da26a3e73e94fb789c3c1f74
Summary:
Fixes https://github.com/pytorch/pytorch/issues/61733
Allow FX tracer to trace control flow (if/while) statements when parameter shapes are in the condition.
If the user specifies the new `param_shapes_constant` option when constructing a tracer, the model's parameter shape attributes are evaluated eagerly and the resulting constants are emitted into the IR during tracing.
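A small usage sketch (mine, not the added test) on a toy module:
```python
import torch
import torch.fx as fx

class MyModule(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.linear = torch.nn.Linear(4, 4)

    def forward(self, x):
        # The branch condition depends on a parameter's shape; with
        # param_shapes_constant=True the shape is evaluated during tracing
        # and only the taken branch is recorded in the IR.
        if self.linear.weight.shape[0] >= 2:
            return self.linear(x)
        return x

graph = fx.Tracer(param_shapes_constant=True).trace(MyModule())
print(graph)
```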
Also added a new test: `python test/fx/test_fx_param_shape_control_flow.py`
The test also performs a somewhat whitebox style testing to check the generated Python code from the IR.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/61820
Reviewed By: bdhirsh
Differential Revision: D29969299
Pulled By: jerryzhenleicai
fbshipit-source-id: 99aae824bdfec880be69258de7ead5c8cd59eddc
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/62292
This PR adds pytree support for namedtuples. The challenge with namedtuple is that each namedtuple class is actually a distinct type. This PR does the following (a usage sketch follows the list):
- It adds a namedtuple flatten/unflatten. The flatten function returns a context that is the actual type of the namedtuple subclass; the unflatten function uses that type to reconstruct the namedtuple.
- It special-cases all pytree logic to treat all namedtuples the same. This is done by creating a `_get_node_type(pytree)` helper function that returns `namedtuple` if `pytree` is any namedtuple subclass. The effect is that all namedtuple subclasses go through the namedtuple flatten/unflatten functions.
- It adds a `_namedtuple_flatten_spec` function for FX pytrees. This function flattens the namedtuple based on the spec and is analogous to `_tuple_flatten_spec`.
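A minimal round-trip sketch (mine, not one of the new tests) of the intended behavior via `torch.utils._pytree`:
```python
import collections
import torch
from torch.utils._pytree import tree_flatten, tree_unflatten

Point = collections.namedtuple("Point", ["x", "y"])

p = Point(torch.tensor(1.0), torch.tensor(2.0))
leaves, spec = tree_flatten(p)          # leaves: [tensor(1.), tensor(2.)]
restored = tree_unflatten(leaves, spec)
assert isinstance(restored, Point)      # the concrete namedtuple subclass is preserved
```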
Test Plan
- new tests in test/test_pytree.py and test/test_fx.py
Test Plan: Imported from OSS
Reviewed By: albanD
Differential Revision: D29947302
Pulled By: zou3519
fbshipit-source-id: 19c00665b13546642c315df0f243ad99b8e7ff7c
Summary:
### Issue
Build PyTorch wheel packages during the build stage for pull requests and install them during the test stage.
### Fix
Update all tests that call lib*.so (under the `./build` folder) to instead call lib*.so in `{ent}/pytorch/lib/python3.8/site-packages/torch`.
### Diff
This diff updates test_fx, test_backend, and test_torchbind first, to check whether the current CI passes.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/61960
Test Plan: check that all CI workflows pass
Reviewed By: malfet, saketh-are
Differential Revision: D29823235
Pulled By: tktrungna
fbshipit-source-id: e7f652def698e303d4843fbaedf4859f5eca2fd9
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/61780
These changes allow objects to control, from within their own source, how they are handled when they appear as an argument to a torch.fx call_module node. Previously, we had been using a custom Tracer with an overridden create_arg() method, branching on class name to handle unusual args (data classes, etc.).
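For context, a rough sketch of the previous workaround described above (a custom tracer that special-cases unusual argument types in `create_arg()`); `ScaleConfig` and its dict encoding are purely illustrative:
```python
import dataclasses
import torch.fx

@dataclasses.dataclass
class ScaleConfig:  # hypothetical non-standard argument type
    scale: float

class MyTracer(torch.fx.Tracer):
    def create_arg(self, a):
        # Branch on the unusual type before deferring to the default handling.
        if isinstance(a, ScaleConfig):
            return {"scale": a.scale}
        return super().create_arg(a)
```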
Reviewed By: suo, houseroad
Differential Revision: D27976120
fbshipit-source-id: 0c5249c5f8398368ca0fbec0ad8a07ccf99b7da4
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/61463
Seems like a small oversight(?): the current test fails when warnings are recorded. Discovered this when calling `graph.call_module(existing_call_module_node.target)`, which raised a warning.
Test Plan: `buck test //caffe2/test:fx`
Reviewed By: ansley
Differential Revision: D29637799
fbshipit-source-id: 2305629863230235f76a926fe2e4de480cbf853c
Summary:
Reference https://github.com/pytorch/pytorch/issues/50345
`zeta` was already present in the codebase to support computation of `polygamma`.
However, `zeta` only had a `double(double, double)` signature **for CPU** before this PR (which meant that `polygamma` computations were always upcast to `double` for the zeta part).
With this PR, float computations take place in float and double computations in double.
The code has also been refactored, moving the duplicated code from `Math.cuh` to `Math.h`.
**Note**: In SciPy, `q` is optional and defaults to `1`, which corresponds to the Riemann zeta function. For `torch.special.zeta`, however, I made `q` mandatory, because it feels odd that without `q` this is the Riemann zeta function and with `q` it is the general Hurwitz zeta function. Sticking to the general form makes more sense, since passing `1` for `q` is trivial.
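For reference, a quick usage sketch of the documented API (values approximate):
```python
import torch

x = torch.tensor([2.0, 4.0])
q = torch.tensor(1.0)
# Hurwitz zeta; q is required. Passing q=1 recovers the Riemann zeta function:
# zeta(2) ≈ 1.6449, zeta(4) ≈ 1.0823.
torch.special.zeta(x, q)
```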
Verify:
* [x] Docs https://14234587-65600975-gh.circle-artifacts.com/0/docs/special.html#torch.special.zeta
Pull Request resolved: https://github.com/pytorch/pytorch/pull/59623
Reviewed By: ngimel
Differential Revision: D29348269
Pulled By: mruberry
fbshipit-source-id: a3f9ebe1f7724dbe66de2b391afb9da1cfc3e4bb
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/60057
This ensures that if a function was `wrap`'d before symbolic tracing and then passed into the Transformer, it will still be wrapped.
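A minimal sketch of the scenario (mine, not the added test):
```python
import torch
import torch.fx

@torch.fx.wrap  # register as a leaf function: traced as a single call_function node
def my_helper(x):
    return x + 1

class M(torch.nn.Module):
    def forward(self, x):
        return my_helper(x)

gm = torch.fx.symbolic_trace(M())

# Run the traced module through a (no-op) Transformer pass; per this change,
# my_helper should still be treated as wrapped afterwards.
transformed = torch.fx.Transformer(gm).transform()
print(transformed.code)
```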
Test Plan: Added test to `test_fx.py`
Reviewed By: jamesr66a
Differential Revision: D29151191
fbshipit-source-id: 93560be59505bdcfe8d4f013e21d4719788afd59