mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 21:14:14 +08:00
Pull Request resolved: https://github.com/pytorch/pytorch/pull/136964 Approved by: https://github.com/justinchuby, https://github.com/albanD
555 lines
17 KiB
Python
555 lines
17 KiB
Python
# Owner(s): ["oncall: jit"]
|
|
# ruff: noqa: F841
|
|
|
|
import os
|
|
import sys
|
|
from typing import Any, Tuple
|
|
|
|
import torch
|
|
import torch.nn as nn
|
|
|
|
|
|
# Make the helper files in test/ importable
|
|
pytorch_test_dir = os.path.dirname(os.path.dirname(os.path.realpath(__file__)))
|
|
sys.path.append(pytorch_test_dir)
|
|
from typing import List
|
|
|
|
from torch import Tensor
|
|
from torch.jit import Future
|
|
from torch.testing._internal.jit_utils import _inline_everything, JitTestCase
|
|
|
|
|
|
class TestAsync(JitTestCase):
    """Exercise ``torch.jit`` fork/wait futures from eager, scripted and traced code."""

    def test_async_python(self):
        """Call fork/wait on a scripted function directly from Python."""

        @torch.jit.script
        def foo(x):
            return torch.neg(x)

        x = torch.rand(3, 4)
        fut = torch.jit.fork(foo, x)
        foo(x)
        torch.jit.wait(fut)
        # assert nothing; only to make sure the fake python path works

    def test_async_future_type_python(self):
        """Use a ``List[Future[Tensor]]`` annotation in plain Python code."""

        def foo(inp):
            futures = torch.jit.annotate(List[torch.jit.Future[torch.Tensor]], [])
            for _ in range(5):
                futures.append(torch.jit.fork(lambda x: x, inp))
            return [torch.jit.wait(future) for future in futures]

        # assert nothing, just to make sure python type parsing works
        foo(torch.randn(3, 4))

    def test_async_parsing(self):
        """Script a function that collects and waits on annotated futures."""

        @torch.jit.script
        def foo(x: Tensor) -> List[Tensor]:
            return [torch.neg(x), x.t()]

        @torch.jit.script
        def bar(x):
            futures = torch.jit.annotate(List[Future[List[Tensor]]], [])
            for _ in range(3):
                future = torch.jit.annotate(
                    Future[List[Tensor]], torch.jit.fork(foo, x)
                )
                futures.append(future)

            output = torch.jit.annotate(List[List[Tensor]], [])
            for i in range(3):
                output.append(torch.jit.wait(futures[i]))
            return output

        x = torch.rand(3, 3)
        result = bar(x)
        self.assertEqual(len(result), 3)

    def test_async_script(self):
        """A forked call and a direct call inside script give equal results."""

        @torch.jit.script
        def foo(x):
            return torch.neg(x), x

        x = torch.rand(3, 4)

        @torch.jit.script
        def wait_script(x):
            fut = torch.jit.fork(foo, x)
            y_hat = foo(x)
            y = torch.jit.wait(fut)
            return y, y_hat

        y, y_hat = wait_script(x)

        self.assertEqual(y, y_hat)

    def test_async_script_capture(self):
        """Fork a script method that reads a module parameter and constant."""

        class Mod(torch.jit.ScriptModule):
            __constants__ = ["const"]

            def __init__(self) -> None:
                super().__init__()
                self.const = 42
                self.param = nn.Parameter(torch.randn(2, 2))

            @torch.jit.script_method
            def foo(self, x1, x2):
                return torch.neg(x1), self.param, self.const, torch.neg(x2), self.param

            @torch.jit.script_method
            def forward(self, x1, x2):
                fut = torch.jit.fork(self.foo, x1, x2)
                y_hat = self.foo(x1, x2)
                y = torch.jit.wait(fut)
                return y, y_hat

        x1 = torch.rand(3, 4)
        x2 = torch.rand(5, 6)

        m = Mod()

        with torch.jit.optimized_execution(False):
            y, y_hat = m.forward(x1, x2)

        self.assertEqual(y, y_hat)

    def test_async_script_nested(self):
        """Fork a scripted function that itself forks and waits."""

        @torch.jit.script
        def foo(x):
            return torch.neg(x), x

        x = torch.rand(3, 4)

        @torch.jit.script
        def wait_script(x):
            fut = torch.jit._fork(foo, x)
            y_hat = foo(x)
            y = torch.jit._wait(fut)
            return y, y_hat

        @torch.jit.script
        def wait_script_nest(x):
            fut = torch.jit._fork(wait_script, x)
            return torch.jit._wait(fut)

        y, y_hat = wait_script_nest(x)

        self.assertEqual(y, y_hat)

    def test_async_script_no_script_mod(self):
        """Scripting ``_fork`` on a non-callable value raises a highlighted error."""
        x = torch.rand(3, 4)

        with self.assertRaisesRegexWithHighlight(
            RuntimeError, "cannot call a value", "torch.jit._fork(x"
        ):

            @torch.jit.script
            def wait_script(x):
                fut = torch.jit._fork(x)
                return fut

    def test_async_script_multi_waits(self):
        """Waiting twice on the same future returns equal values."""

        @torch.jit.script
        def foo(x):
            return torch.neg(x).t() + x

        @torch.jit.script
        def wait_script(x):
            fut = torch.jit._fork(foo, x)

            # wait twice on the same future
            y1 = torch.jit._wait(fut)
            y2 = torch.jit._wait(fut)
            return y1, y2

        x = torch.rand(2, 2)
        y1, y2 = wait_script(x)
        self.assertEqual(y1, y2)

    def test_async_script_multi_forks(self):
        """Launch several forks, wait on only some, and compare with direct calls."""

        @torch.jit.script
        def foo1(x):
            return torch.neg(x).t() + x

        @torch.jit.script
        def foo2(x, y):
            return torch.neg(x).t() + x + torch.neg(y).t()

        @torch.jit.script
        def foo3(x, y, z):
            return torch.neg(z).t() + y.t() + x

        x1 = torch.rand(10, 10)
        x2 = torch.rand(10, 10)
        x3 = torch.rand(10, 10)

        @torch.jit.script
        def wait_script(x1, x2, x3):
            f1 = torch.jit._fork(foo1, x1)
            f2 = torch.jit._fork(foo2, x1, x2)
            f3 = torch.jit._fork(foo3, x1, x2, x3)
            f4 = torch.jit._fork(foo1, x2)
            f5 = torch.jit._fork(foo2, x2, x3)

            # ignore some forks
            y1 = torch.jit._wait(f1)
            y2 = torch.jit._wait(f2)
            y3 = torch.jit._wait(f3)

            return y1, y2, y3

        y1, y2, y3 = wait_script(x1, x2, x3)
        self.assertEqual(y1, foo1(x1))
        self.assertEqual(y2, foo2(x1, x2))
        self.assertEqual(y3, foo3(x1, x2, x3))

    def test_async_kwargs(self):
        """Pass positional/keyword argument mixes through ``_fork`` in all modes."""

        def foo(x1, x2):
            return 2 * x1 + x2

        x1 = torch.rand(3, 4)
        x2 = torch.rand(3, 4)
        y_hat = foo(x1, x2)

        # Cover tracing and bare functions with permutations of args, kwargs
        for func in [
            lambda x1, x2: torch.jit._wait(torch.jit._fork(foo, x1, x2)),
            lambda x1, x2: torch.jit._wait(torch.jit._fork(foo, x1, x2=x2)),
            lambda x1, x2: torch.jit._wait(torch.jit._fork(foo, x1=x1, x2=x2)),
            lambda x1, x2: torch.jit._wait(torch.jit._fork(foo, x2=x2, x1=x1)),
        ]:
            for wrapper in [
                func,
                torch.jit.trace(func, (x1, x2)),
            ]:
                self.assertEqual(wrapper(x1, x2), y_hat)
                self.assertEqual(wrapper(x1, x2=x2), y_hat)
                self.assertEqual(wrapper(x1=x1, x2=x2), y_hat)
                self.assertEqual(wrapper(x2=x2, x1=x1), y_hat)

        # Cover scripting
        @torch.jit.script
        def foo_script_args(x1, x2):
            return torch.jit._wait(torch.jit._fork(foo, x1, x2))

        @torch.jit.script
        def foo_script_kwargs(x1, x2):
            return torch.jit._wait(torch.jit._fork(foo, x1=x1, x2=x2))

        for wrapper in [
            foo_script_args,
            foo_script_kwargs,
        ]:
            self.assertEqual(wrapper(x1, x2), y_hat)
            self.assertEqual(wrapper(x1, x2=x2), y_hat)
            self.assertEqual(wrapper(x1=x1, x2=x2), y_hat)
            self.assertEqual(wrapper(x2=x2, x1=x1), y_hat)

    @_inline_everything
    def test_async_script_trace(self):
        """Trace a module wrapping a script module that forks traced and builtin ops."""

        class Traced(nn.Module):
            def forward(self, x):
                return (torch.neg(x), x)

        class Mod(torch.jit.ScriptModule):
            def __init__(self) -> None:
                super().__init__()
                x = torch.rand(3, 3)
                # Pass a real one-element tuple; ``(x)`` is just ``x`` in parentheses.
                self.traced = torch.jit.trace(Traced(), (x,), _force_outplace=True)

            @torch.jit.script_method
            def forward(
                self, x: Tensor
            ) -> Tuple[List[Tensor], Tuple[Tensor, Tensor], Tensor]:
                future1 = torch.jit._fork(self.traced, x)
                future2 = torch.jit._fork(torch.neg, x)

                tensor_tuple = torch.jit._wait(future1)
                tensor_single = torch.jit._wait(future2)

                tensor_list = []
                tensor_list.append(tensor_tuple[0])
                tensor_list.append(tensor_single)

                # return a nested structure of tensors
                return (tensor_list, tensor_tuple, tensor_tuple[1])

        class TupleCl(nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.module = Mod()

            def forward(self, x):
                z = torch.neg(x)
                y = self.module(x)
                # Flatten the nested result; avoid shadowing the builtin ``list``.
                flattened = [z, y[0][0], y[0][1], y[1][0], y[1][1], y[2]]
                return tuple(flattened)

        x = torch.rand(3, 3)
        module = torch.jit.trace(TupleCl(), (x,), _force_outplace=True)

        # Make sure we have forks
        self.assertGraphContainsExactly(
            module.graph, kind="prim::fork", num_kind_nodes=2
        )
        # Make sure 1 ::neg is in the root graph and 2 ::negs are in the subgraphs
        self.assertGraphContainsExactly(
            module.graph, kind="aten::neg", num_kind_nodes=1
        )
        self.assertGraphContainsExactly(
            module.graph, kind="aten::neg", num_kind_nodes=3, consider_subgraphs=True
        )

        y = torch.neg(x)
        self.assertEqual(module(x), (y, y, y, y, x, x))

    def test_async_script_error(self):
        """Errors raised inside forked code surface with the fork call highlighted."""
        x = torch.rand(3, 4)

        @torch.jit.script
        def foo(x):
            # error here
            return x.t() + x

        @torch.jit.script
        def wait_script(x):
            fut = torch.jit._fork(foo, x)
            return torch.jit._wait(fut)

        @torch.jit.script
        def wait_script_nest(x):
            fut = torch.jit._fork(wait_script, x)
            return torch.jit._wait(fut)

        # no future
        error_msg = "The size.*must match the size of tensor"
        with self.assertRaisesRegexWithHighlight(Exception, error_msg, "x.t() + x"):
            foo(x)

        # one future
        with self.assertRaisesRegexWithHighlight(
            Exception, error_msg, "torch.jit._fork(foo, x"
        ):
            wait_script(x)

        # two futures with a different error
        x = torch.rand(3, 4, 5)
        with self.assertRaisesRegexWithHighlight(
            Exception,
            "expects a tensor with <= 2 dimensions",
            "torch.jit._fork(wait_script, x",
        ):
            wait_script_nest(x)

    def test_async_grad_guard_with_grad(self):
        """Under ``enable_grad`` results require grad inside the fork and after wait."""

        @torch.jit.script
        def foo(x):
            y = x * 2
            return y.requires_grad

        @torch.jit.script
        def bar(x):
            fut = torch.jit._fork(foo, x)
            requires_grad_in_fork = torch.jit._wait(fut)
            z = x * 2
            return (requires_grad_in_fork, z.requires_grad)

        x = torch.randn(3, requires_grad=True)

        with torch.enable_grad():
            (inside_fork, after_wait) = bar(x)

        self.assertEqual(inside_fork, True)
        self.assertEqual(after_wait, True)

    def test_async_grad_guard_no_grad(self):
        """Under ``no_grad`` results do not require grad inside the fork or after wait."""

        @torch.jit.script
        def foo(x):
            y = x * 2
            return y.requires_grad

        @torch.jit.script
        def bar(x):
            fut = torch.jit._fork(foo, x)
            requires_grad_in_fork = torch.jit._wait(fut)
            z = x * 2
            return (requires_grad_in_fork, z.requires_grad)

        x = torch.randn(3, requires_grad=True)

        with torch.no_grad():
            (inside_fork, after_wait) = bar(x)

        self.assertEqual(inside_fork, False)
        self.assertEqual(after_wait, False)

    def test_trace_fork_wait(self):
        """Tracing a function that forks and waits records fork and wait nodes."""

        def fork_body(x):
            return x.neg(), x.neg() + 1

        def fn(x):
            fut = torch.jit._fork(fork_body, x)
            vals = torch.jit._wait(fut)
            return vals[0], vals[1], x - 1

        traced = torch.jit.trace(fn, (torch.rand(3, 4),))
        x = torch.rand(3, 4)
        self.assertEqual(fn(x), traced(x))

        self.assertGraphContainsExactly(
            traced.graph, kind="prim::fork", num_kind_nodes=1
        )
        self.assertGraphContainsExactly(
            traced.graph, kind="aten::wait", num_kind_nodes=1
        )
        self.assertGraphContainsExactly(
            traced.graph, kind="aten::neg", num_kind_nodes=2, consider_subgraphs=True
        )

    def test_trace_fork_wait_leaking(self):
        """Returning a value leaked out of a forked body makes tracing fail."""
        my_list = []

        def fork_body(x):
            my_list.append(x + 1)
            return x + 1

        def fn(x):
            fut = torch.jit._fork(fork_body, x)
            # The waited value is deliberately dropped; the leaked one is returned.
            torch.jit._wait(fut)
            return my_list[0]

        with self.assertRaisesRegexWithHighlight(
            RuntimeError,
            "did not have observable data dependence with trace inputs; "
            "this probably indicates your program cannot be understood "
            "by the tracer.",
            "",
        ):
            torch.jit.trace(fn, (torch.rand(3, 4),), check_trace=False)

    def test_trace_fork_wait_inline(self):
        """The inline-fork-wait pass removes fork/wait nodes from a traced graph."""

        def fork_body(x):
            return x + 1, x + 2

        def fn(x):
            fut = torch.jit._fork(fork_body, x)
            val = torch.jit._wait(fut)
            return val[1]

        traced = torch.jit.trace(fn, (torch.rand(3, 4),))
        torch._C._jit_pass_inline_fork_wait(traced.graph)
        self.assertGraphContainsExactly(
            traced.graph, kind="prim::fork", num_kind_nodes=0
        )
        self.assertGraphContainsExactly(
            traced.graph, kind="aten::wait", num_kind_nodes=0
        )
        self.assertGraphContainsExactly(
            traced.graph, kind="aten::add", num_kind_nodes=2
        )

    def test_trace_fork_wait_list_modulecalls(self):
        """Trace a wrapper whose submodule returns a list of futures."""

        def add_one(input):
            return input + torch.ones(input.size())

        class TestListFutureModule(nn.Module):
            def forward(self, input):
                # The same tensor is forked three times.
                input_list = [input] * 3

                fut_list: List[Future[torch.Tensor]] = [
                    torch.jit._fork(add_one, input_tensor)
                    for input_tensor in input_list
                ]
                # return list[future[tensor]] here to ensure tracing
                # module calls return the correct types
                return fut_list

        class TestModuleWrapper(nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.list_fut_mod = TestListFutureModule()

            def forward(self, input):
                fut_list = self.list_fut_mod(input)
                res = input
                for fut in fut_list:
                    res = res + fut.wait()
                return res

        self.checkTrace(TestModuleWrapper(), (torch.randn(5, 5),))

    def test_trace_modulecalls_with_different_output_types(self):
        """Trace a wrapper whose submodule returns a tensor together with a future."""

        def add_one(input):
            return input + torch.ones(input.size())

        class DifferentOutputModule(nn.Module):
            def forward(self, input):
                fut_res = torch.jit._fork(add_one, input)

                # return different types from module call
                return input, fut_res

        class TestModule(nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.gen_output = DifferentOutputModule()

            def forward(self, input):
                res, fut_res = self.gen_output(input)
                res = res + fut_res.wait()
                return res

        self.checkTrace(TestModule(), (torch.randn(5, 5),))

    def test_no_future_subtype_message(self):
        """Scripting an unparameterized ``Future`` annotation raises a clear error."""
        with self.assertRaisesRegexWithHighlight(
            RuntimeError, "Future without a contained type", ""
        ):

            @torch.jit.script
            def forward(self, x):
                futs = torch.jit.annotate(List[torch.jit.Future], [])

    def test_future_subtyping(self):
        """
        Test that futures subtype each other properly.
        """

        # Successful subtyping.
        def returns_int(x: int) -> int:
            return x + x + 1

        def returns_future_any(x: int) -> torch.jit.Future[Any]:
            return torch.jit._fork(returns_int, x)

        @torch.jit.script
        def fn_int(x: int) -> Any:
            fut = returns_future_any(x)
            return fut.wait()

        # Unsuccessful subtyping.
        with self.assertRaisesRegexWithHighlight(
            RuntimeError,
            r"was annotated as having type Future\[float\] but is actually of type Future\[int\]",
            "fut = returns_future_float(x",
        ):

            def returns_future_float(x: int) -> torch.jit.Future[float]:
                return torch.jit._fork(returns_int, x)

            @torch.jit.script
            def fn_float(x: int) -> Any:
                fut = returns_future_float(x)
                return fut.wait()
|
|
|
|
|
|
if __name__ == "__main__":
    # Running this module as a script is rejected; the message names the
    # test/test_jit.py entry point to use instead.
    message = (
        "This test file is not meant to be run directly, use:\n\n"
        "\tpython test/test_jit.py TESTNAME\n\n"
        "instead."
    )
    raise RuntimeError(message)
|