This PR adds support for non-functional collectives under `FakeTensorMode` and `fake_pg`. It helps eliminate the patching of collectives for memory and runtime estimation.
It also modifies the `ModTracker` to enable the post-backward hook call for modules whose inputs don't require gradients but parameters do.
For the memory tracking, we now enable tracking DTensor dispatcher for custom dispatch functions like `entropy_loss`.
Dispatcher is only enabled for the memory tracking part and disabled as soon as it is done.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/147566
Approved by: https://github.com/weifengpy
Uses `dict.fromkeys` whenever possible as covered by flake8-comprehensions rule C420. While the ruff rule RUF025 is still in preview, flake8-comprehensions have added a new rule which covers this. Use dict.fromkeys is faster when the value being added to the dictionary is the same at every iteration and is immutable, it also removes an unnecessary dict comprehension.
This rule will be enabled with our current ruleset in RUF in 0.6 as C420.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/130699
Approved by: https://github.com/lezcano, https://github.com/ezyang
For #125323
* Fixes typing for python < 3.10
* Fixes#129390
For #124688
* Improved attribution by registering `register_hook` and `post_accumulate_grad_hook` on params.
* Fixed pre-mature per module bw peak state initialization for AC.
* This improves per-module stats, global `peak_mem` was already accurate and remains unaffected.
For #128508
* When AC is applied to a `mod (nn.Module)` the backward order of execution is `pre-bw -> pre-fw -> post-fw -> post-bw`. Since the `ModTracker` maintains the `parents` attribute as set, the `post-fw` during backward was prematurely removing it from parents.
* With the fix we now maintain a per-module counter and only remove a module from `parents` when its counter goes to 0.
* Added tests to ensure this.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/129400
Approved by: https://github.com/awgu, https://github.com/huydhn
We present a utility MemTracker, that tracks the module-wise memory for the code executed under its context. The core features that this tool aims to provide are:
1. Capturing 'snapshots' of memory for each module during its execution. Specifically, at 8 points, during pre-forward, post-forward, pre-backward, 2nd pre-forward (if AC is applied), 2nd post-forward (if AC is applied), post-backward. Also capturing peak memory snapshot during forward and backward.
2. Each such snapshot provides the per device (cpu, cuda etc) memory breakdown in terms of the global parameters, gradients, activations, optimizer states and temporary memory.
3. A summary for each module (that can be analyzed or processed later), in terms of the memory occupied by its own parameters, buffers, inputs and outputs. The remaining components can be derived from these per module attributes and its corresponding captured snapshots.
4. Record the global peak memory consumption per device and their respective breakdowns.
5. Ability to do all of this under the FakeTensorMode so that all these statistics can be obtained without executing code on real data.
6. Ability to register and track modules, optimizers and any other tensors that are created outside the context of MemTracker.
7. Ability to capture a custom memory snapshot at any point during program execution execution.
8. Utility functions to display all of these statistics in user-friendly and human readable manner.
These features will enable users to anticipate OOMs, debug and pinpoint where majority of memory comes from, experiment with different activation checkpointing policies, batch sizes, mixed precision, model architecture features (ex. number of layers, hidden dimensions, number of attention heads etc.) and inter-device memory movement (ex. CPU off-loading) among others. Basically anything and everything related to device memory.
* __->__ #128508
Example:
> import torch
> import torchvision.models as models
> from torch.distributed._tools.mem_tracker import MemTracker
> device, dtype = "cuda", torch.float32
> with torch.device(device):
> model = models.resnet18().to(dtype=dtype)
> optim = torch.optim.Adam(model.parameters(), foreach=True)
> mem_tracker = MemTracker()
> mem_tracker.track_external(model, optim)
> with mem_tracker as mt:
> for i in range(2):
> input_batch = torch.randn(256, 3, 224, 224, device=device, dtype=dtype)
> model(input_batch).sum().backward()
> optim.step()
> optim.zero_grad()
> if i == 0:
> # to account for lazy init of optimizer state
> mt.reset_mod_stats()
> mt.display_snapshot("peak", units="MiB", tabulate=True)
> mt.display_modulewise_snapshots(depth=2, units="MiB", tabulate=True)
> # Check for accuracy of peak memory
> tracker_max = mt.get_tracker_snapshot('peak')[device]['Total']
> cuda_max = torch.cuda.max_memory_allocated()
> accuracy = tracker_max / cuda_max
> print(f"Tracker Max: {tracker_max}, CUDA Max: {cuda_max}, Accuracy: {accuracy}")
Output
<img width="1197" alt="Screenshot 2024-06-15 at 12 10 12 AM" src="https://github.com/pytorch/pytorch/assets/12934972/83e953db-43dc-4094-90eb-9f1d2ca8e758">
Pull Request resolved: https://github.com/pytorch/pytorch/pull/124688
Approved by: https://github.com/awgu