Define the SYCL device version assertation used in the other backend, like XPU (#84106)

# Motivation:
We need a device version assertation that can be used in SYCL kernel. SYCL_KERNEL_ASSERT will be used in the kernel launched on device XPU.

# Solution:
We add a macro SYCL_KERNEL_ASSERT via __assert_fail declaration for Linux and _wassert declaration for Windows even though  NDEBUG is enabled.

# Additional context:
`__assert_fail` in SYCL kernel
`extern SYCL_EXTERNAL void __assert_fail(const char *expr, const char *file, unsigned int line, const char *func);`
`_wassert` in SYCL kernel
`extern SYCL_EXTERNAL void _wassert(const wchar_t *wexpr, const wchar_t *wfile, unsigned line);`
No additional unit test because this change could not affect PyTorch's functionality. It only affects assertation in kernel on XPU backend. So it is difficult to add ut to test it.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/84106
Approved by: https://github.com/malfet
This commit is contained in:
Yu, Guangye
2022-09-01 22:22:25 +00:00
committed by PyTorch MergeBot
parent 1463c6f3de
commit ff56f1c30d

View File

@ -331,21 +331,33 @@ constexpr uint32_t CUDA_THREADS_PER_BLOCK_FALLBACK = 256;
(defined(USE_ROCM) && defined(ROCM_DISABLE_GPU_ASSERTS))
// Those platforms do not support assert()
#define CUDA_KERNEL_ASSERT(cond)
#define SYCL_KERNEL_ASSERT(cond)
#elif defined(_MSC_VER)
#if defined(NDEBUG)
extern "C" {
C10_IMPORT
#if defined(__SYCL_DEVICE_ONLY__)
extern SYCL_EXTERNAL void _wassert(
const wchar_t* wexpr,
const wchar_t* wfile,
unsigned line);
#else
#if defined(__CUDA_ARCH__)
__host__ __device__
#endif // __CUDA_ARCH__
void
_wassert(wchar_t const* _Message, wchar_t const* _File, unsigned _Line);
}
#endif
#endif // __SYCL_DEVICE_ONLY__
#endif // NDEBUG
#define CUDA_KERNEL_ASSERT(cond) \
if (C10_UNLIKELY(!(cond))) { \
(void)(_wassert(_CRT_WIDE(#cond), _CRT_WIDE(__FILE__), static_cast<unsigned>(__LINE__)), 0); \
}
#define SYCL_KERNEL_ASSERT(cond) \
if (C10_UNLIKELY(!(cond))) { \
(void)(_wassert(_CRT_WIDE(#cond), _CRT_WIDE(__FILE__), static_cast<unsigned>(__LINE__)), 0); \
}
#else // __APPLE__, _MSC_VER
#if defined(NDEBUG)
extern "C" {
@ -390,6 +402,11 @@ __device__ __attribute__((noinline)) __attribute__((weak)) void __assert_fail(
__assert_fail( \
#cond, __FILE__, static_cast<unsigned int>(__LINE__), __func__); \
}
#define SYCL_KERNEL_ASSERT(cond) \
if (C10_UNLIKELY(!(cond))) { \
__assert_fail( \
#cond, __FILE__, static_cast<unsigned int>(__LINE__), __func__); \
}
#endif // __APPLE__
#ifdef __APPLE__