Add semantics for creating a buffer object analogous to creating a parameter. This is done by introducing a new `Buffer` class that can be used for type disambiguation. The underlying functionality of registering a buffer remains the same, as the `register_buffer` method has not been changed. The `persistent` parameter in the `Buffer` type indicates whether the buffer object should be persistent or not. The other non-test changes make the new `Buffer` type recognized by inductor and dynamo. The remaining changes are test changes to make sure that the `Buffer` type can be used as a drop-in replacement for `register_buffer`, as it just leads to `register_buffer` being called. This new functionality still allows normal tensors to be used as buffers, so these changes are intended to be backwards compatible.
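A minimal sketch of the intended usage, assuming the new class is exposed as `torch.nn.Buffer` and mirrors `register_buffer`'s `persistent` flag:
```python
import torch
import torch.nn as nn

class Model(nn.Module):
    def __init__(self):
        super().__init__()
        # Assigning a Buffer is intended to behave like calling register_buffer.
        self.running_mean = nn.Buffer(torch.zeros(8))               # persistent by default
        self.scratch = nn.Buffer(torch.zeros(8), persistent=False)  # excluded from state_dict

m = Model()
print(sorted(name for name, _ in m.named_buffers()))  # ['running_mean', 'scratch']
print(sorted(m.state_dict().keys()))                  # ['running_mean'] (non-persistent excluded)
```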
Fixes #35735
Co-authored-by: Mikayla Gawarecki <mikaylagawarecki@gmail.com>
Pull Request resolved: https://github.com/pytorch/pytorch/pull/125971
Approved by: https://github.com/albanD, https://github.com/anijain2305, https://github.com/mlazos
# Motivation
In the backward pass of per-parameter-sharding FSDP, each rank performs a reduce-scatter to sync gradients across ranks. A rank chunks each gradient tensor into `world_size` slices along the 0th dimension and concatenates all slices along the 1st dimension. Gradient tensors are padded before concatenation when `tensor.size(0) % world_size != 0`.
### Example 1
Consider `world_size=3` and tensors A (2x4), B (3x3), C (1x2):
Input tensors:
```
AAAA  BBB  CC
AAAA  BBB
      BBB
```
Reduce-scatter-copy-in Output:
```
AAAABBBCC
AAAABBB00
0000BBB00
```
### Example 2
Consider `world_size=2` and tensors A (2x4), B (3x3), C (1x2), D (4x2):
Input tensors:
```
AAAA  BBB  CC  DD
AAAA  BBB      DD
      BBB      DD
               DD
```
Reduce-scatter-copy-in first pad:
```
AAAA  BBB  CC  DD
AAAA  BBB  00  DD
      BBB      DD
      000      DD
```
Then chunk each padded tensor along dim 0 and concatenate the chunks along dim 1 as the output:
```
AAAABBBBBBCCDDDD
AAAABBB00000DDDD
```
The performance of reduce-scatter-copy-in is critical to per-parameter-sharding FSDP. However, implementing reduce-scatter-copy-in by composing existing ATen ops involves `cat` and irregular `pad`, leading to redundant data copies and unsatisfactory performance.
# PR
We provide ATen native support for reduce-scatter-copy-in, namely `_chunk_cat()`:
```
_chunk_cat(Tensor[] tensors, int dim, int num_chunks) -> Tensor
```
This PR includes the registration of `_chunk_cat` and `_chunk_cat.out`, OpInfo tests, and basic implementation composing existing ATen ops.
In the next PR, we will add the CUDA implementation. Compared with baselines composing existing ATen ops, the `_chunk_cat()` CUDA implementation improves copy bandwidth from 498 GB/s to 966 GB/s on a production benchmark.
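As a rough sketch of the intended semantics by composing existing ops (not the actual ATen or CUDA implementation; the helper name is ours and it only covers the FSDP case of `dim=0` with 2-D inputs):
```python
import torch

def chunk_cat_reference(tensors, dim, num_chunks):
    # Pad each tensor along `dim` to a multiple of num_chunks, chunk it,
    # flatten each chunk, then concatenate the i-th chunks of all tensors.
    per_tensor_chunks = []
    for t in tensors:
        pad = (-t.size(dim)) % num_chunks
        if pad:
            shape = list(t.shape)
            shape[dim] = pad
            t = torch.cat([t, t.new_zeros(shape)], dim=dim)
        per_tensor_chunks.append([c.flatten() for c in t.chunk(num_chunks, dim=dim)])
    rows = [torch.cat([chunks[i] for chunks in per_tensor_chunks]) for i in range(num_chunks)]
    return torch.stack(rows)

# Example 1 above: world_size=3, A (2x4), B (3x3), C (1x2) -> output shape (3, 9)
A, B, C = torch.ones(2, 4), torch.ones(3, 3), torch.ones(1, 2)
print(chunk_cat_reference([A, B, C], dim=0, num_chunks=3).shape)  # torch.Size([3, 9])
```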
## Requirements on input
1. If input tensors have different ndims, `dim` should be non-negative and less than the ndim of every input tensor. If all input tensors have the same ndim, both negative and non-negative `dim` are supported.
2. For `wrapped_dim`, all tensors should have the same sizes for dimensions `0, ..., wrapped_dim-1`. There are no requirements on the sizes of dimension `wrapped_dim` and beyond.
3. `num_chunks` must be positive.
4. The input tensor list must be non-empty, and each input tensor must have at least one element.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/121081
Approved by: https://github.com/albanD
This PR wraps the `InstructionTranslator` run with a try-catch block so as to run translation validation (TV) if it ends up raising an error.
In this context, we run TV to catch simplification errors that may render `ShapeEnv.divisible` and `ShapeEnv.replacements` incorrect.
For example: #101173 describes a SymPy simplification bug that doesn't reach TV, since TV is run only at the end of tracing.
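A minimal sketch of the control flow, using hypothetical names rather than the actual dynamo code:
```python
def run_with_translation_validation(translator, shape_env):
    # Hypothetical sketch: run the instruction translator, and if it raises,
    # run translation validation first so that simplification errors that
    # corrupted ShapeEnv.divisible / ShapeEnv.replacements surface before
    # the original exception propagates.
    try:
        translator.run()
    except Exception:
        shape_env.run_translation_validation()  # hypothetical TV entry point
        raise
```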
Pull Request resolved: https://github.com/pytorch/pytorch/pull/106645
Approved by: https://github.com/ezyang
This PR makes Z3 expressions easier to read and understand by creating a custom printer
for them.
Z3 expressions can be printed in 2 forms:
1. Using the builtin `str(e)` function
2. Using the `e.sexpr()` method
The problem is that (1) is a bit hard to read because its line breaks are not so intuitive, while (2) is a bit nicer but the `to_int` and `to_real` functions clutter things up.
The custom printer is an improved `sexpr()` function that:
- Keeps everything on one line
- Gets rid of the `to_int` and `to_real` functions
- Reconstructs floor division operations
- Merges commutative operation chains
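For context, a small illustrative example of the two builtin forms the custom printer improves on (requires the `z3-solver` package; the floor-division encoding via `to_real`/`to_int` below is one way such expressions show up):
```python
import z3

x, y = z3.Ints("x y")
# Floor division encoded through real conversion and truncation back to int.
e = z3.ToInt(z3.ToReal(x) / z3.ToReal(y)) == 2

print(str(e))     # form (1): the builtin printer, which breaks long expressions across lines
print(e.sexpr())  # form (2): s-expression form, cluttered by to_real/to_int
```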
Pull Request resolved: https://github.com/pytorch/pytorch/pull/106643
Approved by: https://github.com/ezyang
Summary:
This change fixes `split_module`'s interaction with dead code. Previously, if a dead region was split out, `split_module` would throw an error while attempting to access the outputs for the partition, even though the partition has no outputs.
This change adds a new unit test to cover the dead code case and changes the output check to allow no outputs. A split module with no outputs will now return None, like a normal Python function.
Unit Test Added:
test_split_module_dead_code
A module with dead code:
```
class ModWithDeadCode(torch.nn.Module):
    def forward(self, x):
        output = x * 2  # we want this
        dead_line = x + 2  # this is dead
        return output
```
Before:
```
torch/fx/passes/split_module.py, line 357, in split_module
base_mod_env[list(partition.outputs)[0]] = output_val
IndexError: list index out of range
```
After:
```
class GraphModule(torch.nn.Module):
    def forward(self, x):
        # No stacktrace found for following nodes
        submod_2 = self.submod_2(x)
        submod_1 = self.submod_1(x); x = None
        return submod_1

class GraphModule(torch.nn.Module):
    def forward(self, x):
        # No stacktrace found for following nodes
        add = x + 2; x = None
        return None

class GraphModule(torch.nn.Module):
    def forward(self, x):
        # No stacktrace found for following nodes
        mul = x * 2; x = None
        return mul
```
Submod 2 is correctly extracted
Test Plan: Tested with new unit test
Differential Revision: D47196732
Pull Request resolved: https://github.com/pytorch/pytorch/pull/104554
Approved by: https://github.com/yf225
Add semantics for creating a buffer object analogous to creating a parameter. This is done by introducing a new `Buffer` class that can be used for type disambiguation. The underlying functionality of registering a buffer remains the same, as the `register_buffer` method has not been changed. The `persistent` parameter in the `Buffer` type indicates whether the buffer object should be persistent or not. The other non-test changes make the new `Buffer` type recognized by inductor and dynamo. The remaining changes are test changes to make sure that the `Buffer` type can be used as a drop-in replacement for `register_buffer`, as it just leads to `register_buffer` being called. This new functionality still allows normal tensors to be used as buffers, so these changes are intended to be backwards compatible.
Fixes #35735
Pull Request resolved: https://github.com/pytorch/pytorch/pull/104069
Approved by: https://github.com/mikaylagawarecki
Applies the remaining flake8-comprehensions fixes and checks. This change replaces all remaining unnecessary generator expressions with list/dict/set comprehensions, which are more succinct, performant, and better supported by our torch.jit compiler. It also removes useless generators such as `set(a for a in b)`, resolving them into just the set call.
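Illustrative examples of the kinds of rewrites applied:
```python
pairs = [("a", 1), ("b", 2)]

d = dict((k, v) for k, v in pairs)  # unnecessary generator expression
d = {k: v for k, v in pairs}        # rewritten as a dict comprehension

b = [3, 1, 2]
s = set(a for a in b)               # useless generator inside set()
s = set(b)                          # resolved into just the set call
```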
Pull Request resolved: https://github.com/pytorch/pytorch/pull/94676
Approved by: https://github.com/ezyang
Summary: One common cause of jit unscriptability issues is the loss of node type annotations on local names after one or several FX transforms. One way to improve type coverage is to eagerly annotate `getitem` nodes with the element type from their parent sequence node. This diff introduces an FX pass to do that.
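A minimal sketch of the idea, with a hypothetical pass name (not necessarily the pass added by this diff):
```python
import operator
import typing
import torch.fx

def annotate_getitem_nodes(graph: torch.fx.Graph) -> None:
    # If a node is known to produce a Tuple[T0, T1, ...], eagerly propagate the
    # element type to each getitem node that indexes into it with a constant.
    for node in graph.nodes:
        if node.op == "call_function" and node.target is operator.getitem:
            parent, index = node.args
            parent_type = getattr(parent, "type", None)
            if parent_type is not None and typing.get_origin(parent_type) is tuple:
                elem_types = typing.get_args(parent_type)
                if isinstance(index, int) and index < len(elem_types):
                    node.type = elem_types[index]
```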
Test Plan:
```
buck2 test //caffe2/test:fx_experimental
```
Reviewed By: xush6528
Differential Revision: D41749744
Pull Request resolved: https://github.com/pytorch/pytorch/pull/90237
Approved by: https://github.com/xush6528
Summary: Make `NormalizeArgs` preserve node types when transforming the graph. This bug is preventing me from scripting a graph that goes through the fx2trt `acc_tracer`.
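For illustration, a hedged sketch of the behavior the fix guarantees (assuming `NormalizeArgs` from `torch.fx.experimental.normalize`):
```python
import torch
import torch.fx
from torch.fx.experimental.normalize import NormalizeArgs

class M(torch.nn.Module):
    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return torch.add(x, x)

traced = torch.fx.symbolic_trace(M())
normalized = NormalizeArgs(traced).transform()
# With the fix, node.type annotations survive normalization, which keeps the
# resulting graph scriptable.
for node in normalized.graph.nodes:
    print(node.name, node.type)
```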
Test Plan: New unit test
Reviewed By: ipiszy
Differential Revision: D39753021
Pull Request resolved: https://github.com/pytorch/pytorch/pull/85637
Approved by: https://github.com/Chillee
Summary: We were handling constant attrs in a few different ways before, leading to confusion and missed handling for fused dtypes. This diff consolidates some of that code and unbreaks the current breakage.
Test Plan: CI. Recently broken tests now pass.
Differential Revision: D36335238
Pull Request resolved: https://github.com/pytorch/pytorch/pull/77401
Approved by: https://github.com/jaybean-dev, https://github.com/jamesr66a
Previously, we were taking the `.op` from OpOverload/OpOverloadPacket and looking for a mapping in `_jit_builtins` for their signature. Those will only exist for operators on the public API, not the overload packets, e.g. `torch.resize_as_`, not `torch.ops.aten.resize_as_` (at least in this case, and I'm pretty sure generally). The OpOverloads/OpOverloadPackets have schemas stored on them, so we can just use those directly.
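For illustration, the schema can be read off an overload directly (this assumes the private `_schema` attribute and the `overloads()` helper that current OpOverload/OpOverloadPacket objects expose):
```python
import torch

overload = torch.ops.aten.resize_as_.default
print(overload._schema)    # the aten schema stored on this overload

packet = torch.ops.aten.resize_as_
print(packet.overloads())  # overload names available on the packet
```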
Pull Request resolved: https://github.com/pytorch/pytorch/pull/77182
Approved by: https://github.com/anjali411
This is the `__torch_dispatch__` subclass used for tracing by AOTAutograd (https://github.com/pytorch/functorch/blob/main/functorch/_src/python_key.py).
Given that a couple of folks are now interested in using this infra, it seems like a good idea to put it in core, and focus our efforts on a single implementation.
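For reference, a minimal usage sketch (assuming a `make_fx`-style entry point in `torch.fx.experimental.proxy_tensor`, which is where this tracer lives in core):
```python
import torch
from torch.fx.experimental.proxy_tensor import make_fx

def f(x):
    return torch.sin(x) * 2

# Tracing dispatches through the __torch_dispatch__ proxy tensor subclass
# and records each ATen call into an FX graph.
gm = make_fx(f)(torch.randn(3))
print(gm.graph)
```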
I put this up as a WIP, just for discussion; some questions off the top of my head:
1. What should be the intended way of extending this tracer? Should we define extension points, or should folks simply copy paste and modify? If we do define extension points, what are the extension points we should define?
2. There are some open questions about the way we're overriding FX to resolve some lingering issues (i.e. dealing with `nn.Parameter` and `call_module` calls). @ezyang implemented an alternate version of this tensor in https://github.com/albanD/subclass_zoo/blob/main/tracer_tensor.py, but it appears he ran into some issues with it that led to me submitting this implementation. That being said, I think some of the things over there should still be ported.
3. Given that this is going to be shared infra, what other features should we put in here? One that comes to mind is to allow for meta-tensor tracing (perhaps by default?), with a more solid fallback.
Some of the other implementations (for reference on requirements).
1. FX2TRT: D34868356 (internal only)
2. Edge's? @gmagogsfm
cc: @ezyang , @jamesr66a , @zou3519 , @gmagogsfm, @842974287
Pull Request resolved: https://github.com/pytorch/pytorch/pull/74360
Approved by: https://github.com/ezyang
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/76469
Broken by Original commit changeset: 450e86c4e08a
Original Phabricator Diff: D35874477
Test Plan: Added unit test coverage to test_fx_experimental
Reviewed By: albanD
Differential Revision: D35978105
fbshipit-source-id: f22670b3b00a86777a26feaf4cb911595d150a17
(cherry picked from commit 91868b1e872c19d58d96a6c80a5e78dc6ffe4c7b)
This PR makes the following improvements:
- moves the custom skip list for test_normalize_operator_exhaustive in test_fx_experimental to use the typical OpInfo skip architecture. The skips were updated to xfails, and that identified some operators which were no longer failing the test
- redundant tests with OpInfo-based testing in test_jit.py were removed
- test_dtypes was improved so its error messages are clear and it makes test_nondifferentiable redundant; the latter test has been removed
- OpInfo.supports_complex_autograd() is removed in favor of a more accurate and general test for whether the particular dtype is in the backward dtypes of the operator
- gradchecks have been improved to verify that an operator doesn't support grad if it claims not to
- gradchecks have been improved to test the gradient of all input tensors that require gradient
- the concept of "default test dtypes" has been removed
- excessive and mostly redundant out testing for elementwise unary operators has been removed
- metadata for whether an op supports nuanced "safe casting" to out behavior has been removed from OpInfos
- numerous skips have been converted to xfails
- numerous OpInfos have had their metadata fixed based on the new checks
- jit-specific utilities in common_methods_invocations.py have been moved to jit_programming_utils.py
Pull Request resolved: https://github.com/pytorch/pytorch/pull/75951
Approved by: https://github.com/ngimel