Which inherits from `RuntimeError` and contains `error_code`, which in case of CUDA should contain error returned by `cudaGetLastError`
`torch::detail::_new_accelerator_error_object(c10::AcceleratorError&)` follows the pattern of CPython's [`PyErr_SetString`](cb8a72b301/Python/errors.c (L282)), namely
- Convert cstr into Python string with `PyUnicode_FromString`
- Create new exception object using `PyObject_CallOneArg` just like it's done in [`_PyErr_CreateException`](cb8a72b301/Python/errors.c (L32))
- Set `error_code` property using `PyObject_SetAttrString`
- decref all temporary references
Test that it works and captures CPP backtrace (in addition to CI) by running
```python
import os
os.environ['TORCH_SHOW_CPP_STACKTRACES'] = '1'
import torch
x = torch.rand(10, device="cuda")
y = torch.arange(20, device="cuda")
try:
x[y] = 2
print(x)
except torch.AcceleratorError as e:
print("Exception was raised", e.args[0])
print("Captured error code is ", e.error_code)
```
which produces following output
```
Exception was raised CUDA error: device-side assert triggered
CUDA kernel errors might be asynchronously reported at some other API call, so the stacktrace below might be incorrect.
For debugging consider passing CUDA_LAUNCH_BLOCKING=1
Compile with `TORCH_USE_CUDA_DSA` to enable device-side assertions.
Exception raised from c10_cuda_check_implementation at /home/ubuntu/pytorch/c10/cuda/CUDAException.cpp:41 (most recent call first):
C++ CapturedTraceback:
#4 std::_Function_handler<std::shared_ptr<c10::LazyValue<std::__cxx11::basic_string<char, std::char_traits<char>, std::allocator<char> > > const> (), c10::SetStackTraceFetcher(std::function<std::__cxx11::basic_string<char, std::char_traits<char>, std::allocator<char> > ()>)::{lambda()#1}>::_M_invoke(std::_Any_data const&) from Logging.cpp:0
#5 c10::Error::Error(c10::SourceLocation, std::__cxx11::basic_string<char, std::char_traits<char>, std::allocator<char> >) from ??:0
#6 c10::cuda::c10_cuda_check_implementation(int, char const*, char const*, int, bool) [clone .cold] from CUDAException.cpp:0
#7 void at::native::gpu_kernel_impl<at::native::AbsFunctor<float> >(at::TensorIteratorBase&, at::native::AbsFunctor<float> const&) [clone .isra.0] from tmpxft_000191fc_00000000-6_AbsKernel.cudafe1.cpp:0
#8 at::native::abs_kernel_cuda(at::TensorIteratorBase&) from ??:0
#9 at::Tensor& at::native::unary_op_impl_with_complex_to_float_out<at::native::abs_stub_DECLARE_DISPATCH_type>(at::Tensor&, at::Tensor const&, at::native::abs_stub_DECLARE_DISPATCH_type&, bool) [clone .constprop.0] from UnaryOps.cpp:0
#10 at::(anonymous namespace)::(anonymous namespace)::wrapper_CUDA_out_abs_out(at::Tensor const&, at::Tensor&) from RegisterCUDA_0.cpp:0
#11 at::_ops::abs_out::call(at::Tensor const&, at::Tensor&) from ??:0
#12 at::native::abs(at::Tensor const&) from ??:0
#13 c10::impl::wrap_kernel_functor_unboxed_<c10::impl::detail::WrapFunctionIntoFunctor_<c10::CompileTimeFunctionPointer<at::Tensor (at::Tensor const&), &at::(anonymous namespace)::(anonymous namespace)::wrapper_CompositeExplicitAutograd__abs>, at::Tensor, c10::guts::typelist::typelist<at::Tensor const&> >, at::Tensor (at::Tensor const&)>::call(c10::OperatorKernel*, c10::DispatchKeySet, at::Tensor const&) from RegisterCompositeExplicitAutograd_0.cpp:0
#14 at::_ops::abs::redispatch(c10::DispatchKeySet, at::Tensor const&) from ??:0
#15 torch::autograd::VariableType::(anonymous namespace)::abs(c10::DispatchKeySet, at::Tensor const&) from VariableType_1.cpp:0
#16 c10::impl::wrap_kernel_functor_unboxed_<c10::impl::detail::WrapFunctionIntoFunctor_<c10::CompileTimeFunctionPointer<at::Tensor (c10::DispatchKeySet, at::Tensor const&), &torch::autograd::VariableType::(anonymous namespace)::abs>, at::Tensor, c10::guts::typelist::typelist<c10::DispatchKeySet, at::Tensor const&> >, at::Tensor (c10::DispatchKeySet, at::Tensor const&)>::call(c10::OperatorKernel*, c10::DispatchKeySet, at::Tensor const&) from VariableType_1.cpp:0
#17 at::_ops::abs::call(at::Tensor const&) from ??:0
#18 at::native::isfinite(at::Tensor const&) from ??:0
#19 c10::impl::wrap_kernel_functor_unboxed_<c10::impl::detail::WrapFunctionIntoFunctor_<c10::CompileTimeFunctionPointer<at::Tensor (at::Tensor const&), &at::(anonymous namespace)::(anonymous namespace)::wrapper_CompositeImplicitAutograd__isfinite>, at::Tensor, c10::guts::typelist::typelist<at::Tensor const&> >, at::Tensor (at::Tensor const&)>::call(c10::OperatorKernel*, c10::DispatchKeySet, at::Tensor const&) from RegisterCompositeImplicitAutograd_0.cpp:0
#20 at::_ops::isfinite::call(at::Tensor const&) from ??:0
#21 torch::autograd::THPVariable_isfinite(_object*, _object*, _object*) from python_torch_functions_2.cpp:0
#22 PyObject_CallFunctionObjArgs from ??:0
#23 _PyObject_MakeTpCall from ??:0
#24 _PyEval_EvalFrameDefault from ??:0
#25 _PyObject_FastCallDictTstate from ??:0
#26 _PyStack_AsDict from ??:0
#27 _PyObject_MakeTpCall from ??:0
#28 _PyEval_EvalFrameDefault from ??:0
#29 _PyFunction_Vectorcall from ??:0
#30 _PyEval_EvalFrameDefault from ??:0
#31 _PyFunction_Vectorcall from ??:0
#32 _PyEval_EvalFrameDefault from ??:0
#33 _PyFunction_Vectorcall from ??:0
#34 _PyEval_EvalFrameDefault from ??:0
#35 PyFrame_GetCode from ??:0
#36 PyNumber_Xor from ??:0
#37 PyObject_Str from ??:0
#38 PyFile_WriteObject from ??:0
#39 _PyWideStringList_AsList from ??:0
#40 _PyDict_NewPresized from ??:0
#41 _PyEval_EvalFrameDefault from ??:0
#42 PyEval_EvalCode from ??:0
#43 PyEval_EvalCode from ??:0
#44 PyUnicode_Tailmatch from ??:0
#45 PyInit__collections from ??:0
#46 PyUnicode_Tailmatch from ??:0
#47 _PyRun_SimpleFileObject from ??:0
#48 _PyRun_AnyFileObject from ??:0
#49 Py_RunMain from ??:0
#50 Py_BytesMain from ??:0
#51 __libc_init_first from ??:0
#52 __libc_start_main from ??:0
#53 _start from ??:0
Captured error code is 710
```
Pull Request resolved: https://github.com/pytorch/pytorch/pull/152023
Approved by: https://github.com/eqy, https://github.com/mradmila, https://github.com/ngimel
ghstack dependencies: #154436
This adds a non-blocking mode to queue_pop. This allows for workers to poll if work is ready without blocking the main loop. This is useful for the case where you want to have a GPU have maximum utilization when something only periodically is sent on the queue.
We also expose a `torch.distributed.QueueEmptyError` so users can catch the error and handle it accordingly.
Test plan:
```
pytest test/distributed/test_store.py -k queue -v -s -x
```
Pull Request resolved: https://github.com/pytorch/pytorch/pull/151485
Approved by: https://github.com/fduwjj, https://github.com/tianfengfrank
Summary:
`-Wunused-exception-parameter` has identified an unused exception parameter. This diff removes it.
This:
```
try {
...
} catch (exception& e) {
// no use of e
}
```
should instead be written as
```
} catch (exception&) {
```
If the code compiles, this is safe to land.
Test Plan: Sandcastle
Reviewed By: dtolnay
Pull Request resolved: https://github.com/pytorch/pytorch/pull/149328
Approved by: https://github.com/Skylion007, https://github.com/eqy
Introduce `conditional_gil_scoped_release` and use it in `wrap_pybind_function*` to avoid a runtime branch making the code cleaner and faster.
@albanD This is the GIL change extracted from #112607 as discussed.
Also fixes a potential use of a moved-from object introduced in #116560:
- `f` is captured by value in a lambda that may be used called times
- After `std::move(f)` the lambda is not safe to call anymore
CC @cyyever for that change
Co-authored-by: Aaron Gokaslan <aaronGokaslan@gmail.com>
Pull Request resolved: https://github.com/pytorch/pytorch/pull/116695
Approved by: https://github.com/albanD, https://github.com/Skylion007
We have a plethora of error types for various errors raised from c10d. These include `RuntimeError`, `TimeoutError`, `SocketError`, `DistBackendError` etc.
This results in messy code during error handling somewhat like this:
```
if "NCCL" in exception_str:
...
if "Timed out initializing process group in store based barrier on rank" in exception_str:
...
if "The client socket has timed out after" in exception_str:
...
if "Broken pipe" in exception_str:
...
if "Connection reset by peer" in exception_str:
...
```
To address this issue, in this PR I've ensured added these error types:
1. **DistError** - the base type of all distributed errors
2. **DistBackendError** - this already existed and referred to PG backend errors
3. **DistStoreError** - for errors originating from the store
4. **DistNetworkError** - for general network errors coming from the socket library
Pull Request resolved: https://github.com/pytorch/pytorch/pull/108191
Approved by: https://github.com/H-Huang