Adding JIT support for cuda streams and events (#48020)

Summary:
=======

This PR addresses the following:

 * Adds JIT support for CUDA Streams
 * Adds JIT support for CUDA Events
 * Adds JIT support for CUDA Stream context manager

Testing:
======

python test/test_jit.py -v TestCUDA

Pull Request resolved: https://github.com/pytorch/pytorch/pull/48020

Reviewed By: navahgar

Differential Revision: D25725749

Pulled By: nikithamalgifb

fbshipit-source-id: b0addeb49630f8f0c430ed7badeca43bb9d2535c
This commit is contained in:
Nikitha Malgi
2020-12-29 20:22:19 -08:00
committed by Facebook GitHub Bot
parent 97c17b4772
commit 12b73fdbbf
16 changed files with 1057 additions and 2 deletions

View File

@ -17,6 +17,7 @@ namespace c10 {
#define FORALL_NS_SYMBOLS(_) \
_(namespaces, prim) \
_(namespaces, aten) \
_(namespaces, cuda) \
_(namespaces, onnx) \
_(namespaces, attr) \
_(namespaces, scope) \
@ -284,6 +285,9 @@ namespace c10 {
_(aten, zero_) \
_(aten, fill_) \
_(aten, masked_fill_) \
_(cuda, _set_device) \
_(cuda, set_stream) \
_(cuda, _current_device) \
_(aten, swapaxes) \
_(aten, swapaxes_) \
_(aten, swapdims) \
@ -383,6 +387,7 @@ namespace c10 {
#define FORALL_NS_SYMBOLS(_) \
_(namespaces, prim) \
_(namespaces, aten) \
_(namespaces, cuda) \
_(namespaces, onnx) \
_(namespaces, attr) \
_(namespaces, scope) \
@ -453,6 +458,7 @@ struct TORCH_API Symbol {
// (and if it's not, you should add it to the built-ins list above.)
static Symbol attr(const std::string & s);
static Symbol aten(const std::string & s);
static Symbol cuda(const std::string & s);
static Symbol onnx(const std::string & s);
static Symbol prim(const std::string & s);
static Symbol user(const std::string & s);
@ -463,6 +469,7 @@ struct TORCH_API Symbol {
bool is_attr() const;
bool is_aten() const;
bool is_cuda() const;
bool is_prim() const;
bool is_onnx() const;
bool is_user() const;
@ -523,6 +530,7 @@ FORALL_NS_SYMBOLS(DEFINE_SYMBOL)
inline Symbol Symbol::attr(const std::string & s) { return Symbol::fromQualString("attr::" + s); }
inline Symbol Symbol::aten(const std::string & s) { return Symbol::fromQualString("aten::" + s); }
inline Symbol Symbol::cuda(const std::string & s) { return Symbol::fromQualString("cuda::" + s); }
inline Symbol Symbol::onnx(const std::string & s) { return Symbol::fromQualString("onnx::" + s); }
inline Symbol Symbol::prim(const std::string & s) { return Symbol::fromQualString("prim::" + s); }
inline Symbol Symbol::scope(const std::string & s) { return Symbol::fromQualString("scope::" + s); }
@ -531,6 +539,7 @@ inline Symbol Symbol::caffe2(const std::string & s) { return Symbol::fromQualStr
inline Symbol Symbol::dimname(const std::string & s) { return Symbol::fromQualString("dimname::" + s); }
inline bool Symbol::is_attr() const { return ns() == namespaces::attr; }
inline bool Symbol::is_aten() const { return ns() == namespaces::aten; }
inline bool Symbol::is_cuda() const { return ns() == namespaces::cuda; }
inline bool Symbol::is_prim() const { return ns() == namespaces::prim; }
inline bool Symbol::is_onnx() const { return ns() == namespaces::onnx; }
inline bool Symbol::is_user() const { return ns() == namespaces::user; }

View File

@ -120,5 +120,33 @@ TEST(SerializationTest, TypeTags) {
}
}
TEST(SerializationTest, TestJitStream_CUDA) {
torch::jit::Module model;
std::vector<torch::jit::IValue> inputs;
// Deserialize the ScriptModule from a file using torch::jit::load().
// Load the scripted model. This should have been generated by tests_setup.py
// Refer: TorchSaveJitStream_CUDA in test/cpp/jit/tests_setup.py
model = torch::jit::load("saved_stream_model.pt");
auto output = model.forward(inputs);
auto list_of_elements = output.toTuple()->elements();
auto is_stream_s = list_of_elements[0].toBool();
// a,b: These are the two input tensors
// c: This is output tensor generated by the operation torch.cat(a,b)
auto a = list_of_elements[1].toTensor();
auto b = list_of_elements[2].toTensor();
auto c = list_of_elements[3].toTensor();
// op: this is used to verify if the cat operation produced the same results
// as that on the GPU with torch.cat
auto op = at::cat({a, b}, 0);
// Check if the stream is set
ASSERT_TRUE(is_stream_s);
// Check if the sizes of the outputs (op and c) is same on the GPU and CPU
ASSERT_EQ(op.sizes(), c.sizes());
// Check if both the output tensors are equal
ASSERT_TRUE(op.equal(c));
}
} // namespace jit
} // namespace torch

View File

@ -63,11 +63,38 @@ class TorchSaveError(FileSetup):
torch.save(value, self.path, _use_new_zipfile_serialization=False)
class TorchSaveJitStream_CUDA(FileSetup):
path = 'saved_stream_model.pt'
def setup(self):
if not torch.cuda.is_available():
return
class Model(torch.nn.Module):
def forward(self):
device_index = torch.cuda._current_device()
s = torch.jit.cuda.Stream(device_index, 0)
a = torch.rand(3, 4, device="cuda")
b = torch.rand(3, 4, device="cuda")
with torch.jit.cuda.stream(s):
is_stream_s = torch.cuda.current_stream(s.device_index()).id() == s.id()
c = torch.cat((a, b), 0).to("cuda")
s.synchronize()
return is_stream_s, a, b, c
model = Model()
# Script the model and save
script_model = torch.jit.script(model)
torch.jit.save(script_model, self.path)
tests = [
EvalModeForLoadedModule(),
SerializationInterop(),
TorchSaveError(),
TorchSaveJitStream_CUDA()
]
def setup():

476
test/jit/test_cuda.py Normal file
View File

@ -0,0 +1,476 @@
import os
import sys
import gc
import unittest
import torch
from typing import NamedTuple
from torch.testing._internal.jit_utils import JitTestCase
from torch.testing._internal.common_utils import skipIfRocm, skipCUDANonDefaultStreamIf
# Make the helper files in test/ importable
pytorch_test_dir = os.path.dirname(os.path.dirname(os.path.realpath(__file__)))
sys.path.append(pytorch_test_dir)
# Check if GPU is available
TEST_CUDA = torch.cuda.is_available()
# Check if multiple GPU's are available
TEST_MULTIGPU = TEST_CUDA and torch.cuda.device_count() >= 2
# If GPU is not available, then do not run the tests
if not TEST_CUDA:
print('CUDA not available, skipping tests', file=sys.stderr)
JitTestCase = object # noqa: F811
TEST_LARGE_TENSOR = TEST_CUDA
# If GPU is available, then initialize the cuda context and check
# if there is memory available to allocate for LARGE Tensors.
if TEST_CUDA:
torch.ones(1).cuda() # initialize cuda context
TEST_LARGE_TENSOR = torch.cuda.get_device_properties(0).total_memory >= 5e9
if __name__ == "__main__":
raise RuntimeError(
"This test file is not meant to be run directly, use:\n\n"
"\tpython test/test_jit.py TESTNAME\n\n"
"instead."
)
class TestCUDA(JitTestCase):
"""
A suite of tests for the CUDA API in TorchScript.
"""
def setUp(self):
super(TestCUDA, self).setUp()
def tearDown(self):
gc.collect()
torch.cuda.empty_cache()
super(TestCUDA, self).tearDown()
@skipIfRocm
@unittest.skipIf(not TEST_MULTIGPU, "detected only one GPU")
def test_current_stream(self):
# Test current stream on the device and check if the stream device index
# matches with the device ID
@torch.jit.script
def fn():
device_index = torch.cuda._current_device()
s0 = torch.cuda.current_stream(device_index)
s1 = torch.cuda.current_stream(1)
s2 = torch.cuda.current_stream(0)
return s0.device_index(), s1.device_index(), s2.device_index()
d0, d1, d2 = fn()
# By default, the current device ID is 0.
self.assertEqual(0, d0)
self.assertEqual(1, d1)
self.assertEqual(0, d2)
self.assertEqual(d0, d2)
@skipIfRocm
@unittest.skipIf(not TEST_MULTIGPU, "detected only one GPU")
@unittest.skipIf(not TEST_LARGE_TENSOR, "not enough memory")
@skipCUDANonDefaultStreamIf(True)
def test_streams_and_events(self):
# This test checks for the default stream ID is set to 0 on the device
@torch.jit.script
def test_default_streams():
s0 = torch.cuda.default_stream(0)
s1 = torch.cuda.default_stream(1)
d = torch.device('cuda:1')
# Check the current stream id and default id are same
# on the current device. The current device id by default is 0
s2 = torch.cuda.current_stream(0)
check_s2 = s2.id() == s0.id()
check_d0 = torch.cuda._current_device() == s2.device_index()
# Set the current device to d1 and check if the stream
# has been set to the default stream on d1
with torch.jit.cuda.device(d):
s3 = torch.cuda.current_stream(1)
check_s3 = s3.id() == s1.id()
check_d1 = torch.cuda._current_device() == s3.device_index()
# Check if the current device was reset to 0
is_device_d0 = torch.cuda._current_device() == s2.device_index()
return s0.device_index(), s1.device_index(), check_s2, check_s3, check_d0, check_d1, is_device_d0
d0, d1, check_s2, check_s3, check_d0, check_d1, is_device_d0 = test_default_streams()
self.assertEqual(d0, 0)
self.assertEqual(d1, 1)
self.assertTrue(check_s2)
self.assertTrue(check_s3)
self.assertTrue(check_d0)
self.assertTrue(check_d1)
self.assertTrue(is_device_d0)
# This test checks if the Stream Context manager is a no op
# when the stream is none for `with torch.jit.cuda.stream`
@torch.jit.script
def test_set_none_stream():
device_index = torch.cuda._current_device()
current_stream = torch.cuda.current_stream(device_index)
default_stream = torch.cuda.default_stream(device_index)
# When stream is none, check if this operation is a no-op
with torch.jit.cuda.stream(None):
cur_device_index = torch.cuda._current_device()
is_device_index_same = cur_device_index == device_index
is_current_stream_same = torch.cuda.current_stream(cur_device_index).id() == current_stream.id()
is_default_stream_same = torch.cuda.default_stream(device_index).id() == default_stream.id()
# Check if the device index, current stream and default streams have not changed
are_streams_same = is_device_index_same and is_current_stream_same and is_default_stream_same
return are_streams_same
self.assertTrue(test_set_none_stream())
# This test checks if the Device Context manager is a no op
# when the device is none for `with torch.jit.cuda.device`
@torch.jit.script
def test_set_device_none():
device_index = torch.cuda._current_device()
# When device is none, check if this operation is a no-op
with torch.jit.cuda.device(None):
# Check if the current device is the same
is_device_same = torch.cuda._current_device() == device_index
return is_device_same
self.assertTrue(test_set_device_none())
# Check if a CUDA JIT stream is created
# on the _current_device
@torch.jit.script
def test_simple_stream():
device_index = torch.cuda._current_device()
s = torch.jit.cuda.Stream(device_index, 0)
return device_index == s.device_index()
self.assertTrue(test_simple_stream(), "Could not create Stream!")
# Class used to store results for the test: test_get_stream.
class Result(NamedTuple):
t1 : torch.Tensor
t2 : torch.Tensor
is_current_and_default_stream_same : bool
is_default_and_user_stream_not_same : bool
is_stream_set : bool
is_stream_reset : bool
default_stream_query : bool
default_stream_id : int
user_stream_id : int
# The test aims at checking different stream proporties.
@torch.jit.script
def test_get_stream():
device_index = torch.cuda._current_device()
current_stream = torch.cuda.current_stream(device_index)
default_stream = torch.cuda.default_stream(device_index)
user_stream = torch.jit.cuda.Stream(device_index, 0)
# Check if the current and default streams are the same on the device
is_current_and_default_stream_same = current_stream.id() == default_stream.id()
# Check if user stream and default stream are not the same on the device
is_default_and_user_stream_not_same = default_stream.id() != user_stream.id()
with torch.jit.cuda.stream(user_stream):
is_stream_set = torch.cuda.current_stream(device_index).id() == user_stream.id()
# Check if the stream was reset to current_stream
is_stream_reset = torch.cuda.current_stream(device_index).id() == current_stream.id()
tensor1 = torch.rand(10000, 10000, device="cuda")
tensor2 = torch.mm(tensor1, tensor1).to("cuda")
default_stream.synchronize()
default_stream_query = default_stream.query()
# Capture all the results in the class Result
res = Result(
tensor1, tensor2, is_current_and_default_stream_same,
is_default_and_user_stream_not_same, is_stream_set,
is_stream_reset, default_stream_query, default_stream.id(), user_stream.id())
return res
result = test_get_stream()
self.assertEqual(torch.matmul(result.t1, result.t1), result.t2)
self.assertTrue(result.is_current_and_default_stream_same)
self.assertTrue(result.is_default_and_user_stream_not_same)
self.assertTrue(result.is_stream_set)
self.assertTrue(result.is_stream_reset)
self.assertTrue(result.default_stream_query)
self.assertEqual(result.default_stream_id, 0) # Check if the default stream ID is always 0
self.assertNotEqual(result.user_stream_id, 0) # Check if the user stream is always non zero
# Test the stream context manager. This test checks if the stream is switched
# to the user stream on using the stream context manager.
@torch.jit.script
def test_stream_context():
device_index = torch.cuda._current_device()
current_stream = torch.cuda.current_stream(device_index)
user_stream = torch.jit.cuda.Stream(device_index, 0)
A = torch.rand(1000, 1000, device="cuda")
with torch.jit.cuda.stream(user_stream):
check = torch.cuda.current_stream(device_index).id() == user_stream.id()
B = torch.mm(A, A).to("cuda")
# Wait for B to be computed
user_stream.synchronize()
# Check if the stream has been reset on the current device
is_stream_reset = torch.cuda.current_stream(device_index).id() == current_stream.id()
return A, B, check, is_stream_reset
A, B, is_stream_set, is_stream_reset = test_stream_context()
self.assertEqual(torch.matmul(A, A), B)
self.assertTrue(is_stream_set, "Error: Current stream was not set to user stream!")
self.assertTrue(is_stream_reset, "Error: The stream was not restored to previous stream!")
# Test multiple nested streams. Check if the operations are computed as expected on the streams
# This test has been adapted from the eager mode tests available at test/test_cuda.py
@torch.jit.script
def test_multiple_stream():
prev_device_index = torch.cuda._current_device()
prev_current_stream = torch.cuda.current_stream(prev_device_index)
s1 = torch.jit.cuda.Stream(0, 0)
s2 = torch.jit.cuda.Stream(1, 0)
A = torch.rand(1000, 1000, device="cuda")
B = torch.rand(1000, 1000, device="cuda")
with torch.jit.cuda.stream(s1):
C = torch.mm(A, A).to("cuda")
# Check if the stream and device have been set to s1
is_stream_s1 = torch.cuda.current_stream(s1.device_index()).id() == s1.id()
is_device_s1 = torch.cuda._current_device() == s1.device_index()
with torch.jit.cuda.stream(s2):
# Check if the stream and device have been set to s2
is_stream_s2 = torch.cuda.current_stream(s2.device_index()).id() == s2.id()
is_device_s2 = torch.cuda._current_device() == s2.device_index()
D = torch.mm(B, B).to("cuda")
# Check if the stream and device have been set to s1
is_stream_s1_after = torch.cuda.current_stream(s1.device_index()).id() == s1.id()
is_device_s1_after = torch.cuda._current_device() == s1.device_index()
# Wait for D to be computed
s2.synchronize()
# Wait for C to be computed on S1
s1.synchronize()
# Check if the stream and device has been restored to previous stream and device
is_device_current = torch.cuda._current_device() == prev_device_index
is_stream_current = torch.cuda.current_stream(prev_device_index).id() == prev_current_stream.id()
check_stream = is_stream_s1 and is_stream_s2 and is_stream_s1_after and is_stream_current
check_device = is_device_s1 and is_device_s2 and is_device_s1_after and is_device_current
return A, B, C, D, check_stream, check_device
A, B, C, D, check_stream, check_device = test_multiple_stream()
self.assertEqual(torch.matmul(A, A), C)
self.assertEqual(torch.matmul(B, B), D)
self.assertTrue(check_stream)
self.assertTrue(check_device)
# Test multiple streams waiting on each other for the operations to be completed.
@torch.jit.script
def test_data_dependency_between_streams():
device_index = torch.cuda._current_device()
prev_current_stream = torch.cuda.current_stream(device_index)
s1 = torch.jit.cuda.Stream(0, 0)
s2 = torch.jit.cuda.Stream(0, 0)
event = torch.jit.cuda.Event(False, False, False)
A = torch.rand(1000, 1000, device="cuda")
with torch.jit.cuda.stream(s1):
is_stream_s1 = torch.cuda.current_stream(device_index).id() == s1.id()
B = torch.mm(A, A).to("cuda")
s1.record_event(event)
# Check if the current_stream is reset
is_current_stream_1 = torch.cuda.current_stream(device_index).id() == prev_current_stream.id()
# Wait for ops on s1 to be computed
s2.wait_event(event)
with torch.jit.cuda.stream(s2):
is_stream_s2 = torch.cuda.current_stream(device_index).id() == s2.id()
C = torch.mm(B, B).to("cuda")
# Wait for C to be computed
s2.synchronize()
# Check if the current_stream is reset
is_current_stream_2 = torch.cuda.current_stream(device_index).id() == prev_current_stream.id()
check_stream = is_current_stream_1 and is_current_stream_2 and is_stream_s1 and is_stream_s2
return A, B, C, check_stream
A, B, C, check_stream = test_data_dependency_between_streams()
self.assertEqual(torch.matmul(A, A), B)
self.assertEqual(torch.matmul(B, B), C)
self.assertTrue(check_stream)
# Test a simple CUDA event. Test if the CUDA event was created successfully
@torch.jit.script
def test_simple_event():
e = torch.jit.cuda.Event(True, False, False)
return e is not None
self.assertTrue(test_simple_event(), "Could not create CUDA Event!")
# Record the CUDA event for operation torch.mm on the current stream
# and then test if the elapsed time is greater than 0. This test is also
# an adaption from eager mdoe CUDA tests available at test/test_cuda.py
@torch.jit.script
def test_event():
device_index = torch.cuda._current_device()
stream = torch.cuda.current_stream(device_index)
event = torch.jit.cuda.Event(True, False, False)
is_true_event_query = event.query()
start_event = torch.jit.cuda.Event(True, False, False)
stream.record_event(start_event)
tensor1 = torch.rand(1000000000, 1000000000, device="cuda")
tensor2 = torch.mm(tensor1, tensor1).to("cuda")
stream.record_event(event)
event.synchronize()
is_again_true_event_query = event.query()
if not (is_true_event_query and is_again_true_event_query):
return -1.0
return start_event.elapsed_time(event)
self.assertGreater(test_event(), 0)
# Check for stream synchronization , when a large tensor multiplication is
# computed on the stream. The stream.query should be true once the synchroniztion is done
@torch.jit.script
def test_stream_synchronize() -> float:
device_index = torch.cuda._current_device()
s = torch.jit.cuda.Stream(device_index, 0)
e_tik = torch.jit.cuda.Event(True, False, False)
e_tok = torch.jit.cuda.Event(True, False, False)
e_tik.record(s)
tensor1 = torch.rand(1000000000, 1000000000, device="cuda")
with torch.jit.cuda.stream(s):
tensor2 = torch.mm(tensor1, tensor1).to("cuda")
s.synchronize()
e_tok.record(s)
e_tok.synchronize()
if not s.query():
return -1.0
# not necessary to check e_tik and e_tok, as elapsed_time would throw
# exception if otherwise.
return e_tik.elapsed_time(e_tok)
self.assertGreater(test_stream_synchronize(), 0)
# Test event synchronization for the event that records a stream doing
# a large tensor multiplication. Check if the elapsed time is greater than 0
# and the stream.query evaluates to true.
@torch.jit.script
def test_event_synchronize() -> float:
device_index = torch.cuda._current_device()
s = torch.jit.cuda.Stream(device_index, 0)
e_tik = torch.jit.cuda.Event(True, False, False)
e_tok = torch.jit.cuda.Event(True, False, False)
e_tik.record(s)
tensor1 = torch.rand(1000000000, 1000000000, device="cuda")
with torch.jit.cuda.stream(s):
tensor = torch.mm(tensor1, tensor1).to("cuda")
s.record_event(e_tok)
e_tok.synchronize()
s.synchronize()
if not s.query():
return -1.0
# not necessary to check e_tik and e_tok, as elapsed_time would throw
# exception if otherwise.
return e_tik.elapsed_time(e_tok)
self.assertGreater(test_event_synchronize(), 0)
# Test for event wait. Check if event waits for the all the operations on
# the stream to be done. Check for synchronizations and query on the streams
# and events. This test is adapted from eager mode tests for CUDA. Please refer
# test/test_cuda.py
@torch.jit.script
def test_event_wait() -> float:
device_index = torch.cuda._current_device()
s0 = torch.cuda.current_stream(device_index)
s1 = torch.jit.cuda.Stream(device_index, 0)
e_tik = torch.jit.cuda.Event(True, True, False)
e_tok = torch.jit.cuda.Event(True, True, False)
e_tik.record(s0)
tensor1 = torch.rand(1000000000, 1000000000, device="cuda")
with torch.jit.cuda.stream(s0):
tensor2 = torch.mm(tensor1, tensor1).cuda()
e_sync = torch.jit.cuda.Event(True, False, False)
e_sync.record(torch.cuda.current_stream(device_index))
e_sync.wait(s1)
with torch.jit.cuda.stream(s1):
tensor3 = torch.rand(1000000000, 1000000000, device="cuda")
tensor4 = torch.mm(tensor3, tensor3).cuda()
s1.synchronize()
e_tok.record(torch.cuda.current_stream(device_index))
e_tok.synchronize()
s0.synchronize()
if not s0.query() or not s1.query() or not e_sync.query():
return -1.0
# not necessary to check e_tik and e_tok, as elapsed_time would throw
# exception if otherwise.
return e_tik.elapsed_time(e_tok)
self.assertGreater(test_event_wait(), 0)
# Test for stream wait_event. Checks if the stream waits on the event
@torch.jit.script
def test_wait_event():
d1 = torch.device('cuda:1')
with torch.jit.cuda.device(d1):
s0 = torch.cuda.current_stream(1)
tensor1 = torch.rand(1000000000, 1000000000, device="cuda")
tensor2 = torch.mm(tensor1, tensor1).to("cuda")
e0 = torch.jit.cuda.Event(False, False, False)
s0.record_event(e0)
s1 = torch.cuda.current_stream(0)
s1.wait_event(e0)
s1.synchronize()
return e0.query() and s0.query() and s1.query()
self.assertTrue(test_wait_event())
# Test if a scripted module with cuda streams can be saved, loaded and executed
def test_save_load(self):
class Model(torch.nn.Module):
def forward(self):
device_index = torch.cuda._current_device()
s = torch.jit.cuda.Stream(device_index, 0)
a = torch.rand(3, 4, device="cuda")
b = torch.rand(3, 4, device="cuda")
with torch.jit.cuda.stream(s):
is_stream_s = torch.cuda.current_stream(s.device_index()).id() == s.id()
c = torch.cat((a, b), 0).cuda()
s.synchronize()
return is_stream_s, a, b, c
model = Model()
# Script the model and save
script_model = torch.jit.script(model)
is_stream_s, a, b, c = script_model()
# Verify if the output is correct
self.assertTrue(is_stream_s)
self.assertEqual(torch.cat((a, b), 0), c)
# Save and load scripted model
load_model = self.getExportImportCopy(script_model)
is_stream_s, a_load, b_load, c_load = load_model()
self.assertTrue(is_stream_s)
self.assertEqual(torch.cat((a_load, b_load), 0), c_load)

View File

@ -35,6 +35,7 @@ from jit.test_profiler import TestProfiler # noqa: F401
from jit.test_slice import TestSlice # noqa: F401
from jit.test_warn import TestWarn # noqa: F401
from jit.test_isinstance import TestIsinstance # noqa: F401
from jit.test_cuda import TestCUDA # noqa: F401
from jit.test_hash import TestHash # noqa: F401
# Torch

View File

@ -408,6 +408,7 @@ libtorch_cuda_core_sources = [
"torch/csrc/jit/codegen/cuda/transform_rfactor.cpp",
"torch/csrc/jit/codegen/cuda/type.cpp",
"torch/csrc/jit/tensorexpr/cuda_codegen.cpp",
"torch/csrc/jit/runtime/register_cuda_ops.cpp",
]
libtorch_cuda_sources = libtorch_cuda_core_sources + [

179
torch/csrc/jit/cuda/cuda.h Normal file
View File

@ -0,0 +1,179 @@
#include <aten/src/ATen/cuda/CUDAEvent.h>
#include <c10/core/Device.h>
#include <c10/cuda/CUDAStream.h>
#include <torch/custom_class.h>
namespace torch {
namespace jit {
class CUDAEvent;
// This class is a wrapper around c10::cuda::CUDAStream.
// It is needed because TorchBind does not support all of the argument types
// for c10::cuda::CUDAStream. For more details, please refer to
// c10/cuda/CUDAStream.h.
class CUDAStream final : public CustomClassHolder {
public:
CUDAStream(int64_t device = -1, int64_t priority = 0) {
constexpr int64_t PRIORITY_INDEX = 0;
stream_ = std::make_unique<c10::cuda::CUDAStream>(
c10::cuda::getStreamFromPool(priority < PRIORITY_INDEX, device));
}
CUDAStream(c10::cuda::CUDAStream s) {
stream_ = std::make_unique<c10::cuda::CUDAStream>(s);
}
bool query() {
return stream_->query();
}
c10::intrusive_ptr<CUDAEvent> recordEvent(
c10::intrusive_ptr<CUDAEvent> event);
void synchronize() {
stream_->synchronize();
}
void waitEvent(c10::intrusive_ptr<CUDAEvent> event);
void waitStream(c10::intrusive_ptr<CUDAStream> stream);
/// Get the CUDA device index that this stream is associated with.
int64_t device_index() const {
return stream_->device_index();
}
/// Get the full Device that this stream is associated with. The Device
/// is guaranteed to be a CUDA device.
c10::Device device() const {
return stream_->device();
}
/// Return the stream ID corresponding to this particular stream.
int64_t id() const {
return stream_->id();
}
/// Pack a CUDAStream to uint64_t representation.
/// The CUDAStream can be unpacked using unpack(). The format of
/// the uint64_t is unspecified and may be changed.
int64_t pack() const {
return stream_->pack();
}
private:
std::unique_ptr<c10::cuda::CUDAStream> stream_;
friend class CUDAEvent;
};
// This class is a wrapper around at::cuda::CUDAStream.
// It is needed because TorchBind does not support all of the argument types
// for at::cuda::CUDAEvent. For more details, please refer to
// aten/src/ATen/cuda/CUDAEvent.h.
class CUDAEvent final : public CustomClassHolder {
public:
CUDAEvent(
bool enable_timing = false,
bool blocking = false,
bool interprocess = false) {
int flags = cudaEventDisableTiming;
if (enable_timing) {
flags = cudaEventDefault;
}
if (blocking) {
flags |= cudaEventBlockingSync;
}
if (interprocess) {
TORCH_CHECK(!enable_timing);
flags |= cudaEventInterprocess;
}
event_ = std::make_unique<at::cuda::CUDAEvent>(flags);
}
double elapsedTime(c10::intrusive_ptr<CUDAEvent> end) {
return event_->elapsed_time(*end->event_);
}
std::string ipcHandle() {
cudaIpcEventHandle_t handle;
event_->ipc_handle(&handle);
std::string str_handle((const char*)&handle, sizeof(handle));
return str_handle;
}
bool query() {
return event_->query();
}
void record(c10::intrusive_ptr<CUDAStream> stream);
void synchronize() {
event_->synchronize();
}
void wait(c10::intrusive_ptr<CUDAStream> stream);
private:
void recordInternal(CUDAStream* stream);
std::unique_ptr<at::cuda::CUDAEvent> event_;
friend class CUDAStream;
};
c10::intrusive_ptr<CUDAEvent> CUDAStream::recordEvent(
c10::intrusive_ptr<CUDAEvent> event) {
if (!event) {
event = c10::make_intrusive<CUDAEvent>();
}
event->recordInternal(this);
return event;
}
void CUDAStream::waitEvent(c10::intrusive_ptr<CUDAEvent> event) {
event->event_->block(*stream_);
}
void CUDAStream::waitStream(c10::intrusive_ptr<CUDAStream> stream) {
auto ev = c10::make_intrusive<CUDAEvent>();
stream->recordEvent(ev);
waitEvent(ev);
}
void CUDAEvent::record(c10::intrusive_ptr<CUDAStream> stream) {
event_->record(*stream->stream_);
}
void CUDAEvent::recordInternal(CUDAStream* stream) {
event_->record(*stream->stream_);
}
void CUDAEvent::wait(c10::intrusive_ptr<CUDAStream> stream) {
event_->block(*stream->stream_);
}
TORCH_LIBRARY(cuda, m) {
auto stream_class = m.class_<torch::jit::CUDAStream>("Stream").def(
torch::init<int64_t, int64_t>());
auto event_class = m.class_<torch::jit::CUDAEvent>("Event").def(
torch::init<bool, bool, bool>());
stream_class.def("query", &CUDAStream::query)
.def("record_event", &CUDAStream::recordEvent)
.def("synchronize", &CUDAStream::synchronize)
.def("wait_event", &CUDAStream::waitEvent)
.def("wait_stream", &CUDAStream::waitStream)
.def("device_index", &CUDAStream::device_index)
.def("device", &CUDAStream::device)
.def("pack", &CUDAStream::pack)
.def("id", &CUDAStream::id);
event_class.def("elapsed_time", &CUDAEvent::elapsedTime)
.def("query", &CUDAEvent::query)
.def("record", &CUDAEvent::record)
.def("synchronize", &CUDAEvent::synchronize)
.def("wait", &CUDAEvent::wait);
};
} // namespace jit
} // namespace torch

View File

@ -211,6 +211,13 @@ TypePtr ScriptTypeParser::parseTypeFromExprImpl(const Expr& expr) const {
}
}
// Check if the type is a custom class. This is done by checking
// if type_name starts with "torch.classes."
if (type_name.find("torch.classes.") == 0) {
auto custom_class_type = getCustomClass("__torch__." + type_name);
return custom_class_type;
}
throw ErrorReport(expr) << "Unknown type name '" << type_name << "'";
} else if (auto name = parseBaseTypeName(expr)) {
auto itr = string_to_type_lut().find(*name);

View File

@ -572,7 +572,8 @@ void AliasDb::analyzeImpl(Node* node) {
!aliasAnalysisHasSpecialCaseFor(node->kind()),
"Special cases should be handled already if we're here.");
if (node->kind().is_aten() || node->kind().is_prim()) {
if (node->kind().is_aten() || node->kind().is_prim() ||
node->kind().is_cuda()) {
// TODO There is nothing in the system that relies on aten:: and prim::
// ops using AliasAnalysisKind::FROM_SCHEMA or
// AliasAnalysisKind::INTERNAL_SPECIAL_CASE, but this is the intended

View File

@ -1079,6 +1079,11 @@ bool Node::hasSideEffects() const {
case prim::rpc_sync: // It represents RPC message sent.
case prim::rpc_remote: // It represents RPC message sent.
case aten::wait: // It can represent RPC message received.
#ifndef __HIP_PLATFORM_HCC__
case cuda::set_stream:
case cuda::_set_device:
case cuda::_current_device:
#endif
case prim::Enter:
case prim::Exit:
return true;
@ -1094,7 +1099,7 @@ bool Node::hasSideEffects() const {
return false;
}
if (kind_.is_prim() || kind_.is_aten()) {
if (kind_.is_prim() || kind_.is_aten() || kind_.is_cuda()) {
// TODO There is nothing in the system that relies on aten:: and prim::
// ops using AliasAnalysisKind::FROM_SCHEMA,
// AliasAnalysisKind::INTERNAL_SPECIAL_CASE, or

View File

@ -72,6 +72,11 @@ using namespace ::c10::attr;
namespace aten {
using namespace ::c10::aten;
}
namespace cuda {
#ifndef __HIP_PLATFORM_HCC__
using namespace ::c10::cuda;
#endif
} // namespace cuda
struct Function;
struct MatchedSchema;

View File

@ -217,6 +217,32 @@ std::shared_ptr<SugaredValue> PythonModuleValue::attr(
return toSugaredValue(member, m, loc, /*is_constant=*/true);
}
#ifndef __HIP_PLATFORM_HCC__
std::shared_ptr<SugaredValue> CUDAPythonModuleValue::attr(
const SourceRange& loc,
Function& m,
const std::string& field) {
// List of all the cuda operators which are supported in JIT
const std::unordered_set<std::string> cuda_ops = {"current_stream",
"default_stream",
"_current_device",
"_set_device",
"device_index",
"device_count",
"set_stream"};
if (cuda_ops.find(field) != cuda_ops.end()) {
return std::make_shared<BuiltinFunction>(Symbol::cuda(field), c10::nullopt);
}
py::object member = getattr(loc, field);
// note: is_constant = true because we consider that global properties
// on modules like math.pi or torch.float to be constants
// even though it is possible, though rare, for someone to mutate them
return toSugaredValue(member, m, loc, /*is_constant=*/true);
}
#endif
Value* ModuleValue::asValue(const SourceRange& loc, Function& m) {
return self_;
}
@ -938,6 +964,12 @@ std::shared_ptr<SugaredValue> toSugaredValue(
if (auto callee = as_function(obj)) {
return std::make_shared<FunctionValue>(callee->function_);
} else if (py::isinstance<py::module>(obj)) {
#ifndef USE_ROCM
std::string obj_name = py::cast<py::str>(py::getattr(obj, "__name__"));
if (obj_name.compare("torch.cuda") == 0) {
return std::make_shared<CUDAPythonModuleValue>(obj);
}
#endif
return std::make_shared<PythonModuleValue>(obj);
} else if (
obj.ptr() == py::module::import("torch.jit").attr("_fork").ptr() ||

View File

@ -91,6 +91,20 @@ struct VISIBILITY_HIDDEN PythonModuleValue : public PythonValue {
const std::string& field) override;
};
// Used for desugaring uses of the torch.cuda module. All the CUDA APIs with
// torch.cuda.* are resolved using CUDAPythonModuleValue.
#ifndef __HIP_PLATFORM_HCC__
struct VISIBILITY_HIDDEN CUDAPythonModuleValue : public PythonValue {
explicit CUDAPythonModuleValue(py::object mod)
: PythonValue(std::move(mod)) {}
std::shared_ptr<SugaredValue> attr(
const SourceRange& loc,
Function& m,
const std::string& field) override;
};
#endif
// Represents all the parameters of a module as a List[Tensor]
struct VISIBILITY_HIDDEN ConstantParameterList : public SugaredValue {
ConstantParameterList(Value* the_list) : the_list_(the_list) {}

View File

@ -0,0 +1,87 @@
// This file registers special JIT operators used to implement the PyTorch CUDA
// API in TorchScript.
#ifndef __HIP_PLATFORM_HCC__
#include <torch/csrc/api/include/torch/utils.h>
#include <torch/csrc/jit/cuda/cuda.h>
#include <torch/csrc/jit/ir/ir.h>
#include <torch/csrc/jit/runtime/custom_operator.h>
#include <torch/csrc/jit/runtime/operator.h>
namespace torch {
namespace jit {
namespace {
c10::AliasAnalysisKind aliasAnalysisFromSchema() {
return c10::AliasAnalysisKind::FROM_SCHEMA;
}
RegisterOperators const reg({
Operator(
"cuda::current_stream(int64_t val) -> __torch__.torch.classes.cuda.Stream",
[](Stack* stack) {
auto idx = uint16_t(pop(stack).toInt());
auto s = c10::cuda::getCurrentCUDAStream(idx);
auto st = make_custom_class<torch::jit::CUDAStream>(s);
push(stack, IValue(st));
},
aliasAnalysisFromSchema()),
Operator(
"cuda::default_stream(int64_t val) -> __torch__.torch.classes.cuda.Stream",
[](Stack* stack) {
auto idx = uint16_t(pop(stack).toInt());
auto s = c10::cuda::getDefaultCUDAStream(idx);
auto st = make_custom_class<torch::jit::CUDAStream>(s);
push(stack, IValue(st));
},
aliasAnalysisFromSchema()),
Operator(
"cuda::_current_device() -> int",
[](Stack* stack) {
auto v = c10::cuda::current_device();
push(stack, static_cast<int>(v));
},
aliasAnalysisFromSchema()),
Operator(
"cuda::_set_device(int64_t val) -> ()",
[](Stack* stack) {
int64_t idx = -1;
pop(stack, idx);
c10::cuda::set_device(static_cast<c10::DeviceIndex>(idx));
},
aliasAnalysisFromSchema()),
Operator(
"cuda::device_index(Device device) -> int",
[](Stack* stack) {
auto device = pop(stack);
auto idx = device.toDevice().index();
push(stack, idx);
},
aliasAnalysisFromSchema()),
Operator(
"cuda::device_count() -> int",
[](Stack* stack) { push(stack, at::cuda::device_count()); },
aliasAnalysisFromSchema()),
Operator(
"cuda::set_stream(__torch__.torch.classes.cuda.Stream stream) -> ()",
[](Stack* stack) {
auto v = pop(stack);
auto s = v.toCustomClass<torch::jit::CUDAStream>();
// To set the current CUDA stream using
// c10::cuda::setCurrentCUDAStream, the jit::CUDAStream object needs
// to be converted to c10::cuda::CUDAStream. Since the latter cannot
// be returned from a class registered via TorchBind, this can only be
// achieved by packing the c10::cuda::CUDAStream instance contained
// inside the jit::CUDAStream object to a uint64_t representation, and
// unpacking it inside this operator. The unpacked stream is then used
// to set the current CUDA stream.
auto packed = s->pack();
auto unpacked = c10::cuda::CUDAStream::unpack(packed);
c10::cuda::setCurrentCUDAStream(unpacked);
},
aliasAnalysisFromSchema()),
});
} // namespace
} // namespace jit
} // namespace torch
#endif

View File

@ -44,6 +44,7 @@ from torch.jit._async import fork, wait
from torch.jit._serialization import save, load
from torch.jit._fuser import optimized_execution, fuser, last_executed_optimized_graph
from torch.jit.cuda import stream
from torch.jit._freeze import freeze
# For backwards compatibility

182
torch/jit/cuda.py Normal file
View File

@ -0,0 +1,182 @@
# mypy: ignore-errors
r"""
This package adds support for JIT compilation for CUDA Streams and events,
This is similar to API's available in the eager mode
:ref:`cuda-semantics` has more details about working with CUDA.
"""
import torch
from typing import Optional, Any
from torch import device as _device
def get_current_device_index() -> int:
r"""Checks if there are CUDA devices available and
returns the device index of the current default CUDA device.
Returns -1 in case there are no CUDA devices available.
Arguments: ``None``
"""
if torch.cuda.device_count() > 0:
return torch.cuda._current_device()
return -1
def get_device_index(device: Optional[_device] = None, optional: bool = False, allow_cpu: bool = False) -> int:
r"""Gets the device index from :attr:`device`, which can be a torch.device
object, a Python integer, or ``None``.
If :attr:`device` is a torch.device object, returns the device index if it
is a CUDA device. Note that for a CUDA device without a specified index,
, this will return the current default CUDA device if :attr:`optional` is ``True``.
If :attr:`allow_cpu` is ``True``,CPU devices will be accepted and ``-1`` will be
returned in this case.
If :attr:`device` is a Python integer, it is returned as is.
If :attr:`device` is ``None``, this will return the current default CUDA
device if :attr:`optional` is ``True``.
"""
if device is None:
if optional:
return get_current_device_index()
else:
raise ValueError('Expected a torch.device with a specified index '
f'or an integer, but got: {device}')
device_index = -1
if isinstance(device, str):
device = torch.device(device)
if isinstance(device, torch.device):
if not allow_cpu and device.type == 'cpu':
raise ValueError(f'Expected a non cpu device, but got: {device}')
device_index = -1 if device.type == 'cpu' else torch.cuda.device_index(device)
if isinstance(device, int):
device_index = device
return device_index
class device(object):
r"""Context-manager that changes the selected device.
This is similar to device (torch.device or int), but has been
introduced for JIT compatibility.
Arguments:
device (torch.device or int): device index to select. It's a no-op if
this argument is a negative integer or ``None``.
"""
def __init__(self, device: Optional[_device]):
self.idx = -1
self.prev_idx = -1
self.device = device
def __enter__(self):
self.idx = get_device_index(self.device, optional=True)
if self.idx == -1:
return
self.prev_idx = torch.cuda._current_device()
if self.prev_idx != self.idx:
torch.cuda._set_device(self.idx)
def __exit__(self, type: Any, value: Any, traceback: Any):
if self.prev_idx != self.idx:
torch.cuda._set_device(self.prev_idx)
class StreamContext(object):
r"""Context-manager that selects a given stream.
All CUDA kernels queued within its context will be enqueued on a selected
stream.
Arguments:
StreamContext (Stream): selected stream. This manager is a no-op if it's
``None``.
.. note:: Streams are per-device. If the selected stream is not on the
current device, this function will also change the current device to
match the stream.
"""
cur_stream : Optional['torch.classes.cuda.Stream']
def __init__(self, stream: Optional['torch.classes.cuda.Stream']):
self.idx = -1
self.stream = stream
# Initialize the below streams to default stream on the current device
self.device_index = get_current_device_index()
self.src_prev_stream = torch.cuda.default_stream(self.device_index)
self.dst_prev_stream = torch.cuda.default_stream(self.device_index)
def __enter__(self):
self.idx = get_device_index(device=None, optional=True)
# If there is no CUDA device available, return
if self.idx == -1:
return
# Local cur_stream variable for type refinement
cur_stream = self.stream
# Return if stream is None
if cur_stream is None:
return
self.src_prev_stream = torch.cuda.current_stream(self.idx)
# If the stream is not on the current device, then change the device
# and set the current stream on the device
if self.src_prev_stream.device_index() != cur_stream.device_index():
with device(cur_stream.device()):
self.dst_prev_stream = torch.cuda.current_stream(cur_stream.device_index())
torch.cuda._set_device(cur_stream.device_index())
torch.cuda.set_stream(cur_stream)
def __exit__(self, type: Any, value: Any, traceback: Any):
# Local cur_stream variable for type refinement
cur_stream = self.stream
# If stream is None or no CUDA device available, return
if cur_stream is None or self.idx == -1:
return
# If the stream was not on the current device, restore the previous stream on
# the destination device and also reset the current device to the previous device.
# Set the current stream on the device to the src_prev_stream
if self.src_prev_stream.device_index() != cur_stream.device_index():
torch.cuda.set_stream(self.dst_prev_stream)
torch.cuda._set_device(self.idx)
torch.cuda.set_stream(self.src_prev_stream)
def stream(stream: Optional['torch.classes.cuda.Stream']) -> StreamContext:
r"""Wrapper around the Context-manager that selects a given stream.
All CUDA kernels queued within its context will be enqueued on a selected
stream.
Arguments:
stream (Stream): selected stream. This manager is a no-op if it's
``None``.
"""
return StreamContext(stream)
def Stream(device: int = -1, priority: int = 0) -> 'torch.classes.cuda.Stream':
r"""Wrapper around a CUDA stream.
A CUDA stream is a linear sequence of execution that belongs to a specific
device, independent from other streams. See :ref:`cuda-semantics` for
details.
Arguments:
device(int, optional): a device on which to allocate
the stream. If :attr:`device` is ``None`` (default) or a negative
integer, this will use the current device.
priority(int, optional): priority of the stream. Can be either
-1 (high priority) or 0 (low priority). By default, streams have
priority 0.
.. note:: Although CUDA versions >= 11 support more than two levels of
priorities, in PyTorch, we only support two levels of priorities.
"""
return torch.classes.cuda.Stream(device, priority)
def Event(enable_timing: bool = False, blocking: bool = False, interprocess: bool = False) -> 'torch.classes.cuda.Event':
r"""Wrapper around a CUDA event.
CUDA events are synchronization markers that can be used to monitor the
device's progress, to accurately measure timing, and to synchronize CUDA
streams.
Arguments:
enable_timing (bool, optional): indicates if the event should measure time
(default: ``False``)
blocking (bool, optional): if ``True``, :meth:`wait` will be blocking (default: ``False``)
interprocess (bool): if ``True``, the event can be shared between processes
(default: ``False``)
.. _CUDA Event Documentation:
https://docs.nvidia.com/cuda/cuda-runtime-api/group__CUDART__EVENT.html
"""
return torch.classes.cuda.Event(enable_timing, blocking, interprocess)