A rewrite of #138964
In addition to rewriting the conditions for using copy2d, this PR fixes a few other problems with #138964:
1) GPU-to-GPU copies shouldn't rely on copy2d when peer access is disabled
2) copy2d should record an event for pinned host memory, like the regular copy does
3) copy2d shouldn't pretend that it's synchronizing (for the purposes of the CUDA sanitizer tracer) when it's non-blocking
In this PR copy2d behaves in exactly the same way as copy does with respect to those additional syncs, except that it calls a different underlying CUDA API.
Tests are added for multiple cases that go through copy2d, as well as cases that avoid the copy2d path because its conditions are not satisfied.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/146256
Approved by: https://github.com/eqy, https://github.com/malfet
Co-authored-by: Nikita Shulga <2453524+malfet@users.noreply.github.com>
### **Pull Request: Optimized Non-Contiguous Tensor Copy for CPU to GPU in PyTorch**
#### **Summary**
This PR addresses the performance issue identified in [#111570](https://github.com/pytorch/pytorch/issues/111570), where non-contiguous tensors took significantly longer to transfer from CPU to GPU. Through detailed tracing of the call flow, we identified that PyTorch was creating temporary contiguous buffers for non-contiguous tensor transfers, which introduced unnecessary overhead.
#### **Tracing the Issue**
To pinpoint the cause of the slowdown, we followed the call flow from Python’s `tensor.cuda()` method through PyTorch’s backend, ultimately identifying `copy_kernel_cuda` as the key function responsible for CPU-to-GPU tensor transfers. Here’s a summary of the tracing process:
1. **Python Call: `tensor.cuda()`**
- Starting from Python, the `cuda()` method initiates the tensor transfer to the GPU.
2. **`TensorBody.h: cuda()`**
- The `cuda()` method calls `to()`, specifying the target device as CUDA.
3. **`Tensor.cpp: TensorBase::to()`**
- The `to()` function prepares device and data type options before invoking `_ops::to_dtype_layout::call()`.
4. **Operator Call: `_ops::to_dtype_layout::call()`**
- This operator dispatches the request to the backend-specific function responsible for managing the transfer.
5. **`Copy.cpp: copy_()`**
- The `copy_()` function performs preliminary checks (e.g., zero-tensor immutability) and proceeds to call `copy_impl()`.
6. **`Copy.cpp: copy_impl()`**
- This function sets up a tensor iterator and dispatches the copy operation to the appropriate backend through `copy_stub`.
7. **Dispatch to CUDA: `copy_stub`**
- The dispatch mechanism routes the call to the CUDA-specific function, `copy_kernel_cuda`.
8. **`Copy.cu: copy_kernel_cuda()`**
- Here, we identified that PyTorch was creating temporary contiguous buffers for 1D and 2D non-contiguous tensors, which slowed down the copy process. This behavior is managed by the `copy_requires_temporaries()` function.
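For intuition, here is a minimal sketch, assuming a row-major 2D layout, of the kind of check that decides whether a strided tensor can be copied as a single pitched transfer. It is illustrative only; the actual logic in `copy_requires_temporaries()` is more involved, and the function name below is a hypothetical stand-in.
```cpp
#include <cstdint>

// Illustrative applicability check: a 2D strided layout maps onto one
// pitched copy when the innermost dimension is contiguous and the rows
// sit at a constant, non-overlapping pitch.
bool can_use_pitched_copy(int64_t rows, int64_t cols,
                          int64_t row_stride, int64_t col_stride) {
  return col_stride == 1       // innermost dimension is contiguous
      && row_stride >= cols;   // rows laid out at a fixed, valid pitch
}
```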
#### **Solution**
To address this, we modified `copy_kernel_cuda` to handle non-contiguous 1D and 2D tensors directly by using `cudaMemcpy2DAsync`, which allows efficient, stride-aware memory transfers without temporary buffers. Here’s why this approach improves performance (a usage sketch follows the list):
- **Efficiency of `cudaMemcpy2DAsync`**: This CUDA function is optimized for pitched (stride-based) memory transfers, allowing it to handle non-contiguous data layouts effectively by specifying memory strides for source and destination tensors.
- **Reduction of Overhead**: By directly copying non-contiguous tensors without intermediate buffers, we eliminate extra memory allocation and achieve faster CPU-to-GPU transfers.
- **Asynchronous Execution**: `cudaMemcpy2DAsync` runs asynchronously on the CUDA stream, allowing the transfer to overlap with other queued work rather than blocking the host.
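The following is a self-contained sketch, not the PyTorch implementation, of a pitched host-to-device copy with `cudaMemcpy2DAsync`. Buffer sizes and names are illustrative; pinned host memory is used so the copy can actually run asynchronously.
```cpp
#include <cuda_runtime.h>

int main() {
  const size_t rows = 4, cols = 8, src_pitch_elems = 16;

  // Pinned (page-locked) host buffer holding a strided 2D layout:
  // each row has `cols` valid elements, but rows are `src_pitch_elems` apart.
  float* host_buf = nullptr;
  cudaMallocHost((void**)&host_buf, rows * src_pitch_elems * sizeof(float));

  // Contiguous destination buffer on the device.
  float* dev_buf = nullptr;
  cudaMalloc((void**)&dev_buf, rows * cols * sizeof(float));

  cudaStream_t stream;
  cudaStreamCreate(&stream);

  // Pitches are byte distances between row starts; `width` is the number of
  // bytes actually copied per row and `height` is the number of rows.
  cudaMemcpy2DAsync(dev_buf, cols * sizeof(float),
                    host_buf, src_pitch_elems * sizeof(float),
                    cols * sizeof(float), rows,
                    cudaMemcpyHostToDevice, stream);

  cudaStreamSynchronize(stream);

  cudaStreamDestroy(stream);
  cudaFree(dev_buf);
  cudaFreeHost(host_buf);
  return 0;
}
```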
#### **Performance Results**
In my testing, I created tensors of size `327680 x 2000` and used slices for transfer performance measurements. The tests show that the average time for transferring a non-contiguous slice (e.g., rows 10,000 to 50,000) from CPU to GPU now closely matches the contiguous case. This improvement indicates that the updated implementation effectively addresses the performance discrepancy. Below are the measured times and validation checks:
```plaintext
Average time for contiguous slice (rows 10,000-50,000): 66 ms
Average time for non-contiguous slice (rows 10,000-50,000): 66 ms
Validation of contiguous and non-contiguous tensor copies:
✅ PASS: Tensor shapes match.
✅ PASS: Tensor contiguity matches.
✅ PASS: Tensor contents match.
✅ PASS: Tensor data types match.
✅ Success: Both contiguous and non-contiguous tensors were copied correctly to the GPU.
```
#### **Conclusion**
This PR resolves the identified performance issue by eliminating the need for temporary buffers in non-contiguous 1D and 2D tensor transfers, ensuring faster and more efficient copies from CPU to GPU. Future optimizations could further enhance performance for higher-dimensional non-contiguous tensors.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/138964
Approved by: https://github.com/jeffdaily
Co-authored-by: Natalia Gimelshein <ngimel@gmail.com>
Co-authored-by: Jeff Daily <jeff.daily@amd.com>
This method has to be accessible from `c10` to enable CUDA-12 integration.
Implemented by providing a private `c10::cuda::_internal::setHasPrimaryContext` that passes the pointer to the implementation (in `torch_cuda`) back to c10.
A global object's constructor/destructor are used to guarantee RAII-style registration and deregistration.
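A minimal sketch of this registration pattern, with illustrative signatures rather than the exact PyTorch ones:
```cpp
namespace c10 { namespace cuda { namespace _internal {
// Pointer to the real implementation, installed by torch_cuda at load time.
static bool (*has_primary_context_impl)(int) = nullptr;

void setHasPrimaryContext(bool (*func)(int)) {
  has_primary_context_impl = func;
}
}}} // namespace c10::cuda::_internal

// In torch_cuda: the implementation plus a global RAII registrar.
namespace {
bool hasPrimaryContext(int /*device*/) {
  // ... would query the CUDA driver for a primary context here ...
  return true;
}

struct Registrar {
  Registrar() { c10::cuda::_internal::setHasPrimaryContext(hasPrimaryContext); }
  ~Registrar() { c10::cuda::_internal::setHasPrimaryContext(nullptr); }
} registrar; // constructed when torch_cuda loads, destroyed on unload
} // anonymous namespace
```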
Pull Request resolved: https://github.com/pytorch/pytorch/pull/96800
Approved by: https://github.com/ngimel
Summary:
- HIP_VERSION semantic versioning will change in ROCm 4.3. These changes essentially remove the dependency on the HIP_VERSION provided in the HIP header, keeping the code compatible with both older and newer versions of ROCm.
- TORCH_HIP_VERSION is derived from HIP_VERSION_MAJOR and HIP_VERSION_MINOR
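A hedged sketch of the derivation; the exact arithmetic used by PyTorch's build may differ, and HIP_VERSION_MAJOR/HIP_VERSION_MINOR are stubbed here so the snippet stands alone:
```cpp
// Normally provided by the HIP headers; stubbed for illustration.
#ifndef HIP_VERSION_MAJOR
#define HIP_VERSION_MAJOR 4
#define HIP_VERSION_MINOR 3
#endif

// Derive a comparable version number from major/minor components instead of
// depending on HIP_VERSION's raw encoding, which changes in ROCm 4.3.
#define TORCH_HIP_VERSION (HIP_VERSION_MAJOR * 100 + HIP_VERSION_MINOR)

#if TORCH_HIP_VERSION >= 403 // ROCm 4.3 or newer
// ... new code path ...
#else
// ... fallback for older ROCm ...
#endif
```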
Pull Request resolved: https://github.com/pytorch/pytorch/pull/62786
Reviewed By: bdhirsh
Differential Revision: D30281682
Pulled By: seemethere
fbshipit-source-id: e41e69fb9e13de5ddd1af99ba5bbdcbb7b64b673
Summary:
This creates a `torch.cuda.set_warn_on_synchronization()` function that warns or errors when a synchronizing operation is performed. We could wrap it in a context manager for ease of use, but that would be a lie, because it sets global, not thread-local, state. Since it's intended for debugging, maybe that's ok though.
Like all `torch.cuda.*` functions, it goes through CPython, not pybind, so the argument is converted to a long before being passed to the c10 function. I'll make the Python argument a Python enum class, but without pybind it'll still have to go through the long conversion.
For a test script
```
import torch
torch.cuda.set_warn_on_synchronization(1)
x=torch.randn(10, device="cuda")
x.nonzero()
y=torch.randn((), device="cuda")
if y:
print("something")
torch.multinomial(x.abs(), 10, replacement=False)
torch.randperm(20000, device="cuda")
ind = torch.randint(10, (3,), device="cuda")
mask = torch.randint(2, (10,), device="cuda", dtype=torch.bool)
val = torch.randn((), device="cuda")
x[mask]=1.
x[mask] = val
torch.cuda.synchronize()
```
the output is
```
/../playground/sync_warn_test.py:4: UserWarning: called a synchronizing operation (Triggered internally at ../c10/cuda/CUDAFunctions.cpp:145.)
x.nonzero()
/../playground/sync_warn_test.py:7: UserWarning: called a synchronizing operation (Triggered internally at ../c10/cuda/CUDAFunctions.cpp:145.)
if y:
something
/../playground/sync_warn_test.py:9: UserWarning: called a synchronizing operation (Triggered internally at ../c10/cuda/CUDAFunctions.cpp:145.)
torch.multinomial(x.abs(), 10, replacement=False)
/../playground/sync_warn_test.py:15: UserWarning: called a synchronizing operation (Triggered internally at ../c10/cuda/CUDAFunctions.cpp:145.)
x[mask] = val
```
Pull Request resolved: https://github.com/pytorch/pytorch/pull/62092
Reviewed By: mruberry
Differential Revision: D29968792
Pulled By: ngimel
fbshipit-source-id: cc6f817212c164727ed99ecf6ab050dc29631b9e
Summary:
This is a first step towards creating context manager that errors out on synchronizing calls.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/61889
Reviewed By: albanD
Differential Revision: D29805280
Pulled By: ngimel
fbshipit-source-id: b66400fbe0941b7daa51e6b30abe27b9cccd4e8a
Summary:
After the change, async error messages look as follows:
```
$ python -c "import torch;torch.eye(3,3,device='cuda:777')"
Traceback (most recent call last):
File "<string>", line 1, in <module>
RuntimeError: CUDA error: invalid device ordinal
CUDA kernel errors might be asynchronously reported at some other API call,so the stacktrace below might be incorrect.
For debugging consider passing CUDA_LAUNCH_BLOCKING=1.
```
Pull Request resolved: https://github.com/pytorch/pytorch/pull/59467
Reviewed By: ngimel
Differential Revision: D28904360
Pulled By: malfet
fbshipit-source-id: 2a8fa5affed5b4ffcaa602c8ab2669061cde7db0
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/57609
Throw c10::CudaError for CUDA exceptions, for better classification of errors
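A self-contained sketch of the idea, using stand-in types for c10::Error and c10::CudaError and a hypothetical check macro (the real one is C10_CUDA_CHECK):
```cpp
#include <cuda_runtime.h>
#include <cstdio>
#include <stdexcept>
#include <string>

// Stand-ins for c10::Error and its new CUDA-specific subclass.
struct Error : std::runtime_error { using std::runtime_error::runtime_error; };
struct CudaError : Error { using Error::Error; };

// On failure, throw the CUDA-specific type so callers can classify it.
#define CUDA_CHECK(expr)                                      \
  do {                                                        \
    cudaError_t err__ = (expr);                               \
    if (err__ != cudaSuccess) {                               \
      throw CudaError(std::string("CUDA error: ") +           \
                      cudaGetErrorString(err__));             \
    }                                                         \
  } while (0)

int main() {
  try {
    CUDA_CHECK(cudaSetDevice(777)); // invalid device ordinal
  } catch (const CudaError& e) {
    std::fprintf(stderr, "caught CUDA-specific error: %s\n", e.what());
  }
  return 0;
}
```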
Test Plan: Test locally by running some workflows
Reviewed By: dzhulgakov
Differential Revision: D28209356
fbshipit-source-id: 19a5fc8548433238dc224ea81a5f63a945fc5cc3
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/42249
The main change is to bring Caffe2's superior error messages for CUDA initialization into c10 and use them in all code paths.
Basic logic:
| Case | Call to device_count() | init_cuda, e.g. allocating tensor |
| -- | -- | -- |
| all good | non-zero | just works |
| no gpus | 0, no warning | throw exception with good message |
| driver issues | 0, produce warning | throw exception with good message |
| out of memory with ASAN | 0, produce warning | throw exception with ASAN message |
Previously, the error thrown from init_cuda was very generic and the ASAN warning (if any) was buried in the logs.
Other clean up changes:
* cache device_count() always in a static variable (see the sketch after this list)
* move all asan macros in c10
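A minimal sketch of the caching, assuming the warn-and-return-zero behavior from the table above; the real c10 implementation differs in detail:
```cpp
#include <cuda_runtime.h>
#include <cstdio>

static int device_count_impl() {
  int count = 0;
  cudaError_t err = cudaGetDeviceCount(&count);
  if (err != cudaSuccess) {
    // Report zero devices and produce a warning; the descriptive
    // exception is thrown later, when CUDA is actually initialized.
    std::fprintf(stderr, "warning: %s\n", cudaGetErrorString(err));
    return 0;
  }
  return count;
}

int device_count() {
  static int count = device_count_impl(); // queried once, then cached
  return count;
}
```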
Test Plan:
Hard to unittest because of build modes. Verified manually that the behavior from the table above holds by running the following script in different modes (ASAN/no-ASAN, CUDA_VISIBLE_DEVICES=):
```
print('before import')
import torch
print('after import')
print('devices: ', torch.cuda.device_count())
x = torch.tensor([1,2,3])
print('tensor creation')
x = x.cuda()
print('moved to cuda')
```
Reviewed By: ngimel
Differential Revision: D22824329
fbshipit-source-id: 5314007313a3897fc955b02f8b21b661ae35fdf5
Summary:
* Make c10::cuda functions regular non-inlined functions
* Add driver_version() and device_synchronize() functions
With this change I no longer see direct calls to the CUDA API when looking at Modules.cpp.obj.
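A sketch of the pattern: declarations in the header, definitions in a .cpp, so callers link against c10 instead of inlining raw CUDA API calls. Signatures are illustrative; the real ones live in c10/cuda/CUDAFunctions.h.
```cpp
// --- CUDAFunctions.h (declarations only; nothing to inline) ---
namespace c10 { namespace cuda {
int driver_version();
void device_synchronize();
}} // namespace c10::cuda

// --- CUDAFunctions.cpp (the only place that touches the CUDA runtime) ---
#include <cuda_runtime.h>
namespace c10 { namespace cuda {
int driver_version() {
  int version = -1;
  cudaDriverGetVersion(&version); // -1 is kept if the query fails
  return version;
}
void device_synchronize() {
  cudaDeviceSynchronize();
}
}} // namespace c10::cuda
```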
FYI malfet
Pull Request resolved: https://github.com/pytorch/pytorch/pull/42251
Reviewed By: malfet
Differential Revision: D22826505
Pulled By: ziab
fbshipit-source-id: 8dc2f3e209d3710e2ce78411982a10e8c727573c
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/18485
I don't know (1) how we landed the wrong version of the patch and (2) how
this passed the push-blocking test
Reviewed By: pjh5
Differential Revision: D14621961
fbshipit-source-id: 0a3953d7adcdc79727a61c2acff65f436dcafe55
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/18445
ghimport-source-id: 30d018737bf6989bc68b7e3676f44e0ca6141fde
Stack from [ghstack](https://github.com/ezyang/ghstack):
* #18242 Test running a CUDA build on CPU machine.
* **#18445 Unify cudaGetDeviceCount implementations.**
I went about doing this by searching for calls to cudaGetDeviceCount,
and then methodically replacing them with references to c10::cuda::device_count()
or at::cuda::device_count().
There is a point to doing this: the various implementations wildly differed
in their handling of what to do when cudaGetDeviceCount returns an error.
The final standardized behavior is that **all errors are swallowed** and
we return device count of zero. This indirectly fixes running CUDA builds
on CPU, which was broken in #17847.
I added 'noexcept' to the 'deviceCount' virtual method on DeviceGuardImpl.
This is a BC-breaking change for anyone inheriting from DeviceGuardImpl,
but all you need to do is put 'noexcept' on your method, and it is
backwards compatible with older libtorch.
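A self-contained sketch of the break and the one-line fix, with simplified stand-ins for the DeviceGuardImpl interface:
```cpp
#include <cstdint>

struct DeviceGuardImplInterface {
  virtual ~DeviceGuardImplInterface() = default;
  // After this change, noexcept is part of the virtual method's contract.
  virtual int32_t deviceCount() const noexcept = 0;
};

struct MyGuardImpl : DeviceGuardImplInterface {
  // Without `noexcept` here the override would not compile against the new
  // base class; adding it keeps the subclass compatible with both versions.
  int32_t deviceCount() const noexcept override {
    return 1; // e.g. swallow errors and report a fixed or cached count
  }
};
```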
Signed-off-by: Edward Z. Yang <ezyang@fb.com>
Differential Revision: D14612189
fbshipit-source-id: 3c8d186e3dd623c0e27625212c7ce30f75d943cb
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/15692
It was leading to occasional crashes with dynamically linked CUDA because the runtime was already destroyed.
Also, unique_ptr<T[]> is more suitable than deque<T> for the purpose.
Reviewed By: Yangqing
Differential Revision: D13571988
fbshipit-source-id: 37eb26dfbe361c49160367b53f87bd037c6c0e46
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/15316
This starts cleaning up the files in c10 according to the module structure we decided on.
Move to c10/util:
- Half.h, Half-inl.h, Half.cpp, bitcasts.h
Move to c10/core:
- Device.h, Device.cpp
- DeviceType.h, DeviceType.cpp
i-am-not-moving-c2-to-c10
Reviewed By: dzhulgakov
Differential Revision: D13498493
fbshipit-source-id: dfcf1c490474a12ab950c72ca686b8ad86428f63
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/13920
I want to move CUDAStream and CUDAGuard to c10_cuda without also
bringing along CUDAContext or CUDAEvent for the ride (at least for
now). To do this, I need to eliminate those dependencies.
There are a few functions in CUDAContext.h which don't really need
THCState, so they're separated out and put in the general-purpose
c10/cuda/CUDAFunctions.h.
Reviewed By: smessmer
Differential Revision: D13047468
fbshipit-source-id: 7ed9d5e660f95805ab39d7af25892327edae050e