We previously asked users to separate these because we didn't have any way of adding `extern "C"` declarations. Now we do, and we don't need this confusing flag anymore.
This is BC-breaking, but that is fine for this API since it doesn't have major users yet. Please just put all your code in `kernel_source` moving forward.
## BC note
The header_code parameter has been removed from torch.cuda._compile_kernel. Previously, users could pass separate header code that would be prepended to the kernel source. Now, header code must be included directly in the kernel_source parameter.
Note this only affects torch.cuda._compile_kernel, which is a private API.
Example:
Before
```python
kernel = compile_kernel(
kernel_source="global void my_kernel() { ... }",
kernel_name="my_kernel",
header_code="#define SCALE 2.0f\n__device_ float scale(float x) { return x * SCALE; }"
)
```
After
```python
kernel_source = """
#define SCALE 2.0f
__device__ float scale(float x) { return x * SCALE; }
__global__ void my_kernel() { ... }
"""
kernel = _compile_kernel(kernel_source, "my_kernel")
```
Pull Request resolved: https://github.com/pytorch/pytorch/pull/163165
Approved by: https://github.com/janeyx99, https://github.com/albanD
Note that this works only in a limited case, where you *don't* change the seed, but change only the offset of the philox generator. This captures the main use case we're interested in: Rewinding the RNG to a previous state. This is done by torch.utils.checkpoint.checkpoint in particular.
Calls to increase() change only the offset, not the seed. Thus, we allow for "no-op" calls to set_seed where the new seed is the same as the old seed. If a user does happen to try to change the seed during stream capture, they will receive an error.
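A minimal sketch of the allowed pattern (assumes a CUDA device; re-seeding with the same seed during capture is accepted, while switching seeds would raise):
```python
import torch

g = torch.cuda.CUDAGraph()
seed = 1234
torch.cuda.manual_seed(seed)

with torch.cuda.graph(g):
    x = torch.randn(8, device="cuda")
    torch.cuda.manual_seed(seed)        # same seed: a no-op, allowed during capture
    y = torch.randn(8, device="cuda")
    # torch.cuda.manual_seed(seed + 1)  # a different seed would raise an error
```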
Fixes #162504
Pull Request resolved: https://github.com/pytorch/pytorch/pull/162505
Approved by: https://github.com/ngimel, https://github.com/eqy, https://github.com/eellison, https://github.com/eee4017, https://github.com/cyyever
Per the NVRTC doc - https://docs.nvidia.com/cuda/nvrtc/index.html#accessing-lowered-names - we can compile a templated kernel (e.g. `kernel<float>`) with the following steps:
NVRTC side
- (new) `nvrtcAddNameExpression` -> C++ template e.g. `f<float>`
- `nvrtcCompileProgram`
- (new) `nvrtcGetLoweredName` -> get the mangled name. We need to copy this string, since it is freed once the NVRTC program is destroyed
- `nvrtcDestroyProgram`
CUDA side
- use mangled name instead of normal name -> profit
- `extern "C"` is not even needed
Pull Request resolved: https://github.com/pytorch/pytorch/pull/162875
Approved by: https://github.com/msaroufim
## Introduction
During CUDA Graph capture, the CUDA caching allocator currently defers reclaiming blocks until capture ends. This is because CUDA forbids querying events recorded during capture (CUDA operations are not executed during the capture stage), so the allocator cannot use its normal event-based logic. However, capture records a DAG of work (we call it the **capturing graph**). We can use the capturing graph to determine when a block’s old lifetime is fully before future work, and safely reuse it within the same capture.
This PR adds an experimental flag `graph_capture_record_stream_reuse: True|False (default: False)`. When enabled, the allocator inserts lightweight free markers and uses capture ordering to decide if a freed block is safe to reuse during capture. If the proof cannot be established, we fall back to the existing post-capture path.
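For illustration only, assuming the flag is exposed through the usual caching-allocator config string like other allocator options, enabling it might look like:
```python
import os

# Must be set before the CUDA caching allocator is initialized (i.e. before
# the first CUDA allocation); the exact config plumbing is an assumption here.
os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "graph_capture_record_stream_reuse:True"

import torch  # noqa: E402
```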
## Terms
* **Free marker**: A capture-legal no-op (created with `cudaGraphAddEmptyNode`) inserted after the last captured use of the block on each stream that used it.
* **Terminal**: The set of the latest operations on the stream (or in the capturing graph). Any newly captured op on that stream will attach after all nodes in this set. For a stream currently capturing, it is the set of nodes returned in `dependencies_out` by `cudaStreamGetCaptureInfo`.
## When can we reuse a block during capture?
### Strong Rule (Graph-Wide Safety)
This rule provides a universal guarantee that a block is safe for reuse by any stream in the graph.
> A block is safe to reuse if every free marker is a predecessor of every terminal of all active streams in the graph.
Why it's safe:
This rule establishes a strict global ordering. Since any new operation on any stream must be appended after that stream's terminals, this condition guarantees that the block's new lifetime begins only after its old lifetime has completely ended everywhere. This prevents lifetime overlaps when the graph is replayed, ensuring correctness.
### Per-stream Rule (A Practical Optimization)
The strong rule, while safe, is often unnecessarily restrictive. The `DeviceCachingAllocator` introduces a crucial constraint that allows for a simpler check.
In `DeviceCachingAllocator`, `get_free_block` only returns blocks whose `block->stream == p.stream()`. In other words, we never reuse a block on a stream different from the allocation stream. This means we don't need to verify safety across the entire graph. We only need to confirm that the block is safe to reuse from the perspective of its own allocation stream.
> Reuse a block for allocations on stream S if every free marker is a predecessor of every node in the terminal set of S.
In short, a block is considered **reusable** on stream S as long as all markers marking it "free" are guaranteed to complete before any new work that might need it on stream S begins.
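A minimal pseudocode sketch of the per-stream check (illustrative names, not the actual allocator code):
```python
def reusable_on_stream(block, terminal_of_s, is_predecessor):
    # Per-stream rule: every free marker recorded for `block` must precede
    # every node currently in stream S's terminal set; `is_predecessor(a, b)`
    # stands for "a is ordered before b in the capturing graph".
    return all(
        is_predecessor(marker, node)
        for marker in block.free_markers
        for node in terminal_of_s
    )
```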
## Implementation
* On `free(block)` during capture
* For each stream in `block->stream_uses` and the allocation stream, insert a free marker (empty node) and make it that stream’s tail.
* If we cannot place markers for all such streams (for example, a stream is not in capture), defer to the post-capture path.
* Otherwise, store the marker handles and keep the block in the capture-private structures.
* On `allocate(stream)` during capture (attempt per-stream reclaim)
* Query the allocation stream S’s terminal via `cudaStreamGetCaptureInfo`.
* For each deferred block, check whether it is allocated on this stream and whether each of its free markers is a predecessor of the terminal set.
* If yes, hand the block to S for immediate reuse within the same capture.
* If no, keep it deferred; it will be reconsidered as capture progresses and S’s terminal advances.
* On capture end
* Any still-deferred blocks follow the existing post-capture reclamation (event insertion/polling). External behavior remains unchanged if we cannot prove safety during capture.
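For context, a user-level sketch of the free-then-reallocate pattern these steps target (assumes a CUDA device; whether the freed block is actually reclaimed in-capture depends on the ordering proof above):
```python
import torch

g = torch.cuda.CUDAGraph()
with torch.cuda.graph(g):
    x = torch.empty(1 << 20, device="cuda")
    y = x * 2.0
    del x                                    # freed during capture; free markers are inserted
    z = torch.empty(1 << 20, device="cuda")  # same stream: may reuse x's block within the capture
```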
## Examples (2 streams)
<img width="641" height="801" alt="pytorch-remove-cudagraph-defer-reclaiming (6)" src="https://github.com/user-attachments/assets/41adc835-d448-483b-99ba-b4341cb7d2a2" />
* Case 0 — Unsafe
The two frees are not ordered with respect to each other. For stream 1, the other stream’s free marker does not precede this stream’s terminal, so the per-stream condition fails.
Counterexample intuition for the unsafe setups: imagine `f2(x)` runs for a long time. If DeviceCachingAllocator reused block `x` on a stream whose terminal is not ordered after the free markers, the new lifetime could overlap the old one on replay, risking use-after-free or data corruption. The per-stream rule prevents exactly this.
* Case 1 — Reusable on stream 1
Stream 1’s terminal is after both frees, so every free marker precedes stream 1’s terminal. The block is reusable for allocations on stream 1.
* Case 2 — Not reusable on stream 2, but this cannot occur in `DeviceCachingAllocator`
This depicts reusing the block on stream 2 while stream 1’s free is not yet ordered before stream 2’s terminal. Though the block is not safe to reuse on stream 2, DeviceCachingAllocator will not choose that block for stream 2 anyway: `get_free_block` rejects blocks whose `stream != p.stream()`. So this case is unreachable.
* Case 3 — Safe (strong rule holds)
In this scenario, the terminal nodes of all streams are positioned after the block's free markers, satisfying the strong rule. This guarantees the block is safe for reuse by any stream in the capturing graph. However, since `DeviceCachingAllocator` only reuses a block on its original allocation stream, verifying this strong condition is unnecessary. We only need to ensure the per-stream rule is met for the specific stream requesting the block.
* Case 4 — Freeing after a join
See the note below.
## Edge Case: Freeing after a join
Our current dependency tracking has a limitation in scenarios where a block is freed after a stream join; see @galv's [comments here](https://github.com/pytorch/pytorch/pull/158352#pullrequestreview-3112565198).
Case 4 shows a missed opportunity. Because the block's last use is not explicitly marked, we cannot tell that it may have occurred much earlier, long before the join. As a result, we must wait for the join before the block can be reused.
## Thanks
Thanks to @galv for his great idea around graph parsing and empty nodes.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/158352
Approved by: https://github.com/ngimel, https://github.com/eqy
Co-authored-by: Jeff Daily <jeff.daily@amd.com>
This is far simpler than #155164 since we never destroy the cudaGraphExec_t.
The request comes from TRT-LLM specifically. The motivation is that some power users would like to mutate specific kernel parameters via APIs like `cudaGraphExec*SetParams` after a cuda graph has been instantiated. For example, a common request has been to be able to change the sequence length of attention kernels, after having captured a graph for the largest possible sequence length. It turns out that the host overhead you eliminate via cuda graphs in LLM inference ends up causing an increase in computation time when you size your kernels to the maximum possible sequence length (which I believe is done in both TRT-LLM and vLLM). Attention is the most problematic kernel because its computation time is quadratic in the sequence length, rather than linear.
This can work if your attention kernel supports arbitrary shapes (this is not the case for all attention implementations! Many of them specialize with templates), and you have a persistent kernel that allocates only as many blocks as you have SMs (so you don't have to figure out how many blocks to allocate for a specific sequence length). Using a conditional SWITCH node is a better generic approach to this problem, but that requires more infrastructure work.
Note that this requires knowledge of the exact location of the value in your kernel's parameter buffer to mutate. It won't work with arbitrary stream capture code whose kernels you don't know beforehand. So I expect this code path to be rarely used.
Testing:
```
pytest -s -k raw_graph_exec test/test_cuda.py
```
Pull Request resolved: https://github.com/pytorch/pytorch/pull/161294
Approved by: https://github.com/ngimel, https://github.com/BoyuanFeng, https://github.com/eellison, https://github.com/eqy
# Motivation
This PR moves the implementation of `torch.cuda.memory._set_allocator_settings` to `torch._C._accelerator_setAllocatorSettings`.
Since the original API was intended as a temporary/internal utility, I am not exposing the new function as a public API.
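A short sketch of calling the relocated, still-private entry point; the settings string follows the existing allocator-config "key:value" format, and `max_split_size_mb` is one of the existing keys:
```python
import torch

# Private API: subject to change without notice.
torch._C._accelerator_setAllocatorSettings("max_split_size_mb:128")
```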
Pull Request resolved: https://github.com/pytorch/pytorch/pull/156175
Approved by: https://github.com/albanD
ghstack dependencies: #159629, #150312, #156165
Now, instead of erroring out on an `empty_cache` call during graph capture or under a mempool context, we just silently do nothing. This was already the behavior for mempools; cudagraphs used to error out, but it's fine to simply ignore the call.
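A minimal sketch of the new behavior (assumes a CUDA device):
```python
import torch

g = torch.cuda.CUDAGraph()
with torch.cuda.graph(g):
    x = torch.ones(4, device="cuda")
    torch.cuda.empty_cache()  # previously raised during capture; now silently a no-op
```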
Pull Request resolved: https://github.com/pytorch/pytorch/pull/158152
Approved by: https://github.com/zou3519, https://github.com/eqy
Otherwise it turns the test into a trivial one (that always succeeds), as the following example demonstrates:
```python
import torch
from torch.testing._internal.common_utils import serialTest, run_tests, TestCase
class MegaTest(TestCase):
    @serialTest
    def test_foo(self):
        if hasattr(self.test_foo, "pytestmark"):
            print("foo has attr and it is", self.test_foo.pytestmark)
        print("foo")

    @serialTest()
    def test_bar(self):
        if hasattr(self.test_bar, "pytestmark"):
            print("bar has attr and it is", self.test_bar.pytestmark)
        print("bar")

if __name__ == "__main__":
    run_tests()
```
That will print
```
test_bar (__main__.MegaTest.test_bar) ... bar has attr and it is [Mark(name='serial', args=(), kwargs={})]
bar
ok
test_foo (__main__.MegaTest.test_foo) ... ok
----------------------------------------------------------------------
Ran 2 tests in 0.013s
```
Added an assert in the decorator that the argument is a boolean, to prevent such silent skips in the future.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/157388
Approved by: https://github.com/clee2000
Based on the [conversation](https://github.com/pytorch/pytorch/issues/121791), we plan to drop "highest, high, medium" as the way to represent fp32 internal computation data types. Instead, we will directly use the algorithm name to represent it.
### Design Choice: Directly use algorithm names like "TF32", "BF16".
#### Pros
- The names are more informative. 'tf32' is more informative than a simple "high".
- Easier to extend to new algorithms like `tf32x3`
#### Cons
- "HIGHEST, HIGH, MEDIUM" indicated the relative precision between different algorithms. However, we can have more documents to discuss them.
### We provide a layered structure for backends/operators.
('f32' is short for 'fp32_precision')

### We provide 3 fp32 compute precisions that can be set:
- **"ieee"**: Not allowed to use any other internal computation data type.
- **"tf32"**: Allowed to use tf32 as the internal computation data type.
- **"bf16"**: Allowed to use bf16 as the internal computation data type.
- **"none"**: Precision is not set; it can be overridden by its parent node.
### Overriding Precision Settings
A child node can be overridden by its parent node if it is set to the default.
For current default settings:
```
backend = generic, op = all, precision setting = none
backend = cuda, op = all, precision setting = none
backend = cuda, op = conv, precision setting = tf32
backend = cuda, op = rnn, precision setting = tf32
backend = cuda, op = matmul, precision setting = none
backend = mkldnn, op = all, precision setting = none
backend = mkldnn, op = conv, precision setting = none
backend = mkldnn, op = rnn, precision setting = none
backend = mkldnn, op = matmul, precision setting = none
```
- If the user sets `torch.backends.mkldnn.fp32_precision="bf16"`, its child nodes `torch.backends.mkldnn.matmul.fp32_precision` / `torch.backends.mkldnn.conv.fp32_precision` / `torch.backends.mkldnn.rnn.fp32_precision` will also be overridden to "bf16".
- If the user sets `torch.backends.fp32_precision="bf16"`, `torch.backends.mkldnn.fp32_precision` and its child nodes will also be overridden to "bf16" (see the sketch below).
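A short sketch using the knob names from the bullets above (a sketch of the surface described in this PR, not a stable public contract):
```python
import torch

# Setting a parent overrides children that are still at the default ("none").
torch.backends.mkldnn.fp32_precision = "bf16"
print(torch.backends.mkldnn.matmul.fp32_precision)  # inherited: "bf16"

# Explicitly set children keep their own setting.
torch.backends.cudnn.conv.fp32_precision = "tf32"
torch.backends.cudnn.rnn.fp32_precision = "ieee"
```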
### Backward Compatibility
Since the new API allows users more fine-grained control, there can be conflicts. For example, the previous `torch.backends.cudnn.allow_tf32` is not enough to represent a state such as `torch.backends.cudnn.rnn.fp32_precision="ieee"` together with `torch.backends.cudnn.conv.fp32_precision="tf32"`. Therefore, our goals for backward compatibility are:
- If the user only uses the previous APIs, they will work as previously expected.
- If the user uses the **new** API to change the state to one that is **un-representable** by the old API and then tries to read it via the **old** API, we raise a RuntimeError and point the user to the documentation.
### Test Plan
```
python test/test_cuda.py -k test_fp32_precision_with_tf32
python test/test_cuda.py -k test_fp32_precision_with_float32_matmul_precision
python test/test_cuda.py -k test_invalid_status_for_legacy_api
python test/test_mkldnn.py -k test_mlkdnn_get_set
python test/test_mkldnn.py -k test_generic_precision
python test/test_mkldnn.py -k test_invalid
python test/test_mkldnn.py -k test_default_use_parent
```
Pull Request resolved: https://github.com/pytorch/pytorch/pull/125888
Approved by: https://github.com/jgong5, https://github.com/albanD
Co-authored-by: Jiang, Yanbing <yanbing.jiang@intel.com>
There are a few considerations here:
1. A user might want to modify the cudaGraph_t either during the stream capture or after the stream capture (but before instantiation). This draft implements modification after stream capture only, though support could be added for modification during stream capture by applying
https://github.com/pytorch/pytorch/pull/140979/files#diff-d7302d133bb5e0890fc94de9aeea4d9d442555a3b40772c9db10edb5cf36a35cR391-R404
2. Previously, the cudaGraph_t would be destroyed before the end of capture_end() unless the user had previously called enable_debug_mode(). There is no way to implement this correctly without removing this restriction, or forcing the user to always call enable_debug_mode(). However, enable_debug_mode() is a confusing API (despite being an instance method, it would modify a static global variable; thus, putting one CUDAGraph object into debug mode puts all of them into debug mode, which is not acceptable in my opinion). Therefore, I made enable_debug_mode() into a no-op. This means that the CPU memory usage will increase after this change. I think this is likely to be fine.
3. No python bindings yet. These should be easy to add. It is probably worthwhile to take some time to make sure that the returned cudaGraph_t can be converted into the cuda-python cudaGraph_t in a reasonable, hopefully type-safe, manner (but without making cuda-python a dependency of pytorch), since I imagine most users will use the pip cuda-python package to make modifications.
4. There are two foot guns:
a. The cudaGraph_t returned by raw_cuda_graph() is not owned by the user, so it will be destroyed once the owning CUDAGraph is destroyed (or calls reset()).
b. The following sequence won't work as intended:
```
g = torch.cuda.CUDAGraph()
with torch.cuda.graph(g):
    foo()
g.replay()
raw_graph = g.raw_cuda_graph()
modify(raw_graph)
g.replay()
```
This won't work because the user must call instantiate() again after modifying cudaGraph_t. You could add a "safety" mechanism by traversing the cudaGraph_t to create a hash and seeing if the hash changes between calls to replay(), but this is likely way too expensive.
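A sketch of the corrected sequence per the description above (`foo()` and `modify()` are the same placeholders as in the snippet):
```
g = torch.cuda.CUDAGraph()
with torch.cuda.graph(g):
    foo()
raw_graph = g.raw_cuda_graph()
modify(raw_graph)
g.instantiate()  # re-instantiate so the modifications take effect on replay
g.replay()
```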
I think these two foot guns are probably okay given that this is a bit of an experts' API.
Fixes #155106
Pull Request resolved: https://github.com/pytorch/pytorch/pull/155164
Approved by: https://github.com/ngimel
These are created by the user passing cudaEventRecordExternal and
cudaEventWaitExternal to cudaEventRecordWithFlags() and
cudaStreamWaitEvent() respectively.
We do this by allowing the user to specify external=True when
constructing a torch.cuda.Event().
If external=False, the cudaEventRecord and cudaStreamWaitEvent APIs
have a different meaning described here:
https://docs.nvidia.com/cuda/cuda-c-programming-guide/index.html#cross-stream-dependencies-and-events
In short, they will be used to express fork and join operations in
the graph if external=False.
External events can be used for expressing a fine-grained dependency
on the outcome of some nodes in a cuda graph (rather than all
nodes). They can also be used for timing parts of a cuda graph's
execution, rather than timing the entire graph's execution.
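A hedged sketch of the timing use case (assumes a CUDA device; `external=True` is the new constructor flag described above):
```python
import torch

start = torch.cuda.Event(enable_timing=True, external=True)
end = torch.cuda.Event(enable_timing=True, external=True)

x = torch.randn(1 << 20, device="cuda")
g = torch.cuda.CUDAGraph()
with torch.cuda.graph(g):
    start.record()
    y = x * 2.0        # only this part of the graph is timed
    end.record()

g.replay()
torch.cuda.synchronize()
print(start.elapsed_time(end), "ms")
```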
Finishes #146145
I'm a dummy and don't know how to use ghstack at this time. The first commit is a bug fix for _CudaKernel, which would previously always launch work on the NULL stream, rather than the user-passed stream.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/155372
Approved by: https://github.com/ngimel
Fixes #136849
## Test Result
```python
>>> import torch
>>> device = torch.cuda.device_count() + 1
>>> torch.cuda.current_stream(device) # INTERNAL ASSERT FAILED
Traceback (most recent call last):
File "<stdin>", line 1, in <module>
File "/home/zong/code/pytorch/torch/cuda/__init__.py", line 1083, in current_stream
streamdata = torch._C._cuda_getCurrentStream(
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
RuntimeError: Device index value 3 is out of index range [0, 2)
>>> torch.cuda.default_stream(device) # INTERNAL ASSERT FAILED
Traceback (most recent call last):
File "<stdin>", line 1, in <module>
File "/home/zong/code/pytorch/torch/cuda/__init__.py", line 1101, in default_stream
streamdata = torch._C._cuda_getDefaultStream(
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
RuntimeError: Device index value 3 is out of index range [0, 2)
>>> torch.cuda.set_per_process_memory_fraction(0.5, device) # INTERNAL ASSERT FAILED
Traceback (most recent call last):
File "<stdin>", line 1, in <module>
File "/home/zong/code/pytorch/torch/cuda/memory.py", line 193, in set_per_process_memory_fraction
torch._C._cuda_setMemoryFraction(fraction, device)
RuntimeError: Allocator not initialized for device : did you call init?
```
Pull Request resolved: https://github.com/pytorch/pytorch/pull/155318
Approved by: https://github.com/albanD
This adds `torch.AcceleratorError`, which inherits from `RuntimeError` and contains `error_code`, which in the case of CUDA holds the error returned by `cudaGetLastError`.
`torch::detail::_new_accelerator_error_object(c10::AcceleratorError&)` follows the pattern of CPython's [`PyErr_SetString`](cb8a72b301/Python/errors.c (L282)), namely
- Convert cstr into Python string with `PyUnicode_FromString`
- Create new exception object using `PyObject_CallOneArg` just like it's done in [`_PyErr_CreateException`](cb8a72b301/Python/errors.c (L32))
- Set `error_code` property using `PyObject_SetAttrString`
- decref all temporary references
Test that it works and captures CPP backtrace (in addition to CI) by running
```python
import os
os.environ['TORCH_SHOW_CPP_STACKTRACES'] = '1'
import torch
x = torch.rand(10, device="cuda")
y = torch.arange(20, device="cuda")
try:
    x[y] = 2
    print(x)
except torch.AcceleratorError as e:
    print("Exception was raised", e.args[0])
    print("Captured error code is ", e.error_code)
```
which produces the following output:
```
Exception was raised CUDA error: device-side assert triggered
CUDA kernel errors might be asynchronously reported at some other API call, so the stacktrace below might be incorrect.
For debugging consider passing CUDA_LAUNCH_BLOCKING=1
Compile with `TORCH_USE_CUDA_DSA` to enable device-side assertions.
Exception raised from c10_cuda_check_implementation at /home/ubuntu/pytorch/c10/cuda/CUDAException.cpp:41 (most recent call first):
C++ CapturedTraceback:
#4 std::_Function_handler<std::shared_ptr<c10::LazyValue<std::__cxx11::basic_string<char, std::char_traits<char>, std::allocator<char> > > const> (), c10::SetStackTraceFetcher(std::function<std::__cxx11::basic_string<char, std::char_traits<char>, std::allocator<char> > ()>)::{lambda()#1}>::_M_invoke(std::_Any_data const&) from Logging.cpp:0
#5 c10::Error::Error(c10::SourceLocation, std::__cxx11::basic_string<char, std::char_traits<char>, std::allocator<char> >) from ??:0
#6 c10::cuda::c10_cuda_check_implementation(int, char const*, char const*, int, bool) [clone .cold] from CUDAException.cpp:0
#7 void at::native::gpu_kernel_impl<at::native::AbsFunctor<float> >(at::TensorIteratorBase&, at::native::AbsFunctor<float> const&) [clone .isra.0] from tmpxft_000191fc_00000000-6_AbsKernel.cudafe1.cpp:0
#8 at::native::abs_kernel_cuda(at::TensorIteratorBase&) from ??:0
#9 at::Tensor& at::native::unary_op_impl_with_complex_to_float_out<at::native::abs_stub_DECLARE_DISPATCH_type>(at::Tensor&, at::Tensor const&, at::native::abs_stub_DECLARE_DISPATCH_type&, bool) [clone .constprop.0] from UnaryOps.cpp:0
#10 at::(anonymous namespace)::(anonymous namespace)::wrapper_CUDA_out_abs_out(at::Tensor const&, at::Tensor&) from RegisterCUDA_0.cpp:0
#11 at::_ops::abs_out::call(at::Tensor const&, at::Tensor&) from ??:0
#12 at::native::abs(at::Tensor const&) from ??:0
#13 c10::impl::wrap_kernel_functor_unboxed_<c10::impl::detail::WrapFunctionIntoFunctor_<c10::CompileTimeFunctionPointer<at::Tensor (at::Tensor const&), &at::(anonymous namespace)::(anonymous namespace)::wrapper_CompositeExplicitAutograd__abs>, at::Tensor, c10::guts::typelist::typelist<at::Tensor const&> >, at::Tensor (at::Tensor const&)>::call(c10::OperatorKernel*, c10::DispatchKeySet, at::Tensor const&) from RegisterCompositeExplicitAutograd_0.cpp:0
#14 at::_ops::abs::redispatch(c10::DispatchKeySet, at::Tensor const&) from ??:0
#15 torch::autograd::VariableType::(anonymous namespace)::abs(c10::DispatchKeySet, at::Tensor const&) from VariableType_1.cpp:0
#16 c10::impl::wrap_kernel_functor_unboxed_<c10::impl::detail::WrapFunctionIntoFunctor_<c10::CompileTimeFunctionPointer<at::Tensor (c10::DispatchKeySet, at::Tensor const&), &torch::autograd::VariableType::(anonymous namespace)::abs>, at::Tensor, c10::guts::typelist::typelist<c10::DispatchKeySet, at::Tensor const&> >, at::Tensor (c10::DispatchKeySet, at::Tensor const&)>::call(c10::OperatorKernel*, c10::DispatchKeySet, at::Tensor const&) from VariableType_1.cpp:0
#17 at::_ops::abs::call(at::Tensor const&) from ??:0
#18 at::native::isfinite(at::Tensor const&) from ??:0
#19 c10::impl::wrap_kernel_functor_unboxed_<c10::impl::detail::WrapFunctionIntoFunctor_<c10::CompileTimeFunctionPointer<at::Tensor (at::Tensor const&), &at::(anonymous namespace)::(anonymous namespace)::wrapper_CompositeImplicitAutograd__isfinite>, at::Tensor, c10::guts::typelist::typelist<at::Tensor const&> >, at::Tensor (at::Tensor const&)>::call(c10::OperatorKernel*, c10::DispatchKeySet, at::Tensor const&) from RegisterCompositeImplicitAutograd_0.cpp:0
#20 at::_ops::isfinite::call(at::Tensor const&) from ??:0
#21 torch::autograd::THPVariable_isfinite(_object*, _object*, _object*) from python_torch_functions_2.cpp:0
#22 PyObject_CallFunctionObjArgs from ??:0
#23 _PyObject_MakeTpCall from ??:0
#24 _PyEval_EvalFrameDefault from ??:0
#25 _PyObject_FastCallDictTstate from ??:0
#26 _PyStack_AsDict from ??:0
#27 _PyObject_MakeTpCall from ??:0
#28 _PyEval_EvalFrameDefault from ??:0
#29 _PyFunction_Vectorcall from ??:0
#30 _PyEval_EvalFrameDefault from ??:0
#31 _PyFunction_Vectorcall from ??:0
#32 _PyEval_EvalFrameDefault from ??:0
#33 _PyFunction_Vectorcall from ??:0
#34 _PyEval_EvalFrameDefault from ??:0
#35 PyFrame_GetCode from ??:0
#36 PyNumber_Xor from ??:0
#37 PyObject_Str from ??:0
#38 PyFile_WriteObject from ??:0
#39 _PyWideStringList_AsList from ??:0
#40 _PyDict_NewPresized from ??:0
#41 _PyEval_EvalFrameDefault from ??:0
#42 PyEval_EvalCode from ??:0
#43 PyEval_EvalCode from ??:0
#44 PyUnicode_Tailmatch from ??:0
#45 PyInit__collections from ??:0
#46 PyUnicode_Tailmatch from ??:0
#47 _PyRun_SimpleFileObject from ??:0
#48 _PyRun_AnyFileObject from ??:0
#49 Py_RunMain from ??:0
#50 Py_BytesMain from ??:0
#51 __libc_init_first from ??:0
#52 __libc_start_main from ??:0
#53 _start from ??:0
Captured error code is 710
```
Pull Request resolved: https://github.com/pytorch/pytorch/pull/152023
Approved by: https://github.com/eqy, https://github.com/mradmila, https://github.com/ngimel
ghstack dependencies: #154436
Removes MemPoolContext from custom user mempools. The ground truth for which pool should be used is the active pool in graph_pools; MemPoolContext just introduced an opportunity for the pool pointed to by MemPoolContext and the active pool in graph_pools to go out of sync (see all the asserts in the code trying to ensure they stay in sync, and yet it could still happen in a multithreaded scenario; see my recent PR #153990).
Pull Request resolved: https://github.com/pytorch/pytorch/pull/154042
Approved by: https://github.com/albanD, https://github.com/syed-ahmed
Fixes #152008
This PR fixes a segmentation fault that occurred when exiting the program due to improper background thread management in CachingHostAllocator.
Previously, the background thread continued running and called process_events() even after the allocator object was destroyed, leading to a crash on exit.
f12d8d60b1/aten/src/ATen/core/CachingHostAllocator.h (L218)
```cpp
// Launch the background thread and process events in a loop.
static bool background_thread_flag [[maybe_unused]] = [this] {
  getBackgroundThreadPool()->run([&]() {
    while (true) {
      process_events(); // <-- This line may cause segfault on exit
      std::this_thread::sleep_for(std::chrono::microseconds(100));
    }
  });
  return true;
}();
```
The fix adds a mechanism to signal the background thread to exit before the object is destructed, ensuring the thread stops safely.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/154117
Approved by: https://github.com/ngimel, https://github.com/cyyever