Summary:
The current order of implicit sharing breaks common annotation patterns of SharedQuantizationSpec, so we changed the order here.
But it's still not going to work in all possible annotation cases, so quantizer implementors need to be careful.
In general, annotations should work as long as SharedQuantizationSpec only refers to nodes/edges that come before the node/edge currently being annotated.
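For reference, here is a sketch of the kind of annotation pattern that is expected to keep working (illustrative quantizer-side code, not taken from this PR; the helper name and arguments are made up):
```
from torch.ao.quantization.quantizer import (
    QuantizationAnnotation,
    SharedQuantizationSpec,
)

# Hypothetical helper: annotate an `add` node so that its second input and its
# output share quantization parameters with the first input edge. The shared
# specs only point at an edge that appears *before* the edge/node being
# annotated, which is the pattern the new sharing order supports.
def annotate_add(add_node, input_act0, input_act1, act_qspec):
    add_node.meta["quantization_annotation"] = QuantizationAnnotation(
        input_qspec_map={
            input_act0: act_qspec,
            input_act1: SharedQuantizationSpec((input_act0, add_node)),
        },
        output_qspec=SharedQuantizationSpec((input_act0, add_node)),
        _annotated=True,
    )
```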
Test Plan: CI; also verified that this fixes some internal tests
Differential Revision: D51605918
Pull Request resolved: https://github.com/pytorch/pytorch/pull/114704
Approved by: https://github.com/andrewor14
While many models regress in training when converted to channels last, in inference the results are quite different. Almost all of the models experienced a speedup when converted to channels last. There were a few big regressions in torchbench - `timm_regnet` from `1.4343 → 1.0573` and `timm_resnet` from `1.7484 → 1.2868`.
I used a modified version of the operator benchmark script [here](https://gist.github.com/eellison/e11dc645412f52e8b45fb26ba6f9f6a1) to measure the average speedup of convolutions across all of the input shapes found in torchbench, according to the existing classifications that @shunting314 used - grouped convs, small-channel convs, and convs with more in-channels than out-channels. Only grouped convolutions benchmarked as a slowdown in inference.
I updated the inference heuristic to multiply the flops of each conv with its predicted speedup/slowdown in channels last. With this heuristic the two previously regressing models no longer regress.
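As a rough sketch of the scoring idea (the category names and speedup factors below are purely illustrative, not the values used by the actual heuristic):
```
# Hypothetical predicted channels-last speedup per conv category (illustrative only).
PREDICTED_SPEEDUP = {
    "grouped": 0.9,          # grouped convs benchmarked as a slowdown
    "small_channels": 1.2,
    "in_gt_out": 1.1,
    "other": 1.3,
}

def prefers_channels_last(convs):
    """convs: iterable of (flops, category) pairs collected from the graph."""
    baseline = sum(flops for flops, _ in convs)
    weighted = sum(flops * PREDICTED_SPEEDUP.get(cat, 1.0) for flops, cat in convs)
    # Flip the graph to channels last only if the FLOP-weighted prediction is a net win.
    return weighted > baseline

print(prefers_channels_last([(1e9, "other"), (2e8, "grouped")]))  # True
```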
Speeds up inference for torchbench ~8% and timm ~6%. The motivating model here was SDXL which now hits channels last and improves 10%.
There were some models that were sped up in training when forcing channels last (along with a number of regressions). It's possible there is some speedup in training to be had with additional heuristics. We could also have more granular classification/predictions which might benefit both training and inference.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/114600
Approved by: https://github.com/jansel, https://github.com/shunting314
As in the title.
The `bsr_dense_addmm` kernel implemented in this PR is a generalization of `bsr_dense_mm` in the following respects (in addition to taking input, beta, and alpha parameters):
- it implements a `SPLIT_N` kernel parameter that enables efficient kernel launches for wide inputs. For instance, the timing of nn.linear with 256x256 BSR weights having 16x16 blocks and a 256x131072 strided input is reduced by about 16x (this corresponds to the 94% speed-up value listed below).
- it supports rectangular blocks in sparse BSR tensor weights
The performance increase of nn.linear is as follows (float16, `NVIDIA A100-SXM4-80GB`):
- with 16x16 blocks, the average/maximal speed up is 55/94 %
- with 32x32 blocks, the average/maximal speed up is 33/63 %
- with 64x64 blocks, the average/maximal speed up is 23/42 %
- with 128x128 blocks, the average/maximal speed up is 15/39 %
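For context, a minimal sketch of the kind of nn.linear call these numbers refer to (shapes taken from the 16x16-block case above; assumes CUDA and float16, and uses a dense weight converted to BSR purely for illustration):
```
import torch
import torch.nn.functional as F

weight = torch.randn(256, 256, dtype=torch.float16, device="cuda")
bias = torch.randn(256, dtype=torch.float16, device="cuda")
x = torch.randn(131072, 256, dtype=torch.float16, device="cuda")

# Convert the weight to BSR with 16x16 blocks; F.linear with a BSR weight
# dispatches to the Triton-based sparse kernels on CUDA.
bsr_weight = weight.to_sparse_bsr(blocksize=(16, 16))
y = F.linear(x, bsr_weight, bias)
```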
Pull Request resolved: https://github.com/pytorch/pytorch/pull/114595
Approved by: https://github.com/cpuhrsch
Quick recap of events:
(1) https://github.com/pytorch/pytorch/pull/111347, which fixed a perf regression in 2.1 compared to 2.0, introduced a correctness problem around mutations of inputs that require grad when they show up in an inference-only graph (the specific case where this can happen is rare and nobody reported the issue, but it was fixed a few weeks later)
(2) That fix happened here: https://github.com/pytorch/pytorch/pull/113584, which makes sure to keep input mutations outside of the graph, so the autograd engine can set metadata properly on them
(3) That fix in turn caused a slight regression compared to (1), which is what this PR attempts to fix. In particular, it is safe to keep the mutations in the graph for code like the below:
```
@torch.compile
def f(x):
    x.mul_(2)

x = torch.ones(2, requires_grad=True).clone()
# x requires_grad, so the input mutation will change some autograd metadata, like the version counter
# However, the mutation is under no_grad, so we don't have to worry about e.g. aliases of x having their .grad_fn fields changed
with torch.no_grad():
    f(x)
```
This particular case is pretty important to the shampoo optimizer code, which is run under `torch.compile`, and mutates parameters (which require grad).
Pull Request resolved: https://github.com/pytorch/pytorch/pull/114646
Approved by: https://github.com/zou3519
Summary:
1. We stop using excess memory in generate_opcheck_tests. This is safe because all the individual test utils already ensure that they do not modify the inputs.
2. We re-enable the fbgemm TBE tests (see internal diff, but all of this is open source). They were previously removed because they OOM'ed when run serially; (1) and (3) cut down the memory usage to ~20 GB peak.
3. I needed to skip some newly failing generated tests and also some that had an impact on the memory usage.
Test Plan: - run tests
Reviewed By: sryap
Differential Revision: D51601964
Pull Request resolved: https://github.com/pytorch/pytorch/pull/114641
Approved by: https://github.com/williamwen42
_cslt_sparse_mm + additional stride checking in test.
Summary:
This PR adds in meta registrations for _cslt_sparse_mm.
Based on the work @drisspg did
in #114370.
Additionally, it updates the tests to check that the strides of the sparse result and the result returned by sparse+compile are the same, to avoid errors like those found in
https://github.com/pytorch/pytorch/pull/114477.
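For readers unfamiliar with meta registrations, here is a minimal sketch of the general mechanism using a hypothetical custom op (illustrative only; this PR registers a meta function for the existing `aten::_cslt_sparse_mm` overload, not for this op):
```
import torch

# Hypothetical library and op, for illustration.
lib = torch.library.Library("mylib", "DEF")
lib.define("my_mm(Tensor a, Tensor b) -> Tensor")

def my_mm_cpu(a, b):
    return a @ b

def my_mm_meta(a, b):
    # A meta kernel only computes output metadata (shape/dtype/device),
    # which is what torch.compile / fake tensors need in order to trace the op.
    return a.new_empty((a.shape[0], b.shape[1]))

lib.impl("my_mm", my_mm_cpu, "CPU")
lib.impl("my_mm", my_mm_meta, "Meta")

a = torch.randn(4, 8, device="meta")
b = torch.randn(8, 16, device="meta")
print(torch.ops.mylib.my_mm(a, b).shape)  # torch.Size([4, 16])
```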
Test Plan:
```
python test/test_sparse_semi_structured.py -k compile_cusparselt
python test/test_sparse_semi_structured.py -k compile_cutlass
```
Pull Request resolved: https://github.com/pytorch/pytorch/pull/114685
Approved by: https://github.com/alexsamardzic, https://github.com/drisspg
I can hold off on reviews / landing until I talk to Driss and we confirm that we need this for FP8. This PR also needs testing and probably shouldn't land until Tugsuu's input mutation handling [PR](https://github.com/pytorch/pytorch/pull/111046) goes through.
What this PR tries to solve is the case where a model mutates some nn module state (a buffer) during the **backward** pass. It appears that this might be necessary for FP8's delayed scaling.
Today, AOTAutograd simply does not notice when graph inputs are mutated during the backward pass: it functionalizes the mutations away without recording that they were input mutations. This PR tries to:
(a) detect this situation (input mutations during the backward)
(b) put `copy_()`'s in the graph to properly handle the input mutation when we can. In cases where we can't keep the copy_() in the graph, we just error loudly (I imagine that these cases will be extremely rare, but we can fix them if they ever come up).
This is mostly a prototype for now, not ready for review.
I made this example locally to test out:
```
import torch

class MutatingAutogradFn(torch.autograd.Function):
    @staticmethod
    def forward(ctx, x, buf):
        ctx.save_for_backward(buf)
        return x

    @staticmethod
    def backward(ctx, x_grad):
        buf = ctx.saved_tensors[0]
        buf.add_(x_grad)
        return x_grad * 3, None

class Mod(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.buf = torch.ones(2)

    @torch._dynamo.allow_in_graph
    def backward_mutating_fn(self, x, buf):
        return MutatingAutogradFn.apply(x, buf)

    def forward(self, x):
        tmp = self.backward_mutating_fn(x, self.buf)
        return tmp + self.buf

m = Mod()
x = torch.ones(2, requires_grad=True)
out = m(x)
# After the fw, buf should not have been mutated
print(m.buf)
out.sum().backward()
# bw has run, so buf should now be mutated
print(m.buf)
print(x.grad)
```
Pull Request resolved: https://github.com/pytorch/pytorch/pull/112906
Approved by: https://github.com/ezyang
Smuggle important and not too slow tests to run on this trunk job,
instead of just on the periodic job where they currently reside.
- test_dtensor_compile took 70 sec and test_fsdp_2d_parallel took 198 sec locally
As a follow-up, organize the distributed-mgpu tests better and maybe rename this job to reflect its more general 'dist mgpu' coverage.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/114642
Approved by: https://github.com/wanchaol, https://github.com/malfet
I'm looking to repurpose some logic in `torch.utils.collect_env` for the `geowatch` package. I'm mostly able to just use this script as a library, which is great because it reduces code in my package. However, the issue is that the package patterns that are relevant to torch are hard-coded inside of `get_conda_packages` and `get_pip_packages`.
The changes I made are simple. I defined the default package patterns as two global sets, and I added an argument to each function that lets the user customize exactly what package patterns are relevant. If they are not specified the defaults are used.
I was considering extending the power of the patterns by utilizing `fnmatch`, `re` (or [xdev.pattern](https://github.com/Erotemic/xdev/blob/main/xdev/patterns.py) which abstracts them both), but instead I opted to just use the existing `__contains__` test to keep things simple.
From torch's perspective this should make maintaining this file slightly easier: to update the relevant packages, a developer now edits two neighboring top-level globals instead of two separate local variables. It does add an argument to two functions that torch itself never passes, so there is an argument for removing it and letting users modify the globals instead, but I think the way I did it balances the tradeoffs well.
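A rough sketch of the pattern described above (the names below are illustrative, not the actual `collect_env` signatures):
```
# Hypothetical defaults; the real globals list the torch-relevant package patterns.
DEFAULT_PIP_PATTERNS = {"torch", "numpy", "mypy", "triton", "onnx"}

def filter_pip_packages(freeze_lines, patterns=None):
    """Keep only the `pip freeze` lines matching the relevant patterns."""
    if patterns is None:
        patterns = DEFAULT_PIP_PATTERNS
    # Plain substring (`__contains__`) test, as in the existing implementation.
    return [line for line in freeze_lines if any(p in line for p in patterns)]

# A downstream package could pass its own patterns:
# filter_pip_packages(lines, patterns=DEFAULT_PIP_PATTERNS | {"geowatch"})
```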
Pull Request resolved: https://github.com/pytorch/pytorch/pull/112993
Approved by: https://github.com/zou3519
Updated version of #108885 addressing the review. In this PR:
- We add a VT.can_reconstruct utility that checks if VT.reconstruct() does something.
- If functools.wraps(fn) is passed a `fn` that either has a source or has .can_reconstruct() == True, then we stash the source (or the VT).
- Later on, we use the source (or VT.reconstruct) to actually reconstruct the object in codegen.
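A minimal sketch of the kind of code this is meant to support (illustrative, not taken from the new tests):
```
import functools
import torch

def add_one(fn):
    @functools.wraps(fn)  # fn is a nested user function, so its VT can be reconstructed
    def wrapper(x):
        return fn(x) + 1
    return wrapper

@torch.compile(backend="eager")
def f(x):
    def inner(y):
        return y.sin()
    return add_one(inner)(x)

print(f(torch.randn(3)))
```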
Test Plan:
- New tests
Pull Request resolved: https://github.com/pytorch/pytorch/pull/114279
Approved by: https://github.com/voznesenskym
torch.split(x, l) fails when l contains unbacked SymInts.
E.g. l = y.tolist() makes l unbacked, because its values depend on the data of y. The downstream call `SliceView.create()` evaluates the shape even when the input shape is an unbacked SymInt, which triggers the bug.
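An illustrative repro of the failing pattern (assumes `capture_scalar_outputs` so that `.tolist()` stays in the graph; the real coverage is the inductor test below):
```
import torch

torch._dynamo.config.capture_scalar_outputs = True  # keep .tolist() in the graph

@torch.compile
def f(x, y):
    sizes = y.tolist()            # data-dependent values -> unbacked SymInts
    return torch.split(x, sizes)  # previously hit the SliceView.create() issue

x = torch.randn(10)
y = torch.tensor([3, 3, 4])
print([t.shape for t in f(x, y)])
```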
Test Plan:
python test/inductor/test_unbacked_symints.py -k test_split_with_sizes
Pull Request resolved: https://github.com/pytorch/pytorch/pull/113406
Approved by: https://github.com/aakhundov, https://github.com/ezyang
Add a TORCH_NCCL_DUMP_DEBUG_INFO env var to control dumping independently
of the desync debug feature.
It currently defaults to disabled (so there is no behavior change by default),
but the plan is to default it to true after validation.
Also moves the 'sleep for 30 sec' that used to happen after desync debug to before
it. In my view sleeping before desync is equivalent since we always sleep the same
duration, and it keeps the code simpler.
Fixes #114433
Pull Request resolved: https://github.com/pytorch/pytorch/pull/114614
Approved by: https://github.com/zdevito
ghstack dependencies: #114651
This should be enough to get @voznesenskym 's FSDP branch to plumb `set_()` through AOTAutograd properly and have everything properly no-op out. Main changes are:
(1) graph break on `aten::set_.source_Tensor_storage_offset` (we could support it but it isn't needed, seems safer to graph break)
(2) Functionalization: add a "proper" functionalization kernel for `aten::set_.source_Tensor`. The previous one we had was codegen'd and it was wrong (it would just clone() and call set_(), which does not do the right thing). I also manually mark on the `FunctionalTensorWrapper` when a given tensor has been mutated by a `set_()` call.
(3) AOTAutograd: I added a new field, `InputAliasInfo.mutates_storage_metadata`, so we can distinguish between "regular" metadata mutations, and metadata mutations due to `set_()` calls. This is mainly because at runtime, one requires calling `as_strided_()` to fix up metadata, while the other requires calling `set_()`.
(4) Made AOTAutograd's detection for metadata mutations / set_() mutations smarter and detect no-ops (if the storage and metadata are all the same).
I also killed `was_updated()` and `was_metadata_updated()`, and replaced them with (existing) `has_data_mutation()` and (new) `has_metadata_mutation()`, which can more accurately distinguish between data mutation vs. `set_()` calls vs. metadata mutation.
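As a toy illustration of why the two kinds of mutation need different runtime fix-ups (plain eager code, not AOTAutograd internals):
```
import torch

a = torch.zeros(4)
b = torch.ones(2)

# "Regular" metadata mutation: same storage, new sizes/strides/offset.
# This kind of mutation can be replayed at runtime with as_strided_().
a.as_strided_((2,), (1,), 2)

# Storage metadata mutation: a now aliases b's storage entirely.
# Replaying this requires set_(), not as_strided_().
a.set_(b)
assert a.untyped_storage().data_ptr() == b.untyped_storage().data_ptr()
```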
**This PR is still silently incorrect in one case though**, which I'd like to discuss more. In particular, this example:
```
def f(x):
    x_view = x.view(-1)
    x.set_(torch.ones(2))
    x_view.mul_(2)
    return
```
If you have an input that experiences both a data-mutation **and** a `x_old.set_(x_new)` call, there are two cases:
(a) the data mutation happened on the storage of `x_new`. This case should be handled automatically: if x_new is a graph intermediate then we will functionalize the mutation. If x_new is a different graph input, then we will perform the usual `copy_()` on that other graph input
(b) the data mutation happened on the storage of `x_old`. This is more of a pain to handle, and doesn't currently work. At runtime, the right thing to do is probably something like:
```
def functionalized_f(x):
    x_view = x.view(-1)
    # set_() desugars into a no-op; later usages of x will use x_output
    x_output = torch.ones(2)
    # functionalize the mutation on x_view
    x_view_updated = x.mul(2)
    x_updated = x_view_updated.view(x.shape)
    # x experienced TWO TYPES of mutations: a data mutation and a metadata mutation
    # We need to return both updated tensors in our graph
    return x_updated, x_output

def runtime_wrapper(x):
    x_data_mutation_result, x_set_mutation_result = compiled_graph(x)
    # First, perform the data mutation on x's old storage
    x.copy_(x_data_mutation_result)
    # Then, swap out the storage of x with the new storage
    x.set_(x_set_mutation_result)
```
There are two things that make this difficult to do though:
(1) Functionalization: the functionalization rule for `set_()` will fully throw away the old `FunctionalStorageImpl` on the graph input. So if there are any mutations to that `FunctionalStorageImpl` later on in the graph, the current graph input won't know about it. Maybe we can have a given `FunctionalTensorWrapper` remember all previous storages that it had, and track mutations on all of them - although this feels pretty complicated.
(2) AOTAutograd now needs to know that we might have *two* graph outputs that correspond to a single "mutated input", which is annoying.
It's worth pointing out that this issue is probably extremely unlikely for anyone to run into - can we just detect it and error? This feels slightly easier than solving it, although not significantly easier. We would still need `FunctionalTensorWrapper` to keep track of mutations on any of its "previous" storages, so it can report this info back to AOTAutograd so we can raise an error.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/111554
Approved by: https://github.com/ezyang
ghstack dependencies: #113926
Summary:
This diff adds support in the ExecuTorch codegen layer to log the outputs of kernels to event_tracer. It does this by calling the `event_tracer_log_evalue` API.
When the `ET_EVENT_TRACER_ENABLED` flag is disabled this is essentially a no-op and will add no overhead.
Test Plan: CI
Reviewed By: larryliu0820
Differential Revision: D51534590
Pull Request resolved: https://github.com/pytorch/pytorch/pull/114584
Approved by: https://github.com/larryliu0820
C++ stacktrace processing (the symbolizer) takes a long time on some systems
using a particular version of addr2line. On slow systems, this makes
flight-recorder dumping slow enough to time out on even toy programs.
TORCH_NCCL_TRACE_CPP_STACK=True will re-enable CPP stacktrace collection
as part of the flight recorder.
CPP stacktrace is fast enough for use on certain combinations of OS. We
can investigate moving to llvm's symbolizer as a replacement.
On devserver with C++ stacktraces disabled/enabled:
```
python test/distributed/test_c10d_nccl.py -k test_short
Ran 1 test in 12.175s
TORCH_NCCL_TRACE_CPP_STACK=1 python test/distributed/test_c10d_nccl.py -k test_short
Ran 1 test in 53.338s
```
Pull Request resolved: https://github.com/pytorch/pytorch/pull/114651
Approved by: https://github.com/zdevito
This adds some unit testing for the `ignored_states` argument and auto wrapping. There is some ongoing discussion with @erhoo82 about his particular use case, but it should not block this PR. (We can land a separate PR if needed.)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/114612
Approved by: https://github.com/wanchaol
ghstack dependencies: #114611
With this PR it is possible to differentiate through NumPy code, modulo
the usual caveats that apply to differentiation:
- there are no graph breaks
- the decomposition in `torch._numpy` is differentiable
@ev-br and I were somewhat careful to achieve the second point, but
it is not tested through and through, so YMMV
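A minimal sketch of what this enables (illustrative; assumes the whole function compiles without graph breaks):
```
import numpy as np
import torch

@torch.compile
def numpy_fn(x, y):
    # Inside the compiled region, NumPy calls are traced into torch._numpy,
    # so gradients can flow through them.
    return np.mean(np.multiply(x, y))

x = torch.randn(16, 8, requires_grad=True)
y = torch.randn(16, 8)
out = numpy_fn(x, y)
assert isinstance(out, torch.Tensor)
out.backward()
print(x.grad.shape)
```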
Pull Request resolved: https://github.com/pytorch/pytorch/pull/114608
Approved by: https://github.com/voznesenskym
Changes:
1. Add `_private_register_pytree_node` API in both C++ and Python pytree. In C++ pytree, the API will only register pytree node for C++ pytree. In Python pytree, the API will only register pytree node for Python pytree.
2. Do not allow registering a type as pytree node twice in the Python pytree.
3. Add thread lock to the Python pytree node register API.
4. The old `_register_pytree_node` API will call the `_private_register_pytree_node` API and raise a deprecation warning.
5. Add a new `register_pytree_node` API to register node type in both C++ and Python implementations.
6. Add tests to ensure a warning will be raised when the old private function is called.
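A sketch of the new unified registration API from item 5 (illustrative container type; the flatten/unflatten functions follow the existing pytree convention of producing/consuming `(children, context)`):
```
import torch.utils._pytree as pytree

class Pair:
    def __init__(self, first, second):
        self.first = first
        self.second = second

pytree.register_pytree_node(
    Pair,
    lambda p: ((p.first, p.second), None),       # flatten -> (children, context)
    lambda children, context: Pair(*children),   # unflatten
)

leaves, spec = pytree.tree_flatten(Pair(1, 2))
print(leaves)                                     # [1, 2]
print(pytree.tree_unflatten(leaves, spec).first)  # 1
```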
Pull Request resolved: https://github.com/pytorch/pytorch/pull/112111
Approved by: https://github.com/zou3519
As in the title.
As a result, `nn.linear(<strided tensor>, <BSR tensor>, bias=<strided tensor>)` performance increases as follows (`float16`, `NVIDIA A100-SXM4-80GB`):
- 256x256 weights, speed up is 14..27 %
- 512x512 weights, speed up is 9..25 %
- 1024x1024 weights, speed up is 5..20 %
- 2048x2048 weights, speed up is 3..16 %
- 4092x4092 weights, speed up is 2..9 %
Pull Request resolved: https://github.com/pytorch/pytorch/pull/114484
Approved by: https://github.com/cpuhrsch