Mirror of https://github.com/pytorch/pytorch.git (synced 2025-10-20 12:54:11 +08:00)
Revert "[CUDA] Add experimental green context support for SM carveout (#159104)"
This reverts commit 746fe78ecd52f3e9cfddda41f0ac82dada7bdd0b. Reverted https://github.com/pytorch/pytorch/pull/159104 on behalf of https://github.com/malfet due to Breaks Windows CD build ([comment](https://github.com/pytorch/pytorch/pull/159104#issuecomment-3378675515))
@@ -51,17 +51,6 @@
#if defined(CUDA_VERSION) && (CUDA_VERSION >= 12030)
#define C10_LIBCUDA_DRIVER_API_OPTIONAL(_) \
  _(cuCtxFromGreenCtx, 12080)              \
  _(cuCtxGetCurrent, 12080)                \
  _(cuCtxPopCurrent, 12080)                \
  _(cuCtxPushCurrent, 12080)               \
  _(cuCtxSetCurrent, 12080)                \
  _(cuGreenCtxCreate, 12080)               \
  _(cuGreenCtxDestroy, 12080)              \
  _(cuDevSmResourceSplitByCount, 12080)    \
  _(cuDeviceGet, 12080)                    \
  _(cuDeviceGetDevResource, 12080)         \
  _(cuDevResourceGenerateDesc, 12080)      \
  _(cuMulticastAddDevice, 12030)           \
  _(cuMulticastBindMem, 12030)             \
  _(cuMulticastCreate, 12030)              \
@@ -262,28 +262,6 @@ See the docs for {class}`~torch.cuda.gds.GdsFile` for an example of how to use t
```

## Green Contexts (experimental)

`torch.cuda.green_contexts` provides thin wrappers around the CUDA Green Context APIs
to enable more general carveout of SM resources for CUDA kernels.

These APIs can be used in PyTorch with CUDA versions greater than or equal to 12.8.

See the docs for {class}`~torch.cuda.green_contexts.GreenContext` for an example of how to use these.

```{eval-rst}
.. currentmodule:: torch.cuda.green_contexts
```

```{eval-rst}
.. autosummary::
    :toctree: generated
    :nosignatures:

    GreenContext
```

% This module needs to be documented. Adding here in the meantime
% for tracking purposes
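For orientation, the removed docs point readers at {class}`~torch.cuda.green_contexts.GreenContext` for usage. Below is a minimal sketch of the API as it existed before this revert, modeled on the test further down in this diff; the 8-SM carveout and device 0 are illustrative, and a CUDA 12.8+ build that still ships the feature is assumed.

```python
import torch
from torch.cuda.green_contexts import GreenContext

x = torch.randn(4096, 4096, device="cuda", dtype=torch.bfloat16)

# Carve out 8 SMs on device 0; kernels launched while the green
# context is current are restricted to that SM subset.
ctx = GreenContext.create(8, 0)

ctx.make_current()
y_carveout = torch.matmul(x, x)   # runs on the carved-out SMs
ctx.pop_current()

y_full = torch.matmul(x, x)       # runs with the full device again
torch.cuda.synchronize()
torch.testing.assert_close(y_carveout, y_full)
```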
@@ -296,10 +274,6 @@ See the docs for {class}`~torch.cuda.green_contexts.GreenContext` for an example
.. py:module:: torch.cuda.gds
```

```{eval-rst}
.. py:module:: torch.cuda.green_contexts
```

```{eval-rst}
.. py:module:: torch.cuda.jiterator
```
@@ -325,4 +299,4 @@ See the docs for {class}`~torch.cuda.green_contexts.GreenContext` for an example
    :hidden:

    cuda.aliases.md
```
```
@@ -1,7 +1,6 @@
# Owner(s): ["module: linear algebra"]

import contextlib
import time
import unittest
from itertools import product
from functools import partial
@@ -842,28 +841,6 @@ class TestMatmulCuda(InductorTestCase):
            op(a, mismatch_batch_dim_b, out_dtype=torch.float32)


    @unittest.skipIf(not _get_torch_cuda_version() >= (12, 8), "Green Context only tested on 12.8+")
    def test_greencontext_carveout(self):
        a = torch.randn(4096, 4096, device='cuda', dtype=torch.bfloat16)
        ctx = torch.cuda.green_contexts.GreenContext.create(1, 0)
        ctx.make_current()
        torch.matmul(a, a)
        torch.cuda.synchronize()
        t0 = time.perf_counter()
        partial_res = torch.matmul(a, a)
        torch.cuda.synchronize()
        t1 = time.perf_counter()
        ctx.pop_current()
        torch.matmul(a, a)
        torch.cuda.synchronize()
        t2 = time.perf_counter()
        full_res = torch.matmul(a, a)
        torch.cuda.synchronize()
        t3 = time.perf_counter()
        self.assertEqual(partial_res, full_res)
        self.assertGreater(t1 - t0, t3 - t2)


    @unittest.skipIf(TEST_WITH_ROCM, "ROCm doesn't support CUTLASS")
    @unittest.skipIf(IS_WINDOWS, "Windows doesn't support CUTLASS extensions")
    @unittest.skipIf(not _IS_SM8X, "mixed dtypes linear only supported on SM 8.x")
@@ -123,7 +123,6 @@ class TestPublicBindings(TestCase):
            "FutureType",
            "Generator",
            "GeneratorType",
            "GreenContext",
            "get_autocast_cpu_dtype",
            "get_autocast_dtype",
            "get_autocast_ipu_dtype",
@@ -2776,17 +2776,3 @@ class _StaticCudaLauncher:
        args: tuple[Any, ...],
        stream: _int,
    ) -> None: ...

# Defined in torch/csrc/cuda/green_context.h
class GreenContext:
    @staticmethod
    def create(
        num_sms: _int,
        device_id: _int,
    ) -> GreenContext: ...
    def make_current(
        self,
    ) -> None: ...
    def pop_current(
        self,
    ) -> None: ...
@@ -23,7 +23,6 @@
#include <c10/cuda/CUDAAllocatorConfig.h>
#include <c10/cuda/CUDACachingAllocator.h>
#include <c10/cuda/CUDAFunctions.h>
#include <torch/csrc/cuda/green_context.h>
#include <ATen/cuda/CUDAGraphsUtils.cuh>

#ifdef USE_NCCL
@@ -1491,13 +1490,6 @@ static void registerCudaPluggableAllocator(PyObject* module) {
        addStorageDeleterFns(storages_to_add_deleters_to, delta);
      });
}

static void initGreenContext(PyObject* module) {
  auto m = py::handle(module).cast<py::module>();
  py::class_<GreenContext>(m, "GreenContext")
      .def_static("create", &GreenContext::create)
      .def("make_current", &GreenContext::makeCurrent)
      .def("pop_current", &GreenContext::popCurrent);
}

static void bindGetDeviceProperties(PyObject* module) {
  // Add method to torch.cuda
@@ -2222,7 +2214,6 @@ void initModule(PyObject* module) {
  registerCudaDeviceProperties(module);
  registerCudaPluggableAllocator(module);
  initCudaMethodBindings(module);
  initGreenContext(module);
}

} // namespace torch::cuda
@@ -1,213 +0,0 @@
#pragma once
#include <ATen/cuda/CUDAEvent.h>
#if defined(CUDA_VERSION) && !defined(USE_ROCM)
#include <c10/cuda/driver_api.h>
#include <cuda.h>
#include <memory>
#include <stdexcept>
#include <vector>
#endif

class GreenContext {
 public:
  GreenContext(int device_id, unsigned int num_sms) {
#if defined(CUDA_VERSION) && CUDA_VERSION >= 12080 && !defined(USE_ROCM)
    int driver_version;
    C10_CUDA_CHECK(cudaDriverGetVersion(&driver_version));
    TORCH_CHECK(
        driver_version >= 12080, "cuda driver too old to use green context!");
    CUcontext pctx = nullptr;
    C10_CUDA_DRIVER_CHECK(c10::cuda::DriverAPI::get()->cuCtxGetCurrent_(&pctx));
    if (C10_UNLIKELY(!pctx)) {
      TORCH_WARN(
          "Attempted to create a green context but"
          " there was no primary context! Creating a primary context...");

      cudaFree(0);
    }

    CUdevice device;
    device_id_ = device_id;
    C10_CUDA_DRIVER_CHECK(
        c10::cuda::DriverAPI::get()->cuDeviceGet_(&device, device_id));

    // Get device resources
    CUdevResource device_resource;
    C10_CUDA_DRIVER_CHECK(c10::cuda::DriverAPI::get()->cuDeviceGetDevResource_(
        device, &device_resource, CU_DEV_RESOURCE_TYPE_SM));

    // Split resources
    std::vector<CUdevResource> result(1);
    auto result_data = result.data();
    unsigned int nb_groups = 1;
    CUdevResource remaining;

    C10_CUDA_DRIVER_CHECK(
        c10::cuda::DriverAPI::get()->cuDevSmResourceSplitByCount_(
            result_data,
            &nb_groups,
            &device_resource,
            &remaining,
            0, // default flags
            num_sms));

    TORCH_CHECK(nb_groups == 1, "Failed to create single resource group");

    // Generate resource descriptor
    CUdevResourceDesc desc;
    C10_CUDA_DRIVER_CHECK(
        c10::cuda::DriverAPI::get()->cuDevResourceGenerateDesc_(
            &desc, result_data, 1));

    // Create green context
    // CU_GREEN_CTX_DEFAULT_STREAM is required per docs:
    // https://docs.nvidia.com/cuda/cuda-driver-api/group__CUDA__GREEN__CONTEXTS.html
    C10_CUDA_DRIVER_CHECK(c10::cuda::DriverAPI::get()->cuGreenCtxCreate_(
        &green_ctx_, desc, device, CU_GREEN_CTX_DEFAULT_STREAM));

    // Convert to regular context
    C10_CUDA_DRIVER_CHECK(
        c10::cuda::DriverAPI::get()->cuCtxFromGreenCtx_(&context_, green_ctx_));
    TORCH_CHECK(context_, "Green ctx conversion to regular ctx failed!");
#else
    TORCH_CHECK(false, "Green Context is only supported on CUDA 12.8+!");
#endif
  }

  static std::unique_ptr<GreenContext> create(
      unsigned int num_sms,
      std::optional<unsigned int> device_id) {
#if defined(CUDA_VERSION) && CUDA_VERSION >= 12080 && !defined(USE_ROCM)
    if (!device_id.has_value()) {
      device_id = at::cuda::current_device();
    }
    return std::make_unique<GreenContext>(device_id.value(), num_sms);
#else
    TORCH_CHECK(false, "Green Context is only supported on CUDA 12.8+!");
#endif
  }

  // Delete copy constructor and assignment
  GreenContext(const GreenContext&) = delete;
  GreenContext& operator=(const GreenContext&) = delete;

  // Implement move operations
  GreenContext(GreenContext&& other) noexcept {
#if defined(CUDA_VERSION) && CUDA_VERSION >= 12080 && !defined(USE_ROCM)
    device_id_ = std::exchange(other.device_id_, -1);
    green_ctx_ = std::exchange(other.green_ctx_, nullptr);
    context_ = std::exchange(other.context_, nullptr);
    parent_stream_ = std::exchange(other.parent_stream_, nullptr);
#else
    TORCH_CHECK(false, "Green Context is only supported on CUDA 12.8+!");
#endif
  }

  GreenContext& operator=(GreenContext&& other) noexcept {
#if defined(CUDA_VERSION) && CUDA_VERSION >= 12080 && !defined(USE_ROCM)
    if (this != &other) {
      // Clean up current resources
      if (green_ctx_) {
        CUcontext current = nullptr;
        C10_CUDA_DRIVER_CHECK(
            c10::cuda::DriverAPI::get()->cuCtxGetCurrent_(&current));
        if (current == context_) {
          TORCH_CHECK(
              false,
              "attempting to overwrite current green ctx "
              "when it is active!");
        }
        C10_CUDA_DRIVER_CHECK(cuGreenCtxDestroy(green_ctx_));
      }

      // Take ownership of other's resources
      device_id_ = std::exchange(other.device_id_, -1);
      green_ctx_ = std::exchange(other.green_ctx_, nullptr);
      context_ = std::exchange(other.context_, nullptr);
      parent_stream_ = std::exchange(other.parent_stream_, nullptr);
    }
    return *this;
#else
    TORCH_CHECK(false, "Green Context is only supported on CUDA 12.8+!");
#endif
  }

  ~GreenContext() noexcept {
#if defined(CUDA_VERSION) && CUDA_VERSION >= 12080 && !defined(USE_ROCM)
    C10_CUDA_DRIVER_CHECK(
        c10::cuda::DriverAPI::get()->cuGreenCtxDestroy_(green_ctx_));
#else
    TORCH_CHECK(false, "Green Context is only supported on CUDA 12.8+!");
#endif
  }

  // Get the underlying CUDA context
  CUcontext getContext() const {
#if defined(CUDA_VERSION) && CUDA_VERSION >= 12080 && !defined(USE_ROCM)
    return context_;
#else
    TORCH_CHECK(false, "Green Context is only supported on CUDA 12.8+!");
#endif
  }

  // Get the underlying green context
#if defined(CUDA_VERSION) && CUDA_VERSION >= 12080 && !defined(USE_ROCM)
  CUgreenCtx getGreenContext() const {
    return green_ctx_;
  }
#endif

  // Make this context current
  void makeCurrent() {
#if defined(CUDA_VERSION) && CUDA_VERSION >= 12080 && !defined(USE_ROCM)
    auto current_stream = c10::cuda::getCurrentCUDAStream();
    parent_stream_ = current_stream.stream();

    at::cuda::CUDAEvent ev;
    ev.record(current_stream);

    CUcontext current = nullptr;
    C10_CUDA_DRIVER_CHECK(
        c10::cuda::DriverAPI::get()->cuCtxGetCurrent_(&current));
    if (!current) {
      C10_CUDA_DRIVER_CHECK(
          c10::cuda::DriverAPI::get()->cuCtxSetCurrent_(context_));
    } else {
      C10_CUDA_DRIVER_CHECK(
          c10::cuda::DriverAPI::get()->cuCtxPushCurrent_(context_));
    }
    // currently hardcodes the new green context to use the default stream
    // TODO(eqy): consider creating a new stream if e.g., it allows interop
    // with CUDA Graph captures etc.
    auto default_stream = c10::cuda::getDefaultCUDAStream();
    ev.block(default_stream);
    c10::cuda::setCurrentCUDAStream(default_stream);
#else
    TORCH_CHECK(false, "Green Context is only supported on CUDA 12.8+!");
#endif
  }

  void popCurrent() {
#if defined(CUDA_VERSION) && CUDA_VERSION >= 12080 && !defined(USE_ROCM)
    // see above note about stream being hardcoded to the default stream
    at::cuda::CUDAEvent ev;
    ev.record(c10::cuda::getCurrentCUDAStream());
    CUcontext popped;
    C10_CUDA_DRIVER_CHECK(
        c10::cuda::DriverAPI::get()->cuCtxPopCurrent_(&popped));
    TORCH_INTERNAL_ASSERT(
        popped == context_, "expected popped context to be the current ctx");
    ev.block(c10::cuda::getStreamFromExternal(parent_stream_, device_id_));
#else
    TORCH_CHECK(false, "Green Context is only supported on CUDA 12.8+!");
#endif
  }

 private:
#if defined(CUDA_VERSION) && CUDA_VERSION >= 12080 && !defined(USE_ROCM)
  int device_id_ = -1;
  CUgreenCtx green_ctx_ = nullptr;
  CUcontext context_ = nullptr;
  cudaStream_t parent_stream_ = nullptr;
#endif
};
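One behavior of the removed `makeCurrent`/`popCurrent` above is worth flagging: entering the green context records an event on the caller's current stream and then hardcodes subsequent work to the green context's default stream (see the TODO in the source), and `popCurrent` fences the original stream on that work without restoring it as the current stream. A rough sketch of the observable effect from Python, assuming the reverted bindings are still importable; the SM count is illustrative.

```python
import torch
from torch.cuda.green_contexts import GreenContext

side_stream = torch.cuda.Stream()
torch.cuda.set_stream(side_stream)   # queue work on a non-default stream

ctx = GreenContext.create(8, 0)
ctx.make_current()
# makeCurrent() switched the current stream to the default stream after
# fencing it on the work previously queued on side_stream.
assert torch.cuda.current_stream() == torch.cuda.default_stream()

torch.matmul(torch.randn(1024, 1024, device="cuda"),
             torch.randn(1024, 1024, device="cuda"))

ctx.pop_current()
# popCurrent() makes side_stream wait on the green-context work, but the
# caller must make side_stream current again if that is what they want.
torch.cuda.set_stream(side_stream)
```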
@@ -35,7 +35,6 @@ from .graphs import (
    is_current_stream_capturing,
    make_graphed_callables,
)
from .green_contexts import GreenContext
from .streams import Event, ExternalStream, Stream
@@ -1845,7 +1844,6 @@ __all__ = [
    "ExternalStream",
    "Stream",
    "StreamContext",
    "GreenContext",
    "amp",
    "caching_allocator_alloc",
    "caching_allocator_delete",
@@ -1,42 +0,0 @@
import torch


_GreenContext = object
SUPPORTED = False

if hasattr(torch._C, "GreenContext"):
    _GreenContext = torch._C.GreenContext  # type: ignore[misc]
    SUPPORTED = True


# Python shim helps Sphinx process docstrings more reliably.
class GreenContext(_GreenContext):
    r"""Wrapper around a CUDA green context.

    .. warning::
        This API is in beta and may change in future releases.
    """

    @staticmethod
    def create(num_sms: int, device_id: int = 0) -> _GreenContext:
        r"""Create a CUDA green context.

        Arguments:
            num_sms (int): The number of SMs to use in the green context.
            device_id (int, optional): The device index of the green context.
        """
        if not SUPPORTED:
            raise RuntimeError("PyTorch was not built with Green Context support!")
        return _GreenContext.create(num_sms, device_id)  # type: ignore[attr-defined]

    # Note that these functions are bypassed, but we define them here
    # for Sphinx documentation purposes
    def make_current(self) -> None:
        r"""Make the green context the current context."""
        return super().make_current()  # type: ignore[misc]

    def pop_current(self) -> None:
        r"""Assuming the green context is the current context, pop it from the
        context stack and restore the previous context.
        """
        return super().pop_current()  # type: ignore[misc]