42 Commits

f414aa8e0d Add pyrefly suppressions (3/n) (#164588)
Adds suppressions so that pyrefly will typecheck clean: https://github.com/pytorch/pytorch/issues/163283

Test plan:
dmypy restart && python3 scripts/lintrunner.py -a
pyrefly check

Step 1: uncomment lines in the pyrefly.toml file
Step 2: run pyrefly check
Step 3: add suppressions, clean up unused suppressions
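For reference, a suppression of the kind this series adds looks roughly like the snippet below; the exact comment syntax is an assumption based on pyrefly's documentation, not taken from this PR:

```
from typing import Any

def untyped() -> Any:
    return "not an int"

# pyrefly would flag the assignment below; the trailing comment (syntax
# assumed from pyrefly's docs) suppresses that one error.
value: int = untyped()  # pyrefly: ignore
```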
before: https://gist.github.com/maggiemoss/bb31574ac8a59893c9cf52189e67bb2d

after:

 0 errors (1,970 ignored)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/164588
Approved by: https://github.com/oulgen
2025-10-03 22:03:03 +00:00
315ffdc1e4 [4/N] Apply ruff UP035 rule to python code (#164206)
Follows #164104

Pull Request resolved: https://github.com/pytorch/pytorch/pull/164206
Approved by: https://github.com/albanD
2025-10-01 19:05:53 +00:00
cf94cadbee [CUDAGraph] Add getter for cuda graph exec (#161294)
This is far simpler than #155164 since we never destroy the cudaGraphExec_t.

The request comes specifically from TRT-LLM. The motivation is that some power users would like to mutate specific kernel parameters via APIs like `cudaGraphExec*SetParams` after a cuda graph has been instantiated. For example, a common request has been the ability to change the sequence length of attention kernels after having captured a graph for the largest possible sequence length. It turns out that the host overhead you eliminate via cuda graphs in LLM inference gets traded for extra computation time when you size your kernels for the maximum possible sequence length (which I believe is done in both TRT-LLM and vLLM). Attention is the most problematic kernel because its computation time is quadratic in the sequence length, rather than linear.

This can work if your attention kernel handles arbitrary shapes (this is not the case for all attention implementations! Many of them specialize with templates) and you have a persistent kernel that allocates only as many blocks as you have SMs (so you don't have to figure out how many blocks to allocate for a specific sequence length). Using a conditional SWITCH node is a better generic approach to this problem, but that requires more infrastructure work.

Note that this requires knowing the exact location of the value to mutate in your kernel's parameter buffer. It won't work with arbitrary stream capture code whose kernels you don't know beforehand, so I expect this code path to be rarely used.
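A hedged sketch of the intended workflow follows; the getter name `raw_cuda_graph_exec` is inferred from the test name below and should be treated as an assumption, and the cuda-python mutation call is shown only in a comment:

```
import torch

g = torch.cuda.CUDAGraph()
static_in = torch.zeros(1, device="cuda")
with torch.cuda.graph(g):
    static_out = static_in * 2

g.replay()

# Assumed getter name (inferred from the test name below); the returned
# cudaGraphExec_t handle is NOT owned by the caller.
exec_handle = g.raw_cuda_graph_exec()

# A power user who knows the exact kernel node and parameter-buffer layout
# could now mutate parameters in place via cuda-python, e.g.:
#   from cuda import cudart
#   cudart.cudaGraphExecKernelNodeSetParams(exec_handle, node, params)
```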

Testing:

```
pytest -s -k raw_graph_exec test/test_cuda.py
```

Pull Request resolved: https://github.com/pytorch/pytorch/pull/161294
Approved by: https://github.com/ngimel, https://github.com/BoyuanFeng, https://github.com/eellison, https://github.com/eqy
2025-08-25 20:57:37 +00:00
e20736bf1d Don't GC as often when collecting cudagraphs (#158193)
TL;DR: Cuts vLLM cudagraph collection from 80s -> 24s

Stop garbage collecting by default on every cudagraph recording. The old behavior can be re-enabled by setting `TORCH_CUDAGRAPH_GC=1` or the config `force_cudagraph_gc`.

We were previously garbage collecting at the beginning of each cudagraph capture. vLLM collects 5427 graphs, and most of those garbage collections weren't actually collecting any memory (CPU or GPU). This changes it to collect no more than once every 10s, so if we're capturing in a loop we don't burn all our cycles looking for garbage.
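A minimal sketch of this throttling policy (an illustration, not PyTorch's actual implementation; the env-var name comes from this commit message):

```
import gc
import os
import time

_last_gc = 0.0

def maybe_collect(min_interval_s=10.0):
    # Run gc.collect() at most once per interval, unless forced via env var.
    global _last_gc
    if os.environ.get("TORCH_CUDAGRAPH_GC") == "1":
        gc.collect()  # old behavior: collect before every capture
        return
    now = time.monotonic()
    if now - _last_gc >= min_interval_s:
        gc.collect()
        _last_gc = now
```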

(These numbers have a lot of variance from run to run but give the correct general scale.)
```
       | calls | total | synchronize |  gcs | collect | empty cache | sys freed | cuda freed |
-------+-------+-------+-------------+------+---------+-------------+-----------+------------+
before |  5427 |   78s |       1.48s | 5427 |  53.22s |       1.21s |    145855 | 1539309568 |
-------+-------+-------+-------------+------+---------+-------------+-----------+------------+
after  |  5427 |   24s |          0s |    3 |   1.53s |       0.84s |       592 | 1539309568 |
-------+-------+-------+-------------+------+---------+-------------+-----------+------------+
```
total - this is the total time reported by vLLM's "Graph capturing finished" log.
The rest of these are measured in torch.cuda.graphs.graph.__enter__():
  calls - number of times torch.cuda.graphs.graph.__enter__ was called
  synchronize - this is the duration taken by the cuda.synchronize call
  gcs - number of times gc.collect was called
  collect - this is the duration taken by the gc.collect call
  empty cache - this is the duration taken by the torch.cuda.empty_cache call
  sys freed - the number of bytes reported freed by gc.collect
  cuda freed - the number of bytes reported freed by torch.cuda.memory_reserved

So it seems like the heavy lifting is done by torch.cuda.empty_cache(), which is fairly quick.

Cudagraph results from the TorchInductor Performance Dashboard (this is from the original version using the GC clock, so the real results will be slightly better than this): https://github.com/user-attachments/assets/69b705ef-47ce-4b6e-9733-1ec941cad93d

Pull Request resolved: https://github.com/pytorch/pytorch/pull/158193
Approved by: https://github.com/ngimel
2025-07-24 21:37:11 +00:00
250ae2531c Fix types in graphs.py (#158192)
Added type annotations for torch/cuda/graphs.py

Pull Request resolved: https://github.com/pytorch/pytorch/pull/158192
Approved by: https://github.com/oulgen
2025-07-15 19:49:38 +00:00
4cc8b60d1b [BE][1/16] fix typos in torch/ (#156311)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/156311
Approved by: https://github.com/albanD
2025-07-09 11:02:22 +00:00
3fd84a8592 [BE][PYFMT] migrate PYFMT for torch/[a-c]*/ to ruff format (#144554)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/144554
Approved by: https://github.com/soulitzer
2025-07-03 18:56:07 +00:00
9ed0060225 Provide access to the cudaGraph_t underlying a CUDAGraph. (#155164)
There are a few considerations here:

1. A user might want to modify the cudaGraph_t either during the stream capture or after the stream capture (but before instantiation). This draft implements modification after stream capture only, though support could be added for modification during stream capture by applying
https://github.com/pytorch/pytorch/pull/140979/files#diff-d7302d133bb5e0890fc94de9aeea4d9d442555a3b40772c9db10edb5cf36a35cR391-R404

2. Previously, the cudaGraph_t would be destroyed before the end of capture_end() unless the user had previously called enable_debug_mode(). There is no way to implement this correctly without removing this restriction, or forcing the user to always call enable_debug_mode(). However, enable_debug_mode() is a confusing API (despite being an instance method, it would modify a static global variable; thus, putting one CUDAGraph object into debug mode puts all of them into debug mode, which is not acceptable in my opinion). Therefore, I made enable_debug_mode() into a no-op. This means that the CPU memory usage will increase after this change. I think this is likely to be fine.

3. No python bindings yet. These should be easy to add. It is probably worthwhile to take some time to make sure that the returned cudaGraph_t can be converted into the cuda-python cudaGraph_t in a reasonable, hopefully type-safe, manner (but without making cuda-python a dependency of pytorch), since I imagine most users will use the pip cuda-python package to make modifications.

4. There are two foot guns:

   a. The cudaGraph_t returned by raw_cuda_graph() is not owned by the user, so it will be destroyed once the owning CUDAGraph is destroyed (or calls reset()).

   b. The following sequence won't work as intended:

```
g = torch.cuda.CUDAGraph()
with torch.cuda.graph(g):
    foo()
g.replay()
raw_graph = g.raw_cuda_graph()
modify(raw_graph)
g.replay()
```

This won't work because the user must call instantiate() again after modifying cudaGraph_t. You could add a "safety" mechanism by traversing the cudaGraph_t to create a hash and seeing if the hash changes between calls to replay(), but this is likely way too expensive.
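Per the description above, the corrected sequence re-instantiates before the second replay (a sketch using the calls named in this commit message):

```
g = torch.cuda.CUDAGraph()
with torch.cuda.graph(g):
    foo()
g.replay()
raw_graph = g.raw_cuda_graph()  # borrowed handle; still owned by g
modify(raw_graph)               # e.g. via cuda-python graph APIs
g.instantiate()                 # re-instantiate after mutating the graph
g.replay()
```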

I think these two foot guns are probably okay given that this is a bit of an experts' API.

Fixes #155106

Pull Request resolved: https://github.com/pytorch/pytorch/pull/155164
Approved by: https://github.com/ngimel
2025-06-18 03:39:28 +00:00
9a883007a2 Revert "Implement cuda graphs implementation of torch.cond and torch.while_loop (#140979)"
This reverts commit c7515da7b00de40942c83dc5856b6daec727e280.

Reverted https://github.com/pytorch/pytorch/pull/140979 on behalf of https://github.com/huydhn due to This change has been reported to break internal code ([comment](https://github.com/pytorch/pytorch/pull/140979#issuecomment-2657361940))
2025-02-13 18:04:26 +00:00
c7515da7b0 Implement cuda graphs implementation of torch.cond and torch.while_loop (#140979)
This is a new PR for #130386 , which got stale and was closed. Since I force-pushed to that branch in order to rebase it on top of main, the PR can no longer be reopened, according to https://github.com/isaacs/github/issues/361

I fixed the possibly-not-warmed-up problem described here: https://github.com/pytorch/pytorch/pull/130386/files#r1690856534

Since starting this, torch.cond and torch.while_loop now apparently have support for backward passes. I will look into what it might take to support that.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/140979
Approved by: https://github.com/eqy, https://github.com/eellison
2025-02-11 18:16:15 +00:00
c0582fd0f8 Remove unused Python variables in torch/[b-z]* (#136963)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/136963
Approved by: https://github.com/ezyang
2024-10-19 16:45:22 +00:00
5c9d5272e4 fixes #124582 (#128483)
Added a check to make_graphed_callables for the existence of outputs requiring grad.

Added a new test case and updated an existing test case to include parameterless modules.
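A rough sketch of the kind of guard this adds (names and placement are illustrative, not the actual patch):

```
import torch
from torch.utils import _pytree as pytree

def any_output_requires_grad(outputs):
    # Only build a backward graph when some output actually requires grad;
    # with parameterless modules, none of them may.
    return any(
        isinstance(o, torch.Tensor) and o.requires_grad
        for o in pytree.tree_leaves(outputs)
    )
```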

Fixes #124582

Pull Request resolved: https://github.com/pytorch/pytorch/pull/128483
Approved by: https://github.com/eqy, https://github.com/ezyang
2024-07-02 08:45:59 +00:00
83bb9b7c53 [BE] explicitly export subpackage torch.utils (#128342)
Resolves #126401

Pull Request resolved: https://github.com/pytorch/pytorch/pull/128342
Approved by: https://github.com/Skylion007
ghstack dependencies: #127707
2024-06-13 04:39:16 +00:00
62bcdc0ac9 Flip default value for mypy disallow_untyped_defs [4/11] (#127841)
See #127836 for details.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/127841
Approved by: https://github.com/oulgen
2024-06-08 18:36:48 +00:00
af9acc4168 Fix public binding to actually traverse modules (#126103)
The current call passes `['/actual/path']` to os.walk, which is a string pointing to no path and thus silently leads to an empty traversal.
There is an unused function just above that handles this correctly, so I guess that is what was supposed to be called.
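The failure mode is easy to reproduce, since os.walk swallows errors by default (onerror=None):

```
import os

# A top path that doesn't exist yields nothing instead of raising.
print(list(os.walk("['/actual/path']")))  # -> []
```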

Pull Request resolved: https://github.com/pytorch/pytorch/pull/126103
Approved by: https://github.com/suo
2024-05-15 19:36:03 +00:00
ca9678405a [CUDA graphs] Pool argument for make_graphed_callables (#121475)
It is just a nice feature to have for situations when users want multiple graph captures and/or graphed callables to share the same memory pool.
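A hedged usage sketch (assuming the new argument is a `pool` keyword accepting a handle from torch.cuda.graph_pool_handle(), consistent with torch.cuda.graph):

```
import torch

def f(x):
    return x.relu()

def h(x):
    return x * 2

pool = torch.cuda.graph_pool_handle()
x = torch.randn(8, device="cuda", requires_grad=True)

# Two separately graphed callables sharing one memory pool
# (the `pool` keyword name is an assumption here).
graphed_f = torch.cuda.make_graphed_callables(f, (x,), pool=pool)
graphed_h = torch.cuda.make_graphed_callables(h, (x,), pool=pool)
```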

Pull Request resolved: https://github.com/pytorch/pytorch/pull/121475
Approved by: https://github.com/eellison, https://github.com/eqy
2024-03-09 00:15:38 +00:00
9deaa2e812 [BE]: FURB187 Use inplace reverse on lists: faster, more readable. (#121140)
Use the `reverse()` method, as it's faster and in-place.
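For example:

```
xs = [3, 1, 2]
xs.reverse()  # in-place, no new list allocated
# allocating alternatives: xs = xs[::-1] or xs = list(reversed(xs))
print(xs)  # [2, 1, 3]
```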

Pull Request resolved: https://github.com/pytorch/pytorch/pull/121140
Approved by: https://github.com/albanD
2024-03-05 01:36:17 +00:00
46e3f670b4 refactor code to share across different devices (#120602)
# Motivation
Refactor utils code so it can be shared across CUDA, XPU, and other backends.

# Solution
Move `_dummy_type` and `_LazySeedTracker` to torch._utils.

# Additional Context
When upstreaming, these code changes are isolated into an additional PR to minimize their impact on the CUDA code.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/120602
Approved by: https://github.com/albanD, https://github.com/jgong5, https://github.com/gujinghui, https://github.com/EikanWang
2024-02-28 09:42:58 +00:00
4f5785b6b3 Enable possibly-undefined error code (#118533)
Fixes https://github.com/pytorch/pytorch/issues/118129

Suppressions automatically added with

```
import re

# Collect "path:line:col: error: ... [error-code]" records from the mypy log.
with open("error_file.txt", "r") as f:
    errors = f.readlines()

error_lines = {}
for error in errors:
    match = re.match(r"(.*):(\d+):\d+: error:.*\[(.*)\]", error)
    if match:
        file_path, line_number, error_type = match.groups()
        if file_path not in error_lines:
            error_lines[file_path] = {}
        error_lines[file_path][int(line_number)] = error_type

# Append a "# type: ignore[error-code]" suppression to each flagged line.
for file_path, lines in error_lines.items():
    with open(file_path, "r") as f:
        code = f.readlines()
    for line_number, error_type in sorted(lines.items(), key=lambda x: x[0], reverse=True):
        code[line_number - 1] = code[line_number - 1].rstrip() + f"  # type: ignore[{error_type}]\n"
    with open(file_path, "w") as f:
        f.writelines(code)
```

Signed-off-by: Edward Z. Yang <ezyang@meta.com>

Co-authored-by: Catherine Lee <csl@fb.com>
Pull Request resolved: https://github.com/pytorch/pytorch/pull/118533
Approved by: https://github.com/Skylion007, https://github.com/zou3519
2024-01-30 21:07:01 +00:00
40ece2e579 Revert "Enable possibly-undefined error code (#118533)"
This reverts commit 4f13f69a45ef53747e2eefffd65d91ce840b431b.

Reverted https://github.com/pytorch/pytorch/pull/118533 on behalf of https://github.com/clee2000 due to sorry i'm trying to figure out a codev merge conflict, if this works i'll be back to rebase and merge ([comment](https://github.com/pytorch/pytorch/pull/118533#issuecomment-1917695185))
2024-01-30 19:00:34 +00:00
4f13f69a45 Enable possibly-undefined error code (#118533)
Fixes https://github.com/pytorch/pytorch/issues/118129

Suppressions automatically added with

```
import re

# Collect "path:line:col: error: ... [error-code]" records from the mypy log.
with open("error_file.txt", "r") as f:
    errors = f.readlines()

error_lines = {}
for error in errors:
    match = re.match(r"(.*):(\d+):\d+: error:.*\[(.*)\]", error)
    if match:
        file_path, line_number, error_type = match.groups()
        if file_path not in error_lines:
            error_lines[file_path] = {}
        error_lines[file_path][int(line_number)] = error_type

# Append a "# type: ignore[error-code]" suppression to each flagged line.
for file_path, lines in error_lines.items():
    with open(file_path, "r") as f:
        code = f.readlines()
    for line_number, error_type in sorted(lines.items(), key=lambda x: x[0], reverse=True):
        code[line_number - 1] = code[line_number - 1].rstrip() + f"  # type: ignore[{error_type}]\n"
    with open(file_path, "w") as f:
        f.writelines(code)
```

Signed-off-by: Edward Z. Yang <ezyang@meta.com>

Pull Request resolved: https://github.com/pytorch/pytorch/pull/118533
Approved by: https://github.com/Skylion007, https://github.com/zou3519
2024-01-30 05:08:10 +00:00
46712b019d Enable local_partial_types (#118467)
When using dmypy, this setting is enabled and cannot be turned off. Force it for regular mypy too.

Signed-off-by: Edward Z. Yang <ezyang@meta.com>

Pull Request resolved: https://github.com/pytorch/pytorch/pull/118467
Approved by: https://github.com/Skylion007
ghstack dependencies: #118414, #118418, #118432
2024-01-28 13:38:22 +00:00
01478f1afa Fix pydocstyle errors listed in issue 112589 (#113227)
Fixes #112589

Fixed errors relating to pydocstyle in the following files. The remaining errors relate to docstrings at the module level and at methods within each module (see details below).

pydocstyle torch/cuda/_utils.py --count
before: 3
after: 0

pydocstyle torch/cuda/jiterator.py --count
before: 3
after: 1

**remaining errors:**
```
torch/cuda/jiterator.py:1 at module level:
        D100: Missing docstring in public module
```

pydocstyle torch/cuda/graphs.py --count
before: 25
after: 7

**remaining errors:**
```
torch/cuda/graphs.py:1 at module level:
        D100: Missing docstring in public module
torch/cuda/graphs.py:54 in public method `__new__`:
        D102: Missing docstring in public method
torch/cuda/graphs.py:108 in public method `debug_dump`:
        D205: 1 blank line required between summary line and description (found 0)
torch/cuda/graphs.py:108 in public method `debug_dump`:
        D400: First line should end with a period (not ':')
torch/cuda/graphs.py:150 in public method `__init__`:
        D107: Missing docstring in __init__
torch/cuda/graphs.py:172 in public method `__enter__`:
        D105: Missing docstring in magic method
torch/cuda/graphs.py:186 in public method `__exit__`:
        D105: Missing docstring in magic method
```

pydocstyle torch/cuda/_sanitizer.py --count
before: 35
after: 31

**remaining errors:**
```
torch/cuda/_sanitizer.py:43 in public class `AccessType`:
        D101: Missing docstring in public class
torch/cuda/_sanitizer.py:47 in public method `__str__`:
        D105: Missing docstring in magic method
torch/cuda/_sanitizer.py:84 in public method `__init__`:
        D107: Missing docstring in __init__
torch/cuda/_sanitizer.py:96 in public method `__str__`:
        D105: Missing docstring in magic method
torch/cuda/_sanitizer.py:139 in public method `__init__`:
        D107: Missing docstring in __init__
torch/cuda/_sanitizer.py:142 in public method `__str__`:
        D105: Missing docstring in magic method
torch/cuda/_sanitizer.py:218 in public class `StreamSynchronizations`:
        D101: Missing docstring in public class
torch/cuda/_sanitizer.py:219 in public method `__init__`:
        D107: Missing docstring in __init__
torch/cuda/_sanitizer.py:256 in public method `create_stream`:
        D102: Missing docstring in public method
torch/cuda/_sanitizer.py:268 in public method `create_event`:
        D102: Missing docstring in public method
torch/cuda/_sanitizer.py:272 in public method `delete_event`:
        D102: Missing docstring in public method
torch/cuda/_sanitizer.py:276 in public method `update_seq_num`:
        D102: Missing docstring in public method
torch/cuda/_sanitizer.py:280 in public method `record_state`:
        D102: Missing docstring in public method
torch/cuda/_sanitizer.py:291 in public method `stream_wait_for_event`:
        D102: Missing docstring in public method
torch/cuda/_sanitizer.py:298 in public method `all_streams_wait_for_event`:
        D102: Missing docstring in public method
torch/cuda/_sanitizer.py:307 in public method `all_streams_wait_for_stream`:
        D102: Missing docstring in public method
torch/cuda/_sanitizer.py:316 in public method `sync_all_streams`:
        D102: Missing docstring in public method
torch/cuda/_sanitizer.py:323 in public method `is_ordered_after`:
        D102: Missing docstring in public method
torch/cuda/_sanitizer.py:339 in public method `__init__`:
        D107: Missing docstring in __init__
torch/cuda/_sanitizer.py:460 in public function `zip_by_key`:
        D103: Missing docstring in public function
torch/cuda/_sanitizer.py:466 in public function `zip_arguments`:
        D103: Missing docstring in public function
torch/cuda/_sanitizer.py:478 in public class `ArgumentHandler`:
        D101: Missing docstring in public class
torch/cuda/_sanitizer.py:479 in public method `__init__`:
        D107: Missing docstring in __init__
torch/cuda/_sanitizer.py:505 in public method `parse_inputs`:
        D102: Missing docstring in public method
torch/cuda/_sanitizer.py:520 in public method `parse_outputs`:
        D102: Missing docstring in public method
torch/cuda/_sanitizer.py:527 in public class `CUDASanitizerDispatchMode`:
        D101: Missing docstring in public class
torch/cuda/_sanitizer.py:528 in public method `__init__`:
        D107: Missing docstring in __init__
torch/cuda/_sanitizer.py:562 in public method `__torch_dispatch__`:
        D105: Missing docstring in magic method
torch/cuda/_sanitizer.py:597 in public method `__init__`:
        D107: Missing docstring in __init__
torch/cuda/_sanitizer.py:601 in public method `enable`:
        D102: Missing docstring in public method
torch/cuda/_sanitizer.py:605 in public method `__del__`:
        D105: Missing docstring in magic method
```

pydocstyle torch/storage.py --count
before: 90
after: 37

**remaining errors:**
```
torch/storage.py:1 at module level:
        D100: Missing docstring in public module
torch/storage.py:310 in public class `UntypedStorage`:
        D101: Missing docstring in public class
torch/storage.py:311 in public method `__getitem__`:
        D105: Missing docstring in magic method
torch/storage.py:317 in public method `is_cuda`:
        D102: Missing docstring in public method
torch/storage.py:321 in public method `is_hpu`:
        D102: Missing docstring in public method
torch/storage.py:325 in public method `share_memory_`:
        D102: Missing docstring in public method
torch/storage.py:444 in public class `TypedStorage`:
        D101: Missing docstring in public class
torch/storage.py:453 in public method `fill_`:
        D102: Missing docstring in public method
torch/storage.py:458 in public method `__new__`:
        D102: Missing docstring in public method
torch/storage.py:530 in public method `__init__`:
        D107: Missing docstring in __init__
torch/storage.py:599 in public method `is_cuda`:
        D102: Missing docstring in public method
torch/storage.py:604 in public method `is_hpu`:
        D102: Missing docstring in public method
torch/storage.py:624 in public method `__len__`:
        D105: Missing docstring in magic method
torch/storage.py:653 in public method `__setitem__`:
        D105: Missing docstring in magic method
torch/storage.py:681 in public method `__getitem__`:
        D105: Missing docstring in magic method
torch/storage.py:715 in public method `copy_`:
        D102: Missing docstring in public method
torch/storage.py:723 in public method `nbytes`:
        D102: Missing docstring in public method
torch/storage.py:731 in public method `type`:
        D102: Missing docstring in public method
torch/storage.py:744 in public method `cuda`:
        D102: Missing docstring in public method
torch/storage.py:751 in public method `hpu`:
        D102: Missing docstring in public method
torch/storage.py:758 in public method `element_size`:
        D102: Missing docstring in public method
torch/storage.py:766 in public method `get_device`:
        D102: Missing docstring in public method
torch/storage.py:770 in public method `__str__`:
        D105: Missing docstring in magic method
torch/storage.py:781 in public method `__repr__`:
        D105: Missing docstring in magic method
torch/storage.py:785 in public method `__iter__`:
        D105: Missing docstring in magic method
torch/storage.py:789 in public method `__copy__`:
        D105: Missing docstring in magic method
torch/storage.py:793 in public method `__deepcopy__`:
        D105: Missing docstring in magic method
torch/storage.py:801 in public method `__sizeof__`:
        D105: Missing docstring in magic method
torch/storage.py:877 in public method `device`:
        D102: Missing docstring in public method
torch/storage.py:881 in public method `size`:
        D102: Missing docstring in public method
torch/storage.py:891 in public method `pickle_storage_type`:
        D102: Missing docstring in public method
torch/storage.py:902 in public method `__reduce__`:
        D105: Missing docstring in magic method
torch/storage.py:907 in public method `data_ptr`:
        D102: Missing docstring in public method
torch/storage.py:915 in public method `resize_`:
        D102: Missing docstring in public method
torch/storage.py:931 in public method `from_buffer`:
        D102: Missing docstring in public method
torch/storage.py:1032 in public method `from_file`:
        D402: First line should not be the function's "signature"
torch/storage.py:1075 in public method `is_shared`:
        D102: Missing docstring in public method

```

Pull Request resolved: https://github.com/pytorch/pytorch/pull/113227
Approved by: https://github.com/kit1980
2023-11-13 22:05:45 +00:00
66c32d099a Use pytree.arg_tree_leaves everywhere (#112394)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/112394
Approved by: https://github.com/lezcano
ghstack dependencies: #112391, #112392, #112393
2023-10-31 15:57:06 +00:00
bbd5b935e4 Use pytree.tree_leaves everywhere (#112324)
This changes all the instances I could find of `tree_flatten(...)[0]` or
`x, _ = tree_flatten` to use `tree_leaves`.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/112324
Approved by: https://github.com/lezcano
ghstack dependencies: #112327, #112323
2023-10-30 03:39:04 +00:00
0a9778a372 Expose cudaStreamCaptureMode in CUDA Graphs, use local setting in inductor (#107407)
> capture_error_mode (str, optional): specifies the cudaStreamCaptureMode for the graph capture stream.
> Can be "global", "thread_local" or "relaxed". During cuda graph capture, some actions, such as cudaMalloc,
> may be unsafe. "global" will error on actions in other threads, "thread_local" will only error for
> actions in the current thread, and "relaxed" will not error on these actions.

Inductor codegen is single-threaded, so it should be safe to enable "thread_local" for inductor's cuda graph capturing. We have seen errors when inductor cudagraphs has been used concurrently with data preprocessing in other threads.
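A hedged usage sketch of the knob at the Python level, passed through the torch.cuda.graph context manager (parameter name per the docstring quoted above):

```
import torch

g = torch.cuda.CUDAGraph()
x = torch.randn(8, device="cuda")

# "thread_local": only capture-unsafe actions issued from this thread error
# out, so e.g. data-preprocessing threads can keep calling cudaMalloc.
with torch.cuda.graph(g, capture_error_mode="thread_local"):
    y = x * 2

g.replay()
```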

Differential Revision: [D48656014](https://our.internmc.facebook.com/intern/diff/D48656014)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/107407
Approved by: https://github.com/albanD, https://github.com/eqy
2023-08-25 01:44:26 +00:00
3bf922a6ce Apply UFMT to low traffic torch modules (#106249)
Signed-off-by: Edward Z. Yang <ezyang@meta.com>

Pull Request resolved: https://github.com/pytorch/pytorch/pull/106249
Approved by: https://github.com/Skylion007
2023-07-29 23:37:30 +00:00
79c5e33349 [BE] Enable ruff's UP rules and autoformat nn/ mps/ and torch/ (#105436)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/105436
Approved by: https://github.com/malfet, https://github.com/albanD
2023-07-21 07:38:46 +00:00
5b1cedacde [BE] [2/3] Rewrite super() calls in functorch and torch (#94588)
Rewrite calls to the Python built-in `super()`. Only non-semantic changes are applied.

- #94587
- #94588
- #94592

Also, methods with only a `super()` call are removed:

```diff
class MyModule(nn.Module):
-   def __init__(self):
-       super().__init__()
-
    def forward(self, ...):
        ...
```

Cases where the rewrite would change semantics are kept unchanged, e.g.:

f152a79be9/caffe2/python/net_printer.py (L184-L190)

f152a79be9/test/test_jit_fuser_te.py (L2628-L2635)

Pull Request resolved: https://github.com/pytorch/pytorch/pull/94588
Approved by: https://github.com/ezyang, https://github.com/albanD
2023-02-10 21:16:33 +00:00
8fce9a09cd [BE]: pyupgrade Python to 3.8 - imports and object inheritance only (#94308)
Apply parts of pyupgrade to torch (starting with the safest changes).
This PR only does two things: removes the need to inherit from object and removes unused future imports.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/94308
Approved by: https://github.com/ezyang, https://github.com/albanD
2023-02-07 21:10:56 +00:00
4372dbb89f use pytree to allow any input format for cuda graph (#90941)
Summary:
1. Use pytree to allow any input format for make_graphed_callables.
2. Add an allow_unused_input argument to make_graphed_callables.
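A hedged usage sketch under those two changes (argument names as listed above; an illustration, not code from the PR):

```
import torch

def fn(inputs):
    # Arbitrary nested inputs are flattened with pytree.
    return inputs["x"].relu()

sample = {
    "x": torch.randn(8, device="cuda", requires_grad=True),
    "unused": torch.randn(8, device="cuda", requires_grad=True),
}
graphed = torch.cuda.make_graphed_callables(
    fn, (sample,), allow_unused_input=True  # tolerate the unused input
)
```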

Test Plan: buck2 test mode/dev-nosan  //caffe2/test:cuda --  --print-passing-details

Differential Revision: D42077976

Pull Request resolved: https://github.com/pytorch/pytorch/pull/90941
Approved by: https://github.com/ngimel
2022-12-16 03:01:47 +00:00
62e450d55f [CUDA Graphs] Add option to dump a captured graph for debugging (#85519)
CC @xwang233 @ptrblck @ngimel
Pull Request resolved: https://github.com/pytorch/pytorch/pull/85519
Approved by: https://github.com/ngimel
2022-12-06 22:03:05 +00:00
ce1b727e77 Disable autocast cache in torch.cuda.make_graphed_callables (#84289)
There are conflicts between `torch.clear_autocast_cache()` and `cudaMallocAsync` from #82682.
Moreover, the use of autocast caching is not reasonable during training, which is the main target of `make_graphed_callables`.

cc @eqy @ptrblck
Pull Request resolved: https://github.com/pytorch/pytorch/pull/84289
Approved by: https://github.com/ngimel
2022-09-01 21:34:51 +00:00
da0a3fe058 [Re-land] [CUDA graphs] Clear autocast amp cache (#81896)
Re-lands #81558, which got reverted due to failing tests.

This failure happened because of a test that I designed poorly. [The loop here](https://github.com/pytorch/pytorch/pull/81558/files#diff-893b1eea27352f336f4cd832919e48d721e4e90186e63400b8596db6b82e7450R3837) runs with `cache_enabled=False` and then `cache_enabled=True`; because of this loop, the graph from the previous iteration (case `False`) conflicts with the next one (case `True`). I redesigned the test so that it does not loop: the new test makes separate function calls with different argument values.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/81896
Approved by: https://github.com/ngimel
2022-08-02 23:22:00 +00:00
f5b460b200 Revert "[CUDA graphs] Clear autocast amp cache (#81558)"
This reverts commit e9d07bd4f0b8d0894566dca61e39909afb2d29ec.

Reverted https://github.com/pytorch/pytorch/pull/81558 on behalf of https://github.com/janeyx99 due to Breaks windows 11.6 tests on trunk e9d07bd4f0
2022-07-21 12:46:36 +00:00
e9d07bd4f0 [CUDA graphs] Clear autocast amp cache (#81558)
According to [autocast_mode.cpp](https://github.com/pytorch/pytorch/blob/master/aten/src/ATen/autocast_mode.cpp), `cached_casts` is to be cleared at the end of each forward pass. However, this was not the case in the current implementation of `make_graphed_callables`, so a graph created the following way behaves incorrectly:

```
    with torch.cuda.amp.autocast(cache_enabled=True):
        graphed_foo = torch.cuda.make_graphed_callables(foo, tensors)
```

cc @ptrblck
Pull Request resolved: https://github.com/pytorch/pytorch/pull/81558
Approved by: https://github.com/ngimel
2022-07-21 01:44:14 +00:00
72a4f6773d Add an argument to specify warmup iterations (#78124)
Summary: Add an argument to the API ``torch.cuda.make_graphed_callables`` specifying the number of warmup iterations. By default it uses 3 warmup iterations; to work with NCCL, it needs 11.
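A hedged usage sketch (the keyword name `num_warmup_iters` matches the current signature but should be treated as an assumption here):

```
import torch

def fn(x):
    return x.relu()

x = torch.randn(8, device="cuda")

# NCCL workloads need more warmup than the default of 3 iterations.
graphed = torch.cuda.make_graphed_callables(fn, (x,), num_warmup_iters=11)
```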

Differential Revision: D36606758

Pull Request resolved: https://github.com/pytorch/pytorch/pull/78124
Approved by: https://github.com/jianyuh
2022-05-25 01:21:15 +00:00
929f1d5317 [RELAND] Adds torch.cuda.is_current_stream_capturing (#77789)
Resubmit of https://github.com/pytorch/pytorch/pull/77673, which was reverted due to Windows test failures: https://github.com/pytorch/pytorch/pull/77673#issuecomment-1130425845.

I suspect these failures happened because I don't explicitly set a side stream for graph capture in the new test.
Not setting a side stream explicitly is alright on Linux because cuda tests implicitly use a side stream.
I think Windows cuda tests implicitly use the default stream, breaking capture and leaving the backend in a bad state.
Other graph tests explicitly set side streams and don't error in Windows builds, so I'm 95% sure doing the same for the new test will work.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/77789
Approved by: https://github.com/ezyang
2022-05-18 23:18:53 +00:00
0d8a0f186b Revert "Adds torch.cuda.is_current_stream_capturing (#77673)"
This reverts commit d03d43df527e48771875537ad20212d5cb333215.

Reverted https://github.com/pytorch/pytorch/pull/77673 on behalf of https://github.com/suo
2022-05-18 19:31:49 +00:00
d03d43df52 Adds torch.cuda.is_current_stream_capturing (#77673)
Exposes a way to query if CUDA graph capture is underway on the current stream.
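Minimal usage sketch:

```
import torch

def maybe_sync():
    # Synchronizing during capture would abort it, so guard with the query.
    if not torch.cuda.is_current_stream_capturing():
        torch.cuda.synchronize()
```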
Pull Request resolved: https://github.com/pytorch/pytorch/pull/77673
Approved by: https://github.com/ezyang
2022-05-18 16:46:35 +00:00
e3210ca184 [CUDA graphs] Beta, not prototype (#65247)
Summary:
Powers have decided this API should be listed as beta.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/65247

Reviewed By: malfet

Differential Revision: D31057940

Pulled By: ngimel

fbshipit-source-id: 137b63cbd2c7409fecdc161a22135619bfc96bfa
2021-09-20 13:32:36 -07:00
8d08b103be [CUDA graphs] Prototype API and documentation (#63269)
Summary:
RFC: https://github.com/pytorch/pytorch/issues/61880

Pull Request resolved: https://github.com/pytorch/pytorch/pull/63269

Reviewed By: mruberry

Differential Revision: D30596643

Pulled By: ngimel

fbshipit-source-id: b1f8061406364b667e2c2d4d30fbce1f0d8456be
2021-08-31 13:34:23 -07:00