mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 21:14:14 +08:00
[Static Runtime] test case for staticRuntime::runAsync() API (#80407)
Summary: - Python interface to call StaticRuntime::runAsync API - creates a custom executor with execution on inter-op thread pool - test cases for different async graph scenarios like multiple forks, nested forks, exception handling Test Plan: - local tests buck test mode/opt caffe2/test:static_runtime buck test mode/opt caffe2/benchmarks/static_runtime/fb:test_fb_operators buck test mode/opt caffe2/benchmarks/static_runtime:static_runtime_cpptest - OSS CI tests Differential Revision: D37471859 Pull Request resolved: https://github.com/pytorch/pytorch/pull/80407 Approved by: https://github.com/tenpercent
This commit is contained in:
committed by
PyTorch MergeBot
parent
9402219a36
commit
fefdad6137
@ -23,6 +23,9 @@ class StaticModule:
|
||||
def benchmark(self, args, kwargs, warmup_runs, main_runs):
|
||||
self.static_module.benchmark(args, kwargs, warmup_runs, main_runs)
|
||||
|
||||
def runAsync(self, args, kwargs):
|
||||
return self.static_module.runAsync(args, kwargs)
|
||||
|
||||
def benchmark_individual_ops(self, args, kwargs, warmup_runs, main_runs):
|
||||
return self.static_module.benchmark_individual_ops(
|
||||
args, kwargs, warmup_runs, main_runs
|
||||
@ -222,6 +225,20 @@ class TestStaticModule(TestCase):
|
||||
output_test = static_runtime_module(inp1, inp2)
|
||||
torch.testing.assert_close(output_test, output_ref)
|
||||
|
||||
"""
|
||||
Test Case: To test simple fork/wait operation with
|
||||
StaticRuntime runAsync API returning future
|
||||
"""
|
||||
def test_fork_wait_1_async(self):
|
||||
inp1 = torch.ones(5, 5)
|
||||
inp2 = torch.randn(5, 5)
|
||||
torch_graph = torch.jit.script(fork_wait_graph1)
|
||||
output_ref = torch_graph(inp1, inp2)
|
||||
static_runtime_module = StaticModule(torch_graph)
|
||||
output_test = static_runtime_module.runAsync((inp1, inp2), {})
|
||||
output_test.wait()
|
||||
torch.testing.assert_close(output_test.value(), output_ref)
|
||||
|
||||
"""
|
||||
Test Case: To test fork/wait operation in a graph on
|
||||
a loop subgraph performing mix of operations
|
||||
@ -235,6 +252,20 @@ class TestStaticModule(TestCase):
|
||||
output_test = static_runtime_module(inp1, inp2)
|
||||
torch.testing.assert_close(output_test, output_ref)
|
||||
|
||||
"""
|
||||
Test Case: To test fork/wait operation on a loop
|
||||
subgraph with StaticRuntime runAsync API returning future
|
||||
"""
|
||||
def test_fork_wait_2_async(self):
|
||||
inp1 = torch.randn(5, 5)
|
||||
inp2 = torch.randn(5, 5)
|
||||
torch_graph = torch.jit.script(fork_wait_graph2)
|
||||
output_ref = torch_graph(inp1, inp2)
|
||||
static_runtime_module = StaticModule(torch_graph)
|
||||
output_test = static_runtime_module.runAsync((inp1, inp2), {})
|
||||
output_test.wait()
|
||||
torch.testing.assert_close(output_test.value(), output_ref)
|
||||
|
||||
"""
|
||||
Test Case: To test fork/wait operation in a graph on
|
||||
having multiple fork/wait operations
|
||||
@ -247,6 +278,21 @@ class TestStaticModule(TestCase):
|
||||
static_runtime_module = StaticModule(torch_graph)
|
||||
output_test = static_runtime_module(input, num_forks)
|
||||
torch.testing.assert_close(output_test, output_ref)
|
||||
|
||||
"""
|
||||
Test Case: To test fork/wait operation in a graph with
|
||||
multiple fork/wait operations on runAsync API returning future
|
||||
"""
|
||||
def test_fork_wait_3_async(self):
|
||||
input = torch.ones(3, 3)
|
||||
num_forks = 10
|
||||
torch_graph = torch.jit.script(fork_wait_graph3)
|
||||
output_ref = torch_graph(input, num_forks)
|
||||
static_runtime_module = StaticModule(torch_graph)
|
||||
output_test = static_runtime_module.runAsync((input, num_forks), {})
|
||||
output_test.wait()
|
||||
torch.testing.assert_close(output_test.value(), output_ref)
|
||||
|
||||
"""
|
||||
Test Case: To test fork/wait operation in a graph on
|
||||
multiple nested fork/wait operations
|
||||
@ -261,6 +307,22 @@ class TestStaticModule(TestCase):
|
||||
output_test = static_runtime_module(input, num_forks, num_child_forks)
|
||||
torch.testing.assert_close(output_test, output_ref)
|
||||
|
||||
"""
|
||||
Test Case: To test fork/wait operation in a graph with multiple
|
||||
nested fork/wait operations on runAsync API returning future
|
||||
"""
|
||||
def test_fork_wait_4_async(self):
|
||||
input = torch.ones(3, 3)
|
||||
num_forks = 10
|
||||
num_child_forks = 10
|
||||
torch_graph = torch.jit.script(fork_wait_graph4)
|
||||
static_runtime_module = StaticModule(torch_graph)
|
||||
output_ref = torch_graph(input, num_forks, num_child_forks)
|
||||
output_test = static_runtime_module.runAsync(
|
||||
(input, num_forks, num_child_forks), {})
|
||||
output_test.wait()
|
||||
torch.testing.assert_close(output_test.value(), output_ref)
|
||||
|
||||
"""
|
||||
Test Case: To test exception handling in fork/wait
|
||||
operation. Add.Tensor op is called for tensors with
|
||||
@ -290,6 +352,36 @@ class TestStaticModule(TestCase):
|
||||
f"not contain expected substring: \"{expected_error_msg}\""
|
||||
) from error
|
||||
|
||||
"""
|
||||
Test Case: To test exception handling in fork/wait
|
||||
operation with runAsync API. Add.Tensor op is called for
|
||||
tensors with non-matching dims on the forked subgraph
|
||||
and the exception raised by subgraph is set on future returned
|
||||
by prim::fork to parent graph. Returned exception is
|
||||
checked for substring expected_error_msg as declared below
|
||||
"""
|
||||
def test_fork_wait_exception_async(self):
|
||||
# incompatible tensors for add due to shape mismatch
|
||||
input1 = torch.randn(4, 7)
|
||||
input2 = torch.randn(4, 5)
|
||||
torch_graph = torch.jit.script(fork_wait_graph_exception)
|
||||
try:
|
||||
static_runtime_module = StaticModule(torch_graph)
|
||||
output_test = static_runtime_module.runAsync(
|
||||
(input1, input2), {})
|
||||
except Exception as error:
|
||||
expected_error_msg = (
|
||||
"The size of tensor a (7) must match the size "
|
||||
"of tensor b (5) at non-singleton dimension 1"
|
||||
)
|
||||
# test fails if error does not contain expected substr
|
||||
if str(error).find(expected_error_msg) == -1:
|
||||
raise RuntimeError(
|
||||
"Tried execution of add.Tensors with incompatible shape. "
|
||||
"Exception raised by forked runtime execution does "
|
||||
f"not contain expected substring: \"{expected_error_msg}\""
|
||||
) from error
|
||||
|
||||
def test_multihead_attention_layer(self):
|
||||
HID_DIM = 256
|
||||
QUERY_LEN = 8
|
||||
|
Reference in New Issue
Block a user