Mirror of https://github.com/pytorch/pytorch.git (synced 2025-10-20 21:14:14 +08:00)

Adding JIT support for cuda streams and events (#48020)

Summary:
This PR addresses the following:
* Adds JIT support for CUDA Streams
* Adds JIT support for CUDA Events
* Adds JIT support for CUDA Stream context manager

Testing:
python test/test_jit.py -v TestCUDA

Pull Request resolved: https://github.com/pytorch/pytorch/pull/48020
Reviewed By: navahgar
Differential Revision: D25725749
Pulled By: nikithamalgifb
fbshipit-source-id: b0addeb49630f8f0c430ed7badeca43bb9d2535c

Committed by: Facebook GitHub Bot
Parent: 97c17b4772
Commit: 12b73fdbbf
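For context, a minimal sketch (not part of the diff; assumes a CUDA-enabled build with at least one GPU) of what the new TorchScript API looks like once this PR lands. It mirrors the patterns exercised in test/jit/test_cuda.py below: a user stream is created, work is enqueued on it through the stream context manager, and CUDA events time the work.

import torch

@torch.jit.script
def stream_and_event_example() -> float:
    # Create a user stream on the current device and two timing events
    device_index = torch.cuda._current_device()
    s = torch.jit.cuda.Stream(device_index, 0)
    start = torch.jit.cuda.Event(True, False, False)
    end = torch.jit.cuda.Event(True, False, False)

    a = torch.rand(1000, 1000, device="cuda")
    start.record(s)
    # All kernels launched inside this block are enqueued on s
    with torch.jit.cuda.stream(s):
        b = torch.mm(a, a)
    s.record_event(end)
    end.synchronize()
    return start.elapsed_time(end)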
@@ -17,6 +17,7 @@ namespace c10 {
#define FORALL_NS_SYMBOLS(_) \
  _(namespaces, prim)              \
  _(namespaces, aten)              \
  _(namespaces, cuda)              \
  _(namespaces, onnx)              \
  _(namespaces, attr)              \
  _(namespaces, scope)             \
@@ -284,6 +285,9 @@ namespace c10 {
  _(aten, zero_)                   \
  _(aten, fill_)                   \
  _(aten, masked_fill_)            \
  _(cuda, _set_device)             \
  _(cuda, set_stream)              \
  _(cuda, _current_device)         \
  _(aten, swapaxes)                \
  _(aten, swapaxes_)               \
  _(aten, swapdims)                \
@@ -383,6 +387,7 @@ namespace c10 {
#define FORALL_NS_SYMBOLS(_) \
  _(namespaces, prim)              \
  _(namespaces, aten)              \
  _(namespaces, cuda)              \
  _(namespaces, onnx)              \
  _(namespaces, attr)              \
  _(namespaces, scope)             \
@@ -453,6 +458,7 @@ struct TORCH_API Symbol {
  // (and if it's not, you should add it to the built-ins list above.)
  static Symbol attr(const std::string & s);
  static Symbol aten(const std::string & s);
  static Symbol cuda(const std::string & s);
  static Symbol onnx(const std::string & s);
  static Symbol prim(const std::string & s);
  static Symbol user(const std::string & s);
@@ -463,6 +469,7 @@ struct TORCH_API Symbol {

  bool is_attr() const;
  bool is_aten() const;
  bool is_cuda() const;
  bool is_prim() const;
  bool is_onnx() const;
  bool is_user() const;
@@ -523,6 +530,7 @@ FORALL_NS_SYMBOLS(DEFINE_SYMBOL)

inline Symbol Symbol::attr(const std::string & s) { return Symbol::fromQualString("attr::" + s); }
inline Symbol Symbol::aten(const std::string & s) { return Symbol::fromQualString("aten::" + s); }
inline Symbol Symbol::cuda(const std::string & s) { return Symbol::fromQualString("cuda::" + s); }
inline Symbol Symbol::onnx(const std::string & s) { return Symbol::fromQualString("onnx::" + s); }
inline Symbol Symbol::prim(const std::string & s) { return Symbol::fromQualString("prim::" + s); }
inline Symbol Symbol::scope(const std::string & s) { return Symbol::fromQualString("scope::" + s); }
@@ -531,6 +539,7 @@ inline Symbol Symbol::caffe2(const std::string & s) { return Symbol::fromQualStr
inline Symbol Symbol::dimname(const std::string & s) { return Symbol::fromQualString("dimname::" + s); }
inline bool Symbol::is_attr() const { return ns() == namespaces::attr; }
inline bool Symbol::is_aten() const { return ns() == namespaces::aten; }
inline bool Symbol::is_cuda() const { return ns() == namespaces::cuda; }
inline bool Symbol::is_prim() const { return ns() == namespaces::prim; }
inline bool Symbol::is_onnx() const { return ns() == namespaces::onnx; }
inline bool Symbol::is_user() const { return ns() == namespaces::user; }
@@ -120,5 +120,33 @@ TEST(SerializationTest, TypeTags) {
  }
}

TEST(SerializationTest, TestJitStream_CUDA) {
  torch::jit::Module model;
  std::vector<torch::jit::IValue> inputs;
  // Deserialize the ScriptModule from a file using torch::jit::load().
  // Load the scripted model. This should have been generated by tests_setup.py
  // Refer: TorchSaveJitStream_CUDA in test/cpp/jit/tests_setup.py
  model = torch::jit::load("saved_stream_model.pt");

  auto output = model.forward(inputs);
  auto list_of_elements = output.toTuple()->elements();
  auto is_stream_s = list_of_elements[0].toBool();

  // a, b: These are the two input tensors
  // c: This is the output tensor generated by the operation torch.cat(a, b)
  auto a = list_of_elements[1].toTensor();
  auto b = list_of_elements[2].toTensor();
  auto c = list_of_elements[3].toTensor();
  // op: this is used to verify if the cat operation produced the same results
  // as that on the GPU with torch.cat
  auto op = at::cat({a, b}, 0);

  // Check if the stream is set
  ASSERT_TRUE(is_stream_s);
  // Check if the sizes of the outputs (op and c) are the same on the GPU and CPU
  ASSERT_EQ(op.sizes(), c.sizes());
  // Check if both the output tensors are equal
  ASSERT_TRUE(op.equal(c));
}
} // namespace jit
} // namespace torch
@@ -63,11 +63,38 @@ class TorchSaveError(FileSetup):

        torch.save(value, self.path, _use_new_zipfile_serialization=False)


class TorchSaveJitStream_CUDA(FileSetup):
    path = 'saved_stream_model.pt'

    def setup(self):
        if not torch.cuda.is_available():
            return

        class Model(torch.nn.Module):
            def forward(self):
                device_index = torch.cuda._current_device()
                s = torch.jit.cuda.Stream(device_index, 0)
                a = torch.rand(3, 4, device="cuda")
                b = torch.rand(3, 4, device="cuda")

                with torch.jit.cuda.stream(s):
                    is_stream_s = torch.cuda.current_stream(s.device_index()).id() == s.id()
                    c = torch.cat((a, b), 0).to("cuda")
                s.synchronize()
                return is_stream_s, a, b, c

        model = Model()

        # Script the model and save
        script_model = torch.jit.script(model)
        torch.jit.save(script_model, self.path)


tests = [
    EvalModeForLoadedModule(),
    SerializationInterop(),
    TorchSaveError(),
    TorchSaveJitStream_CUDA()
]

def setup():
test/jit/test_cuda.py (new file, 476 lines)
@@ -0,0 +1,476 @@
import os
import sys
import gc
import unittest

import torch
from typing import NamedTuple
from torch.testing._internal.jit_utils import JitTestCase
from torch.testing._internal.common_utils import skipIfRocm, skipCUDANonDefaultStreamIf

# Make the helper files in test/ importable
pytorch_test_dir = os.path.dirname(os.path.dirname(os.path.realpath(__file__)))
sys.path.append(pytorch_test_dir)

# Check if GPU is available
TEST_CUDA = torch.cuda.is_available()
# Check if multiple GPUs are available
TEST_MULTIGPU = TEST_CUDA and torch.cuda.device_count() >= 2

# If GPU is not available, then do not run the tests
if not TEST_CUDA:
    print('CUDA not available, skipping tests', file=sys.stderr)
    JitTestCase = object  # noqa: F811

TEST_LARGE_TENSOR = TEST_CUDA

# If GPU is available, then initialize the cuda context and check
# if there is memory available to allocate for LARGE Tensors.
if TEST_CUDA:
    torch.ones(1).cuda()  # initialize cuda context
    TEST_LARGE_TENSOR = torch.cuda.get_device_properties(0).total_memory >= 5e9

if __name__ == "__main__":
    raise RuntimeError(
        "This test file is not meant to be run directly, use:\n\n"
        "\tpython test/test_jit.py TESTNAME\n\n"
        "instead."
    )

class TestCUDA(JitTestCase):
    """
    A suite of tests for the CUDA API in TorchScript.
    """
    def setUp(self):
        super(TestCUDA, self).setUp()

    def tearDown(self):
        gc.collect()
        torch.cuda.empty_cache()
        super(TestCUDA, self).tearDown()

    @skipIfRocm
    @unittest.skipIf(not TEST_MULTIGPU, "detected only one GPU")
    def test_current_stream(self):
        # Test current stream on the device and check if the stream device index
        # matches with the device ID
        @torch.jit.script
        def fn():
            device_index = torch.cuda._current_device()
            s0 = torch.cuda.current_stream(device_index)
            s1 = torch.cuda.current_stream(1)
            s2 = torch.cuda.current_stream(0)

            return s0.device_index(), s1.device_index(), s2.device_index()

        d0, d1, d2 = fn()

        # By default, the current device ID is 0.
        self.assertEqual(0, d0)
        self.assertEqual(1, d1)
        self.assertEqual(0, d2)
        self.assertEqual(d0, d2)

    @skipIfRocm
    @unittest.skipIf(not TEST_MULTIGPU, "detected only one GPU")
    @unittest.skipIf(not TEST_LARGE_TENSOR, "not enough memory")
    @skipCUDANonDefaultStreamIf(True)
    def test_streams_and_events(self):
        # This test checks that the default stream ID is set to 0 on the device
        @torch.jit.script
        def test_default_streams():
            s0 = torch.cuda.default_stream(0)
            s1 = torch.cuda.default_stream(1)

            d = torch.device('cuda:1')

            # Check the current stream id and default id are the same
            # on the current device. The current device id by default is 0
            s2 = torch.cuda.current_stream(0)
            check_s2 = s2.id() == s0.id()
            check_d0 = torch.cuda._current_device() == s2.device_index()

            # Set the current device to d1 and check if the stream
            # has been set to the default stream on d1
            with torch.jit.cuda.device(d):
                s3 = torch.cuda.current_stream(1)
                check_s3 = s3.id() == s1.id()
                check_d1 = torch.cuda._current_device() == s3.device_index()

            # Check if the current device was reset to 0
            is_device_d0 = torch.cuda._current_device() == s2.device_index()

            return s0.device_index(), s1.device_index(), check_s2, check_s3, check_d0, check_d1, is_device_d0

        d0, d1, check_s2, check_s3, check_d0, check_d1, is_device_d0 = test_default_streams()

        self.assertEqual(d0, 0)
        self.assertEqual(d1, 1)
        self.assertTrue(check_s2)
        self.assertTrue(check_s3)
        self.assertTrue(check_d0)
        self.assertTrue(check_d1)
        self.assertTrue(is_device_d0)

        # This test checks if the Stream Context manager is a no-op
        # when the stream is none for `with torch.jit.cuda.stream`
        @torch.jit.script
        def test_set_none_stream():
            device_index = torch.cuda._current_device()
            current_stream = torch.cuda.current_stream(device_index)
            default_stream = torch.cuda.default_stream(device_index)

            # When stream is none, check if this operation is a no-op
            with torch.jit.cuda.stream(None):
                cur_device_index = torch.cuda._current_device()
                is_device_index_same = cur_device_index == device_index
                is_current_stream_same = torch.cuda.current_stream(cur_device_index).id() == current_stream.id()
                is_default_stream_same = torch.cuda.default_stream(device_index).id() == default_stream.id()

            # Check if the device index, current stream and default streams have not changed
            are_streams_same = is_device_index_same and is_current_stream_same and is_default_stream_same
            return are_streams_same
        self.assertTrue(test_set_none_stream())

        # This test checks if the Device Context manager is a no-op
        # when the device is none for `with torch.jit.cuda.device`
        @torch.jit.script
        def test_set_device_none():
            device_index = torch.cuda._current_device()
            # When device is none, check if this operation is a no-op
            with torch.jit.cuda.device(None):
                # Check if the current device is the same
                is_device_same = torch.cuda._current_device() == device_index
            return is_device_same
        self.assertTrue(test_set_device_none())

        # Check if a CUDA JIT stream is created
        # on the _current_device
        @torch.jit.script
        def test_simple_stream():
            device_index = torch.cuda._current_device()
            s = torch.jit.cuda.Stream(device_index, 0)
            return device_index == s.device_index()

        self.assertTrue(test_simple_stream(), "Could not create Stream!")

        # Class used to store results for the test: test_get_stream.
        class Result(NamedTuple):
            t1 : torch.Tensor
            t2 : torch.Tensor
            is_current_and_default_stream_same : bool
            is_default_and_user_stream_not_same : bool
            is_stream_set : bool
            is_stream_reset : bool
            default_stream_query : bool
            default_stream_id : int
            user_stream_id : int

        # The test aims at checking different stream properties.
        @torch.jit.script
        def test_get_stream():
            device_index = torch.cuda._current_device()
            current_stream = torch.cuda.current_stream(device_index)
            default_stream = torch.cuda.default_stream(device_index)
            user_stream = torch.jit.cuda.Stream(device_index, 0)

            # Check if the current and default streams are the same on the device
            is_current_and_default_stream_same = current_stream.id() == default_stream.id()
            # Check if user stream and default stream are not the same on the device
            is_default_and_user_stream_not_same = default_stream.id() != user_stream.id()

            with torch.jit.cuda.stream(user_stream):
                is_stream_set = torch.cuda.current_stream(device_index).id() == user_stream.id()

            # Check if the stream was reset to current_stream
            is_stream_reset = torch.cuda.current_stream(device_index).id() == current_stream.id()

            tensor1 = torch.rand(10000, 10000, device="cuda")
            tensor2 = torch.mm(tensor1, tensor1).to("cuda")
            default_stream.synchronize()
            default_stream_query = default_stream.query()

            # Capture all the results in the class Result
            res = Result(
                tensor1, tensor2, is_current_and_default_stream_same,
                is_default_and_user_stream_not_same, is_stream_set,
                is_stream_reset, default_stream_query, default_stream.id(), user_stream.id())
            return res

        result = test_get_stream()

        self.assertEqual(torch.matmul(result.t1, result.t1), result.t2)
        self.assertTrue(result.is_current_and_default_stream_same)
        self.assertTrue(result.is_default_and_user_stream_not_same)
        self.assertTrue(result.is_stream_set)
        self.assertTrue(result.is_stream_reset)
        self.assertTrue(result.default_stream_query)
        self.assertEqual(result.default_stream_id, 0)  # Check if the default stream ID is always 0
        self.assertNotEqual(result.user_stream_id, 0)  # Check if the user stream is always non-zero

        # Test the stream context manager. This test checks if the stream is switched
        # to the user stream on using the stream context manager.
        @torch.jit.script
        def test_stream_context():
            device_index = torch.cuda._current_device()
            current_stream = torch.cuda.current_stream(device_index)
            user_stream = torch.jit.cuda.Stream(device_index, 0)
            A = torch.rand(1000, 1000, device="cuda")

            with torch.jit.cuda.stream(user_stream):
                check = torch.cuda.current_stream(device_index).id() == user_stream.id()
                B = torch.mm(A, A).to("cuda")
            # Wait for B to be computed
            user_stream.synchronize()
            # Check if the stream has been reset on the current device
            is_stream_reset = torch.cuda.current_stream(device_index).id() == current_stream.id()

            return A, B, check, is_stream_reset

        A, B, is_stream_set, is_stream_reset = test_stream_context()
        self.assertEqual(torch.matmul(A, A), B)
        self.assertTrue(is_stream_set, "Error: Current stream was not set to user stream!")
        self.assertTrue(is_stream_reset, "Error: The stream was not restored to previous stream!")

        # Test multiple nested streams. Check if the operations are computed as expected on the streams
        # This test has been adapted from the eager mode tests available at test/test_cuda.py
        @torch.jit.script
        def test_multiple_stream():
            prev_device_index = torch.cuda._current_device()
            prev_current_stream = torch.cuda.current_stream(prev_device_index)
            s1 = torch.jit.cuda.Stream(0, 0)
            s2 = torch.jit.cuda.Stream(1, 0)

            A = torch.rand(1000, 1000, device="cuda")
            B = torch.rand(1000, 1000, device="cuda")
            with torch.jit.cuda.stream(s1):
                C = torch.mm(A, A).to("cuda")
                # Check if the stream and device have been set to s1
                is_stream_s1 = torch.cuda.current_stream(s1.device_index()).id() == s1.id()
                is_device_s1 = torch.cuda._current_device() == s1.device_index()
                with torch.jit.cuda.stream(s2):
                    # Check if the stream and device have been set to s2
                    is_stream_s2 = torch.cuda.current_stream(s2.device_index()).id() == s2.id()
                    is_device_s2 = torch.cuda._current_device() == s2.device_index()
                    D = torch.mm(B, B).to("cuda")
                # Check if the stream and device have been set to s1
                is_stream_s1_after = torch.cuda.current_stream(s1.device_index()).id() == s1.id()
                is_device_s1_after = torch.cuda._current_device() == s1.device_index()
                # Wait for D to be computed
                s2.synchronize()
            # Wait for C to be computed on S1
            s1.synchronize()

            # Check if the stream and device have been restored to the previous stream and device
            is_device_current = torch.cuda._current_device() == prev_device_index
            is_stream_current = torch.cuda.current_stream(prev_device_index).id() == prev_current_stream.id()

            check_stream = is_stream_s1 and is_stream_s2 and is_stream_s1_after and is_stream_current
            check_device = is_device_s1 and is_device_s2 and is_device_s1_after and is_device_current
            return A, B, C, D, check_stream, check_device
        A, B, C, D, check_stream, check_device = test_multiple_stream()

        self.assertEqual(torch.matmul(A, A), C)
        self.assertEqual(torch.matmul(B, B), D)
        self.assertTrue(check_stream)
        self.assertTrue(check_device)

        # Test multiple streams waiting on each other for the operations to be completed.
        @torch.jit.script
        def test_data_dependency_between_streams():
            device_index = torch.cuda._current_device()
            prev_current_stream = torch.cuda.current_stream(device_index)
            s1 = torch.jit.cuda.Stream(0, 0)
            s2 = torch.jit.cuda.Stream(0, 0)
            event = torch.jit.cuda.Event(False, False, False)

            A = torch.rand(1000, 1000, device="cuda")
            with torch.jit.cuda.stream(s1):
                is_stream_s1 = torch.cuda.current_stream(device_index).id() == s1.id()
                B = torch.mm(A, A).to("cuda")
            s1.record_event(event)
            # Check if the current_stream is reset
            is_current_stream_1 = torch.cuda.current_stream(device_index).id() == prev_current_stream.id()
            # Wait for ops on s1 to be computed
            s2.wait_event(event)
            with torch.jit.cuda.stream(s2):
                is_stream_s2 = torch.cuda.current_stream(device_index).id() == s2.id()
                C = torch.mm(B, B).to("cuda")
            # Wait for C to be computed
            s2.synchronize()
            # Check if the current_stream is reset
            is_current_stream_2 = torch.cuda.current_stream(device_index).id() == prev_current_stream.id()

            check_stream = is_current_stream_1 and is_current_stream_2 and is_stream_s1 and is_stream_s2
            return A, B, C, check_stream

        A, B, C, check_stream = test_data_dependency_between_streams()
        self.assertEqual(torch.matmul(A, A), B)
        self.assertEqual(torch.matmul(B, B), C)
        self.assertTrue(check_stream)

        # Test a simple CUDA event. Test if the CUDA event was created successfully
        @torch.jit.script
        def test_simple_event():
            e = torch.jit.cuda.Event(True, False, False)
            return e is not None
        self.assertTrue(test_simple_event(), "Could not create CUDA Event!")

        # Record the CUDA event for operation torch.mm on the current stream
        # and then test if the elapsed time is greater than 0. This test is also
        # an adaptation from eager mode CUDA tests available at test/test_cuda.py
        @torch.jit.script
        def test_event():
            device_index = torch.cuda._current_device()
            stream = torch.cuda.current_stream(device_index)
            event = torch.jit.cuda.Event(True, False, False)
            is_true_event_query = event.query()
            start_event = torch.jit.cuda.Event(True, False, False)
            stream.record_event(start_event)
            tensor1 = torch.rand(1000000000, 1000000000, device="cuda")
            tensor2 = torch.mm(tensor1, tensor1).to("cuda")
            stream.record_event(event)
            event.synchronize()
            is_again_true_event_query = event.query()

            if not (is_true_event_query and is_again_true_event_query):
                return -1.0
            return start_event.elapsed_time(event)

        self.assertGreater(test_event(), 0)

        # Check for stream synchronization, when a large tensor multiplication is
        # computed on the stream. The stream.query should be true once the synchronization is done
        @torch.jit.script
        def test_stream_synchronize() -> float:
            device_index = torch.cuda._current_device()
            s = torch.jit.cuda.Stream(device_index, 0)
            e_tik = torch.jit.cuda.Event(True, False, False)
            e_tok = torch.jit.cuda.Event(True, False, False)

            e_tik.record(s)
            tensor1 = torch.rand(1000000000, 1000000000, device="cuda")
            with torch.jit.cuda.stream(s):
                tensor2 = torch.mm(tensor1, tensor1).to("cuda")
            s.synchronize()
            e_tok.record(s)
            e_tok.synchronize()

            if not s.query():
                return -1.0

            # not necessary to check e_tik and e_tok, as elapsed_time would throw
            # exception if otherwise.
            return e_tik.elapsed_time(e_tok)
        self.assertGreater(test_stream_synchronize(), 0)

        # Test event synchronization for the event that records a stream doing
        # a large tensor multiplication. Check if the elapsed time is greater than 0
        # and the stream.query evaluates to true.
        @torch.jit.script
        def test_event_synchronize() -> float:
            device_index = torch.cuda._current_device()
            s = torch.jit.cuda.Stream(device_index, 0)
            e_tik = torch.jit.cuda.Event(True, False, False)
            e_tok = torch.jit.cuda.Event(True, False, False)

            e_tik.record(s)
            tensor1 = torch.rand(1000000000, 1000000000, device="cuda")
            with torch.jit.cuda.stream(s):
                tensor = torch.mm(tensor1, tensor1).to("cuda")
            s.record_event(e_tok)
            e_tok.synchronize()
            s.synchronize()

            if not s.query():
                return -1.0

            # not necessary to check e_tik and e_tok, as elapsed_time would throw
            # exception if otherwise.
            return e_tik.elapsed_time(e_tok)

        self.assertGreater(test_event_synchronize(), 0)

        # Test for event wait. Check if the event waits for all the operations on
        # the stream to be done. Check for synchronizations and query on the streams
        # and events. This test is adapted from eager mode tests for CUDA. Please refer
        # to test/test_cuda.py
        @torch.jit.script
        def test_event_wait() -> float:
            device_index = torch.cuda._current_device()
            s0 = torch.cuda.current_stream(device_index)
            s1 = torch.jit.cuda.Stream(device_index, 0)
            e_tik = torch.jit.cuda.Event(True, True, False)
            e_tok = torch.jit.cuda.Event(True, True, False)

            e_tik.record(s0)
            tensor1 = torch.rand(1000000000, 1000000000, device="cuda")
            with torch.jit.cuda.stream(s0):
                tensor2 = torch.mm(tensor1, tensor1).cuda()
            e_sync = torch.jit.cuda.Event(True, False, False)
            e_sync.record(torch.cuda.current_stream(device_index))
            e_sync.wait(s1)
            with torch.jit.cuda.stream(s1):
                tensor3 = torch.rand(1000000000, 1000000000, device="cuda")
                tensor4 = torch.mm(tensor3, tensor3).cuda()
            s1.synchronize()
            e_tok.record(torch.cuda.current_stream(device_index))
            e_tok.synchronize()
            s0.synchronize()

            if not s0.query() or not s1.query() or not e_sync.query():
                return -1.0

            # not necessary to check e_tik and e_tok, as elapsed_time would throw
            # exception if otherwise.
            return e_tik.elapsed_time(e_tok)
        self.assertGreater(test_event_wait(), 0)

        # Test for stream wait_event. Checks if the stream waits on the event
        @torch.jit.script
        def test_wait_event():
            d1 = torch.device('cuda:1')

            with torch.jit.cuda.device(d1):
                s0 = torch.cuda.current_stream(1)
                tensor1 = torch.rand(1000000000, 1000000000, device="cuda")
                tensor2 = torch.mm(tensor1, tensor1).to("cuda")
                e0 = torch.jit.cuda.Event(False, False, False)
                s0.record_event(e0)

            s1 = torch.cuda.current_stream(0)
            s1.wait_event(e0)
            s1.synchronize()

            return e0.query() and s0.query() and s1.query()
        self.assertTrue(test_wait_event())

    # Test if a scripted module with cuda streams can be saved, loaded and executed
    def test_save_load(self):
        class Model(torch.nn.Module):
            def forward(self):
                device_index = torch.cuda._current_device()
                s = torch.jit.cuda.Stream(device_index, 0)
                a = torch.rand(3, 4, device="cuda")
                b = torch.rand(3, 4, device="cuda")

                with torch.jit.cuda.stream(s):
                    is_stream_s = torch.cuda.current_stream(s.device_index()).id() == s.id()
                    c = torch.cat((a, b), 0).cuda()
                s.synchronize()
                return is_stream_s, a, b, c

        model = Model()

        # Script the model and save
        script_model = torch.jit.script(model)
        is_stream_s, a, b, c = script_model()
        # Verify if the output is correct
        self.assertTrue(is_stream_s)
        self.assertEqual(torch.cat((a, b), 0), c)

        # Save and load scripted model
        load_model = self.getExportImportCopy(script_model)
        is_stream_s, a_load, b_load, c_load = load_model()
        self.assertTrue(is_stream_s)
        self.assertEqual(torch.cat((a_load, b_load), 0), c_load)
@@ -35,6 +35,7 @@ from jit.test_profiler import TestProfiler  # noqa: F401
from jit.test_slice import TestSlice  # noqa: F401
from jit.test_warn import TestWarn  # noqa: F401
from jit.test_isinstance import TestIsinstance  # noqa: F401
from jit.test_cuda import TestCUDA  # noqa: F401
from jit.test_hash import TestHash  # noqa: F401

# Torch
@@ -408,6 +408,7 @@ libtorch_cuda_core_sources = [
    "torch/csrc/jit/codegen/cuda/transform_rfactor.cpp",
    "torch/csrc/jit/codegen/cuda/type.cpp",
    "torch/csrc/jit/tensorexpr/cuda_codegen.cpp",
    "torch/csrc/jit/runtime/register_cuda_ops.cpp",
]

libtorch_cuda_sources = libtorch_cuda_core_sources + [
torch/csrc/jit/cuda/cuda.h (new file, 179 lines)
@@ -0,0 +1,179 @@
#include <aten/src/ATen/cuda/CUDAEvent.h>
#include <c10/core/Device.h>
#include <c10/cuda/CUDAStream.h>
#include <torch/custom_class.h>

namespace torch {
namespace jit {

class CUDAEvent;
// This class is a wrapper around c10::cuda::CUDAStream.
// It is needed because TorchBind does not support all of the argument types
// for c10::cuda::CUDAStream. For more details, please refer to
// c10/cuda/CUDAStream.h.
class CUDAStream final : public CustomClassHolder {
 public:
  CUDAStream(int64_t device = -1, int64_t priority = 0) {
    constexpr int64_t PRIORITY_INDEX = 0;
    stream_ = std::make_unique<c10::cuda::CUDAStream>(
        c10::cuda::getStreamFromPool(priority < PRIORITY_INDEX, device));
  }

  CUDAStream(c10::cuda::CUDAStream s) {
    stream_ = std::make_unique<c10::cuda::CUDAStream>(s);
  }

  bool query() {
    return stream_->query();
  }

  c10::intrusive_ptr<CUDAEvent> recordEvent(
      c10::intrusive_ptr<CUDAEvent> event);

  void synchronize() {
    stream_->synchronize();
  }

  void waitEvent(c10::intrusive_ptr<CUDAEvent> event);

  void waitStream(c10::intrusive_ptr<CUDAStream> stream);

  /// Get the CUDA device index that this stream is associated with.
  int64_t device_index() const {
    return stream_->device_index();
  }

  /// Get the full Device that this stream is associated with. The Device
  /// is guaranteed to be a CUDA device.
  c10::Device device() const {
    return stream_->device();
  }

  /// Return the stream ID corresponding to this particular stream.
  int64_t id() const {
    return stream_->id();
  }

  /// Pack a CUDAStream to uint64_t representation.
  /// The CUDAStream can be unpacked using unpack(). The format of
  /// the uint64_t is unspecified and may be changed.
  int64_t pack() const {
    return stream_->pack();
  }

 private:
  std::unique_ptr<c10::cuda::CUDAStream> stream_;
  friend class CUDAEvent;
};

// This class is a wrapper around at::cuda::CUDAEvent.
// It is needed because TorchBind does not support all of the argument types
// for at::cuda::CUDAEvent. For more details, please refer to
// aten/src/ATen/cuda/CUDAEvent.h.
class CUDAEvent final : public CustomClassHolder {
 public:
  CUDAEvent(
      bool enable_timing = false,
      bool blocking = false,
      bool interprocess = false) {
    int flags = cudaEventDisableTiming;
    if (enable_timing) {
      flags = cudaEventDefault;
    }
    if (blocking) {
      flags |= cudaEventBlockingSync;
    }
    if (interprocess) {
      TORCH_CHECK(!enable_timing);
      flags |= cudaEventInterprocess;
    }

    event_ = std::make_unique<at::cuda::CUDAEvent>(flags);
  }

  double elapsedTime(c10::intrusive_ptr<CUDAEvent> end) {
    return event_->elapsed_time(*end->event_);
  }

  std::string ipcHandle() {
    cudaIpcEventHandle_t handle;
    event_->ipc_handle(&handle);
    std::string str_handle((const char*)&handle, sizeof(handle));
    return str_handle;
  }

  bool query() {
    return event_->query();
  }

  void record(c10::intrusive_ptr<CUDAStream> stream);

  void synchronize() {
    event_->synchronize();
  }
  void wait(c10::intrusive_ptr<CUDAStream> stream);

 private:
  void recordInternal(CUDAStream* stream);
  std::unique_ptr<at::cuda::CUDAEvent> event_;

  friend class CUDAStream;
};

c10::intrusive_ptr<CUDAEvent> CUDAStream::recordEvent(
    c10::intrusive_ptr<CUDAEvent> event) {
  if (!event) {
    event = c10::make_intrusive<CUDAEvent>();
  }

  event->recordInternal(this);
  return event;
}

void CUDAStream::waitEvent(c10::intrusive_ptr<CUDAEvent> event) {
  event->event_->block(*stream_);
}

void CUDAStream::waitStream(c10::intrusive_ptr<CUDAStream> stream) {
  auto ev = c10::make_intrusive<CUDAEvent>();
  stream->recordEvent(ev);
  waitEvent(ev);
}

void CUDAEvent::record(c10::intrusive_ptr<CUDAStream> stream) {
  event_->record(*stream->stream_);
}

void CUDAEvent::recordInternal(CUDAStream* stream) {
  event_->record(*stream->stream_);
}

void CUDAEvent::wait(c10::intrusive_ptr<CUDAStream> stream) {
  event_->block(*stream->stream_);
}

TORCH_LIBRARY(cuda, m) {
  auto stream_class = m.class_<torch::jit::CUDAStream>("Stream").def(
      torch::init<int64_t, int64_t>());
  auto event_class = m.class_<torch::jit::CUDAEvent>("Event").def(
      torch::init<bool, bool, bool>());

  stream_class.def("query", &CUDAStream::query)
      .def("record_event", &CUDAStream::recordEvent)
      .def("synchronize", &CUDAStream::synchronize)
      .def("wait_event", &CUDAStream::waitEvent)
      .def("wait_stream", &CUDAStream::waitStream)
      .def("device_index", &CUDAStream::device_index)
      .def("device", &CUDAStream::device)
      .def("pack", &CUDAStream::pack)
      .def("id", &CUDAStream::id);

  event_class.def("elapsed_time", &CUDAEvent::elapsedTime)
      .def("query", &CUDAEvent::query)
      .def("record", &CUDAEvent::record)
      .def("synchronize", &CUDAEvent::synchronize)
      .def("wait", &CUDAEvent::wait);
};

} // namespace jit
} // namespace torch
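As a usage note (an observation based on torch/jit/cuda.py below, not part of the diff): the classes bound in TORCH_LIBRARY(cuda, m) surface in Python and TorchScript under torch.classes.cuda, which is exactly what the torch.jit.cuda wrappers construct.

import torch

# Hypothetical eager-mode check on a CUDA-enabled build: instantiate the
# TorchBind classes registered above directly from the torch.classes namespace.
s = torch.classes.cuda.Stream(0, 0)               # device index 0, priority 0
e = torch.classes.cuda.Event(True, False, False)  # enable_timing=True
print(s.device_index(), s.id(), e.query())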
@@ -211,6 +211,13 @@ TypePtr ScriptTypeParser::parseTypeFromExprImpl(const Expr& expr) const {
      }
    }

    // Check if the type is a custom class. This is done by checking
    // if type_name starts with "torch.classes."
    if (type_name.find("torch.classes.") == 0) {
      auto custom_class_type = getCustomClass("__torch__." + type_name);
      return custom_class_type;
    }

    throw ErrorReport(expr) << "Unknown type name '" << type_name << "'";
  } else if (auto name = parseBaseTypeName(expr)) {
    auto itr = string_to_type_lut().find(*name);
@@ -572,7 +572,8 @@ void AliasDb::analyzeImpl(Node* node) {
      !aliasAnalysisHasSpecialCaseFor(node->kind()),
      "Special cases should be handled already if we're here.");

-  if (node->kind().is_aten() || node->kind().is_prim()) {
+  if (node->kind().is_aten() || node->kind().is_prim() ||
+      node->kind().is_cuda()) {
    // TODO There is nothing in the system that relies on aten:: and prim::
    // ops using AliasAnalysisKind::FROM_SCHEMA or
    // AliasAnalysisKind::INTERNAL_SPECIAL_CASE, but this is the intended
@@ -1079,6 +1079,11 @@ bool Node::hasSideEffects() const {
    case prim::rpc_sync: // It represents RPC message sent.
    case prim::rpc_remote: // It represents RPC message sent.
    case aten::wait: // It can represent RPC message received.
#ifndef __HIP_PLATFORM_HCC__
    case cuda::set_stream:
    case cuda::_set_device:
    case cuda::_current_device:
#endif
    case prim::Enter:
    case prim::Exit:
      return true;
@@ -1094,7 +1099,7 @@ bool Node::hasSideEffects() const {
    return false;
  }

-  if (kind_.is_prim() || kind_.is_aten()) {
+  if (kind_.is_prim() || kind_.is_aten() || kind_.is_cuda()) {
    // TODO There is nothing in the system that relies on aten:: and prim::
    // ops using AliasAnalysisKind::FROM_SCHEMA,
    // AliasAnalysisKind::INTERNAL_SPECIAL_CASE, or
@@ -72,6 +72,11 @@ using namespace ::c10::attr;
namespace aten {
using namespace ::c10::aten;
}
namespace cuda {
#ifndef __HIP_PLATFORM_HCC__
using namespace ::c10::cuda;
#endif
} // namespace cuda

struct Function;
struct MatchedSchema;
@@ -217,6 +217,32 @@ std::shared_ptr<SugaredValue> PythonModuleValue::attr(
  return toSugaredValue(member, m, loc, /*is_constant=*/true);
}

#ifndef __HIP_PLATFORM_HCC__
std::shared_ptr<SugaredValue> CUDAPythonModuleValue::attr(
    const SourceRange& loc,
    Function& m,
    const std::string& field) {
  // List of all the cuda operators which are supported in JIT
  const std::unordered_set<std::string> cuda_ops = {"current_stream",
                                                    "default_stream",
                                                    "_current_device",
                                                    "_set_device",
                                                    "device_index",
                                                    "device_count",
                                                    "set_stream"};

  if (cuda_ops.find(field) != cuda_ops.end()) {
    return std::make_shared<BuiltinFunction>(Symbol::cuda(field), c10::nullopt);
  }

  py::object member = getattr(loc, field);
  // note: is_constant = true because we consider that global properties
  // on modules like math.pi or torch.float to be constants
  // even though it is possible, though rare, for someone to mutate them
  return toSugaredValue(member, m, loc, /*is_constant=*/true);
}
#endif

Value* ModuleValue::asValue(const SourceRange& loc, Function& m) {
  return self_;
}
@@ -938,6 +964,12 @@ std::shared_ptr<SugaredValue> toSugaredValue(
  if (auto callee = as_function(obj)) {
    return std::make_shared<FunctionValue>(callee->function_);
  } else if (py::isinstance<py::module>(obj)) {
#ifndef USE_ROCM
    std::string obj_name = py::cast<py::str>(py::getattr(obj, "__name__"));
    if (obj_name.compare("torch.cuda") == 0) {
      return std::make_shared<CUDAPythonModuleValue>(obj);
    }
#endif
    return std::make_shared<PythonModuleValue>(obj);
  } else if (
      obj.ptr() == py::module::import("torch.jit").attr("_fork").ptr() ||
@@ -91,6 +91,20 @@ struct VISIBILITY_HIDDEN PythonModuleValue : public PythonValue {
      const std::string& field) override;
};

// Used for desugaring uses of the torch.cuda module. All the CUDA APIs with
// torch.cuda.* are resolved using CUDAPythonModuleValue.
#ifndef __HIP_PLATFORM_HCC__
struct VISIBILITY_HIDDEN CUDAPythonModuleValue : public PythonValue {
  explicit CUDAPythonModuleValue(py::object mod)
      : PythonValue(std::move(mod)) {}

  std::shared_ptr<SugaredValue> attr(
      const SourceRange& loc,
      Function& m,
      const std::string& field) override;
};
#endif

// Represents all the parameters of a module as a List[Tensor]
struct VISIBILITY_HIDDEN ConstantParameterList : public SugaredValue {
  ConstantParameterList(Value* the_list) : the_list_(the_list) {}
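For reference, a small sketch (added for illustration, mirroring the new tests) of what this desugaring enables: inside a scripted function, attribute lookups on the torch.cuda module resolve through CUDAPythonModuleValue to the cuda:: builtins registered in register_cuda_ops.cpp below.

import torch

@torch.jit.script
def current_stream_id() -> int:
    # torch.cuda._current_device() lowers to the cuda::_current_device op
    idx = torch.cuda._current_device()
    # torch.cuda.current_stream(int) lowers to cuda::current_stream and
    # returns a __torch__.torch.classes.cuda.Stream instance
    s = torch.cuda.current_stream(idx)
    return s.id()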
torch/csrc/jit/runtime/register_cuda_ops.cpp (new file, 87 lines)
@@ -0,0 +1,87 @@
// This file registers special JIT operators used to implement the PyTorch CUDA
// API in TorchScript.
#ifndef __HIP_PLATFORM_HCC__
#include <torch/csrc/api/include/torch/utils.h>
#include <torch/csrc/jit/cuda/cuda.h>
#include <torch/csrc/jit/ir/ir.h>
#include <torch/csrc/jit/runtime/custom_operator.h>
#include <torch/csrc/jit/runtime/operator.h>

namespace torch {
namespace jit {

namespace {

c10::AliasAnalysisKind aliasAnalysisFromSchema() {
  return c10::AliasAnalysisKind::FROM_SCHEMA;
}

RegisterOperators const reg({
    Operator(
        "cuda::current_stream(int64_t val) -> __torch__.torch.classes.cuda.Stream",
        [](Stack* stack) {
          auto idx = uint16_t(pop(stack).toInt());
          auto s = c10::cuda::getCurrentCUDAStream(idx);
          auto st = make_custom_class<torch::jit::CUDAStream>(s);
          push(stack, IValue(st));
        },
        aliasAnalysisFromSchema()),
    Operator(
        "cuda::default_stream(int64_t val) -> __torch__.torch.classes.cuda.Stream",
        [](Stack* stack) {
          auto idx = uint16_t(pop(stack).toInt());
          auto s = c10::cuda::getDefaultCUDAStream(idx);
          auto st = make_custom_class<torch::jit::CUDAStream>(s);
          push(stack, IValue(st));
        },
        aliasAnalysisFromSchema()),
    Operator(
        "cuda::_current_device() -> int",
        [](Stack* stack) {
          auto v = c10::cuda::current_device();
          push(stack, static_cast<int>(v));
        },
        aliasAnalysisFromSchema()),
    Operator(
        "cuda::_set_device(int64_t val) -> ()",
        [](Stack* stack) {
          int64_t idx = -1;
          pop(stack, idx);
          c10::cuda::set_device(static_cast<c10::DeviceIndex>(idx));
        },
        aliasAnalysisFromSchema()),
    Operator(
        "cuda::device_index(Device device) -> int",
        [](Stack* stack) {
          auto device = pop(stack);
          auto idx = device.toDevice().index();
          push(stack, idx);
        },
        aliasAnalysisFromSchema()),
    Operator(
        "cuda::device_count() -> int",
        [](Stack* stack) { push(stack, at::cuda::device_count()); },
        aliasAnalysisFromSchema()),
    Operator(
        "cuda::set_stream(__torch__.torch.classes.cuda.Stream stream) -> ()",
        [](Stack* stack) {
          auto v = pop(stack);
          auto s = v.toCustomClass<torch::jit::CUDAStream>();
          // To set the current CUDA stream using
          // c10::cuda::setCurrentCUDAStream, the jit::CUDAStream object needs
          // to be converted to c10::cuda::CUDAStream. Since the latter cannot
          // be returned from a class registered via TorchBind, this can only be
          // achieved by packing the c10::cuda::CUDAStream instance contained
          // inside the jit::CUDAStream object to a uint64_t representation, and
          // unpacking it inside this operator. The unpacked stream is then used
          // to set the current CUDA stream.
          auto packed = s->pack();
          auto unpacked = c10::cuda::CUDAStream::unpack(packed);
          c10::cuda::setCurrentCUDAStream(unpacked);
        },
        aliasAnalysisFromSchema()),
});
} // namespace
} // namespace jit
} // namespace torch
#endif
@@ -44,6 +44,7 @@ from torch.jit._async import fork, wait
from torch.jit._serialization import save, load
from torch.jit._fuser import optimized_execution, fuser, last_executed_optimized_graph

from torch.jit.cuda import stream
from torch.jit._freeze import freeze

# For backwards compatibility
torch/jit/cuda.py (new file, 182 lines)
@@ -0,0 +1,182 @@
# mypy: ignore-errors

r"""
This package adds support for JIT compilation of CUDA Streams and events.
This is similar to the APIs available in eager mode.
:ref:`cuda-semantics` has more details about working with CUDA.
"""

import torch
from typing import Optional, Any
from torch import device as _device

def get_current_device_index() -> int:
    r"""Checks if there are CUDA devices available and
    returns the device index of the current default CUDA device.
    Returns -1 in case there are no CUDA devices available.

    Arguments: ``None``
    """
    if torch.cuda.device_count() > 0:
        return torch.cuda._current_device()
    return -1

def get_device_index(device: Optional[_device] = None, optional: bool = False, allow_cpu: bool = False) -> int:
    r"""Gets the device index from :attr:`device`, which can be a torch.device
    object, a Python integer, or ``None``.

    If :attr:`device` is a torch.device object, returns the device index if it
    is a CUDA device. Note that for a CUDA device without a specified index,
    this will return the current default CUDA device if :attr:`optional` is ``True``.
    If :attr:`allow_cpu` is ``True``, CPU devices will be accepted and ``-1`` will be
    returned in this case.

    If :attr:`device` is a Python integer, it is returned as is.

    If :attr:`device` is ``None``, this will return the current default CUDA
    device if :attr:`optional` is ``True``.
    """
    if device is None:
        if optional:
            return get_current_device_index()
        else:
            raise ValueError('Expected a torch.device with a specified index '
                             f'or an integer, but got: {device}')
    device_index = -1
    if isinstance(device, str):
        device = torch.device(device)

    if isinstance(device, torch.device):
        if not allow_cpu and device.type == 'cpu':
            raise ValueError(f'Expected a non cpu device, but got: {device}')
        device_index = -1 if device.type == 'cpu' else torch.cuda.device_index(device)

    if isinstance(device, int):
        device_index = device

    return device_index

class device(object):
    r"""Context-manager that changes the selected device.
    This is similar to device (torch.device or int), but has been
    introduced for JIT compatibility.
    Arguments:
        device (torch.device or int): device index to select. It's a no-op if
            this argument is a negative integer or ``None``.
    """
    def __init__(self, device: Optional[_device]):
        self.idx = -1
        self.prev_idx = -1
        self.device = device

    def __enter__(self):
        self.idx = get_device_index(self.device, optional=True)

        if self.idx == -1:
            return
        self.prev_idx = torch.cuda._current_device()

        if self.prev_idx != self.idx:
            torch.cuda._set_device(self.idx)

    def __exit__(self, type: Any, value: Any, traceback: Any):
        if self.prev_idx != self.idx:
            torch.cuda._set_device(self.prev_idx)

class StreamContext(object):
    r"""Context-manager that selects a given stream.
    All CUDA kernels queued within its context will be enqueued on a selected
    stream.
    Arguments:
        StreamContext (Stream): selected stream. This manager is a no-op if it's
            ``None``.
    .. note:: Streams are per-device. If the selected stream is not on the
        current device, this function will also change the current device to
        match the stream.
    """
    cur_stream : Optional['torch.classes.cuda.Stream']

    def __init__(self, stream: Optional['torch.classes.cuda.Stream']):
        self.idx = -1
        self.stream = stream
        # Initialize the below streams to default stream on the current device
        self.device_index = get_current_device_index()
        self.src_prev_stream = torch.cuda.default_stream(self.device_index)
        self.dst_prev_stream = torch.cuda.default_stream(self.device_index)

    def __enter__(self):
        self.idx = get_device_index(device=None, optional=True)
        # If there is no CUDA device available, return
        if self.idx == -1:
            return

        # Local cur_stream variable for type refinement
        cur_stream = self.stream
        # Return if stream is None
        if cur_stream is None:
            return
        self.src_prev_stream = torch.cuda.current_stream(self.idx)
        # If the stream is not on the current device, then change the device
        # and set the current stream on the device
        if self.src_prev_stream.device_index() != cur_stream.device_index():
            with device(cur_stream.device()):
                self.dst_prev_stream = torch.cuda.current_stream(cur_stream.device_index())
            torch.cuda._set_device(cur_stream.device_index())
        torch.cuda.set_stream(cur_stream)

    def __exit__(self, type: Any, value: Any, traceback: Any):
        # Local cur_stream variable for type refinement
        cur_stream = self.stream
        # If stream is None or no CUDA device available, return
        if cur_stream is None or self.idx == -1:
            return
        # If the stream was not on the current device, restore the previous stream on
        # the destination device and also reset the current device to the previous device.
        # Set the current stream on the device to the src_prev_stream
        if self.src_prev_stream.device_index() != cur_stream.device_index():
            torch.cuda.set_stream(self.dst_prev_stream)
            torch.cuda._set_device(self.idx)
        torch.cuda.set_stream(self.src_prev_stream)

def stream(stream: Optional['torch.classes.cuda.Stream']) -> StreamContext:
    r"""Wrapper around the Context-manager that selects a given stream.
    All CUDA kernels queued within its context will be enqueued on a selected
    stream.
    Arguments:
        stream (Stream): selected stream. This manager is a no-op if it's
            ``None``.
    """
    return StreamContext(stream)

def Stream(device: int = -1, priority: int = 0) -> 'torch.classes.cuda.Stream':
    r"""Wrapper around a CUDA stream.
    A CUDA stream is a linear sequence of execution that belongs to a specific
    device, independent from other streams. See :ref:`cuda-semantics` for
    details.
    Arguments:
        device(int, optional): a device on which to allocate
            the stream. If :attr:`device` is ``None`` (default) or a negative
            integer, this will use the current device.
        priority(int, optional): priority of the stream. Can be either
            -1 (high priority) or 0 (low priority). By default, streams have
            priority 0.
    .. note:: Although CUDA versions >= 11 support more than two levels of
        priorities, in PyTorch, we only support two levels of priorities.
    """
    return torch.classes.cuda.Stream(device, priority)

def Event(enable_timing: bool = False, blocking: bool = False, interprocess: bool = False) -> 'torch.classes.cuda.Event':
    r"""Wrapper around a CUDA event.
    CUDA events are synchronization markers that can be used to monitor the
    device's progress, to accurately measure timing, and to synchronize CUDA
    streams.
    Arguments:
        enable_timing (bool, optional): indicates if the event should measure time
            (default: ``False``)
        blocking (bool, optional): if ``True``, :meth:`wait` will be blocking (default: ``False``)
        interprocess (bool): if ``True``, the event can be shared between processes
            (default: ``False``)
    .. _CUDA Event Documentation:
       https://docs.nvidia.com/cuda/cuda-runtime-api/group__CUDART__EVENT.html
    """
    return torch.classes.cuda.Event(enable_timing, blocking, interprocess)
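A short end-to-end usage sketch (added for illustration; assumes a CUDA build, and the file name is hypothetical), mirroring TorchSaveJitStream_CUDA and test_save_load above: a module that uses the stream context manager can be scripted, saved, reloaded, and executed.

import torch

class StreamModel(torch.nn.Module):
    def forward(self):
        idx = torch.cuda._current_device()
        s = torch.jit.cuda.Stream(idx, 0)
        a = torch.rand(3, 4, device="cuda")
        b = torch.rand(3, 4, device="cuda")
        # Concatenate on the user stream, then wait for the work to finish
        with torch.jit.cuda.stream(s):
            c = torch.cat((a, b), 0)
        s.synchronize()
        return c

scripted = torch.jit.script(StreamModel())
torch.jit.save(scripted, "stream_example.pt")  # illustrative path
loaded = torch.jit.load("stream_example.pt")
print(loaded().shape)                          # torch.Size([6, 4])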