[Static Runtime] test cases for StaticRuntime::runAsync() API (#80407)

Summary:
- Adds a Python interface to call the StaticRuntime::runAsync API
- Creates a custom executor that schedules execution on the inter-op thread pool
- Adds test cases for different async graph scenarios: multiple forks, nested forks, and exception handling (a minimal sketch of the pattern follows below)
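
For context, here is a minimal sketch of the pattern the new tests exercise. The fork_wait_graph* helpers used by the tests are defined elsewhere in the test file and are not shown in this diff, so fork_wait_example below is only an assumed, illustrative stand-in; StaticModule, runAsync, wait() and value() are the names that do appear in the diff.

import torch

# Illustrative fork/wait graph: torch.jit.fork becomes a prim::fork node whose
# forked subgraph Static Runtime schedules on the inter-op thread pool.
@torch.jit.script
def fork_wait_example(a: torch.Tensor, b: torch.Tensor) -> torch.Tensor:
    fut = torch.jit.fork(torch.add, a, b)  # launch the subgraph asynchronously
    return torch.jit.wait(fut)             # block until the forked work is done

# With the StaticModule wrapper from this diff, the scripted graph runs
# asynchronously: runAsync returns a future that is waited on and then queried.
# sm = StaticModule(fork_wait_example)
# fut = sm.runAsync((torch.ones(5, 5), torch.randn(5, 5)), {})
# fut.wait()
# out = fut.value()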

Test Plan:
- local tests
    buck test mode/opt caffe2/test:static_runtime
    buck test mode/opt caffe2/benchmarks/static_runtime/fb:test_fb_operators
    buck test mode/opt caffe2/benchmarks/static_runtime:static_runtime_cpptest
- OSS CI tests (a direct local invocation of the new cases is sketched below)
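
Assuming the modified test file is the standard OSS test_static_runtime.py (its path is not shown in this diff), the new cases can presumably also be run directly via unittest's name filter:

    python test/test_static_runtime.py -k fork_wait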

Differential Revision: D37471859

Pull Request resolved: https://github.com/pytorch/pytorch/pull/80407
Approved by: https://github.com/tenpercent
Author: Akshay Parashar
Date: 2022-07-05 23:40:53 +00:00
Committed by: PyTorch MergeBot
Commit: fefdad6137 (parent: 9402219a36)
2 changed files with 114 additions and 0 deletions


@@ -23,6 +23,9 @@ class StaticModule:
    def benchmark(self, args, kwargs, warmup_runs, main_runs):
        self.static_module.benchmark(args, kwargs, warmup_runs, main_runs)

    def runAsync(self, args, kwargs):
        return self.static_module.runAsync(args, kwargs)

    def benchmark_individual_ops(self, args, kwargs, warmup_runs, main_runs):
        return self.static_module.benchmark_individual_ops(
            args, kwargs, warmup_runs, main_runs
@@ -222,6 +225,20 @@ class TestStaticModule(TestCase):
        output_test = static_runtime_module(inp1, inp2)
        torch.testing.assert_close(output_test, output_ref)

    """
    Test Case: To test a simple fork/wait operation with the
    StaticRuntime runAsync API returning a future
    """
    def test_fork_wait_1_async(self):
        inp1 = torch.ones(5, 5)
        inp2 = torch.randn(5, 5)
        torch_graph = torch.jit.script(fork_wait_graph1)
        output_ref = torch_graph(inp1, inp2)
        static_runtime_module = StaticModule(torch_graph)
        output_test = static_runtime_module.runAsync((inp1, inp2), {})
        output_test.wait()
        torch.testing.assert_close(output_test.value(), output_ref)

    """
    Test Case: To test fork/wait operation in a graph on
    a loop subgraph performing mix of operations
@@ -235,6 +252,20 @@ class TestStaticModule(TestCase):
        output_test = static_runtime_module(inp1, inp2)
        torch.testing.assert_close(output_test, output_ref)

    """
    Test Case: To test a fork/wait operation on a loop
    subgraph with the StaticRuntime runAsync API returning a future
    """
    def test_fork_wait_2_async(self):
        inp1 = torch.randn(5, 5)
        inp2 = torch.randn(5, 5)
        torch_graph = torch.jit.script(fork_wait_graph2)
        output_ref = torch_graph(inp1, inp2)
        static_runtime_module = StaticModule(torch_graph)
        output_test = static_runtime_module.runAsync((inp1, inp2), {})
        output_test.wait()
        torch.testing.assert_close(output_test.value(), output_ref)

    """
    Test Case: To test fork/wait operation in a graph on
    having multiple fork/wait operations
@@ -247,6 +278,21 @@ class TestStaticModule(TestCase):
        static_runtime_module = StaticModule(torch_graph)
        output_test = static_runtime_module(input, num_forks)
        torch.testing.assert_close(output_test, output_ref)

    """
    Test Case: To test a graph with multiple fork/wait
    operations via the runAsync API returning a future
    """
    def test_fork_wait_3_async(self):
        input = torch.ones(3, 3)
        num_forks = 10
        torch_graph = torch.jit.script(fork_wait_graph3)
        output_ref = torch_graph(input, num_forks)
        static_runtime_module = StaticModule(torch_graph)
        output_test = static_runtime_module.runAsync((input, num_forks), {})
        output_test.wait()
        torch.testing.assert_close(output_test.value(), output_ref)

    """
    Test Case: To test fork/wait operation in a graph on
    multiple nested fork/wait operations
@@ -261,6 +307,22 @@ class TestStaticModule(TestCase):
        output_test = static_runtime_module(input, num_forks, num_child_forks)
        torch.testing.assert_close(output_test, output_ref)

    """
    Test Case: To test a graph with multiple nested fork/wait
    operations via the runAsync API returning a future
    """
    def test_fork_wait_4_async(self):
        input = torch.ones(3, 3)
        num_forks = 10
        num_child_forks = 10
        torch_graph = torch.jit.script(fork_wait_graph4)
        static_runtime_module = StaticModule(torch_graph)
        output_ref = torch_graph(input, num_forks, num_child_forks)
        output_test = static_runtime_module.runAsync(
            (input, num_forks, num_child_forks), {})
        output_test.wait()
        torch.testing.assert_close(output_test.value(), output_ref)

    """
    Test Case: To test exception handling in fork/wait
    operation. Add.Tensor op is called for tensors with
@@ -290,6 +352,36 @@ class TestStaticModule(TestCase):
                    f"not contain expected substring: \"{expected_error_msg}\""
                ) from error

    """
    Test Case: To test exception handling in a fork/wait
    operation with the runAsync API. The add.Tensor op is called on
    tensors with non-matching dims in the forked subgraph, and the
    exception raised by the subgraph is set on the future returned
    by prim::fork to the parent graph. The returned exception is
    checked for the substring expected_error_msg declared below
    """
    def test_fork_wait_exception_async(self):
        # incompatible tensors for add due to shape mismatch
        input1 = torch.randn(4, 7)
        input2 = torch.randn(4, 5)
        torch_graph = torch.jit.script(fork_wait_graph_exception)
        try:
            static_runtime_module = StaticModule(torch_graph)
            output_test = static_runtime_module.runAsync(
                (input1, input2), {})
        except Exception as error:
            expected_error_msg = (
                "The size of tensor a (7) must match the size "
                "of tensor b (5) at non-singleton dimension 1"
            )
            # the test fails if the error does not contain the expected substring
            if str(error).find(expected_error_msg) == -1:
                raise RuntimeError(
                    "Tried execution of add.Tensors with incompatible shape. "
                    "Exception raised by forked runtime execution does "
                    f"not contain expected substring: \"{expected_error_msg}\""
                ) from error

    def test_multihead_attention_layer(self):
        HID_DIM = 256
        QUERY_LEN = 8