Introduce CUDA Device Assertions Infrastructure (#84609)

Summary:
This diff introduces a set of changes that makes it possible for the host to get assertions from CUDA devices. This includes the introduction of

**`CUDA_KERNEL_ASSERT2`**

A preprocessor macro to be used within a CUDA kernel that, upon an assertion failure, writes the assertion message, file, line number, and possibly other information to UVM (Managed memory). The failing thread then early-exits the kernel rather than triggering the underlying CUDA assertion, which would place the GPU in a Bad State requiring recovery; the host detects the failure by inspecting UVM. In my tests, data written to UVM appears there before the GPU reaches the Bad State and is still accessible from the host after the GPU is in this state.

Messages are written to a multi-message buffer which can, in theory, hold many assertion failures. I've done this as a precaution in case there are several, but I don't actually know whether that is possible and a simpler design which holds only a single message may well be all that is necessary.

**`TORCH_DSA_KERNEL_ARGS`**

This preprocessor macro is added as an _argument_ to a kernel function's signature. It expands to supply the standardized names of all the arguments needed by `C10_CUDA_COMMUNICATING_KERNEL_ASSERTION` to handle device-side assertions. This includes, e.g., the name of the pointer to the UVM memory the assertion would be written to. This macro abstracts the arguments so there is a single point of change if the system needs to be modified.
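
Together with `CUDA_KERNEL_ASSERT2`, a device-side-assertion-aware kernel looks roughly like this (a sketch modeled on the tests added in this diff; `my_kernel` and its `a` parameter are illustrative):
```
__global__ void my_kernel(const int a, TORCH_DSA_KERNEL_ARGS) {
  // On failure, the assertion message, file, line, and block/thread IDs are
  // written to the UVM buffer supplied via TORCH_DSA_KERNEL_ARGS, and the
  // failing thread early-exits the kernel.
  CUDA_KERNEL_ASSERT2(a == 42);
}
```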

**`c10::cuda::get_global_cuda_kernel_launch_registry()`**

This host-side function returns a singleton object that manages the host's part of the device-side assertion system. Upon construction, the singleton allocates sufficient UVM (Managed) memory to hold information about several device-side assertion failures. The singleton also provides methods for getting the current traceback (used to identify when a kernel was launched). To avoid consuming all the host's memory, the singleton stores launches in a circular buffer; a unique "generation number" is used to ensure that kernel launch failures map to their actual launch points (in case the circular buffer wraps before the failure is detected).
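
Sketched in simplified form (names match the new `CUDAKernelLaunchRegistry`; logic condensed from the retrieval code below), the wraparound check works like this:
```
// `caller` is the generation number that was passed to the failing kernel
const auto& info = kernel_launches[caller % max_kernel_launches];
if (info.generation_number == caller) {
  // The slot still describes the launch that failed; report its file/line
} else {
  // The circular buffer wrapped before the failure was detected, so the
  // launch-site information has been overwritten
}
```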

**`TORCH_DSA_KERNEL_LAUNCH`**

This host-side preprocessor macro replaces the standard
```
kernel_name<<<blocks, threads, shmem, stream>>>(args)
```
invocation with
```
TORCH_DSA_KERNEL_LAUNCH(kernel_name, blocks, threads, shmem, stream, args);
```
Internally, it fetches the UVM (Managed) pointer and generation number from the singleton and appends these to the standard argument list. It also checks that the kernel launched correctly. This abstraction on kernel launches can be modified to provide additional safety/logging.
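
For example, launching the illustrative kernel from above (mirroring the tests added in this diff):
```
const auto stream = c10::cuda::getStreamFromPool();
TORCH_DSA_KERNEL_LAUNCH(
    my_kernel, /* kernel */
    1,         /* blocks */
    128,       /* threads */
    0,         /* shared mem */
    stream,    /* stream */
    42);       /* the kernel's own arguments */
```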

**`c10::cuda::c10_retrieve_device_side_assertion_info`**

This host-side function checks whether any kernel assertions have occurred. If one has, it raises an exception containing (a sample report follows this list):
1. Information (file, line number) about the kernel that was launched.
2. Information (file, line number, message) about the device-side assertion.
3. Information (file, line number) about where the failure was detected.
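
For reference, a report looks roughly like this (values and paths illustrative; the line formats are those used by the retrieval function):
```
1 CUDA device-side assertion failures were found on GPU #0!
Assertion failure 0
  GPU assertion failure message = a == 42
  File containing assertion = my_kernels.cu:7
  Device function containing assertion = my_kernel
  Thread ID that failed assertion = [0,0,0]
  Block ID that failed assertion = [0,0,0]
  File containing kernel launch = my_host_code.cpp:22
  Function containing kernel launch = my_function
  Name of kernel launched that led to failure = my_kernel
  Device that launched kernel = 0
  Stream kernel was launched on = 4
```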

**Checking for device-side assertions**

Device-side assertions are most likely to be noticed by the host when a CUDA API call such as `cudaDeviceSynchronize` is made and fails with a `cudaError_t` indicating
> CUDA error: device-side assert triggered

Therefore, we rewrite `C10_CUDA_CHECK()` to include a call to `c10_retrieve_device_side_assertion_info()`. To make the code cleaner, most of the logic of `C10_CUDA_CHECK()` is now contained within a new function `c10_cuda_check_implementation()` to which `C10_CUDA_CHECK` passes the preprocessor information about filenames, function names, and line numbers. (In C++20 we can use `std::source_location` to eliminate macros entirely!)
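
A sketch of that C++20 alternative (illustrative only; `check_implementation_cxx20` is a hypothetical name, not part of this diff):
```
#include <source_location>

// Illustrative only: how C++20 can supply the caller's location without macros
void check_implementation_cxx20(
    const std::source_location loc = std::source_location::current()) {
  // loc.file_name(), loc.function_name(), and loc.line() provide what
  // C10_CUDA_CHECK currently forwards via __FILE__, __func__, and __LINE__
}
```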

# Notes on special cases

* Multiple assertions from the same block are recorded
* Multiple assertions from different blocks are recorded
* Launching kernels from many threads on many streams seems to be handled correctly
* If two processes are using the same GPU and one of the processes fails with a device-side assertion, the other process continues without issue
* X Multiple assertions from separate kernels on different streams seem to be recorded, but we can't reproduce the test condition
* X Multiple assertions from separate devices should all be shown upon exit, but we've been unable to generate a test that produces this condition

Differential Revision: D37621532

Pull Request resolved: https://github.com/pytorch/pytorch/pull/84609
Approved by: https://github.com/ezyang, https://github.com/malfet
Richard Barnes
2022-12-08 01:26:07 +00:00
committed by PyTorch MergeBot
parent f99f239531
commit ad188a227e
18 changed files with 1371 additions and 22 deletions


@ -21,16 +21,18 @@ configure_file(
 # and headers you add
 set(C10_CUDA_SRCS
     CUDACachingAllocator.cpp
+    CUDADeviceAssertionHost.cpp
     CUDAException.cpp
     CUDAFunctions.cpp
+    CUDAMallocAsyncAllocator.cpp
     CUDAMiscFunctions.cpp
     CUDAStream.cpp
-    CUDACachingAllocator.cpp
-    CUDAMallocAsyncAllocator.cpp
     impl/CUDAGuardImpl.cpp
     impl/CUDATest.cpp
 )
 set(C10_CUDA_HEADERS
+    CUDACachingAllocator.h
+    CUDADeviceAssertionHost.h
     CUDAException.h
     CUDAFunctions.h
     CUDAGuard.h
@ -38,7 +40,6 @@ set(C10_CUDA_HEADERS
     CUDAMathCompat.h
     CUDAMiscFunctions.h
     CUDAStream.h
-    CUDACachingAllocator.h
     impl/CUDAGuardImpl.h
     impl/CUDATest.h
 )


@ -0,0 +1,98 @@
#pragma once
#include <c10/cuda/CUDAException.h>
#include <c10/macros/Macros.h>
namespace c10 {
namespace cuda {
#ifdef TORCH_USE_CUDA_DSA
// Copy string from `src` to `dst`
static __device__ void dstrcpy(char* dst, const char* src) {
int i = 0;
// Copy string from source to destination, ensuring that it
// isn't longer than `C10_CUDA_DSA_MAX_STR_LEN-1`
while (*src != '\0' && i++ < C10_CUDA_DSA_MAX_STR_LEN - 1) {
*dst++ = *src++;
}
*dst = '\0';
}
__device__ __noinline__ void dsa_add_new_assertion_failure(
DeviceAssertionsData* assertions_data,
const char* assertion_msg,
const char* filename,
const char* function_name,
const int line_number,
const uint32_t caller,
const dim3 block_id,
const dim3 thread_id) {
// `assertions_data` may be nullptr if device-side assertion checking
// is disabled at run-time. If it is disabled at compile time this
// function will never be called
if (!assertions_data) {
return;
}
// Atomically increment so other threads can fail at the same time
// Note that incrementing this means that the CPU can observe that
// a failure has happened and can begin to respond before we've
// written information about that failure out to the buffer.
const auto nid = atomicAdd(&(assertions_data->assertion_count), 1);
if (nid >= C10_CUDA_DSA_ASSERTION_COUNT) {
// At this point we've run out of assertion buffer space.
// We could print a message about this, but that'd get
// spammy if a lot of threads did it, so we just silently
// ignore any other assertion failures. In most cases the
// failures will all probably be analogous anyway.
return;
}
// Write information about the assertion failure to memory.
// Note that this occurs only after the `assertion_count`
// increment broadcasts that there's been a problem.
auto& self = assertions_data->assertions[nid];
dstrcpy(self.assertion_msg, assertion_msg);
dstrcpy(self.filename, filename);
dstrcpy(self.function_name, function_name);
self.line_number = line_number;
self.caller = caller;
self.block_id[0] = block_id.x;
self.block_id[1] = block_id.y;
self.block_id[2] = block_id.z;
self.thread_id[0] = thread_id.x;
self.thread_id[1] = thread_id.y;
self.thread_id[2] = thread_id.z;
}
// Emulates a kernel assertion. The assertion won't stop the kernel's progress,
// so you should assume everything the kernel produces is garbage if there's an
// assertion failure.
// NOTE: This assumes that `assertions_data` and `assertion_caller_id` are
// arguments of the kernel and therefore accessible.
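// Example usage (mirrors the tests accompanying this change):
//   __global__ void my_kernel(const int a, TORCH_DSA_KERNEL_ARGS) {
//     CUDA_KERNEL_ASSERT2(a == 42);
//   }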
#define CUDA_KERNEL_ASSERT2(condition) \
do { \
if (C10_UNLIKELY(!(condition))) { \
/* Has an atomic element so threads can fail at the same time */ \
c10::cuda::dsa_add_new_assertion_failure( \
assertions_data, \
C10_STRINGIZE(condition), \
__FILE__, \
__FUNCTION__, \
__LINE__, \
assertion_caller_id, \
blockIdx, \
threadIdx); \
/* Now that the kernel has failed we early exit the kernel, but */ \
/* otherwise keep going and rely on the host to check UVM and */ \
/* determine we've had a problem */ \
return; \
} \
} while (false)
#else
#define CUDA_KERNEL_ASSERT2(condition) assert(condition)
#endif
} // namespace cuda
} // namespace c10


@ -0,0 +1,367 @@
#include <c10/cuda/CUDADeviceAssertionHost.h>
#include <c10/cuda/CUDAException.h>
#include <c10/util/Backtrace.h>
#include <c10/util/Exception.h>
#include <c10/util/irange.h>
#include <cuda_runtime.h>
#include <algorithm>
#include <cstdlib>
#include <cstring>
#include <iostream>
#include <memory>
#include <sstream>
#include <stdexcept>
#include <string>
#include <thread>
#define CHECK_CUDA_API_CALL_WITHOUT_CHECKING_DEVICE_ASSERTS() \
c10_cuda_check_implementation(__FILE__, __FUNCTION__, __LINE__, false)
namespace c10 {
namespace cuda {
namespace {
/// Get the number of CUDA devices
/// We need our own implementation of this function to prevent
/// an infinite initialization loop for CUDAKernelLaunchRegistry
int dsa_get_device_count() {
int device_count = -1;
C10_CUDA_ERROR_HANDLED(cudaGetDeviceCount(&device_count));
CHECK_CUDA_API_CALL_WITHOUT_CHECKING_DEVICE_ASSERTS();
return device_count;
}
#ifdef TORCH_USE_CUDA_DSA
// Forward declaration; defined below in this namespace
int dsa_get_device_compute_capability(const int device_num);
#endif
bool dsa_check_if_all_devices_support_managed_memory() {
// It looks as though this'll work best on CUDA GPUs with Pascal
// architectures or newer, per
// https://developer.nvidia.com/blog/unified-memory-cuda-beginners/
#ifdef TORCH_USE_CUDA_DSA
for (const auto i : c10::irange(dsa_get_device_count())) {
if (dsa_get_device_compute_capability(i) < 6) {
return false;
}
}
return true;
#else
return false;
#endif
}
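/// Check whether an environment variable is set to a truthy value.
/// Any value other than "0" counts as set, e.g. `PYTORCH_USE_CUDA_DSA=1`.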
bool env_flag_set(const char* env_var_name) {
const char* const env_string = std::getenv(env_var_name);
return (env_string == nullptr) ? false : std::strcmp(env_string, "0");
}
/// Deleter for UVM/managed memory pointers
void uvm_deleter(DeviceAssertionsData* uvm_assertions_ptr) {
// Ignore error in destructor
if (uvm_assertions_ptr) {
C10_CUDA_IGNORE_ERROR(cudaFree(uvm_assertions_ptr));
}
}
#ifdef TORCH_USE_CUDA_DSA
/// Get current device id
/// We need our own implementation of this function to prevent
/// an infinite initialization loop for CUDAKernelLaunchRegistry
int dsa_get_device_id() {
int device = -1;
C10_CUDA_ERROR_HANDLED(cudaGetDevice(&device));
CHECK_CUDA_API_CALL_WITHOUT_CHECKING_DEVICE_ASSERTS();
return device;
}
/// Get a device's compute capability - note that this dangerously assumes
/// that if one CUDA GPU supports device-side assertions they all do. This is
/// probably fine since the latest CUDA GPU that doesn't support UVM is the
/// K80 released 2014-11-17. Mixing that GPU with a newer one is likely to be
/// rare enough that the defensive check isn't worth the complexity.
/// We need our own implementation of this function to prevent
/// an infinite initialization loop for CUDAKernelLaunchRegistry
int dsa_get_device_compute_capability(const int device_num) {
int compute_capability = -1;
C10_CUDA_ERROR_HANDLED(cudaDeviceGetAttribute(
&compute_capability, cudaDevAttrComputeCapabilityMajor, device_num));
CHECK_CUDA_API_CALL_WITHOUT_CHECKING_DEVICE_ASSERTS();
return compute_capability;
}
#endif
} // namespace
/// Check that kernels ran correctly by checking the message buffer. BLOCKING.
std::string c10_retrieve_device_side_assertion_info() {
#ifdef TORCH_USE_CUDA_DSA
const auto& launch_registry = CUDAKernelLaunchRegistry::get_singleton_ref();
if (!launch_registry.enabled) {
return "Device-side assertion tracking was not enabled by user.";
} else if (!launch_registry.do_all_devices_support_managed_memory) {
return "Device-side assertions disabled because not all devices support managed memory.";
}
// Hack that saves a lot of challenging sync logic.
// The GPU increments the number of errors it's observed and the CPU can see
// that happening immediately which means we can make it here before the GPU
// is done writing information about those errors to memory.
// A short pause gives it time to finish. Since something's gone wrong, this
// pause shouldn't affect perf.
std::this_thread::sleep_for(std::chrono::seconds(1));
// The snapshot causes a brief block. That's okay because this function only
// executes if something's gone wrong such that speed is no longer a priority.
const auto launch_data = launch_registry.snapshot();
const auto& assertion_data = launch_data.first;
const auto& launch_infos = launch_data.second;
std::stringstream oss;
oss << "# of GPUs this process interacted with = " << assertion_data.size()
    << std::endl;
// Loop over each device that could be managed by the process
for (const auto device_num : c10::irange(assertion_data.size())) {
const auto& assertion_data_for_device = assertion_data.at(device_num);
// Did anything fail?
const auto failures_found = std::min(
assertion_data_for_device.assertion_count,
C10_CUDA_DSA_ASSERTION_COUNT);
if (failures_found == 0) {
continue;
}
// Something failed, let's talk about that
oss << failures_found
<< " CUDA device-side assertion failures were found on GPU #"
<< device_num << "!" << std::endl;
if (assertion_data_for_device.assertion_count >
C10_CUDA_DSA_ASSERTION_COUNT) {
oss << "But at least " << assertion_data_for_device.assertion_count
<< " assertion failures occurred on the device" << std::endl;
oss << "Adjust `C10_CUDA_DSA_ASSERTION_COUNT` if you need more assertion failure info"
<< std::endl;
}
for (const auto i : c10::irange(failures_found)) {
const auto& self = assertion_data_for_device.assertions[i];
const auto& launch_info = launch_infos[self.caller % launch_infos.size()];
oss << "Assertion failure " << i << std::endl;
oss << " GPU assertion failure message = " << self.assertion_msg
<< std::endl;
oss << " File containing assertion = " << self.filename << ":"
<< self.line_number << std::endl;
oss << " Device function containing assertion = " << self.function_name
<< std::endl;
oss << " Thread ID that failed assertion = [" << self.thread_id[0] << ","
<< self.thread_id[1] << "," << self.thread_id[2] << "]" << std::endl;
oss << " Block ID that failed assertion = [" << self.block_id[0] << ","
<< self.block_id[1] << "," << self.block_id[2] << "]" << std::endl;
if (launch_info.generation_number == self.caller) {
oss << " File containing kernel launch = "
<< launch_info.launch_filename << ":" << launch_info.launch_linenum
<< std::endl;
oss << " Function containing kernel launch = "
<< launch_info.launch_function << std::endl;
oss << " Name of kernel launched that led to failure = "
<< launch_info.kernel_name << std::endl;
oss << " Device that launched kernel = " << launch_info.device
<< std::endl;
oss << " Stream kernel was launched on = " << launch_info.stream
<< std::endl;
oss << " Backtrace of kernel launch site = ";
if (!launch_registry.gather_launch_stacktrace) {
oss << "Launch stacktracing disabled." << std::endl;
} else {
oss << "\n" << launch_info.launch_stacktrace << std::endl;
}
} else {
oss << " CPU launch site info: Unavailable, the circular queue wrapped around. Increase `CUDAKernelLaunchRegistry::max_size`."
<< std::endl;
}
}
}
return oss.str();
#else
return "Compile with `TORCH_USE_CUDA_DSA` to enable device-side assertions.\n";
#endif
}
CUDAKernelLaunchRegistry::CUDAKernelLaunchRegistry()
: do_all_devices_support_managed_memory(
dsa_check_if_all_devices_support_managed_memory()),
gather_launch_stacktrace(check_env_for_enable_launch_stacktracing()),
enabled(check_env_for_dsa_enabled()) {
for (C10_UNUSED const auto _ : c10::irange(dsa_get_device_count())) {
uvm_assertions.emplace_back(nullptr, uvm_deleter);
}
kernel_launches.resize(max_kernel_launches);
}
bool CUDAKernelLaunchRegistry::check_env_for_enable_launch_stacktracing()
const {
return env_flag_set("PYTORCH_CUDA_DSA_STACKTRACING");
}
bool CUDAKernelLaunchRegistry::check_env_for_dsa_enabled() const {
return env_flag_set("PYTORCH_USE_CUDA_DSA");
}
uint32_t CUDAKernelLaunchRegistry::insert(
const char* launch_filename,
const char* launch_function,
const uint32_t launch_linenum,
const char* kernel_name,
const int32_t stream_id) {
#ifdef TORCH_USE_CUDA_DSA
if (!is_enabled()) {
return 0;
}
const auto backtrace = gather_launch_stacktrace ? c10::get_backtrace() : "";
const std::lock_guard<std::mutex> lock(read_write_mutex);
const auto my_gen_number = generation_number++;
// TODO: It would probably be good to get a stack trace here so that
// we can better indicate which launch caused the failure.
kernel_launches[my_gen_number % max_kernel_launches] = {
launch_filename,
launch_function,
launch_linenum,
backtrace,
kernel_name,
dsa_get_device_id(),
stream_id,
my_gen_number};
return my_gen_number;
#else
return 0;
#endif
}
std::pair<std::vector<DeviceAssertionsData>, std::vector<CUDAKernelLaunchInfo>>
CUDAKernelLaunchRegistry::snapshot() const {
// This is likely to be the longest-lasting hold on the mutex, but
// we only expect it to be called in cases where we're already failing
// and speed is no longer important
const std::lock_guard<std::mutex> lock(read_write_mutex);
std::vector<DeviceAssertionsData> device_assertions_data;
for (const auto& x : uvm_assertions) {
if (x) {
device_assertions_data.push_back(*x);
} else {
device_assertions_data.emplace_back();
}
}
return std::make_pair(device_assertions_data, kernel_launches);
}
DeviceAssertionsData* CUDAKernelLaunchRegistry::
get_uvm_assertions_ptr_for_current_device() {
#ifdef TORCH_USE_CUDA_DSA
if (!is_enabled()) {
return nullptr;
}
const auto device_num = dsa_get_device_id();
// If we've already set up this GPU with managed memory, return a pointer to
// the managed memory. This is a lock-free quick-return path.
if (uvm_assertions.at(device_num)) {
return uvm_assertions.at(device_num).get();
}
// Need a lock here so there's not race-condition on creating the new device
// assertions buffer
const std::lock_guard<std::mutex> lock(gpu_alloc_mutex);
// If we've already set up this GPU with managed memory, return a pointer to
// the managed memory. This locked path ensures that the device memory is
// allocated only once
if (uvm_assertions.at(device_num)) {
return uvm_assertions.at(device_num).get();
}
// Otherwise, set up the GPU to be able to use the device-side assertion
// system
DeviceAssertionsData* uvm_assertions_ptr = nullptr;
C10_CUDA_ERROR_HANDLED(
cudaMallocManaged(&uvm_assertions_ptr, sizeof(DeviceAssertionsData)));
CHECK_CUDA_API_CALL_WITHOUT_CHECKING_DEVICE_ASSERTS();
C10_CUDA_ERROR_HANDLED(cudaMemAdvise(
uvm_assertions_ptr,
sizeof(DeviceAssertionsData),
cudaMemAdviseSetPreferredLocation,
cudaCpuDeviceId));
CHECK_CUDA_API_CALL_WITHOUT_CHECKING_DEVICE_ASSERTS();
// GPU will establish direct mapping of data in CPU memory, no page faults
// will be generated
C10_CUDA_ERROR_HANDLED(cudaMemAdvise(
uvm_assertions_ptr,
sizeof(DeviceAssertionsData),
cudaMemAdviseSetAccessedBy,
cudaCpuDeviceId));
CHECK_CUDA_API_CALL_WITHOUT_CHECKING_DEVICE_ASSERTS();
// Initialize the memory from the CPU; otherwise, pages may have to be created
// on demand. We think that UVM documentation indicates that first access may
// not honor preferred location, which would be bad, if true, because we want
// this memory on the host so we can access it post-assertion. Initializing
// this on the CPU helps ensure that that's where the memory will live.
*uvm_assertions_ptr = DeviceAssertionsData();
// Ownership and lifetime management of `uvm_assertions_ptr` now passes to the
// uvm_assertions unique_ptr vector
uvm_assertions.at(device_num).reset(uvm_assertions_ptr);
return uvm_assertions_ptr;
#else
return nullptr;
#endif
}
CUDAKernelLaunchRegistry& CUDAKernelLaunchRegistry::get_singleton_ref() {
static CUDAKernelLaunchRegistry launch_registry;
return launch_registry;
}
bool CUDAKernelLaunchRegistry::has_failed() const {
for (const auto& x : uvm_assertions) {
if (x && x->assertion_count > 0) {
return true;
}
}
return false;
}
bool CUDAKernelLaunchRegistry::is_enabled() const {
#ifdef TORCH_USE_CUDA_DSA
return do_all_devices_support_managed_memory && enabled;
#else
return false;
#endif
}
} // namespace cuda
} // namespace c10


@ -0,0 +1,156 @@
#pragma once
#include <c10/cuda/CUDAMacros.h>
#include <memory>
#include <mutex>
#include <string>
#include <vector>
#ifdef USE_CUDA
#define TORCH_USE_CUDA_DSA
#endif
/// Number of assertion failure messages we can store. If this is too small
/// threads will fail silently.
constexpr int C10_CUDA_DSA_ASSERTION_COUNT = 10;
constexpr int C10_CUDA_DSA_MAX_STR_LEN = 512;
namespace c10 {
namespace cuda {
/// Holds information about any device-side assertions that fail.
/// Held in managed memory and accessed by both the CPU and the GPU.
struct DeviceAssertionData {
/// Stringification of the assertion
char assertion_msg[C10_CUDA_DSA_MAX_STR_LEN];
/// File the assertion was in
char filename[C10_CUDA_DSA_MAX_STR_LEN];
/// Name of the function the assertion was in
char function_name[C10_CUDA_DSA_MAX_STR_LEN];
/// Line number the assertion was at
int line_number;
/// Number uniquely identifying the kernel launch that triggered the assertion
uint32_t caller;
/// block_id of the thread that failed the assertion
int32_t block_id[3];
/// thread_id of the thread that failed the assertion
int32_t thread_id[3];
};
/// Used to hold assertions generated by the device
/// Held in managed memory and accessed by both the CPU and the GPU.
struct DeviceAssertionsData {
/// Total number of assertions found; a subset of these will be recorded
/// in `assertions`
int32_t assertion_count;
/// An array of assertions that will be written to in a race-free manner
DeviceAssertionData assertions[C10_CUDA_DSA_ASSERTION_COUNT];
};
/// Used to hold info about kernel launches so that we can run kernels
/// asynchronously and still associate launches with device-side
/// assertion failures
struct CUDAKernelLaunchInfo {
/// Filename of the code where the kernel was launched from
const char* launch_filename;
/// Function from which the kernel was launched
const char* launch_function;
/// Line number of where the code was launched from
uint32_t launch_linenum;
/// Backtrace of where the kernel was launched from, only populated if
/// CUDAKernelLaunchRegistry::gather_launch_stacktrace is true
std::string launch_stacktrace;
/// Kernel that was launched
const char* kernel_name;
/// Device the kernel was launched on
int device;
/// Stream the kernel was launched on
int32_t stream;
/// A number that uniquely identifies the kernel launch
uint64_t generation_number;
};
/// Circular buffer used to hold information about kernel launches. This is
/// later used to reconstruct how a device-side kernel assertion failure
/// occurred. CUDAKernelLaunchRegistry is used as a singleton.
class C10_CUDA_API CUDAKernelLaunchRegistry {
private:
/// Assume that this is the max number of kernel launches that might ever be
/// enqueued across all streams on a single device
static constexpr int max_kernel_launches = 1024;
/// How many kernel launch infos we've inserted. Used to ensure that circular
/// queue doesn't provide false information by always increasing, but also to
/// mark where we are inserting into the queue
#ifdef TORCH_USE_CUDA_DSA
uint64_t generation_number = 0;
#endif
/// Shared mutex between writer and accessor to ensure multi-threaded safety.
mutable std::mutex read_write_mutex;
/// Used to prevent race conditions in GPU memory allocation
mutable std::mutex gpu_alloc_mutex;
/// Pointer to managed memory keeping track of device-side assertions. There
/// is one entry for each possible device the process might work with. Unused
/// entries are nullptrs. We could also use an unordered_set here, but this
/// vector design will be faster and the wasted memory is small since we
/// expect the number of GPUs per node will always be small
std::vector<
std::unique_ptr<DeviceAssertionsData, void (*)(DeviceAssertionsData*)>>
uvm_assertions;
/// A single circular buffer holds information about every kernel launch the
/// process makes across all devices.
std::vector<CUDAKernelLaunchInfo> kernel_launches;
bool check_env_for_enable_launch_stacktracing() const;
bool check_env_for_dsa_enabled() const;
public:
CUDAKernelLaunchRegistry();
/// Register a new kernel launch and obtain a generation number back to be
/// passed to the kernel
uint32_t insert(
const char* launch_filename,
const char* launch_function,
const uint32_t launch_linenum,
const char* kernel_name,
const int32_t stream_id);
/// Get copies of the kernel launch registry and each device's assertion
/// failure buffer so they can be inspected without raising race conditions
std::
pair<std::vector<DeviceAssertionsData>, std::vector<CUDAKernelLaunchInfo>>
snapshot() const;
/// Get a pointer to the current device's assertion failure buffer. If no such
/// buffer exists then one is created. This means that the first kernel launch
/// made on each device will be slightly slower because memory allocations are
/// required
DeviceAssertionsData* get_uvm_assertions_ptr_for_current_device();
/// Gets the global singleton of the registry
static CUDAKernelLaunchRegistry& get_singleton_ref();
/// If not all devices support DSA, we disable it
const bool do_all_devices_support_managed_memory = false;
/// Whether or not to gather stack traces when launching kernels
bool gather_launch_stacktrace = false;
/// Whether or not host-side DSA is enabled or disabled at run-time
/// Device-side code cannot be adjusted at run-time
bool enabled = false;
/// Whether or not a device has indicated a failure
bool has_failed() const;
/// Since multiple mechanisms can enable/disable, we add a function that
/// aggregates them
bool is_enabled() const;
};
std::string c10_retrieve_device_side_assertion_info();
} // namespace cuda
} // namespace c10
// Each kernel launched with TORCH_DSA_KERNEL_LAUNCH
// requires the same input arguments. We introduce the following macro to
// standardize these.
#define TORCH_DSA_KERNEL_ARGS \
c10::cuda::DeviceAssertionsData *const assertions_data, \
uint32_t assertion_caller_id
// This macro can be used to pass the DSA arguments onward to another
// function
#define TORCH_DSA_KERNEL_ARGS_PASS assertions_data, assertion_caller_id
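// Example (illustrative): forwarding the DSA arguments to a device helper
//   __device__ void helper(const int x, TORCH_DSA_KERNEL_ARGS) {
//     CUDA_KERNEL_ASSERT2(x >= 0);
//   }
//   __global__ void my_kernel(const int x, TORCH_DSA_KERNEL_ARGS) {
//     helper(x, TORCH_DSA_KERNEL_ARGS_PASS);
//   }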


@ -1,5 +1,6 @@
 #include <c10/cuda/CUDAException.h>
+#include <c10/cuda/CUDADeviceAssertionHost.h>
 #include <c10/util/Exception.h>
 #include <cuda_runtime.h>
@ -13,19 +14,27 @@ void c10_cuda_check_implementation(
     const char* function_name,
     const int line_number,
     const bool include_device_assertions) {
-  // We retrieve the error here in order to keep CUDA data types out of
-  // CUDAException.h thereby simplifying including it in other files
-  const cudaError_t err = cudaGetLastError();
+  const auto cuda_error = cudaGetLastError();
+  const auto cuda_kernel_failure = include_device_assertions
+      ? c10::cuda::CUDAKernelLaunchRegistry::get_singleton_ref().has_failed()
+      : false;
-  if (C10_LIKELY(err == cudaSuccess)) {
+  if (C10_LIKELY(cuda_error == cudaSuccess && !cuda_kernel_failure)) {
     return;
   }
   std::string check_message;
 #ifndef STRIP_ERROR_MESSAGES
   check_message.append("CUDA error: ");
-  check_message.append(cudaGetErrorString(err));
+  check_message.append(cudaGetErrorString(cuda_error));
   check_message.append(c10::cuda::get_cuda_check_suffix());
+  check_message.append("\n");
+  if (include_device_assertions) {
+    check_message.append(c10_retrieve_device_side_assertion_info());
+  } else {
+    check_message.append(
+        "Device-side assertions were explicitly omitted for this error check; the error probably arose while initializing the DSA handlers.");
+  }
 #endif
   TORCH_CHECK(false, check_message);


@ -1,9 +1,11 @@
 #pragma once
+#include <c10/cuda/CUDADeviceAssertionHost.h>
 #include <c10/cuda/CUDAMacros.h>
 #include <c10/cuda/CUDAMiscFunctions.h>
 #include <c10/macros/Macros.h>
 #include <c10/util/Exception.h>
+#include <c10/util/irange.h>
 #include <cuda.h>
 // Note [CHECK macro]
@ -22,17 +24,17 @@ class C10_CUDA_API CUDAError : public c10::Error {
 };
 } // namespace c10
 #define C10_CUDA_CHECK(EXPR)                                        \
   do {                                                              \
-    const cudaError_t __err = EXPR;                                 \
-    if (C10_UNLIKELY(__err != cudaSuccess)) {                       \
-      c10::cuda::c10_cuda_check_implementation(                     \
-          __FILE__,                                                 \
-          __func__, /* Line number's data type is not well-defined between \
-                       compilers, so we perform an explicit cast */ \
-          static_cast<uint32_t>(__LINE__),                          \
-          true);                                                    \
-    }                                                               \
+    /* We get & disarm the error inside of */                       \
+    /* `c10_cuda_check_implementation` */                           \
+    C10_UNUSED const cudaError_t __err = EXPR;                      \
+    c10::cuda::c10_cuda_check_implementation(                       \
+        __FILE__,                                                   \
+        __func__, /* Line number's data type is not well-defined between \
+                     compilers, so we perform an explicit cast */   \
+        static_cast<uint32_t>(__LINE__),                            \
+        true);                                                      \
   } while (0)
 #define C10_CUDA_CHECK_WARN(EXPR) \
@ -70,6 +72,21 @@ class C10_CUDA_API CUDAError : public c10::Error {
 // diagnostic if it didn't.
 #define C10_CUDA_KERNEL_LAUNCH_CHECK() C10_CUDA_CHECK(cudaGetLastError())
+
+/// Launches a CUDA kernel appending to it all the information needed to handle
+/// device-side assertion failures. Checks that the launch was successful.
+#define TORCH_DSA_KERNEL_LAUNCH(                                      \
+    kernel, blocks, threads, shared_mem, stream, ...)                \
+  do {                                                               \
+    auto& launch_registry =                                          \
+        c10::cuda::CUDAKernelLaunchRegistry::get_singleton_ref();    \
+    kernel<<<blocks, threads, shared_mem, stream>>>(                 \
+        __VA_ARGS__,                                                 \
+        launch_registry.get_uvm_assertions_ptr_for_current_device(), \
+        launch_registry.insert(                                      \
+            __FILE__, __FUNCTION__, __LINE__, #kernel, stream.id())); \
+    C10_CUDA_KERNEL_LAUNCH_CHECK();                                  \
+  } while (0)
 namespace c10 {
 namespace cuda {


@ -11,7 +11,7 @@ namespace impl {
 bool has_cuda_gpu() {
   int count;
-  C10_CUDA_CHECK(cudaGetDeviceCount(&count));
+  C10_CUDA_IGNORE_ERROR(cudaGetDeviceCount(&count));
   return count != 0;
 }


@ -1,6 +1,13 @@
 # ---[ Test binaries.
 set(C10_CUDA_ALL_TEST_FILES
+    impl/CUDAAssertionsTest_1_var_test.cu
+    impl/CUDAAssertionsTest_catches_stream.cu
+    impl/CUDAAssertionsTest_catches_thread_and_block_and_device.cu
+    impl/CUDAAssertionsTest_from_2_processes.cu
+    impl/CUDAAssertionsTest_multiple_writes_from_blocks_and_threads.cu
+    impl/CUDAAssertionsTest_multiple_writes_from_multiple_blocks.cu
+    impl/CUDAAssertionsTest_multiple_writes_from_same_block.cu
     impl/CUDATest.cpp
 )
 if(BUILD_TEST)


@ -1,10 +1,42 @@
+dsa_tests = [
+    "impl/CUDAAssertionsTest_1_var_test.cu",
+    "impl/CUDAAssertionsTest_catches_stream.cu",
+    "impl/CUDAAssertionsTest_catches_thread_and_block_and_device.cu",
+    "impl/CUDAAssertionsTest_from_2_processes.cu",
+    "impl/CUDAAssertionsTest_multiple_writes_from_blocks_and_threads.cu",
+    "impl/CUDAAssertionsTest_multiple_writes_from_multiple_blocks.cu",
+    "impl/CUDAAssertionsTest_multiple_writes_from_same_block.cu",
+]
+
 def define_targets(rules):
     rules.cc_test(
         name = "test",
-        srcs = ["impl/CUDATest.cpp"],
+        srcs = [
+            "impl/CUDATest.cpp",
+        ],
         deps = [
             "@com_google_googletest//:gtest_main",
             "//c10/cuda",
         ],
         target_compatible_with = rules.requires_cuda_enabled(),
     )
+
+    for src in dsa_tests:
+        name = src.replace("impl/", "").replace(".cu", "")
+        rules.cuda_library(
+            name = "test_" + name + "_lib",
+            srcs = [
+                src,
+            ],
+            deps = [
+                "@com_google_googletest//:gtest_main",
+                "//c10/cuda",
+            ],
+            target_compatible_with = rules.requires_cuda_enabled(),
+        )
+        rules.cc_test(
+            name = "test_" + name,
+            deps = [
+                ":test_" + name + "_lib",
+            ],
+        )


@ -0,0 +1,102 @@
#include <gmock/gmock.h>
#include <gtest/gtest.h>
#include <c10/cuda/CUDADeviceAssertion.h>
#include <c10/cuda/CUDAException.h>
#include <c10/cuda/CUDAFunctions.h>
#include <c10/cuda/CUDAStream.h>
#include <chrono>
#include <iostream>
#include <string>
#include <thread>
using ::testing::HasSubstr;
void did_not_fail_diagnostics() {
#ifdef TORCH_USE_CUDA_DSA
std::cerr << "DSA was enabled" << std::endl;
#else
std::cerr << "DSA was not enabled" << std::endl;
#endif
std::cerr
<< "c10::cuda::CUDAKernelLaunchRegistry::get_singleton_ref().enabled = "
<< c10::cuda::CUDAKernelLaunchRegistry::get_singleton_ref().enabled
<< std::endl;
std::cerr
<< "c10::cuda::CUDAKernelLaunchRegistry::get_singleton_ref().is_enabled() = "
<< c10::cuda::CUDAKernelLaunchRegistry::get_singleton_ref().is_enabled()
<< std::endl;
std::cerr
<< "c10::cuda::CUDAKernelLaunchRegistry::get_singleton_ref().do_all_devices_support_managed_memory = "
<< c10::cuda::CUDAKernelLaunchRegistry::get_singleton_ref()
.do_all_devices_support_managed_memory
<< std::endl;
}
/**
* Device kernel that takes a single integer parameter as argument and
* will always trigger a device side assertion.
*/
__global__ void cuda_always_fail_assertion_kernel(
const int a,
TORCH_DSA_KERNEL_ARGS) {
CUDA_KERNEL_ASSERT2(a != a);
}
/**
* TEST: Triggering a device-side assertion with a simple <<<1,1>>> config.
* The kernel used takes only one variable as a parameter.
*/
void cuda_device_assertions_1_var_test() {
const auto stream = c10::cuda::getStreamFromPool();
TORCH_DSA_KERNEL_LAUNCH(
cuda_always_fail_assertion_kernel,
1, /* Blocks */
1, /* Threads */
0, /* Shared mem */
stream, /* Stream */
1);
try {
c10::cuda::device_synchronize();
did_not_fail_diagnostics();
throw std::runtime_error("Test didn't fail, but should have.");
} catch (const c10::Error& err) {
const auto err_str = std::string(err.what());
ASSERT_THAT(
err_str,
HasSubstr("CUDA device-side assertion failures were found on GPU #0!"));
ASSERT_THAT(
err_str, HasSubstr("Thread ID that failed assertion = [0,0,0]"));
ASSERT_THAT(err_str, HasSubstr("Block ID that failed assertion = [0,0,0]"));
ASSERT_THAT(err_str, HasSubstr("Device that launched kernel = 0"));
ASSERT_THAT(
err_str,
HasSubstr(
"Name of kernel launched that led to failure = cuda_always_fail_assertion_kernel"));
ASSERT_THAT(
err_str, HasSubstr("File containing kernel launch = " __FILE__));
ASSERT_THAT(
err_str,
HasSubstr(
"Function containing kernel launch = " +
std::string(__FUNCTION__)));
ASSERT_THAT(
err_str,
HasSubstr(
"Stream kernel was launched on = " + std::to_string(stream.id())));
}
}
TEST(CUDATest, cuda_device_assertions_1_var_test) {
#ifdef TORCH_USE_CUDA_DSA
c10::cuda::CUDAKernelLaunchRegistry::get_singleton_ref().enabled = true;
std::cerr << "BEFORE TEST" << std::endl;
did_not_fail_diagnostics();
cuda_device_assertions_1_var_test();
#else
GTEST_SKIP() << "CUDA device-side assertions (DSA) was not enabled.";
#endif
}


@ -0,0 +1,101 @@
#include <gmock/gmock.h>
#include <gtest/gtest.h>
#include <c10/cuda/CUDADeviceAssertion.h>
#include <c10/cuda/CUDAException.h>
#include <c10/cuda/CUDAFunctions.h>
#include <c10/cuda/CUDAStream.h>
#include <chrono>
#include <iostream>
#include <string>
#include <thread>
using ::testing::HasSubstr;
/**
* Device kernel that takes multiple integer parameters as arguments and
* will always trigger a device side assertion.
*/
__global__ void cuda_multiple_vars_always_fail_assertion_kernel(
const int a,
const int b,
const int c,
const int d,
TORCH_DSA_KERNEL_ARGS) {
int i = a + b + c + d;
if (i != 0) {
CUDA_KERNEL_ASSERT2(i == -i);
} else {
CUDA_KERNEL_ASSERT2(i == i + 1);
}
}
/**
* Device kernel that takes a single integer parameter as argument and
* will always trigger a device side assertion.
*/
__global__ void cuda_always_fail_assertion_kernel(
const int a,
TORCH_DSA_KERNEL_ARGS) {
CUDA_KERNEL_ASSERT2(a != a);
}
/**
* TEST: Triggering a device-side assertion with a simple <<<1,1>>> config.
* The kernel used takes multiple variables as parameters.
*/
void cuda_device_assertions_catches_stream() {
const auto stream = c10::cuda::getStreamFromPool();
TORCH_DSA_KERNEL_LAUNCH(
cuda_multiple_vars_always_fail_assertion_kernel,
1, /* Blocks */
1, /* Threads */
0, /* Shared mem */
stream, /* Stream */
1, /* const int a */
2, /* const int b */
3, /* const int c */
4 /* const int d */
);
try {
c10::cuda::device_synchronize();
throw std::runtime_error("Test didn't fail, but should have.");
} catch (const c10::Error& err) {
const auto err_str = std::string(err.what());
ASSERT_THAT(
err_str, HasSubstr("# of GPUs this process interacted with = 1"));
ASSERT_THAT(
err_str,
HasSubstr("CUDA device-side assertion failures were found on GPU #0!"));
ASSERT_THAT(
err_str, HasSubstr("Thread ID that failed assertion = [0,0,0]"));
ASSERT_THAT(err_str, HasSubstr("Block ID that failed assertion = [0,0,0]"));
ASSERT_THAT(err_str, HasSubstr("Device that launched kernel = 0"));
ASSERT_THAT(
err_str,
HasSubstr(
"Name of kernel launched that led to failure = cuda_multiple_vars_always_fail_assertion_kernel"));
ASSERT_THAT(
err_str, HasSubstr("File containing kernel launch = " __FILE__));
ASSERT_THAT(
err_str,
HasSubstr(
"Function containing kernel launch = " +
std::string(__FUNCTION__)));
ASSERT_THAT(
err_str,
HasSubstr(
"Stream kernel was launched on = " + std::to_string(stream.id())));
}
}
TEST(CUDATest, cuda_device_assertions_catches_stream) {
#ifdef TORCH_USE_CUDA_DSA
c10::cuda::CUDAKernelLaunchRegistry::get_singleton_ref().enabled = true;
cuda_device_assertions_catches_stream();
#else
GTEST_SKIP() << "CUDA device-side assertions (DSA) was not enabled.";
#endif
}


@ -0,0 +1,86 @@
#include <gmock/gmock.h>
#include <gtest/gtest.h>
#include <c10/cuda/CUDADeviceAssertion.h>
#include <c10/cuda/CUDAException.h>
#include <c10/cuda/CUDAFunctions.h>
#include <c10/cuda/CUDAStream.h>
#include <chrono>
#include <iostream>
#include <string>
#include <thread>
using ::testing::HasSubstr;
/**
* Device kernel that takes 2 arguments
* @param bad_thread represents the thread we want to trigger assertion on.
* @param bad_block represents the block we want to trigger assertion on.
* This kernel will only trigger a device-side assertion for the <<<bad_block,
* bad_thread>>> pair. All the other block and thread pairs will effectively be
* no-ops.
*/
__global__ void cuda_device_assertions_fail_on_thread_block_kernel(
const int bad_thread,
const int bad_block,
TORCH_DSA_KERNEL_ARGS) {
if (threadIdx.x == bad_thread && blockIdx.x == bad_block) {
CUDA_KERNEL_ASSERT2(false); // This assertion always fails for the targeted thread/block
}
}
/**
* TEST: Triggering a device-side assertion on only 1 thread from a
* <<<1024,128>>> grid. The kernel used is unique: it takes 2 parameters to tell
* which particular block and thread it should assert on; all the other threads
* of the kernel will effectively be no-ops.
*/
void cuda_device_assertions_catches_thread_and_block_and_device() {
const auto stream = c10::cuda::getStreamFromPool();
TORCH_DSA_KERNEL_LAUNCH(
cuda_device_assertions_fail_on_thread_block_kernel,
1024, /* Blocks */
128, /* Threads */
0, /* Shared mem */
stream, /* Stream */
29, /* bad thread */
937 /* bad block */
);
try {
c10::cuda::device_synchronize();
throw std::runtime_error("Test didn't fail, but should have.");
} catch (const c10::Error& err) {
const auto err_str = std::string(err.what());
ASSERT_THAT(
err_str, HasSubstr("Thread ID that failed assertion = [29,0,0]"));
ASSERT_THAT(
err_str, HasSubstr("Block ID that failed assertion = [937,0,0]"));
ASSERT_THAT(err_str, HasSubstr("Device that launched kernel = 0"));
ASSERT_THAT(
err_str,
HasSubstr(
"Name of kernel launched that led to failure = cuda_device_assertions_fail_on_thread_block_kernel"));
ASSERT_THAT(
err_str, HasSubstr("File containing kernel launch = " __FILE__));
ASSERT_THAT(
err_str,
HasSubstr(
"Function containing kernel launch = " +
std::string(__FUNCTION__)));
ASSERT_THAT(
err_str,
HasSubstr(
"Stream kernel was launched on = " + std::to_string(stream.id())));
}
}
TEST(CUDATest, cuda_device_assertions_catches_thread_and_block_and_device) {
#ifdef TORCH_USE_CUDA_DSA
c10::cuda::CUDAKernelLaunchRegistry::get_singleton_ref().enabled = true;
cuda_device_assertions_catches_thread_and_block_and_device();
#else
GTEST_SKIP() << "CUDA device-side assertions (DSA) was not enabled.";
#endif
}


@ -0,0 +1,108 @@
#include <gmock/gmock.h>
#include <gtest/gtest.h>
#include <c10/cuda/CUDADeviceAssertion.h>
#include <c10/cuda/CUDAException.h>
#include <c10/cuda/CUDAFunctions.h>
#include <c10/cuda/CUDAStream.h>
#include <chrono>
#include <iostream>
#include <string>
#include <thread>
using ::testing::HasSubstr;
const auto max_assertions_failure_str =
"Assertion failure " + std::to_string(C10_CUDA_DSA_ASSERTION_COUNT - 1);
/**
* Device kernel that takes a single integer parameter as argument and
* will always trigger a device side assertion.
*/
__global__ void cuda_always_fail_assertion_kernel(
const int a,
TORCH_DSA_KERNEL_ARGS) {
CUDA_KERNEL_ASSERT2(a != a);
}
/**
* Device kernel that takes a single integer parameter as argument and
* will never trigger a device side assertion.
*/
__global__ void cuda_always_succeed_assertion_kernel(
const int a,
TORCH_DSA_KERNEL_ARGS) {
CUDA_KERNEL_ASSERT2(a == a);
}
// Windows doesn't like `fork`
#ifndef _MSC_VER
/**
* TEST: Triggering a device-side assertion from 2 different host processes.
* The following code tests whether two CPU processes that run GPU kernels
* (not necessarily simultaneously), and that assert & write to their
* respective UVM buffers, mess up anything for each other.
* One process's kernel launch fails and causes a device-side assertion, and
* that process is still alive when the second process interacts with the GPU
* and tries to launch another kernel.
*/
void cuda_device_assertions_from_2_processes() {
const auto n1 = fork();
if (n1 == 0) {
// This is the child process (fork() returns 0 in the child), which will
// trigger the assertion failure. It should execute before the other
// process; we achieve this by putting that process to sleep.
TORCH_DSA_KERNEL_LAUNCH(
cuda_always_fail_assertion_kernel,
1, /* Blocks */
1, /* Threads */
0, /* Shared mem */
c10::cuda::getStreamFromPool(), /* Stream */
1);
try {
c10::cuda::device_synchronize();
throw std::runtime_error("Test didn't fail, but should have.");
} catch (const c10::Error& err) {
const auto err_str = std::string(err.what());
ASSERT_THAT(
err_str,
HasSubstr(
"1 CUDA device-side assertion failures were found on GPU #0!"));
}
// Keep this alive so we can see what happened to the other process
std::this_thread::sleep_for(std::chrono::milliseconds(3000));
} else {
// This is the parent process.
// We put it to sleep for 2 seconds, to make sure that the other process
// has already triggered its assertion failure.
std::this_thread::sleep_for(std::chrono::milliseconds(2000));
TORCH_DSA_KERNEL_LAUNCH(
cuda_always_succeed_assertion_kernel,
1, /* Blocks */
1, /* Threads */
0, /* Shared mem */
c10::cuda::getStreamFromPool(), /* Stream */
1);
try {
c10::cuda::device_synchronize();
} catch (const c10::Error& err) {
ASSERT_TRUE(false); // This kernel should not have failed, but did.
}
// End this process
exit(0);
}
}
TEST(CUDATest, cuda_device_assertions_from_2_processes) {
#ifdef TORCH_USE_CUDA_DSA
c10::cuda::CUDAKernelLaunchRegistry::get_singleton_ref().enabled = true;
cuda_device_assertions_from_2_processes();
#else
GTEST_SKIP() << "CUDA device-side assertions (DSA) was not enabled.";
#endif
}
#else
#endif


@ -0,0 +1,93 @@
#include <gmock/gmock.h>
#include <gtest/gtest.h>
#include <c10/cuda/CUDADeviceAssertion.h>
#include <c10/cuda/CUDAException.h>
#include <c10/cuda/CUDAFunctions.h>
#include <c10/cuda/CUDAStream.h>
#include <atomic>
#include <chrono>
#include <iostream>
#include <string>
#include <thread>
using ::testing::HasSubstr;
const auto max_assertions_failure_str =
"Assertion failure " + std::to_string(C10_CUDA_DSA_ASSERTION_COUNT - 1);
/**
* Device kernel that takes a single integer parameter as argument and
* will always trigger a device side assertion.
*/
__global__ void cuda_always_fail_assertion_kernel(
const int a,
TORCH_DSA_KERNEL_ARGS) {
CUDA_KERNEL_ASSERT2(a != a);
}
/**
* TEST: Triggering device-side assertions from multiple blocks and multiple
* threads (<<<10,128>>>), with the kernel launched simultaneously from
* several host threads.
*/
void cuda_device_assertions_multiple_writes_from_blocks_and_threads() {
std::atomic<bool> run_threads{false};
// Create a function to launch kernel that waits for a signal, to try to
// ensure everything is happening simultaneously
const auto launch_the_kernel = [&]() {
// Busy loop waiting for the signal to go
while (!run_threads) {
}
TORCH_DSA_KERNEL_LAUNCH(
cuda_always_fail_assertion_kernel,
10, /* Blocks */
128, /* Threads */
0, /* Shared mem */
c10::cuda::getCurrentCUDAStream(), /* Stream */
1);
};
// Spin up a bunch of busy-looping threads
std::vector<std::thread> threads;
for (int i = 0; i < 10; i++) {
threads.emplace_back(launch_the_kernel);
}
// Paranoid - wait for all the threads to get setup
std::this_thread::sleep_for(std::chrono::milliseconds(100));
// Mash
run_threads = true;
// Clean-up
for (auto& x : threads) {
x.join();
}
try {
c10::cuda::device_synchronize();
throw std::runtime_error("Test didn't fail, but should have.");
} catch (const c10::Error& err) {
const auto err_str = std::string(err.what());
ASSERT_THAT(err_str, HasSubstr(max_assertions_failure_str));
ASSERT_THAT(err_str, HasSubstr("Device that launched kernel = 0"));
ASSERT_THAT(
err_str,
HasSubstr(
"Name of kernel launched that led to failure = cuda_always_fail_assertion_kernel"));
ASSERT_THAT(
err_str, HasSubstr("File containing kernel launch = " __FILE__));
}
}
TEST(CUDATest, cuda_device_assertions_multiple_writes_from_blocks_and_threads) {
#ifdef TORCH_USE_CUDA_DSA
c10::cuda::CUDAKernelLaunchRegistry::get_singleton_ref().enabled = true;
cuda_device_assertions_multiple_writes_from_blocks_and_threads();
#else
GTEST_SKIP() << "CUDA device-side assertions (DSA) was not enabled.";
#endif
}


@ -0,0 +1,90 @@
#include <gmock/gmock.h>
#include <gtest/gtest.h>
#include <c10/cuda/CUDADeviceAssertion.h>
#include <c10/cuda/CUDAException.h>
#include <c10/cuda/CUDAFunctions.h>
#include <c10/cuda/CUDAStream.h>
#include <chrono>
#include <iostream>
#include <string>
#include <thread>
using ::testing::HasSubstr;
const auto max_assertions_failure_str =
"Assertion failure " + std::to_string(C10_CUDA_DSA_ASSERTION_COUNT - 1);
/**
* Device kernel that takes a single integer parameter as argument and
* will always trigger a device side assertion.
*/
__global__ void cuda_always_fail_assertion_kernel(
const int a,
TORCH_DSA_KERNEL_ARGS) {
CUDA_KERNEL_ASSERT2(a != a);
}
/**
* TEST: Triggering device-side assertions from multiple blocks but a single
* thread each (<<<10,1>>>). Here we trigger the assertion in 10 blocks, each
* with only 1 thread. Since we have more than 10 SMs on a GPU, we expect each
* block to be executed and successfully assert; hence we will see assertions
* logged from each block here.
*/
void cuda_device_assertions_multiple_writes_from_multiple_blocks() {
const auto stream = c10::cuda::getStreamFromPool();
TORCH_DSA_KERNEL_LAUNCH(
cuda_always_fail_assertion_kernel,
10, /* Blocks */
1, /* Threads */
0, /* Shared mem */
stream, /* Stream */
1);
try {
c10::cuda::device_synchronize();
throw std::runtime_error("Test didn't fail, but should have.");
} catch (const c10::Error& err) {
const auto err_str = std::string(err.what());
ASSERT_THAT(err_str, HasSubstr(max_assertions_failure_str));
ASSERT_THAT(
err_str, HasSubstr("Thread ID that failed assertion = [0,0,0]"));
ASSERT_THAT(err_str, HasSubstr("Block ID that failed assertion = [0,0,0]"));
ASSERT_THAT(err_str, HasSubstr("Block ID that failed assertion = [1,0,0]"));
ASSERT_THAT(err_str, HasSubstr("Block ID that failed assertion = [2,0,0]"));
ASSERT_THAT(err_str, HasSubstr("Block ID that failed assertion = [3,0,0]"));
ASSERT_THAT(err_str, HasSubstr("Block ID that failed assertion = [4,0,0]"));
ASSERT_THAT(err_str, HasSubstr("Block ID that failed assertion = [5,0,0]"));
ASSERT_THAT(err_str, HasSubstr("Block ID that failed assertion = [6,0,0]"));
ASSERT_THAT(err_str, HasSubstr("Block ID that failed assertion = [7,0,0]"));
ASSERT_THAT(err_str, HasSubstr("Block ID that failed assertion = [8,0,0]"));
ASSERT_THAT(err_str, HasSubstr("Block ID that failed assertion = [9,0,0]"));
ASSERT_THAT(err_str, HasSubstr("Device that launched kernel = 0"));
ASSERT_THAT(
err_str,
HasSubstr(
"Name of kernel launched that led to failure = cuda_always_fail_assertion_kernel"));
ASSERT_THAT(
err_str, HasSubstr("File containing kernel launch = " __FILE__));
ASSERT_THAT(
err_str,
HasSubstr(
"Function containing kernel launch = " +
std::string(__FUNCTION__)));
ASSERT_THAT(
err_str,
HasSubstr(
"Stream kernel was launched on = " + std::to_string(stream.id())));
}
}
TEST(CUDATest, cuda_device_assertions_multiple_writes_from_multiple_blocks) {
#ifdef TORCH_USE_CUDA_DSA
c10::cuda::CUDAKernelLaunchRegistry::get_singleton_ref().enabled = true;
cuda_device_assertions_multiple_writes_from_multiple_blocks();
#else
GTEST_SKIP() << "CUDA device-side assertions (DSA) was not enabled.";
#endif
}


@ -0,0 +1,78 @@
#include <gmock/gmock.h>
#include <gtest/gtest.h>
#include <c10/cuda/CUDADeviceAssertion.h>
#include <c10/cuda/CUDAException.h>
#include <c10/cuda/CUDAFunctions.h>
#include <c10/cuda/CUDAStream.h>
#include <chrono>
#include <iostream>
#include <string>
#include <thread>
using ::testing::HasSubstr;
const auto max_assertions_failure_str =
"Assertion failure " + std::to_string(C10_CUDA_DSA_ASSERTION_COUNT - 1);
/**
* Device kernel that takes a single integer parameter as argument and
* will always trigger a device side assertion.
*/
__global__ void cuda_always_fail_assertion_kernel(
const int a,
TORCH_DSA_KERNEL_ARGS) {
CUDA_KERNEL_ASSERT2(a != a);
}
/**
* TEST: Triggering device-side assertions from a single block and multiple
* threads (<<<1,128>>>). Once the very first thread asserts, all the other
* threads will effectively be in a bad state, and the block ID of the failed
* assertion will be [0,0,0].
*/
void cuda_device_assertions_multiple_writes_from_same_block() {
const auto stream = c10::cuda::getStreamFromPool();
TORCH_DSA_KERNEL_LAUNCH(
cuda_always_fail_assertion_kernel,
1, /* Blocks */
128, /* Threads */
0, /* Shared mem */
stream, /* Stream */
1);
try {
c10::cuda::device_synchronize();
throw std::runtime_error("Test didn't fail, but should have.");
} catch (const c10::Error& err) {
const auto err_str = std::string(err.what());
ASSERT_THAT(err_str, HasSubstr(max_assertions_failure_str));
ASSERT_THAT(err_str, HasSubstr("Block ID that failed assertion = [0,0,0]"));
ASSERT_THAT(err_str, HasSubstr("Device that launched kernel = 0"));
ASSERT_THAT(
err_str,
HasSubstr(
"Name of kernel launched that led to failure = cuda_always_fail_assertion_kernel"));
ASSERT_THAT(
err_str, HasSubstr("File containing kernel launch = " __FILE__));
ASSERT_THAT(
err_str,
HasSubstr(
"Function containing kernel launch = " +
std::string(__FUNCTION__)));
ASSERT_THAT(
err_str,
HasSubstr(
"Stream kernel was launched on = " + std::to_string(stream.id())));
}
}
TEST(CUDATest, cuda_device_assertions_multiple_writes_from_same_block) {
#ifdef TORCH_USE_CUDA_DSA
c10::cuda::CUDAKernelLaunchRegistry::get_singleton_ref().enabled = true;
cuda_device_assertions_multiple_writes_from_same_block();
#else
GTEST_SKIP() << "CUDA device-side assertions (DSA) was not enabled.";
#endif
}


@ -1,5 +1,5 @@
 load("@rules_cc//cc:defs.bzl", "cc_binary", "cc_library", "cc_test")
-load("@rules_cuda//cuda:defs.bzl", "requires_cuda_enabled")
+load("@rules_cuda//cuda:defs.bzl", "cuda_library", "requires_cuda_enabled")
 load("//c10/macros:cmake_configure_file.bzl", "cmake_configure_file")
 load("//tools/config:defs.bzl", "if_cuda")
@ -25,6 +25,7 @@ rules = struct(
     cc_library = cc_library,
     cc_test = cc_test,
     cmake_configure_file = cmake_configure_file,
+    cuda_library = cuda_library,
     filegroup = native.filegroup,
     genrule = _genrule,
     glob = native.glob,


@ -4128,6 +4128,7 @@ CUDA_IDENTIFIER_MAP = collections.OrderedDict(
     (
         "cudaStreamGetPriority",
         ("hipStreamGetPriority", CONV_STREAM, API_RUNTIME, HIP_UNSUPPORTED),
     ),
+    ("cudaCpuDeviceId", ("hipCpuDeviceId", CONV_TYPE, API_RUNTIME)),
     ("cudaStreamDefault", ("hipStreamDefault", CONV_TYPE, API_RUNTIME)),
     ("cudaStreamNonBlocking", ("hipStreamNonBlocking", CONV_TYPE, API_RUNTIME)),
     ("cudaDeviceSynchronize", ("hipDeviceSynchronize", CONV_DEVICE, API_RUNTIME)),
@ -8270,6 +8271,8 @@ C10_MAPPINGS = collections.OrderedDict(
     [
         ("cuda::compat::", ("hip::compat::", API_C10)),
         ("c10/cuda/CUDAAlgorithm.h", ("c10/hip/HIPAlgorithm.h", API_C10)),
+        ("c10/cuda/CUDADeviceAssertion.h", ("c10/hip/HIPDeviceAssertion.h", API_C10)),
+        ("c10/cuda/CUDADeviceAssertionHost.h", ("c10/hip/HIPDeviceAssertionHost.h", API_C10)),
         ("c10/cuda/CUDAException.h", ("c10/hip/HIPException.h", API_C10)),
         ("c10/cuda/CUDAMacros.h", ("c10/hip/HIPMacros.h", API_C10)),
         ("c10/cuda/CUDAMathCompat.h", ("c10/hip/HIPMathCompat.h", API_C10)),