mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 21:14:14 +08:00
Revert "[CUDA] Add experimental green context support for SM carveout (#159104)"
This reverts commit 746fe78ecd52f3e9cfddda41f0ac82dada7bdd0b. Reverted https://github.com/pytorch/pytorch/pull/159104 on behalf of https://github.com/malfet due to Breaks Windows CD build ([comment](https://github.com/pytorch/pytorch/pull/159104#issuecomment-3378675515))
This commit is contained in:
@ -51,17 +51,6 @@
|
|||||||
|
|
||||||
#if defined(CUDA_VERSION) && (CUDA_VERSION >= 12030)
|
#if defined(CUDA_VERSION) && (CUDA_VERSION >= 12030)
|
||||||
#define C10_LIBCUDA_DRIVER_API_OPTIONAL(_) \
|
#define C10_LIBCUDA_DRIVER_API_OPTIONAL(_) \
|
||||||
_(cuCtxFromGreenCtx, 12080) \
|
|
||||||
_(cuCtxGetCurrent, 12080) \
|
|
||||||
_(cuCtxPopCurrent, 12080) \
|
|
||||||
_(cuCtxPushCurrent, 12080) \
|
|
||||||
_(cuCtxSetCurrent, 12080) \
|
|
||||||
_(cuGreenCtxCreate, 12080) \
|
|
||||||
_(cuGreenCtxDestroy, 12080) \
|
|
||||||
_(cuDevSmResourceSplitByCount, 12080) \
|
|
||||||
_(cuDeviceGet, 12080) \
|
|
||||||
_(cuDeviceGetDevResource, 12080) \
|
|
||||||
_(cuDevResourceGenerateDesc, 12080) \
|
|
||||||
_(cuMulticastAddDevice, 12030) \
|
_(cuMulticastAddDevice, 12030) \
|
||||||
_(cuMulticastBindMem, 12030) \
|
_(cuMulticastBindMem, 12030) \
|
||||||
_(cuMulticastCreate, 12030) \
|
_(cuMulticastCreate, 12030) \
|
||||||
|
@ -262,28 +262,6 @@ See the docs for {class}`~torch.cuda.gds.GdsFile` for an example of how to use t
|
|||||||
|
|
||||||
```
|
```
|
||||||
|
|
||||||
## Green Contexts (experimental)
|
|
||||||
|
|
||||||
`torch.cuda.green_contexts` provides thin wrappers around the CUDA Green Context APIs
|
|
||||||
to enable more general carveout of SM resources for CUDA kernels.
|
|
||||||
|
|
||||||
These APIs can be used in PyTorch with CUDA versions greater than or equal to 12.8.
|
|
||||||
|
|
||||||
See the docs for {class}`~torch.cuda.green_contexts.GreenContext` for an example of how to use these.
|
|
||||||
|
|
||||||
```{eval-rst}
|
|
||||||
.. currentmodule:: torch.cuda.green_contexts
|
|
||||||
```
|
|
||||||
|
|
||||||
```{eval-rst}
|
|
||||||
.. autosummary::
|
|
||||||
:toctree: generated
|
|
||||||
:nosignatures:
|
|
||||||
|
|
||||||
GreenContext
|
|
||||||
```
|
|
||||||
|
|
||||||
|
|
||||||
% This module needs to be documented. Adding here in the meantime
|
% This module needs to be documented. Adding here in the meantime
|
||||||
|
|
||||||
% for tracking purposes
|
% for tracking purposes
|
||||||
@ -296,10 +274,6 @@ See the docs for {class}`~torch.cuda.green_contexts.GreenContext` for an example
|
|||||||
.. py:module:: torch.cuda.gds
|
.. py:module:: torch.cuda.gds
|
||||||
```
|
```
|
||||||
|
|
||||||
```{eval-rst}
|
|
||||||
.. py:module:: torch.cuda.green_contexts
|
|
||||||
```
|
|
||||||
|
|
||||||
```{eval-rst}
|
```{eval-rst}
|
||||||
.. py:module:: torch.cuda.jiterator
|
.. py:module:: torch.cuda.jiterator
|
||||||
```
|
```
|
||||||
@ -325,4 +299,4 @@ See the docs for {class}`~torch.cuda.green_contexts.GreenContext` for an example
|
|||||||
:hidden:
|
:hidden:
|
||||||
|
|
||||||
cuda.aliases.md
|
cuda.aliases.md
|
||||||
```
|
```
|
@ -1,7 +1,6 @@
|
|||||||
# Owner(s): ["module: linear algebra"]
|
# Owner(s): ["module: linear algebra"]
|
||||||
|
|
||||||
import contextlib
|
import contextlib
|
||||||
import time
|
|
||||||
import unittest
|
import unittest
|
||||||
from itertools import product
|
from itertools import product
|
||||||
from functools import partial
|
from functools import partial
|
||||||
@ -842,28 +841,6 @@ class TestMatmulCuda(InductorTestCase):
|
|||||||
op(a, mismatch_batch_dim_b, out_dtype=torch.float32)
|
op(a, mismatch_batch_dim_b, out_dtype=torch.float32)
|
||||||
|
|
||||||
|
|
||||||
@unittest.skipIf(not _get_torch_cuda_version() >= (12, 8), "Green Context only tested on 12.8+")
|
|
||||||
def test_greencontext_carveout(self):
|
|
||||||
a = torch.randn(4096, 4096, device='cuda', dtype=torch.bfloat16)
|
|
||||||
ctx = torch.cuda.green_contexts.GreenContext.create(1, 0)
|
|
||||||
ctx.make_current()
|
|
||||||
torch.matmul(a, a)
|
|
||||||
torch.cuda.synchronize()
|
|
||||||
t0 = time.perf_counter()
|
|
||||||
partial_res = torch.matmul(a, a)
|
|
||||||
torch.cuda.synchronize()
|
|
||||||
t1 = time.perf_counter()
|
|
||||||
ctx.pop_current()
|
|
||||||
torch.matmul(a, a)
|
|
||||||
torch.cuda.synchronize()
|
|
||||||
t2 = time.perf_counter()
|
|
||||||
full_res = torch.matmul(a, a)
|
|
||||||
torch.cuda.synchronize()
|
|
||||||
t3 = time.perf_counter()
|
|
||||||
self.assertEqual(partial_res, full_res)
|
|
||||||
self.assertGreater(t1 - t0, t3 - t2)
|
|
||||||
|
|
||||||
|
|
||||||
@unittest.skipIf(TEST_WITH_ROCM, "ROCm doesn't support CUTLASS")
|
@unittest.skipIf(TEST_WITH_ROCM, "ROCm doesn't support CUTLASS")
|
||||||
@unittest.skipIf(IS_WINDOWS, "Windows doesn't support CUTLASS extensions")
|
@unittest.skipIf(IS_WINDOWS, "Windows doesn't support CUTLASS extensions")
|
||||||
@unittest.skipIf(not _IS_SM8X, "mixed dtypes linear only supported on SM 8.x")
|
@unittest.skipIf(not _IS_SM8X, "mixed dtypes linear only supported on SM 8.x")
|
||||||
|
@ -123,7 +123,6 @@ class TestPublicBindings(TestCase):
|
|||||||
"FutureType",
|
"FutureType",
|
||||||
"Generator",
|
"Generator",
|
||||||
"GeneratorType",
|
"GeneratorType",
|
||||||
"GreenContext",
|
|
||||||
"get_autocast_cpu_dtype",
|
"get_autocast_cpu_dtype",
|
||||||
"get_autocast_dtype",
|
"get_autocast_dtype",
|
||||||
"get_autocast_ipu_dtype",
|
"get_autocast_ipu_dtype",
|
||||||
|
@ -2776,17 +2776,3 @@ class _StaticCudaLauncher:
|
|||||||
args: tuple[Any, ...],
|
args: tuple[Any, ...],
|
||||||
stream: _int,
|
stream: _int,
|
||||||
) -> None: ...
|
) -> None: ...
|
||||||
|
|
||||||
# Defined in torch/csrc/cuda/green_context.h
|
|
||||||
class GreenContext:
|
|
||||||
@staticmethod
|
|
||||||
def create(
|
|
||||||
num_sms: _int,
|
|
||||||
device_id: _int,
|
|
||||||
) -> GreenContext: ...
|
|
||||||
def make_current(
|
|
||||||
self,
|
|
||||||
) -> None: ...
|
|
||||||
def pop_current(
|
|
||||||
self,
|
|
||||||
) -> None: ...
|
|
||||||
|
@ -23,7 +23,6 @@
|
|||||||
#include <c10/cuda/CUDAAllocatorConfig.h>
|
#include <c10/cuda/CUDAAllocatorConfig.h>
|
||||||
#include <c10/cuda/CUDACachingAllocator.h>
|
#include <c10/cuda/CUDACachingAllocator.h>
|
||||||
#include <c10/cuda/CUDAFunctions.h>
|
#include <c10/cuda/CUDAFunctions.h>
|
||||||
#include <torch/csrc/cuda/green_context.h>
|
|
||||||
#include <ATen/cuda/CUDAGraphsUtils.cuh>
|
#include <ATen/cuda/CUDAGraphsUtils.cuh>
|
||||||
|
|
||||||
#ifdef USE_NCCL
|
#ifdef USE_NCCL
|
||||||
@ -1491,13 +1490,6 @@ static void registerCudaPluggableAllocator(PyObject* module) {
|
|||||||
addStorageDeleterFns(storages_to_add_deleters_to, delta);
|
addStorageDeleterFns(storages_to_add_deleters_to, delta);
|
||||||
});
|
});
|
||||||
}
|
}
|
||||||
static void initGreenContext(PyObject* module) {
|
|
||||||
auto m = py::handle(module).cast<py::module>();
|
|
||||||
py::class_<GreenContext>(m, "GreenContext")
|
|
||||||
.def_static("create", &GreenContext::create)
|
|
||||||
.def("make_current", &GreenContext::makeCurrent)
|
|
||||||
.def("pop_current", &GreenContext::popCurrent);
|
|
||||||
}
|
|
||||||
|
|
||||||
static void bindGetDeviceProperties(PyObject* module) {
|
static void bindGetDeviceProperties(PyObject* module) {
|
||||||
// Add method to torch.cuda
|
// Add method to torch.cuda
|
||||||
@ -2222,7 +2214,6 @@ void initModule(PyObject* module) {
|
|||||||
registerCudaDeviceProperties(module);
|
registerCudaDeviceProperties(module);
|
||||||
registerCudaPluggableAllocator(module);
|
registerCudaPluggableAllocator(module);
|
||||||
initCudaMethodBindings(module);
|
initCudaMethodBindings(module);
|
||||||
initGreenContext(module);
|
|
||||||
}
|
}
|
||||||
|
|
||||||
} // namespace torch::cuda
|
} // namespace torch::cuda
|
||||||
|
@ -1,213 +0,0 @@
|
|||||||
#pragma once
|
|
||||||
#include <ATen/cuda/CUDAEvent.h>
|
|
||||||
#if defined(CUDA_VERSION) && !defined(USE_ROCM)
|
|
||||||
#include <c10/cuda/driver_api.h>
|
|
||||||
#include <cuda.h>
|
|
||||||
#include <memory>
|
|
||||||
#include <stdexcept>
|
|
||||||
#include <vector>
|
|
||||||
#endif
|
|
||||||
|
|
||||||
class GreenContext {
|
|
||||||
public:
|
|
||||||
GreenContext(int device_id, unsigned int num_sms) {
|
|
||||||
#if defined(CUDA_VERSION) && CUDA_VERSION >= 12080 && !defined(USE_ROCM)
|
|
||||||
int driver_version;
|
|
||||||
C10_CUDA_CHECK(cudaDriverGetVersion(&driver_version));
|
|
||||||
TORCH_CHECK(
|
|
||||||
driver_version >= 12080, "cuda driver too old to use green context!");
|
|
||||||
CUcontext pctx = nullptr;
|
|
||||||
C10_CUDA_DRIVER_CHECK(c10::cuda::DriverAPI::get()->cuCtxGetCurrent_(&pctx));
|
|
||||||
if (C10_UNLIKELY(!pctx)) {
|
|
||||||
TORCH_WARN(
|
|
||||||
"Attempted to create a green context but"
|
|
||||||
" there was no primary context! Creating a primary context...");
|
|
||||||
|
|
||||||
cudaFree(0);
|
|
||||||
}
|
|
||||||
|
|
||||||
CUdevice device;
|
|
||||||
device_id_ = device_id;
|
|
||||||
C10_CUDA_DRIVER_CHECK(
|
|
||||||
c10::cuda::DriverAPI::get()->cuDeviceGet_(&device, device_id));
|
|
||||||
|
|
||||||
// Get device resources
|
|
||||||
CUdevResource device_resource;
|
|
||||||
C10_CUDA_DRIVER_CHECK(c10::cuda::DriverAPI::get()->cuDeviceGetDevResource_(
|
|
||||||
device, &device_resource, CU_DEV_RESOURCE_TYPE_SM));
|
|
||||||
|
|
||||||
// Split resources
|
|
||||||
std::vector<CUdevResource> result(1);
|
|
||||||
auto result_data = result.data();
|
|
||||||
unsigned int nb_groups = 1;
|
|
||||||
CUdevResource remaining;
|
|
||||||
|
|
||||||
C10_CUDA_DRIVER_CHECK(
|
|
||||||
c10::cuda::DriverAPI::get()->cuDevSmResourceSplitByCount_(
|
|
||||||
result_data,
|
|
||||||
&nb_groups,
|
|
||||||
&device_resource,
|
|
||||||
&remaining,
|
|
||||||
0, // default flags
|
|
||||||
num_sms));
|
|
||||||
|
|
||||||
TORCH_CHECK(nb_groups == 1, "Failed to create single resource group");
|
|
||||||
|
|
||||||
// Generate resource descriptor
|
|
||||||
CUdevResourceDesc desc;
|
|
||||||
C10_CUDA_DRIVER_CHECK(
|
|
||||||
c10::cuda::DriverAPI::get()->cuDevResourceGenerateDesc_(
|
|
||||||
&desc, result_data, 1));
|
|
||||||
|
|
||||||
// Create green context
|
|
||||||
// CU_GREEN_CTX_DEFAULT_STREAM is required per docs:
|
|
||||||
// https://docs.nvidia.com/cuda/cuda-driver-api/group__CUDA__GREEN__CONTEXTS.html
|
|
||||||
C10_CUDA_DRIVER_CHECK(c10::cuda::DriverAPI::get()->cuGreenCtxCreate_(
|
|
||||||
&green_ctx_, desc, device, CU_GREEN_CTX_DEFAULT_STREAM));
|
|
||||||
|
|
||||||
// Convert to regular context
|
|
||||||
C10_CUDA_DRIVER_CHECK(
|
|
||||||
c10::cuda::DriverAPI::get()->cuCtxFromGreenCtx_(&context_, green_ctx_));
|
|
||||||
TORCH_CHECK(context_, "Green ctx conversion to regular ctx failed!");
|
|
||||||
#else
|
|
||||||
TORCH_CHECK(false, "Green Context is only supported on CUDA 12.8+!");
|
|
||||||
#endif
|
|
||||||
}
|
|
||||||
|
|
||||||
static std::unique_ptr<GreenContext> create(
|
|
||||||
unsigned int num_sms,
|
|
||||||
std::optional<unsigned int> device_id) {
|
|
||||||
#if defined(CUDA_VERSION) && CUDA_VERSION >= 12080 && !defined(USE_ROCM)
|
|
||||||
if (!device_id.has_value()) {
|
|
||||||
device_id = at::cuda::current_device();
|
|
||||||
}
|
|
||||||
return std::make_unique<GreenContext>(device_id.value(), num_sms);
|
|
||||||
#else
|
|
||||||
TORCH_CHECK(false, "Green Context is only supported on CUDA 12.8+!");
|
|
||||||
#endif
|
|
||||||
}
|
|
||||||
|
|
||||||
// Delete copy constructor and assignment
|
|
||||||
GreenContext(const GreenContext&) = delete;
|
|
||||||
GreenContext& operator=(const GreenContext&) = delete;
|
|
||||||
|
|
||||||
// Implement move operations
|
|
||||||
GreenContext(GreenContext&& other) noexcept {
|
|
||||||
#if defined(CUDA_VERSION) && CUDA_VERSION >= 12080 && !defined(USE_ROCM)
|
|
||||||
device_id_ = std::exchange(other.device_id_, -1);
|
|
||||||
green_ctx_ = std::exchange(other.green_ctx_, nullptr);
|
|
||||||
context_ = std::exchange(other.context_, nullptr);
|
|
||||||
parent_stream_ = std::exchange(other.parent_stream_, nullptr);
|
|
||||||
#else
|
|
||||||
TORCH_CHECK(false, "Green Context is only supported on CUDA 12.8+!");
|
|
||||||
#endif
|
|
||||||
}
|
|
||||||
|
|
||||||
GreenContext& operator=(GreenContext&& other) noexcept {
|
|
||||||
#if defined(CUDA_VERSION) && CUDA_VERSION >= 12080 && !defined(USE_ROCM)
|
|
||||||
if (this != &other) {
|
|
||||||
// Clean up current resources
|
|
||||||
if (green_ctx_) {
|
|
||||||
CUcontext current = nullptr;
|
|
||||||
C10_CUDA_DRIVER_CHECK(
|
|
||||||
c10::cuda::DriverAPI::get()->cuCtxGetCurrent_(¤t));
|
|
||||||
if (current == context_) {
|
|
||||||
TORCH_CHECK(
|
|
||||||
false,
|
|
||||||
"attempting to overwrite current green ctx "
|
|
||||||
"when it is active!");
|
|
||||||
}
|
|
||||||
C10_CUDA_DRIVER_CHECK(cuGreenCtxDestroy(green_ctx_));
|
|
||||||
}
|
|
||||||
|
|
||||||
// Take ownership of other's resources
|
|
||||||
device_id_ = std::exchange(other.device_id_, -1);
|
|
||||||
green_ctx_ = std::exchange(other.green_ctx_, nullptr);
|
|
||||||
context_ = std::exchange(other.context_, nullptr);
|
|
||||||
parent_stream_ = std::exchange(other.parent_stream_, nullptr);
|
|
||||||
}
|
|
||||||
return *this;
|
|
||||||
#else
|
|
||||||
TORCH_CHECK(false, "Green Context is only supported on CUDA 12.8+!");
|
|
||||||
#endif
|
|
||||||
}
|
|
||||||
|
|
||||||
~GreenContext() noexcept {
|
|
||||||
#if defined(CUDA_VERSION) && CUDA_VERSION >= 12080 && !defined(USE_ROCM)
|
|
||||||
C10_CUDA_DRIVER_CHECK(
|
|
||||||
c10::cuda::DriverAPI::get()->cuGreenCtxDestroy_(green_ctx_));
|
|
||||||
#else
|
|
||||||
TORCH_CHECK(false, "Green Context is only supported on CUDA 12.8+!");
|
|
||||||
#endif
|
|
||||||
}
|
|
||||||
|
|
||||||
// Get the underlying CUDA context
|
|
||||||
CUcontext getContext() const {
|
|
||||||
#if defined(CUDA_VERSION) && CUDA_VERSION >= 12080 && !defined(USE_ROCM)
|
|
||||||
return context_;
|
|
||||||
#else
|
|
||||||
TORCH_CHECK(false, "Green Context is only supported on CUDA 12.8+!");
|
|
||||||
#endif
|
|
||||||
}
|
|
||||||
|
|
||||||
// Get the underlying green context
|
|
||||||
#if defined(CUDA_VERSION) && CUDA_VERSION >= 12080 && !defined(USE_ROCM)
|
|
||||||
CUgreenCtx getGreenContext() const {
|
|
||||||
return green_ctx_;
|
|
||||||
}
|
|
||||||
#endif
|
|
||||||
|
|
||||||
// Make this context current
|
|
||||||
void makeCurrent() {
|
|
||||||
#if defined(CUDA_VERSION) && CUDA_VERSION >= 12080 && !defined(USE_ROCM)
|
|
||||||
auto current_stream = c10::cuda::getCurrentCUDAStream();
|
|
||||||
parent_stream_ = current_stream.stream();
|
|
||||||
|
|
||||||
at::cuda::CUDAEvent ev;
|
|
||||||
ev.record(current_stream);
|
|
||||||
|
|
||||||
CUcontext current = nullptr;
|
|
||||||
C10_CUDA_DRIVER_CHECK(
|
|
||||||
c10::cuda::DriverAPI::get()->cuCtxGetCurrent_(¤t));
|
|
||||||
if (!current) {
|
|
||||||
C10_CUDA_DRIVER_CHECK(
|
|
||||||
c10::cuda::DriverAPI::get()->cuCtxSetCurrent_(context_));
|
|
||||||
} else {
|
|
||||||
C10_CUDA_DRIVER_CHECK(
|
|
||||||
c10::cuda::DriverAPI::get()->cuCtxPushCurrent_(context_));
|
|
||||||
}
|
|
||||||
// currently hardcodes the new green context to use the default stream
|
|
||||||
// TODO(eqy): consider creating a new stream if e.g., it allows interop
|
|
||||||
// with CUDA Graph captures etc.
|
|
||||||
auto default_stream = c10::cuda::getDefaultCUDAStream();
|
|
||||||
ev.block(default_stream);
|
|
||||||
c10::cuda::setCurrentCUDAStream(default_stream);
|
|
||||||
#else
|
|
||||||
TORCH_CHECK(false, "Green Context is only supported on CUDA 12.8+!");
|
|
||||||
#endif
|
|
||||||
}
|
|
||||||
|
|
||||||
void popCurrent() {
|
|
||||||
#if defined(CUDA_VERSION) && CUDA_VERSION >= 12080 && !defined(USE_ROCM)
|
|
||||||
// see above note about stream being hardcoded to the default stream
|
|
||||||
at::cuda::CUDAEvent ev;
|
|
||||||
ev.record(c10::cuda::getCurrentCUDAStream());
|
|
||||||
CUcontext popped;
|
|
||||||
C10_CUDA_DRIVER_CHECK(
|
|
||||||
c10::cuda::DriverAPI::get()->cuCtxPopCurrent_(&popped));
|
|
||||||
TORCH_INTERNAL_ASSERT(
|
|
||||||
popped == context_, "expected popped context to be the current ctx");
|
|
||||||
ev.block(c10::cuda::getStreamFromExternal(parent_stream_, device_id_));
|
|
||||||
#else
|
|
||||||
TORCH_CHECK(false, "Green Context is only supported on CUDA 12.8+!");
|
|
||||||
#endif
|
|
||||||
}
|
|
||||||
|
|
||||||
private:
|
|
||||||
#if defined(CUDA_VERSION) && CUDA_VERSION >= 12080 && !defined(USE_ROCM)
|
|
||||||
int device_id_ = -1;
|
|
||||||
CUgreenCtx green_ctx_ = nullptr;
|
|
||||||
CUcontext context_ = nullptr;
|
|
||||||
cudaStream_t parent_stream_ = nullptr;
|
|
||||||
#endif
|
|
||||||
};
|
|
@ -35,7 +35,6 @@ from .graphs import (
|
|||||||
is_current_stream_capturing,
|
is_current_stream_capturing,
|
||||||
make_graphed_callables,
|
make_graphed_callables,
|
||||||
)
|
)
|
||||||
from .green_contexts import GreenContext
|
|
||||||
from .streams import Event, ExternalStream, Stream
|
from .streams import Event, ExternalStream, Stream
|
||||||
|
|
||||||
|
|
||||||
@ -1845,7 +1844,6 @@ __all__ = [
|
|||||||
"ExternalStream",
|
"ExternalStream",
|
||||||
"Stream",
|
"Stream",
|
||||||
"StreamContext",
|
"StreamContext",
|
||||||
"GreenContext",
|
|
||||||
"amp",
|
"amp",
|
||||||
"caching_allocator_alloc",
|
"caching_allocator_alloc",
|
||||||
"caching_allocator_delete",
|
"caching_allocator_delete",
|
||||||
|
@ -1,42 +0,0 @@
|
|||||||
import torch
|
|
||||||
|
|
||||||
|
|
||||||
_GreenContext = object
|
|
||||||
SUPPORTED = False
|
|
||||||
|
|
||||||
if hasattr(torch._C, "GreenContext"):
|
|
||||||
_GreenContext = torch._C.GreenContext # type: ignore[misc]
|
|
||||||
SUPPORTED = True
|
|
||||||
|
|
||||||
|
|
||||||
# Python shim helps Sphinx process docstrings more reliably.
|
|
||||||
class GreenContext(_GreenContext):
|
|
||||||
r"""Wrapper around a CUDA green context.
|
|
||||||
|
|
||||||
.. warning::
|
|
||||||
This API is in beta and may change in future releases.
|
|
||||||
"""
|
|
||||||
|
|
||||||
@staticmethod
|
|
||||||
def create(num_sms: int, device_id: int = 0) -> _GreenContext:
|
|
||||||
r"""Create a CUDA green context.
|
|
||||||
|
|
||||||
Arguments:
|
|
||||||
num_sms (int): The number of SMs to use in the green context.
|
|
||||||
device_id (int, optional): The device index of green context.
|
|
||||||
"""
|
|
||||||
if not SUPPORTED:
|
|
||||||
raise RuntimeError("PyTorch was not built with Green Context support!")
|
|
||||||
return _GreenContext.create(num_sms, device_id) # type: ignore[attr-defined]
|
|
||||||
|
|
||||||
# Note that these functions are bypassed by we define them here
|
|
||||||
# for Sphinx documentation purposes
|
|
||||||
def make_current(self) -> None:
|
|
||||||
r"""Make the green context the current context."""
|
|
||||||
return super().make_current() # type: ignore[misc]
|
|
||||||
|
|
||||||
def pop_current(self) -> None:
|
|
||||||
r"""Assuming the green context is the current context, pop it from the
|
|
||||||
context stack and restore the previous context.
|
|
||||||
"""
|
|
||||||
return super().pop_current() # type: ignore[misc]
|
|
Reference in New Issue
Block a user