Enable UFMT on all of test/jit (#123623)

Partially addresses #123062

Ran lintrunner on:

- `test/jit`

with command:

```bash
lintrunner -a --take UFMT --all-files
```

Pull Request resolved: https://github.com/pytorch/pytorch/pull/123623
Approved by: https://github.com/ezyang
This commit is contained in:
Yuanhao Ji
2024-04-11 23:45:05 +00:00
committed by PyTorch MergeBot
parent d0ccf599cc
commit 604c9c5601
82 changed files with 6281 additions and 3210 deletions

View File

@ -1162,94 +1162,6 @@ exclude_patterns = [
'test/functorch/test_vmap.py',
'test/functorch/test_vmap_registrations.py',
'test/functorch/xfail_suggester.py',
'test/jit/__init__.py',
'test/jit/_imported_class_test/__init__.py',
'test/jit/_imported_class_test/bar.py',
'test/jit/_imported_class_test/foo.py',
'test/jit/_imported_class_test/very/__init__.py',
'test/jit/_imported_class_test/very/very/__init__.py',
'test/jit/_imported_class_test/very/very/nested.py',
'test/jit/fixtures_srcs/__init__.py',
'test/jit/fixtures_srcs/fixtures_src.py',
'test/jit/fixtures_srcs/generate_models.py',
'test/jit/fixtures_srcs/test_upgrader_models_generation.py',
'test/jit/myexception.py',
'test/jit/test_alias_analysis.py',
'test/jit/test_async.py',
'test/jit/test_aten_pow.py',
'test/jit/test_attr.py',
'test/jit/test_autodiff.py',
'test/jit/test_autodiff_subgraph_slicing.py',
'test/jit/test_await.py',
'test/jit/test_backend_nnapi.py',
'test/jit/test_backends.py',
'test/jit/test_batch_mm.py',
'test/jit/test_builtins.py',
'test/jit/test_class_type.py',
'test/jit/test_complex.py',
'test/jit/test_complexity.py',
'test/jit/test_convert_activation.py',
'test/jit/test_cuda.py',
'test/jit/test_custom_operators.py',
'test/jit/test_data_parallel.py',
'test/jit/test_dataclasses.py',
'test/jit/test_dce.py',
'test/jit/test_device_analysis.py',
'test/jit/test_dtype_analysis.py',
'test/jit/test_enum.py',
'test/jit/test_exception.py',
'test/jit/test_freezing.py',
'test/jit/test_functional_blocks.py',
'test/jit/test_fuser_common.py',
'test/jit/test_graph_rewrite_passes.py',
'test/jit/test_hash.py',
'test/jit/test_hooks.py',
'test/jit/test_hooks_modules.py',
'test/jit/test_ignorable_args.py',
'test/jit/test_ignore_context_manager.py',
'test/jit/test_isinstance.py',
'test/jit/test_jit_utils.py',
'test/jit/test_list_dict.py',
'test/jit/test_logging.py',
'test/jit/test_misc.py',
'test/jit/test_models.py',
'test/jit/test_module_apis.py',
'test/jit/test_module_containers.py',
'test/jit/test_module_interface.py',
'test/jit/test_modules.py',
'test/jit/test_op_decompositions.py',
'test/jit/test_optimize_for_mobile_preserve_debug_info.py',
'test/jit/test_parametrization.py',
'test/jit/test_pdt.py',
'test/jit/test_peephole.py',
'test/jit/test_profiler.py',
'test/jit/test_python_bindings.py',
'test/jit/test_python_builtins.py',
'test/jit/test_python_ir.py',
'test/jit/test_recursive_script.py',
'test/jit/test_remove_mutation.py',
'test/jit/test_save_load.py',
'test/jit/test_save_load_for_op_version.py',
'test/jit/test_script_profile.py',
'test/jit/test_scriptmod_ann.py',
'test/jit/test_slice.py',
'test/jit/test_sparse.py',
'test/jit/test_string_formatting.py',
'test/jit/test_symbolic_shape_analysis.py',
'test/jit/test_tensor_creation_ops.py',
'test/jit/test_tensor_methods.py',
'test/jit/test_torchbind.py',
'test/jit/test_tracer.py',
'test/jit/test_type_sharing.py',
'test/jit/test_types.py',
'test/jit/test_typing.py',
'test/jit/test_union.py',
'test/jit/test_unsupported_ops.py',
'test/jit/test_upgraders.py',
'test/jit/test_warn.py',
'test/jit/test_with.py',
'test/jit/xnnpack/test_xnnpack_delegate.py',
'test/jit_hooks/model.py',
'test/lazy/__init__.py',
'test/lazy/test_bindings.py',
'test/lazy/test_debug_util.py',

View File

@ -1,4 +1,5 @@
import torch
# This file contains definitions of script classes.
# They are used by test_jit.py to test ScriptClass imports

View File

@ -1,5 +1,7 @@
import torch
from . import bar
# This file contains definitions of script classes.
# They are used by test_jit.py to test ScriptClass imports

View File

@ -1,4 +1,5 @@
import torch
# This file contains definitions of script classes.
# They are used by test_jit.py to test ScriptClass imports

View File

@ -1,6 +1,8 @@
import torch
from typing import Union
import torch
class TestVersionedDivTensorExampleV7(torch.nn.Module):
def forward(self, a, b):
result_0 = a / b
@ -8,35 +10,52 @@ class TestVersionedDivTensorExampleV7(torch.nn.Module):
result_2 = a.div(b)
return result_0, result_1, result_2
class TestVersionedLinspaceV7(torch.nn.Module):
def forward(self, a: Union[int, float, complex], b: Union[int, float, complex]):
c = torch.linspace(a, b, steps=5)
d = torch.linspace(a, b)
return c, d
class TestVersionedLinspaceOutV7(torch.nn.Module):
def forward(self, a: Union[int, float, complex], b: Union[int, float, complex], out: torch.Tensor):
def forward(
self,
a: Union[int, float, complex],
b: Union[int, float, complex],
out: torch.Tensor,
):
return torch.linspace(a, b, out=out)
class TestVersionedLogspaceV8(torch.nn.Module):
def forward(self, a: Union[int, float, complex], b: Union[int, float, complex]):
c = torch.logspace(a, b, steps=5)
d = torch.logspace(a, b)
return c, d
class TestVersionedLogspaceOutV8(torch.nn.Module):
def forward(self, a: Union[int, float, complex], b: Union[int, float, complex], out: torch.Tensor):
def forward(
self,
a: Union[int, float, complex],
b: Union[int, float, complex],
out: torch.Tensor,
):
return torch.logspace(a, b, out=out)
class TestVersionedGeluV9(torch.nn.Module):
def forward(self, x):
return torch._C._nn.gelu(x)
class TestVersionedGeluOutV9(torch.nn.Module):
def forward(self, x):
out = torch.zeros_like(x)
return torch._C._nn.gelu(x, out=out)
class TestVersionedRandomV10(torch.nn.Module):
def forward(self, x):
out = torch.zeros_like(x)

View File

@ -6,9 +6,11 @@ from pathlib import Path
from typing import Set
import torch
# Use asterisk symbol so developer doesn't need to import here when they add tests for upgraders.
from test.jit.fixtures_srcs.fixtures_src import * # noqa: F403
from torch.jit.mobile import _load_for_lite_interpreter, _export_operator_list
from torch.jit.mobile import _export_operator_list, _load_for_lite_interpreter
logging.basicConfig(stream=sys.stdout, level=logging.INFO)
logger = logging.getLogger(__name__)
logger.setLevel(logging.DEBUG)
@ -105,28 +107,41 @@ ALL_MODULES = {
Get the path to `test/jit/fixtures`, where all test models for operator changes
(upgrader/downgrader) are stored
"""
def get_fixtures_path() -> Path:
pytorch_dir = Path(__file__).resolve().parents[3]
fixtures_path = pytorch_dir / "test" / "jit" / "fixtures"
return fixtures_path
"""
Get all models' name in `test/jit/fixtures`
"""
def get_all_models(model_directory_path: Path) -> Set[str]:
files_in_fixtures = model_directory_path.glob('**/*')
all_models_from_fixtures = [fixture.stem for fixture in files_in_fixtures if fixture.is_file()]
files_in_fixtures = model_directory_path.glob("**/*")
all_models_from_fixtures = [
fixture.stem for fixture in files_in_fixtures if fixture.is_file()
]
return set(all_models_from_fixtures)
"""
Check if a given model already exist in `test/jit/fixtures`
"""
def model_exist(model_file_name: str, all_models: Set[str]) -> bool:
return model_file_name in all_models
"""
Get the operator list given a module
"""
def get_operator_list(script_module: torch) -> Set[str]:
buffer = io.BytesIO(script_module._save_to_buffer_for_lite_interpreter())
buffer.seek(0)
@ -134,21 +149,25 @@ def get_operator_list(script_module: torch) -> Set[str]:
operator_list = _export_operator_list(mobile_module)
return operator_list
"""
Get the output model operator version, given a module
"""
def get_output_model_version(script_module: torch.nn.Module) -> int:
buffer = io.BytesIO()
torch.jit.save(script_module, buffer)
buffer.seek(0)
zipped_model = zipfile.ZipFile(buffer)
try:
version = int(zipped_model.read('archive/version').decode("utf-8"))
version = int(zipped_model.read("archive/version").decode("utf-8"))
return version
except KeyError:
version = int(zipped_model.read('archive/.data/version').decode("utf-8"))
version = int(zipped_model.read("archive/.data/version").decode("utf-8"))
return version
"""
Loop through all test modules. If the corresponding model doesn't exist in
`test/jit/fixtures`, generate one. For the following reason, a model won't be exported:
@ -165,6 +184,8 @@ likely this script is running with the commit to make the change.
3. The model already exists in `test/jit/fixtures`.
"""
def generate_models(model_directory_path: Path):
all_models = get_all_models(model_directory_path)
for a_module, expect_operator in ALL_MODULES.items():
@ -176,13 +197,17 @@ def generate_models(model_directory_path: Path):
"The module %s "
"is not a torch.nn.module instance. "
"Please ensure it's a subclass of torch.nn.module in fixtures_src.py"
"and it's registered as an instance in ALL_MODULES in generated_models.py", torch_module_name)
"and it's registered as an instance in ALL_MODULES in generated_models.py",
torch_module_name,
)
# The corresponding model name is: test_versioned_div_tensor_example_v4
model_name = ''.join([
'_' + char.lower() if char.isupper() else char for char in torch_module_name
]).lstrip('_')
model_name = "".join(
[
"_" + char.lower() if char.isupper() else char
for char in torch_module_name
]
).lstrip("_")
# Some models may not compile anymore, so skip the ones
# that already has pt file for them.
@ -199,7 +224,10 @@ def generate_models(model_directory_path: Path):
logger.error(
"Actual model version %s "
"is equal or larger than %s + 1. "
"Please run the script before the commit to change operator.", actual_model_version, current_operator_version)
"Please run the script before the commit to change operator.",
actual_model_version,
current_operator_version,
)
continue
actual_operator_list = get_operator_list(script_module)
@ -207,16 +235,23 @@ def generate_models(model_directory_path: Path):
logger.error(
"The model includes operator: %s, "
"however it doesn't cover the operator %s."
"Please ensure the output model includes the tested operator.", actual_operator_list, expect_operator)
"Please ensure the output model includes the tested operator.",
actual_operator_list,
expect_operator,
)
continue
export_model_path = str(model_directory_path / (str(model_name) + ".ptl"))
script_module._save_for_lite_interpreter(export_model_path)
logger.info("Generating model %s and it's save to %s", model_name, export_model_path)
logger.info(
"Generating model %s and it's save to %s", model_name, export_model_path
)
def main() -> None:
model_directory_path = get_fixtures_path()
generate_models(model_directory_path)
if __name__ == '__main__':
if __name__ == "__main__":
main()

View File

@ -2,7 +2,7 @@
import torch
from test.jit.fixtures_srcs.generate_models import ALL_MODULES
from torch.testing._internal.common_utils import TestCase, run_tests
from torch.testing._internal.common_utils import run_tests, TestCase
class TestUpgraderModelGeneration(TestCase):
@ -14,7 +14,9 @@ class TestUpgraderModelGeneration(TestCase):
f"The module {module_name} "
f"is not a torch.nn.module instance. "
f"Please ensure it's a subclass of torch.nn.module in fixtures_src.py"
f"and it's registered as an instance in ALL_MODULES in generated_models.py")
f"and it's registered as an instance in ALL_MODULES in generated_models.py",
)
if __name__ == '__main__':
if __name__ == "__main__":
run_tests()

View File

@ -3,5 +3,7 @@ Define exceptions used in test_exception.py. We define them in a
separate file on purpose to make sure the fully qualified exception class name
is captured correctly in suce cases.
"""
class MyKeyError(KeyError):
pass

View File

@ -1,15 +1,18 @@
# Owner(s): ["oncall: jit"]
import torch
from torch._C import parse_ir
from torch.testing._internal.common_utils import TemporaryFileName
from torch.testing._internal.jit_utils import JitTestCase
from torch._C import parse_ir
import torch
if __name__ == '__main__':
raise RuntimeError("This test file is not meant to be run directly, use:\n\n"
"\tpython test/test_jit.py TESTNAME\n\n"
"instead.")
if __name__ == "__main__":
raise RuntimeError(
"This test file is not meant to be run directly, use:\n\n"
"\tpython test/test_jit.py TESTNAME\n\n"
"instead."
)
class TestAliasAnalysis(JitTestCase):
def test_becomes_wildcard_annotations(self):
@ -26,9 +29,13 @@ class TestAliasAnalysis(JitTestCase):
alias_db = graph.alias_db()
split_node = graph.findNode("aten::split")
# split input enters wildcard set, list initalized as containing wildcard set
self.assertTrue(alias_db.may_contain_alias(next(split_node.inputs()), split_node.output()))
self.assertTrue(
alias_db.may_contain_alias(next(split_node.inputs()), split_node.output())
)
# because %x.1 enters wildcard set, it now aliases other members of wildcard set (graph inputs)
self.assertTrue(alias_db.may_contain_alias(next(split_node.inputs()), next(graph.inputs())))
self.assertTrue(
alias_db.may_contain_alias(next(split_node.inputs()), next(graph.inputs()))
)
def test_nested_list_construct_not_wildcard(self):
@torch.jit.script
@ -42,7 +49,9 @@ class TestAliasAnalysis(JitTestCase):
ten_construct = graph.findNode("aten::rand").output()
output = next(graph.outputs())
self.assertTrue(alias_db.may_contain_alias(ten_construct, output))
self.assertFalse(alias_db.may_contain_alias(next(graph.inputs()), ten_construct))
self.assertFalse(
alias_db.may_contain_alias(next(graph.inputs()), ten_construct)
)
def test_recursive_calls(self):
@torch.jit.script
@ -108,7 +117,9 @@ class TestAliasAnalysis(JitTestCase):
class MultiTmpFile:
def __init__(self, N):
self.N = N
self.ctxs = [TemporaryFileName(mode="w", suffix=".py") for _ in range(N)]
self.ctxs = [
TemporaryFileName(mode="w", suffix=".py") for _ in range(N)
]
def __enter__(self):
return [x.__enter__() for x in self.ctxs]

View File

@ -3,18 +3,20 @@
import os
import sys
from typing import Any, Tuple
import torch
import torch.nn as nn
from typing import Any, Tuple
# Make the helper files in test/ importable
pytorch_test_dir = os.path.dirname(os.path.dirname(os.path.realpath(__file__)))
sys.path.append(pytorch_test_dir)
from torch.testing._internal.jit_utils import JitTestCase, _inline_everything
from typing import List
from torch import Tensor
from torch.jit import Future
from torch.testing._internal.jit_utils import _inline_everything, JitTestCase
class TestAsync(JitTestCase):
def test_async_python(self):
@ -51,8 +53,7 @@ class TestAsync(JitTestCase):
futures = torch.jit.annotate(List[Future[List[Tensor]]], [])
for _ in range(3):
future = torch.jit.annotate(
Future[List[Tensor]],
torch.jit.fork(foo, x)
Future[List[Tensor]], torch.jit.fork(foo, x)
)
futures.append(future)
@ -85,7 +86,7 @@ class TestAsync(JitTestCase):
def test_async_script_capture(self):
class Mod(torch.jit.ScriptModule):
__constants__ = ['const']
__constants__ = ["const"]
def __init__(self):
super().__init__()
@ -139,7 +140,10 @@ class TestAsync(JitTestCase):
def test_async_script_no_script_mod(self):
x = torch.rand(3, 4)
with self.assertRaisesRegexWithHighlight(RuntimeError, 'cannot call a value', 'torch.jit._fork(x'):
with self.assertRaisesRegexWithHighlight(
RuntimeError, "cannot call a value", "torch.jit._fork(x"
):
@torch.jit.script
def wait_script(x):
fut = torch.jit._fork(x)
@ -213,7 +217,7 @@ class TestAsync(JitTestCase):
lambda x1, x2: torch.jit._wait(torch.jit._fork(foo, x1, x2)),
lambda x1, x2: torch.jit._wait(torch.jit._fork(foo, x1, x2=x2)),
lambda x1, x2: torch.jit._wait(torch.jit._fork(foo, x1=x1, x2=x2)),
lambda x1, x2: torch.jit._wait(torch.jit._fork(foo, x2=x2, x1=x1))
lambda x1, x2: torch.jit._wait(torch.jit._fork(foo, x2=x2, x1=x1)),
]:
for wrapper in [
func,
@ -234,8 +238,8 @@ class TestAsync(JitTestCase):
return torch.jit._wait(torch.jit._fork(foo, x1=x1, x2=x2))
for wrapper in [
foo_script_args,
foo_script_kwargs,
foo_script_args,
foo_script_kwargs,
]:
self.assertEqual(wrapper(x1, x2), y_hat)
self.assertEqual(wrapper(x1, x2=x2), y_hat)
@ -255,7 +259,9 @@ class TestAsync(JitTestCase):
self.traced = torch.jit.trace(Traced(), (x), _force_outplace=True)
@torch.jit.script_method
def forward(self, x: Tensor) -> Tuple[List[Tensor], Tuple[Tensor, Tensor], Tensor]:
def forward(
self, x: Tensor
) -> Tuple[List[Tensor], Tuple[Tensor, Tensor], Tensor]:
future1 = torch.jit._fork(self.traced, x)
future2 = torch.jit._fork(torch.neg, x)
@ -284,10 +290,16 @@ class TestAsync(JitTestCase):
module = torch.jit.trace(TupleCl(), (x), _force_outplace=True)
# Make sure we have forks
self.assertGraphContainsExactly(module.graph, kind='prim::fork', num_kind_nodes=2)
self.assertGraphContainsExactly(
module.graph, kind="prim::fork", num_kind_nodes=2
)
# Make sure 1 ::neg is in the root graph and 2 ::negs are in the subgraphs
self.assertGraphContainsExactly(module.graph, kind='aten::neg', num_kind_nodes=1)
self.assertGraphContainsExactly(module.graph, kind='aten::neg', num_kind_nodes=3, consider_subgraphs=True)
self.assertGraphContainsExactly(
module.graph, kind="aten::neg", num_kind_nodes=1
)
self.assertGraphContainsExactly(
module.graph, kind="aten::neg", num_kind_nodes=3, consider_subgraphs=True
)
y = torch.neg(x)
self.assertEqual(module(x), (y, y, y, y, x, x))
@ -311,19 +323,23 @@ class TestAsync(JitTestCase):
return torch.jit._wait(fut)
# no future
error_msg = 'The size.*must match the size of tensor'
with self.assertRaisesRegexWithHighlight(Exception, error_msg, 'x.t() + x'):
error_msg = "The size.*must match the size of tensor"
with self.assertRaisesRegexWithHighlight(Exception, error_msg, "x.t() + x"):
foo(x)
# one future
with self.assertRaisesRegexWithHighlight(Exception, error_msg, 'torch.jit._fork(foo, x'):
with self.assertRaisesRegexWithHighlight(
Exception, error_msg, "torch.jit._fork(foo, x"
):
wait_script(x)
# two futures with a different error
x = torch.rand(3, 4, 5)
with self.assertRaisesRegexWithHighlight(Exception,
'expects a tensor with <= 2 dimensions',
'torch.jit._fork(wait_script, x'):
with self.assertRaisesRegexWithHighlight(
Exception,
"expects a tensor with <= 2 dimensions",
"torch.jit._fork(wait_script, x",
):
wait_script_nest(x)
def test_async_grad_guard_with_grad(self):
@ -381,9 +397,15 @@ class TestAsync(JitTestCase):
x = torch.rand(3, 4)
self.assertEqual(fn(x), traced(x))
self.assertGraphContainsExactly(traced.graph, kind='prim::fork', num_kind_nodes=1)
self.assertGraphContainsExactly(traced.graph, kind='aten::wait', num_kind_nodes=1)
self.assertGraphContainsExactly(traced.graph, kind='aten::neg', num_kind_nodes=2, consider_subgraphs=True)
self.assertGraphContainsExactly(
traced.graph, kind="prim::fork", num_kind_nodes=1
)
self.assertGraphContainsExactly(
traced.graph, kind="aten::wait", num_kind_nodes=1
)
self.assertGraphContainsExactly(
traced.graph, kind="aten::neg", num_kind_nodes=2, consider_subgraphs=True
)
def test_trace_fork_wait_leaking(self):
my_list = []
@ -397,9 +419,13 @@ class TestAsync(JitTestCase):
val = torch.jit._wait(fut)
return my_list[0]
with self.assertRaisesRegexWithHighlight(RuntimeError, 'did not have observable data dependence with trace inputs; '
'this probably indicates your program cannot be understood '
'by the tracer.', ''):
with self.assertRaisesRegexWithHighlight(
RuntimeError,
"did not have observable data dependence with trace inputs; "
"this probably indicates your program cannot be understood "
"by the tracer.",
"",
):
traced = torch.jit.trace(fn, (torch.rand(3, 4),), check_trace=False)
def test_trace_fork_wait_inline(self):
@ -413,9 +439,15 @@ class TestAsync(JitTestCase):
traced = torch.jit.trace(fn, (torch.rand(3, 4),))
torch._C._jit_pass_inline_fork_wait(traced.graph)
self.assertGraphContainsExactly(traced.graph, kind='prim::fork', num_kind_nodes=0)
self.assertGraphContainsExactly(traced.graph, kind='aten::wait', num_kind_nodes=0)
self.assertGraphContainsExactly(traced.graph, kind='aten::add', num_kind_nodes=2)
self.assertGraphContainsExactly(
traced.graph, kind="prim::fork", num_kind_nodes=0
)
self.assertGraphContainsExactly(
traced.graph, kind="aten::wait", num_kind_nodes=0
)
self.assertGraphContainsExactly(
traced.graph, kind="aten::add", num_kind_nodes=2
)
def test_trace_fork_wait_list_modulecalls(self):
def add_one(input):
@ -472,7 +504,10 @@ class TestAsync(JitTestCase):
self.checkTrace(TestModule(), (torch.randn(5, 5),))
def test_no_future_subtype_message(self):
with self.assertRaisesRegexWithHighlight(RuntimeError, 'Future without a contained type', ''):
with self.assertRaisesRegexWithHighlight(
RuntimeError, "Future without a contained type", ""
):
@torch.jit.script
def forward(self, x):
futs = torch.jit.annotate(List[torch.jit.Future], [])
@ -481,6 +516,7 @@ class TestAsync(JitTestCase):
"""
Test that futures subtype each other properly.
"""
# Successful subtyping.
def returns_int(x: int) -> int:
return x + x + 1
@ -495,10 +531,11 @@ class TestAsync(JitTestCase):
# Unsuccessful subtyping.
with self.assertRaisesRegexWithHighlight(
RuntimeError,
r"was annotated as having type Future\[float\] but is actually of type Future\[int\]",
"fut = returns_future_float(x"
RuntimeError,
r"was annotated as having type Future\[float\] but is actually of type Future\[int\]",
"fut = returns_future_float(x",
):
def returns_future_float(x: int) -> torch.jit.Future[float]:
return torch.jit._fork(returns_int, (x))
@ -508,8 +545,9 @@ class TestAsync(JitTestCase):
return fut.wait()
if __name__ == '__main__':
raise RuntimeError("This test file is not meant to be run directly, use:\n\n"
"\tpython test/test_jit.py TESTNAME\n\n"
"instead.")
if __name__ == "__main__":
raise RuntimeError(
"This test file is not meant to be run directly, use:\n\n"
"\tpython test/test_jit.py TESTNAME\n\n"
"instead."
)

View File

@ -3,35 +3,40 @@
import torch
from torch.testing._internal.common_utils import TestCase
class TestAtenPow(TestCase):
def test_aten_pow_zero_negative_exponent(self):
'''
"""
1. Testing a = int, b = int
'''
"""
@torch.jit.script
def fn_int_int(a: int, b: int):
return a ** b
return a**b
# Existing correct behaviors of aten::pow
self.assertEqual(fn_int_int(2, 1), 2 ** 1)
self.assertEqual(fn_int_int(2, 0), 2 ** 0)
self.assertEqual(fn_int_int(2, 1), 2**1)
self.assertEqual(fn_int_int(2, 0), 2**0)
self.assertEqual(fn_int_int(2, -2), 2 ** (-2))
self.assertEqual(fn_int_int(-2, 2), (-2) ** 2)
self.assertEqual(fn_int_int(-2, 0), (-2) ** 0)
self.assertEqual(fn_int_int(-2, -2), (-2) ** (-2))
self.assertEqual(fn_int_int(-2, -1), (-2) ** (-1))
self.assertEqual(fn_int_int(0, 2), 0 ** 1)
self.assertEqual(fn_int_int(0, 0), 0 ** 0)
self.assertEqual(fn_int_int(0, 2), 0**1)
self.assertEqual(fn_int_int(0, 0), 0**0)
# zero base and negative exponent case that should trigger RunTimeError
self.assertRaises(RuntimeError, fn_int_int, 0, -2)
'''
"""
2. Testing a = int, b = float
'''
"""
@torch.jit.script
def fn_int_float(a: int, b: float):
return a ** b
return a**b
# Existing correct behaviors of aten::pow
self.assertEqual(fn_int_float(2, 2.5), 2 ** 2.5)
self.assertEqual(fn_int_float(2, 2.5), 2**2.5)
self.assertEqual(fn_int_float(2, -2.5), 2 ** (-2.5))
self.assertEqual(fn_int_float(2, -0.0), 2 ** (-0.0))
self.assertEqual(fn_int_float(2, 0.0), 2 ** (0.0))
@ -40,53 +45,57 @@ class TestAtenPow(TestCase):
self.assertEqual(fn_int_float(-2, -3.0), (-2) ** (-3.0))
self.assertEqual(fn_int_float(-2, -0.0), (-2) ** (-0.0))
self.assertEqual(fn_int_float(-2, 0.0), (-2) ** (0.0))
self.assertEqual(fn_int_float(0, 2.0), 0 ** 2.0)
self.assertEqual(fn_int_float(0, 0.5), 0 ** 0.5)
self.assertEqual(fn_int_float(0, 0.0), 0 ** 0.0)
self.assertEqual(fn_int_float(0, 2.0), 0**2.0)
self.assertEqual(fn_int_float(0, 0.5), 0**0.5)
self.assertEqual(fn_int_float(0, 0.0), 0**0.0)
self.assertEqual(fn_int_float(0, -0.0), 0 ** (-0.0))
# zero base and negative exponent case that should trigger RunTimeError
self.assertRaises(RuntimeError, fn_int_float, 0, -2.5)
'''
"""
3. Testing a = float, b = int
'''
"""
@torch.jit.script
def fn_float_int(a: float, b: int):
return a ** b
return a**b
# Existing correct behaviors of aten::pow
self.assertEqual(fn_float_int(2.5, 2), 2.5 ** 2)
self.assertEqual(fn_float_int(2.5, 2), 2.5**2)
self.assertEqual(fn_float_int(2.5, -2), 2.5 ** (-2))
self.assertEqual(fn_float_int(2.5, -0), 2.5 ** (-0))
self.assertEqual(fn_float_int(2.5, 0), 2.5 ** 0)
self.assertEqual(fn_float_int(-2.5, 2), 2.5 ** 2)
self.assertEqual(fn_float_int(2.5, 0), 2.5**0)
self.assertEqual(fn_float_int(-2.5, 2), 2.5**2)
self.assertEqual(fn_float_int(-2.5, -2), (-2.5) ** (-2))
self.assertEqual(fn_float_int(-2.5, -3), (-2.5) ** (-3))
self.assertEqual(fn_float_int(-2.5, -0), (-2.5) ** (-0))
self.assertEqual(fn_float_int(-2.5, 0), (-2.5) ** 0)
self.assertEqual(fn_float_int(0.0, 2), 0 ** 2)
self.assertEqual(fn_float_int(0.0, 0), 0 ** 0)
self.assertEqual(fn_float_int(0.0, 2), 0**2)
self.assertEqual(fn_float_int(0.0, 0), 0**0)
self.assertEqual(fn_float_int(0.0, -0), 0 ** (-0))
# zero base and negative exponent case that should trigger RunTimeError
self.assertRaises(RuntimeError, fn_float_int, 0.0, -2)
'''
"""
4. Testing a = float, b = float
'''
"""
@torch.jit.script
def fn_float_float(a: float, b: float):
return a ** b
return a**b
# Existing correct behaviors of aten::pow
self.assertEqual(fn_float_float(2.5, 2.0), 2.5 ** 2.0)
self.assertEqual(fn_float_float(2.5, 2.0), 2.5**2.0)
self.assertEqual(fn_float_float(2.5, -2.0), 2.5 ** (-2.0))
self.assertEqual(fn_float_float(2.5, -0.0), 2.5 ** (-0.0))
self.assertEqual(fn_float_float(2.5, 0.0), 2.5 ** 0.0)
self.assertEqual(fn_float_float(-2.5, 2.0), 2.5 ** 2.0)
self.assertEqual(fn_float_float(2.5, 0.0), 2.5**0.0)
self.assertEqual(fn_float_float(-2.5, 2.0), 2.5**2.0)
self.assertEqual(fn_float_float(-2.5, -2.0), (-2.5) ** (-2.0))
self.assertEqual(fn_float_float(-2.5, -3.0), (-2.5) ** (-3.0))
self.assertEqual(fn_float_float(-2.5, -0.0), (-2.5) ** (-0.0))
self.assertEqual(fn_float_float(-2.5, 0.0), (-2.5) ** 0.0)
self.assertEqual(fn_float_float(0.0, 2.0), 0.0 ** 2.0)
self.assertEqual(fn_float_float(0.0, 0.0), 0.0 ** 0.0)
self.assertEqual(fn_float_float(0.0, 2.0), 0.0**2.0)
self.assertEqual(fn_float_float(0.0, 0.0), 0.0**0.0)
self.assertEqual(fn_float_float(0.0, -0.0), 0.0 ** (-0.0))
# zero base and negative exponent case that should trigger RunTimeError
self.assertRaises(RuntimeError, fn_float_float, 0.0, -2.0)

View File

@ -1,20 +1,22 @@
# Owner(s): ["oncall: jit"]
from torch.testing import FileCheck
from torch.testing._internal.jit_utils import JitTestCase
import torch
from typing import NamedTuple, Tuple
import torch
from torch.testing import FileCheck
from torch.testing._internal.jit_utils import JitTestCase
if __name__ == '__main__':
raise RuntimeError("This test file is not meant to be run directly, use:\n\n"
"\tpython test/test_jit.py TESTNAME\n\n"
"instead.")
if __name__ == "__main__":
raise RuntimeError(
"This test file is not meant to be run directly, use:\n\n"
"\tpython test/test_jit.py TESTNAME\n\n"
"instead."
)
class TestGetDefaultAttr(JitTestCase):
def test_getattr_with_default(self):
class A(torch.nn.Module):
def __init__(self):
super().__init__()
@ -22,7 +24,7 @@ class TestGetDefaultAttr(JitTestCase):
def forward(self, x):
y = getattr(self, "init_attr_val") # noqa: B009
w : list[float] = [1.0]
w: list[float] = [1.0]
z = getattr(self, "missing", w) # noqa: B009
z.append(y)
return z
@ -32,7 +34,7 @@ class TestGetDefaultAttr(JitTestCase):
graph = torch.jit.script(A()).graph
# The "init_attr_val" attribute exists
FileCheck().check("prim::GetAttr[name=\"init_attr_val\"]").run(graph)
FileCheck().check('prim::GetAttr[name="init_attr_val"]').run(graph)
# The "missing" attribute does not exist, so there should be no corresponding GetAttr in AST
FileCheck().check_not("missing").run(graph)
# instead the getattr call will emit the default value, which is a list with one float element
@ -46,7 +48,11 @@ class TestGetDefaultAttr(JitTestCase):
y: torch.Tensor
def fn(x: MyTuple) -> Tuple[str, torch.Tensor, int]:
return getattr(x, "x", "fdsa"), getattr(x, "y", torch.ones((3, 3))), getattr(x, "z", 7)
return (
getattr(x, "x", "fdsa"),
getattr(x, "y", torch.ones((3, 3))),
getattr(x, "z", 7),
)
inp = MyTuple(x="test", y=torch.ones(3, 3) * 2)
ref = fn(inp)

View File

@ -1,10 +1,11 @@
# Owner(s): ["oncall: jit"]
from typing import List
import torch
from torch.testing._internal.common_utils import skipIfTorchDynamo
from torch.testing._internal.jit_utils import JitTestCase
from typing import List
@skipIfTorchDynamo()
@ -119,7 +120,6 @@ class TestAutodiffJit(JitTestCase):
self.assertEqual(y_s.requires_grad, y.requires_grad)
self.assertEqual(z_s.requires_grad, z.requires_grad)
def test_autodiff_requires_grad_nograd(self):
@torch.jit.ignore
def python_fn(x):

View File

@ -3,26 +3,38 @@
import os
import sys
import unittest
from torch.testing._internal.common_utils import GRAPH_EXECUTOR, ProfilingMode, \
num_profiled_runs, enable_profiling_mode_for_profiling_tests
from torch.testing._internal.common_jit import check_against_reference
import torch
from torch.testing._internal.common_jit import check_against_reference
from torch.testing._internal.common_utils import (
enable_profiling_mode_for_profiling_tests,
GRAPH_EXECUTOR,
num_profiled_runs,
ProfilingMode,
)
# Make the helper files in test/ importable
pytorch_test_dir = os.path.dirname(os.path.dirname(os.path.realpath(__file__)))
sys.path.append(pytorch_test_dir)
from torch.testing._internal.jit_utils import JitTestCase, disable_autodiff_subgraph_inlining
from typing import List, Optional, Tuple
from torch.testing import FileCheck
from torch.testing._internal.jit_utils import (
disable_autodiff_subgraph_inlining,
JitTestCase,
)
from typing import List, Tuple, Optional
if __name__ == '__main__':
raise RuntimeError("This test file is not meant to be run directly, use:\n\n"
"\tpython test/test_jit.py TESTNAME\n\n"
"instead.")
if __name__ == "__main__":
raise RuntimeError(
"This test file is not meant to be run directly, use:\n\n"
"\tpython test/test_jit.py TESTNAME\n\n"
"instead."
)
@unittest.skipIf(GRAPH_EXECUTOR == ProfilingMode.SIMPLE, "Simple Executor doesn't support gradients")
@unittest.skipIf(
GRAPH_EXECUTOR == ProfilingMode.SIMPLE, "Simple Executor doesn't support gradients"
)
class TestAutodiffSubgraphSlicing(JitTestCase):
# TODO: It is better if we can test directly on graphs instead of the current
# end-to-end fashion.
@ -35,11 +47,17 @@ class TestAutodiffSubgraphSlicing(JitTestCase):
return ge.graph_for(*inputs)
def assertGraphSize(self, graph, size):
nodes = list(filter(lambda n: (n.kind() != "prim::BailOut" and
n.kind() != "prim::BailoutTemplate" and
n.kind() != "prim::TypeCheck" and
n.kind() != "prim::RequiresGradCheck"),
graph.nodes()))
nodes = list(
filter(
lambda n: (
n.kind() != "prim::BailOut"
and n.kind() != "prim::BailoutTemplate"
and n.kind() != "prim::TypeCheck"
and n.kind() != "prim::RequiresGradCheck"
),
graph.nodes(),
)
)
self.assertEqual(len(list(nodes)), size)
def test_chunk_constant_script_ad(self):
@ -52,16 +70,21 @@ class TestAutodiffSubgraphSlicing(JitTestCase):
with disable_autodiff_subgraph_inlining():
with enable_profiling_mode_for_profiling_tests():
output = func(input, profile_and_replay=True)
FileCheck().check_not("prim::DifferentiableGraph").run(func.graph_for(input))
FileCheck().check_not("prim::DifferentiableGraph").run(
func.graph_for(input)
)
@unittest.skipIf(GRAPH_EXECUTOR != ProfilingMode.PROFILING, "This threshold is only valid for Profiling Executor")
@unittest.skipIf(
GRAPH_EXECUTOR != ProfilingMode.PROFILING,
"This threshold is only valid for Profiling Executor",
)
def test_diff_graph_inline_threshold(self):
with enable_profiling_mode_for_profiling_tests():
NUM_RUNS = 1
with num_profiled_runs(NUM_RUNS):
@torch.jit.script
def foo(x):
# two nodes should be fused
# see https://github.com/pytorch/pytorch/blob/master/torch/csrc/jit/runtime/graph_executor_impl.h#L49
return torch.sigmoid(torch.sigmoid(x))
@ -78,12 +101,16 @@ class TestAutodiffSubgraphSlicing(JitTestCase):
bar(input)
bar(input)
self.assertGraphContainsExactly(foo.graph_for(input), 'prim::DifferentiableGraph', 1)
self.assertGraphContainsExactly(bar.graph_for(input), 'prim::DifferentiableGraph', 0)
self.assertGraphContainsExactly(
foo.graph_for(input), "prim::DifferentiableGraph", 1
)
self.assertGraphContainsExactly(
bar.graph_for(input), "prim::DifferentiableGraph", 0
)
def test_bias_as_module_attr(self):
with enable_profiling_mode_for_profiling_tests():
class M(torch.nn.Module):
def __init__(self, has_bias):
super().__init__()
@ -99,19 +126,40 @@ class TestAutodiffSubgraphSlicing(JitTestCase):
scripted_no_bias(x, x)
scripted_no_bias(x, x)
has_bias = M(True)
check_against_reference(self, scripted_no_bias, no_bias, lambda x: x, (x, x,), check_types=False)
check_against_reference(
self,
scripted_no_bias,
no_bias,
lambda x: x,
(
x,
x,
),
check_types=False,
)
scripted_has_bias = torch.jit.script(has_bias)
scripted_has_bias(x, x)
scripted_has_bias(x, x)
scripted_has_bias(x, x)
check_against_reference(self, scripted_has_bias, has_bias, lambda x: x, (x, x,), check_types=False)
check_against_reference(
self,
scripted_has_bias,
has_bias,
lambda x: x,
(
x,
x,
),
check_types=False,
)
def test_constructed_bias(self):
with enable_profiling_mode_for_profiling_tests():
def method1(x, weight, b1, b2):
bias = b1 * b2
return torch.nn.functional.linear(x, weight, bias)
N = 10
x = torch.rand(N, N, requires_grad=True)
weight = torch.rand(N, N, requires_grad=True)
@ -119,35 +167,58 @@ class TestAutodiffSubgraphSlicing(JitTestCase):
b2 = torch.rand(N, N, requires_grad=True)
scripted = self.checkScript(method1, (x, weight, b1, b2))
# check_types requires last_graph on scripted to be set, so we just skip it
check_against_reference(self, scripted, method1, lambda x: x, (x, weight, b1, b2), check_types=False)
check_against_reference(
self,
scripted,
method1,
lambda x: x,
(x, weight, b1, b2),
check_types=False,
)
def test_bias_as_arg(self):
with enable_profiling_mode_for_profiling_tests():
def method1(x, weight, bias: Optional[torch.Tensor]):
return torch.nn.functional.linear(x, weight, bias).relu() + 2
N = 10
x = torch.rand(N, N, requires_grad=True)
weight = torch.rand(N, N, requires_grad=True)
bias = None
scripted = self.checkScript(method1, (x, weight, bias))
# check_types requires last_graph on scripted to be set, so we just skip it
check_against_reference(self, scripted, method1, lambda x: x, (x, weight, bias), check_types=False)
check_against_reference(
self,
scripted,
method1,
lambda x: x,
(x, weight, bias),
check_types=False,
)
bias = torch.rand(N, N, requires_grad=True)
scripted = self.checkScript(method1, (x, weight, bias))
# check_types requires last_graph on scripted to be set, so we just skip it
check_against_reference(self, scripted, method1, lambda x: x, (x, weight, bias), check_types=False)
check_against_reference(
self,
scripted,
method1,
lambda x: x,
(x, weight, bias),
check_types=False,
)
def test_requires_grad_for_tensor_list(self):
with enable_profiling_mode_for_profiling_tests():
# output & var_list[0] should have requires_grad set to True
def func(input0: torch.Tensor, input1: torch.Tensor) -> Tuple[torch.Tensor, List[torch.Tensor]]:
def func(
input0: torch.Tensor, input1: torch.Tensor
) -> Tuple[torch.Tensor, List[torch.Tensor]]:
var_list = [input0, input1]
var = torch.cat(var_list)
output = var + 1.0
return output, var_list
jit_f = torch.jit.script(func)
input0 = torch.randn((2,), requires_grad=True)
input1 = torch.randn((2,))
@ -158,12 +229,14 @@ class TestAutodiffSubgraphSlicing(JitTestCase):
assert output_ref[1][0].requires_grad == output[1][0].requires_grad
assert output_ref[1][1].requires_grad == output[1][1].requires_grad
@unittest.skip("disable until we property handle tensor lists with undefined gradients")
@unittest.skip(
"disable until we property handle tensor lists with undefined gradients"
)
def test_differentiable_graph_ops_requires_grad(self):
x = torch.randn(8, 2, dtype=torch.float).requires_grad_()
y = torch.randn(8, 2, dtype=torch.float)
def t(x : torch.Tensor, y : torch.Tensor, flag : bool):
def t(x: torch.Tensor, y: torch.Tensor, flag: bool):
o = x + 1.0
o1 = torch.relu(o)
o = y + 1.5
@ -186,13 +259,14 @@ class TestAutodiffSubgraphSlicing(JitTestCase):
return o1, o2, o3, oo1, oo2, oo3
with enable_profiling_mode_for_profiling_tests():
t_jit = torch.jit.script(t)
jit_o = t_jit(x, y, False)
jit_o = t_jit(x, y, False)
o = t(x, y, False)
FileCheck().check("prim::DifferentiableGraph").run(t_jit.graph_for(x, y, False))
FileCheck().check("prim::DifferentiableGraph").run(
t_jit.graph_for(x, y, False)
)
# validate the differentiableGraphOps are marking proper requires_grad
for oo, jit_oo in zip(o, jit_o):
self.assertEqual(oo.requires_grad, jit_oo.requires_grad)
@ -204,22 +278,28 @@ class TestAutodiffSubgraphSlicing(JitTestCase):
self.assertEqual(oo.requires_grad, jit_oo.requires_grad)
self.assertEqual(oo, jit_oo)
@unittest.skipIf(GRAPH_EXECUTOR == ProfilingMode.PROFILING, "Simple Executor doesn't support gradients")
@unittest.skipIf(
GRAPH_EXECUTOR == ProfilingMode.PROFILING,
"Simple Executor doesn't support gradients",
)
def test_prune_grad(self):
@torch.jit.script
def t(input, bias):
return torch.nn.functional.relu(input + bias)
input = torch.randn(2, 8, requires_grad=True)
bias = torch.randn(8, requires_grad=False) # bias does NOT require grad
bias = torch.randn(8, requires_grad=False) # bias does NOT require grad
NUM_PROFILED_RUNS = 1
with num_profiled_runs(NUM_PROFILED_RUNS):
WARMUP = 3 # 2 runs to reach backward + 1 to optimize it
WARMUP = 3 # 2 runs to reach backward + 1 to optimize it
for x in range(WARMUP):
o = t(input, bias)
o.sum().backward()
fwd_plan = list(t.get_debug_state().execution_plans.values())[0]
bwd_graph = list(fwd_plan.code.grad_executor_states()[0].execution_plans.values())[0].graph
bwd_graph = list(
fwd_plan.code.grad_executor_states()[0].execution_plans.values()
)[0].graph
tup = next(bwd_graph.outputs())
self.assertEqual(len(list(tup.node().inputs())), 1)
@ -233,7 +313,7 @@ class TestAutodiffSubgraphSlicing(JitTestCase):
graph = self._perform_ad_subgraph_slicing(fn, 1, 1, 1)
self.assertGraphSize(graph, 1)
self.assertGraphContainsExactly(graph, 'prim::DifferentiableGraph', 1)
self.assertGraphContainsExactly(graph, "prim::DifferentiableGraph", 1)
def test_simple_no_merge(self):
# o: autodiff supported. x: not autodiff supported.
@ -245,8 +325,10 @@ class TestAutodiffSubgraphSlicing(JitTestCase):
graph = self._perform_ad_subgraph_slicing(fn, 1, 1, 1)
g_str = str(graph)
FileCheck().check("aten::Int").check("aten::zeros").check_not("aten::mul").run(g_str[0:g_str.find("return")])
self.assertGraphContainsExactly(graph, 'prim::DifferentiableGraph', 1)
FileCheck().check("aten::Int").check("aten::zeros").check_not("aten::mul").run(
g_str[0 : g_str.find("return")]
)
self.assertGraphContainsExactly(graph, "prim::DifferentiableGraph", 1)
def test_does_not_merge_unrelated(self):
# o o
@ -258,7 +340,7 @@ class TestAutodiffSubgraphSlicing(JitTestCase):
graph = self._perform_ad_subgraph_slicing(fn, 1, 1, 1, 1)
self.assertGraphSize(graph, 3)
self.assertGraphContainsExactly(graph, 'prim::DifferentiableGraph', 2)
self.assertGraphContainsExactly(graph, "prim::DifferentiableGraph", 2)
def test_merges_without_cycles(self):
# o --> o --> o
@ -273,7 +355,7 @@ class TestAutodiffSubgraphSlicing(JitTestCase):
graph = self._perform_ad_subgraph_slicing(fn, 1, 1, 1)
self.assertGraphSize(graph, 1)
self.assertGraphContainsExactly(graph, 'prim::DifferentiableGraph', 1)
self.assertGraphContainsExactly(graph, "prim::DifferentiableGraph", 1)
def test_merges_dense(self):
# o o
@ -290,7 +372,7 @@ class TestAutodiffSubgraphSlicing(JitTestCase):
graph = self._perform_ad_subgraph_slicing(fn, 2, 2)
self.assertGraphSize(graph, 2)
self.assertGraphContainsExactly(graph, 'prim::DifferentiableGraph', 1)
self.assertGraphContainsExactly(graph, "prim::DifferentiableGraph", 1)
def test_does_not_create_cycles(self):
# o --> x --> o
@ -303,7 +385,7 @@ class TestAutodiffSubgraphSlicing(JitTestCase):
return c
graph = self._perform_ad_subgraph_slicing(fn, 1, 1, 1)
self.assertGraphContainsExactly(graph, 'prim::DifferentiableGraph', 2)
self.assertGraphContainsExactly(graph, "prim::DifferentiableGraph", 2)
def test_merges_up(self):
# o --> x o
@ -317,8 +399,8 @@ class TestAutodiffSubgraphSlicing(JitTestCase):
graph = self._perform_ad_subgraph_slicing(fn, 1, 1, 1, 1)
g_str = str(graph)
FileCheck().check_not("aten::add").run(g_str[0:g_str.find("return")])
self.assertGraphContainsExactly(graph, 'prim::DifferentiableGraph', 1)
FileCheck().check_not("aten::add").run(g_str[0 : g_str.find("return")])
self.assertGraphContainsExactly(graph, "prim::DifferentiableGraph", 1)
def test_merges_down(self):
# o x --> o
@ -335,8 +417,8 @@ class TestAutodiffSubgraphSlicing(JitTestCase):
num_nodes = 4 if GRAPH_EXECUTOR == ProfilingMode.PROFILING else 3
# add moved down
g_str = str(graph)
FileCheck().check_not("aten::add").run(g_str[0:g_str.find("return")])
self.assertGraphContainsExactly(graph, 'prim::DifferentiableGraph', 1)
FileCheck().check_not("aten::add").run(g_str[0 : g_str.find("return")])
self.assertGraphContainsExactly(graph, "prim::DifferentiableGraph", 1)
def test_respects_lexical_scoping(self):
def fn(x, k):
@ -346,12 +428,10 @@ class TestAutodiffSubgraphSlicing(JitTestCase):
z = y * k
return z, k
graph = self._perform_ad_subgraph_slicing(fn, 1, 1)
# We should not have combined the two multiplications into
# the same group; they should each be a separate DiffGraph
self.assertGraphContainsExactly(graph, 'prim::DifferentiableGraph', 3)
self.assertGraphContainsExactly(graph, "prim::DifferentiableGraph", 3)
def test_merge_respects_aliasing(self):
def fn(x, k, cond):
@ -368,15 +448,13 @@ class TestAutodiffSubgraphSlicing(JitTestCase):
graph = self._perform_ad_subgraph_slicing(fn, [2, 2], [2, 2], 1)
# z2 did did not get merged into the subgraph
FileCheck().check("prim::If").check("aten::select").check_next("aten::select")\
.check_next("aten::add_").check("Differentiable").run(graph)
self.assertGraphContainsExactly(graph, 'prim::DifferentiableGraph', 2)
FileCheck().check("prim::If").check("aten::select").check_next(
"aten::select"
).check_next("aten::add_").check("Differentiable").run(graph)
self.assertGraphContainsExactly(graph, "prim::DifferentiableGraph", 2)
def test_aliased_outputs(self):
with enable_profiling_mode_for_profiling_tests():
# Case 1: aliasing between relu and t
# is within a DifferentiableGraph. It should be valid
# to merge both split_with_sizes in relu in one graph
@ -389,9 +467,9 @@ class TestAutodiffSubgraphSlicing(JitTestCase):
graph = torch._C.parse_ir(input_str)
torch._C._jit_pass_create_autodiff_subgraphs(graph, 1)
FileCheck().check("with prim::DifferentiableGraph") \
.check("aten::relu").check("aten::t") \
.run(graph)
FileCheck().check("with prim::DifferentiableGraph").check(
"aten::relu"
).check("aten::t").run(graph)
# Case 2: aliasing between relu and split_with_sizes
# are both outputs of a Diff graph. It should be invalid
@ -410,11 +488,11 @@ class TestAutodiffSubgraphSlicing(JitTestCase):
graph = torch._C.parse_ir(input_str)
torch._C._jit_pass_create_autodiff_subgraphs(graph, 1)
FileCheck().check("Tensor = prim::DifferentiableGraph") \
.check("with prim::DifferentiableGraph") \
.check("Tensor = aten::relu") \
.check_not("aten::split_with_sizes") \
.run(graph)
FileCheck().check("Tensor = prim::DifferentiableGraph").check(
"with prim::DifferentiableGraph"
).check("Tensor = aten::relu").check_not("aten::split_with_sizes").run(
graph
)
# Case 3: two aliased nodes in a graph.
# Both `split_with_sizes` should be unfused
@ -432,11 +510,11 @@ class TestAutodiffSubgraphSlicing(JitTestCase):
graph = torch._C.parse_ir(input_str)
torch._C._jit_pass_create_autodiff_subgraphs(graph, 1)
FileCheck().check("Tensor = prim::DifferentiableGraph") \
.check("with prim::DifferentiableGraph") \
.check("Tensor = aten::relu") \
.check_not("aten::split_with_sizes") \
.run(graph)
FileCheck().check("Tensor = prim::DifferentiableGraph").check(
"with prim::DifferentiableGraph"
).check("Tensor = aten::relu").check_not("aten::split_with_sizes").run(
graph
)
# Case 4: the aliased output has a descendant
# Both should be unfused. Note, %3 comes before %2
@ -454,11 +532,9 @@ class TestAutodiffSubgraphSlicing(JitTestCase):
graph = torch._C.parse_ir(input_str)
torch._C._jit_pass_create_autodiff_subgraphs(graph, 1)
FileCheck().check("Tensor = prim::DifferentiableGraph") \
.check("with prim::DifferentiableGraph") \
.check("Tensor = aten::relu") \
.check_not("aten::t") \
.run(graph)
FileCheck().check("Tensor = prim::DifferentiableGraph").check(
"with prim::DifferentiableGraph"
).check("Tensor = aten::relu").check_not("aten::t").run(graph)
# Case 5: multiple aliased groups
# Both should be unfused. Note, %3 comes before %2
@ -478,11 +554,9 @@ class TestAutodiffSubgraphSlicing(JitTestCase):
graph = torch._C.parse_ir(input_str)
torch._C._jit_pass_create_autodiff_subgraphs(graph, 1)
FileCheck().check("Tensor = prim::DifferentiableGraph") \
.check("with prim::DifferentiableGraph") \
.check("Tensor = aten::relu") \
.check_not("aten::t") \
.run(graph)
FileCheck().check("Tensor = prim::DifferentiableGraph").check(
"with prim::DifferentiableGraph"
).check("Tensor = aten::relu").check_not("aten::t").run(graph)
def test_has_profiled_info_aliasing_outputs(self):
# The expectation is that CallFunction will prevent the final profile node from
@ -511,9 +585,6 @@ class TestAutodiffSubgraphSlicing(JitTestCase):
output = outputs[0]
self.assertEqual(False, output.requiresGrad())
FileCheck().check("= prim::DifferentiableGraph") \
.check("with prim::DifferentiableGraph") \
.check(" = aten::relu") \
.check("requires_grad=0") \
.check("aten::relu") \
.run(graph)
FileCheck().check("= prim::DifferentiableGraph").check(
"with prim::DifferentiableGraph"
).check(" = aten::relu").check("requires_grad=0").check("aten::relu").run(graph)

View File

@ -1,18 +1,19 @@
# Owner(s): ["oncall: jit"]
import io
import torch
from torch.testing._internal.jit_utils import JitTestCase
from torch.testing._internal.jit_utils import make_global
from typing import List, Optional, Tuple
import torch
from torch import Tensor
from torch._awaits import _Await as Await
from torch.testing._internal.jit_utils import JitTestCase, make_global
class TestAwait(JitTestCase):
def test_await_python(self):
def foo(x: int) -> int:
return x + 13
aw: Await[int] = torch.jit._awaitable(foo, 13)
self.assertTrue(aw.fn()(*aw.args()) == torch.jit._awaitable_wait(aw))
nw = torch.jit._awaitable_nowait(33)
@ -22,6 +23,7 @@ class TestAwait(JitTestCase):
def test_await_type_python(self):
def foo() -> Tensor:
return torch.randn()
awaits = torch.jit.annotate(List[Await[Tensor]], [])
awaits.append(torch.jit._awaitable(foo))
@ -82,9 +84,7 @@ class TestAwait(JitTestCase):
self.assertTrue(torch.allclose(torch.eye(2), script_out))
self.assertTrue(torch.allclose(script_out, out))
def test_await_class_arg(self):
class C:
def __init__(self, a: Tensor, b: Tensor):
self.__a = a
@ -104,6 +104,7 @@ class TestAwait(JitTestCase):
_a = torch.eye(2)
c2_t = torch.jit._awaitable_wait(aw)
return _a + c2_t + x
inp = torch.zeros(2)
sm = torch.jit.script(fn)
@ -120,7 +121,6 @@ class TestAwait(JitTestCase):
self._a = a
self._b = b
make_global(C)
# Can not stay in the class as Jit does not support Recursive annotations
@ -143,7 +143,6 @@ class TestAwait(JitTestCase):
self.assertTrue(torch.allclose(script_out, out))
def test_await_class_return(self):
class C:
__slots__ = ["a", "b"]
@ -151,7 +150,6 @@ class TestAwait(JitTestCase):
self.a = a
self.b = b
make_global(C)
# Can not stay in the class as Jit does not support Recursive annotations
@ -175,7 +173,9 @@ class TestAwait(JitTestCase):
script_out = sm(inp)
self.assertTrue(torch.allclose(torch.eye(2) + 6 * torch.ones(2), script_out))
self.assertTrue(torch.allclose(script_out, out))
self.assertGraphContainsExactly(sm.graph, kind='prim::awaitable_wait', num_kind_nodes=1)
self.assertGraphContainsExactly(
sm.graph, kind="prim::awaitable_wait", num_kind_nodes=1
)
def test_await_getattr_implicit_convertion(self):
class C:
@ -186,7 +186,6 @@ class TestAwait(JitTestCase):
def b(self):
return self._b
make_global(C)
# Can not stay in the class as Jit does not support Recursive annotations
@ -212,10 +211,11 @@ class TestAwait(JitTestCase):
script_out = sm(inp)
self.assertTrue(torch.allclose(torch.eye(2) + 7 * torch.ones(2), script_out))
self.assertTrue(torch.allclose(script_out, out))
self.assertGraphContainsExactly(sm.graph, kind='prim::awaitable_wait', num_kind_nodes=2)
self.assertGraphContainsExactly(
sm.graph, kind="prim::awaitable_wait", num_kind_nodes=2
)
def test_await_nested(self):
class C:
def __init__(self, a: Tensor, b: Tensor):
self.__a = a
@ -250,6 +250,7 @@ class TestAwait(JitTestCase):
def __init__(self, v):
self.parent = torch.jit.annotate(Optional[Tree], None)
self.v = v
make_global(Tree)
def delayed(t: Tree):
@ -275,12 +276,15 @@ class TestAwait(JitTestCase):
sm = torch.jit.script(main)
out = main(inp)
script_out = sm(inp)
self.assertTrue(torch.allclose(2 * torch.eye(2) + 2 * torch.ones(2), script_out))
self.assertTrue(
torch.allclose(2 * torch.eye(2) + 2 * torch.ones(2), script_out)
)
self.assertTrue(torch.allclose(script_out, out))
def test_await_eager_lazy(self):
def delayed(x: Tensor) -> Tensor:
return 2 * (x + 1)
t = torch.ones(2, dtype=torch.int64)
aw = torch.jit._awaitable(delayed, t)
self.assertTrue(isinstance(aw, torch._C._Await))
@ -302,7 +306,9 @@ class TestAwait(JitTestCase):
script_out_aw = sm(inp)
script_out = torch.jit._awaitable_wait(script_out_aw)
self.assertTrue(torch.allclose(2 * torch.eye(2) + 2 * torch.ones(2), script_out))
self.assertTrue(
torch.allclose(2 * torch.eye(2) + 2 * torch.ones(2), script_out)
)
self.assertTrue(torch.allclose(script_out, out))
def test_jit_trace(self):

View File

@ -3,10 +3,10 @@
import os
import sys
import unittest
from pathlib import Path
import torch
import torch._C
from pathlib import Path
from torch.testing._internal.common_utils import IS_FBCODE, skipIfTorchDynamo
# hacky way to skip these tests in fbcode:
@ -15,9 +15,11 @@ from torch.testing._internal.common_utils import IS_FBCODE, skipIfTorchDynamo
# it sees tests but then fails when it tries to actuall run them.
if not IS_FBCODE:
from test_nnapi import TestNNAPI
HAS_TEST_NNAPI = True
else:
from torch.testing._internal.common_utils import TestCase as TestNNAPI
HAS_TEST_NNAPI = False
@ -39,10 +41,14 @@ without the delegate API.
"""
# First skip is needed for IS_WINDOWS or IS_MACOS to skip the tests.
torch_root = Path(__file__).resolve().parent.parent.parent
lib_path = torch_root / 'build' / 'lib' / 'libnnapi_backend.so'
lib_path = torch_root / "build" / "lib" / "libnnapi_backend.so"
@skipIfTorchDynamo("weird py38 failures")
@unittest.skipIf(not os.path.exists(lib_path),
"Skipping the test as libnnapi_backend.so was not found")
@unittest.skipIf(
not os.path.exists(lib_path),
"Skipping the test as libnnapi_backend.so was not found",
)
@unittest.skipIf(IS_FBCODE, "test_nnapi.py not found")
class TestNnapiBackend(TestNNAPI):
def setUp(self):
@ -89,35 +95,44 @@ method_compile_spec must use the following format:
# No forward key
compile_spec = {"backward": {"inputs": args}}
with self.assertRaisesRegex(RuntimeError, "method_compile_spec does not contain the \"forward\" key." + errorMsgTail):
with self.assertRaisesRegex(
RuntimeError,
'method_compile_spec does not contain the "forward" key.' + errorMsgTail,
):
torch._C._jit_to_backend("nnapi", traced, compile_spec)
# No dictionary under the forward key
compile_spec = {"forward": 1}
with self.assertRaisesRegex(RuntimeError,
"method_compile_spec does not contain a dictionary with an \"inputs\" key, "
"under it's \"forward\" key."
+ errorMsgTail):
with self.assertRaisesRegex(
RuntimeError,
'method_compile_spec does not contain a dictionary with an "inputs" key, '
'under it\'s "forward" key.' + errorMsgTail,
):
torch._C._jit_to_backend("nnapi", traced, compile_spec)
# No inputs key (in the dictionary under the forward key)
compile_spec = {"forward": {"not inputs": args}}
with self.assertRaisesRegex(RuntimeError,
"method_compile_spec does not contain a dictionary with an \"inputs\" key, "
"under it's \"forward\" key."
+ errorMsgTail):
with self.assertRaisesRegex(
RuntimeError,
'method_compile_spec does not contain a dictionary with an "inputs" key, '
'under it\'s "forward" key.' + errorMsgTail,
):
torch._C._jit_to_backend("nnapi", traced, compile_spec)
# No Tensor or TensorList under the inputs key
compile_spec = {"forward": {"inputs": 1}}
with self.assertRaisesRegex(RuntimeError,
"method_compile_spec does not contain either a Tensor or TensorList, under it's \"inputs\" key."
+ errorMsgTail):
with self.assertRaisesRegex(
RuntimeError,
'method_compile_spec does not contain either a Tensor or TensorList, under it\'s "inputs" key.'
+ errorMsgTail,
):
torch._C._jit_to_backend("nnapi", traced, compile_spec)
compile_spec = {"forward": {"inputs": [1]}}
with self.assertRaisesRegex(RuntimeError,
"method_compile_spec does not contain either a Tensor or TensorList, under it's \"inputs\" key."
+ errorMsgTail):
with self.assertRaisesRegex(
RuntimeError,
'method_compile_spec does not contain either a Tensor or TensorList, under it\'s "inputs" key.'
+ errorMsgTail,
):
torch._C._jit_to_backend("nnapi", traced, compile_spec)
def tearDown(self):

View File

@ -1,6 +1,5 @@
# Owner(s): ["oncall: jit"]
from torch.testing._internal.jit_utils import JitTestCase
import io
import os
import sys
@ -8,18 +7,20 @@ import unittest
import torch
import torch._C
from torch.testing import FileCheck
from torch.jit.mobile import _load_for_lite_interpreter
from torch.testing import FileCheck
from torch.testing._internal.common_utils import (
find_library_location,
IS_FBCODE,
IS_MACOS,
IS_SANDCASTLE,
IS_WINDOWS,
TEST_WITH_ROCM,
skipIfRocm,
find_library_location,
TEST_WITH_ROCM,
)
from torch.testing._internal.jit_utils import JitTestCase
# Make the helper files in test/ importable
pytorch_test_dir = os.path.dirname(os.path.dirname(os.path.realpath(__file__)))
sys.path.append(pytorch_test_dir)
@ -33,7 +34,9 @@ if __name__ == "__main__":
def to_test_backend(module, method_compile_spec):
return torch._C._jit_to_backend("test_backend", module, {"forward": method_compile_spec})
return torch._C._jit_to_backend(
"test_backend", module, {"forward": method_compile_spec}
)
def to_test_backend_multi(module, method_compile_spec):
@ -63,8 +66,10 @@ class BasicModule(torch.nn.Module):
# This is ignored in IS_WINDOWS or IS_MACOS cases. Hence we need the one in TestBackends.
@unittest.skipIf(TEST_WITH_ROCM or IS_SANDCASTLE or IS_WINDOWS or IS_MACOS or IS_FBCODE,
"Non-portable load_library call used in test")
@unittest.skipIf(
TEST_WITH_ROCM or IS_SANDCASTLE or IS_WINDOWS or IS_MACOS or IS_FBCODE,
"Non-portable load_library call used in test",
)
class JitBackendTestCase(JitTestCase):
"""
A common base class for JIT backend tests that contains common utility
@ -73,7 +78,7 @@ class JitBackendTestCase(JitTestCase):
def setUp(self):
super().setUp()
lib_file_path = find_library_location('libjitbackend_test.so')
lib_file_path = find_library_location("libjitbackend_test.so")
torch.ops.load_library(str(lib_file_path))
# Subclasses are expected to set up three variables in their setUp methods:
# module - a regular, Python version of the module being tested
@ -154,13 +159,17 @@ class BasicModuleTest(JitBackendTestCase):
self.test_execution()
# Save the compile spec to compare against the version retrieved after loading.
pre_compile_spec = self.lowered_module.__getattr__("__loweredModule__").__getattr__("__method_compile_spec")
pre_compile_spec = self.lowered_module.__getattr__(
"__loweredModule__"
).__getattr__("__method_compile_spec")
# Save and load the lowered module.
self.save_load()
# Get the compile spec after loading.
post_compile_spec = self.lowered_module.__getattr__("__loweredModule__").__getattr__("__method_compile_spec")
post_compile_spec = self.lowered_module.__getattr__(
"__loweredModule__"
).__getattr__("__method_compile_spec")
# Compile specs should match.
self.assertEqual(pre_compile_spec, post_compile_spec)
@ -195,9 +204,11 @@ class BasicModuleUnavailableTest(JitBackendTestCase):
input = torch.randn(5)
# Test exception is thrown.
with self.assertRaisesRegexWithHighlight(Exception,
r"Backend is not available.",
"raise Exception(\"Backend is not available.\""):
with self.assertRaisesRegexWithHighlight(
Exception,
r"Backend is not available.",
'raise Exception("Backend is not available."',
):
backend_method = self.lowered_module.__getattr__("forward")
backend_output = backend_method(*(input, input))
@ -207,9 +218,11 @@ class BasicModuleUnavailableTest(JitBackendTestCase):
buffer = io.BytesIO()
torch.jit.save(self.lowered_module, buffer)
buffer.seek(0)
with self.assertRaisesRegexWithHighlight(Exception,
r"Backend is not available.",
"raise Exception(\"Backend is not available.\""):
with self.assertRaisesRegexWithHighlight(
Exception,
r"Backend is not available.",
'raise Exception("Backend is not available."',
):
imported = torch.jit.load(buffer)
@ -218,6 +231,7 @@ class NestedModuleTest(JitBackendTestCase):
Tests for NestedModule that check that a module lowered to a backend can be used
as a submodule.
"""
class NestedModule(torch.nn.Module):
"""
A Module with one submodule that is used to test that lowered Modules
@ -237,7 +251,9 @@ class NestedModuleTest(JitBackendTestCase):
# Both modules in self.module are regular Python modules.
self.module = NestedModuleTest.NestedModule(BasicModule())
# Both modules in self.scripted_module are ScriptModules.
self.scripted_module = torch.jit.script(NestedModuleTest.NestedModule(BasicModule()))
self.scripted_module = torch.jit.script(
NestedModuleTest.NestedModule(BasicModule())
)
# First, script another instance of NestedModule with share_types=False so that it can be
# selectively lowered without modifying the type of self.scripted_module.
@ -246,7 +262,9 @@ class NestedModuleTest(JitBackendTestCase):
{"accum": {"": ""}, "sub_accum": {"": ""}, "forward": {"": ""}},
)
# self.lowered_module is a ScriptModule, but its submodule is a lowered module.
self.lowered_module = torch.jit.script(NestedModuleTest.NestedModule(lowered_module))
self.lowered_module = torch.jit.script(
NestedModuleTest.NestedModule(lowered_module)
)
def test_execution(self):
# Test execution with backend against Python and JIT.
@ -270,6 +288,7 @@ class SelectiveLoweringTest(JitBackendTestCase):
"""
Tests for the selective lowering API.
"""
class OuterModule(torch.nn.Module):
def __init__(self, sub1, sub2, other):
super().__init__()
@ -299,7 +318,10 @@ class SelectiveLoweringTest(JitBackendTestCase):
MiddleModule = SelectiveLoweringTest.MiddleModule
def script_without_type_sharing(mod):
return torch.jit._recursive.create_script_module(mod, torch.jit._recursive.infer_methods_to_compile, share_types=False)
return torch.jit._recursive.create_script_module(
mod, torch.jit._recursive.infer_methods_to_compile, share_types=False
)
# Create Python, JIT and backend versions of a hierarchy that looks like this:
# --------- OuterModule --------
# | | |
@ -308,13 +330,28 @@ class SelectiveLoweringTest(JitBackendTestCase):
# BasicModule BasicModule BasicModule
#
# Two BasicModules will be lowered and the third will not.
self.module = OuterModule(MiddleModule(BasicModule()), MiddleModule(BasicModule()), MiddleModule(BasicModule()))
self.scripted_module = script_without_type_sharing(OuterModule(MiddleModule(
BasicModule()), MiddleModule(BasicModule()), MiddleModule(BasicModule())))
self.lowered_module = script_without_type_sharing(OuterModule(MiddleModule(
BasicModule()), MiddleModule(BasicModule()), MiddleModule(BasicModule())))
self.lowered_module = to_test_backend_selective(self.lowered_module, {"forward": ""}, [
"sub1.submodule", "sub2.submodule"])
self.module = OuterModule(
MiddleModule(BasicModule()),
MiddleModule(BasicModule()),
MiddleModule(BasicModule()),
)
self.scripted_module = script_without_type_sharing(
OuterModule(
MiddleModule(BasicModule()),
MiddleModule(BasicModule()),
MiddleModule(BasicModule()),
)
)
self.lowered_module = script_without_type_sharing(
OuterModule(
MiddleModule(BasicModule()),
MiddleModule(BasicModule()),
MiddleModule(BasicModule()),
)
)
self.lowered_module = to_test_backend_selective(
self.lowered_module, {"forward": ""}, ["sub1.submodule", "sub2.submodule"]
)
def test_execution(self):
input = torch.randn(5)
@ -335,93 +372,93 @@ class SelectiveLoweringTest(JitBackendTestCase):
"""
# Check that self.lowered_module was not lowered, but that it does contain test_backendLoweredModule due to it
# calling the lowered module directly.
FileCheck() \
.check("OuterModule") \
.check("BasicModule") \
.run(self.scripted_module.graph)
FileCheck() \
.check("OuterModule") \
.check_not("__torch__.torch.classes.__backends__.test_backend") \
.check("LoweredWrapper.test_backend") \
.run(self.lowered_module.graph)
FileCheck().check("OuterModule").check("BasicModule").run(
self.scripted_module.graph
)
FileCheck().check("OuterModule").check_not(
"__torch__.torch.classes.__backends__.test_backend"
).check("LoweredWrapper.test_backend").run(self.lowered_module.graph)
# Check that self.lowered_module.sub1/sub2 were not lowered but that BasicModule has been replaced in their graphs.
FileCheck() \
.check("MiddleModule") \
.check("BasicModule") \
.check_not("LoweredWrapper.test_backend") \
.run(self.scripted_module.sub1.graph)
FileCheck() \
.check("MiddleModule") \
.check_not("__torch__.torch.classes.__backends__.test_backend") \
.check("LoweredWrapper.test_backend") \
.run(self.lowered_module.sub1.graph)
FileCheck().check("MiddleModule").check("BasicModule").check_not(
"LoweredWrapper.test_backend"
).run(self.scripted_module.sub1.graph)
FileCheck().check("MiddleModule").check_not(
"__torch__.torch.classes.__backends__.test_backend"
).check("LoweredWrapper.test_backend").run(self.lowered_module.sub1.graph)
FileCheck() \
.check("MiddleModule") \
.check("BasicModule") \
.check_not("LoweredWrapper.test_backend") \
.run(self.scripted_module.sub2.graph)
FileCheck() \
.check("MiddleModule") \
.check_not("__torch__.torch.classes.__backends__.test_backend") \
.check("LoweredWrapper.test_backend") \
.run(self.lowered_module.sub2.graph)
FileCheck().check("MiddleModule").check("BasicModule").check_not(
"LoweredWrapper.test_backend"
).run(self.scripted_module.sub2.graph)
FileCheck().check("MiddleModule").check_not(
"__torch__.torch.classes.__backends__.test_backend"
).check("LoweredWrapper.test_backend").run(self.lowered_module.sub2.graph)
# Check that self.lowered_module.sub1/sub2.submodule were lowered. They should have a new attribute
# __loweredModule__ whose graph should mention __torch__.torch.classes.__backends__.test_backend,
# the TorchBind class for executing functions on the test JIT backend.
FileCheck() \
.check("LoweredModule.test_backend") \
.check("__torch__.torch.classes.__backends__.test_backend") \
.run(self.lowered_module.sub1.submodule.__loweredModule__.graph)
FileCheck().check("LoweredModule.test_backend").check(
"__torch__.torch.classes.__backends__.test_backend"
).run(self.lowered_module.sub1.submodule.__loweredModule__.graph)
FileCheck() \
.check("LoweredModule.test_backend") \
.check("__torch__.torch.classes.__backends__.test_backend") \
.run(self.lowered_module.sub2.submodule.__loweredModule__.graph)
FileCheck().check("LoweredModule.test_backend").check(
"__torch__.torch.classes.__backends__.test_backend"
).run(self.lowered_module.sub2.submodule.__loweredModule__.graph)
# Check that self.other and self.other.submodule have been left untouched by the selective lowering process.
FileCheck() \
.check("MiddleModule") \
.check("BasicModule") \
.check_not("__torch__.torch.classes.__backends__.test_backend") \
.check_not("LoweredWrapper.test_backend") \
.run(self.scripted_module.other.graph)
FileCheck() \
.check("BasicModule") \
.check_not("__torch__.torch.classes.__backends__.test_backend") \
.check_not("LoweredModule.test_backend") \
.run(self.scripted_module.other.submodule.graph)
FileCheck().check("MiddleModule").check("BasicModule").check_not(
"__torch__.torch.classes.__backends__.test_backend"
).check_not("LoweredWrapper.test_backend").run(self.scripted_module.other.graph)
FileCheck().check("BasicModule").check_not(
"__torch__.torch.classes.__backends__.test_backend"
).check_not("LoweredModule.test_backend").run(
self.scripted_module.other.submodule.graph
)
def test_errors(self):
"""
Check errors associated with selective lowering.
"""
# Check error messages thrown when attempting to lower something that is not a ScriptModule.
with self.assertRaisesRegexWithHighlight(RuntimeError, r"Object .* is not a ScriptModule", ""):
with self.assertRaisesRegexWithHighlight(
RuntimeError, r"Object .* is not a ScriptModule", ""
):
to_test_backend_selective(torch.nn.ReLU(), {"forward": ""}, ["submodule"])
MiddleModule = SelectiveLoweringTest.MiddleModule
mod = MiddleModule(BasicModule())
mod.new_attr = 3
with self.assertRaisesRegexWithHighlight(RuntimeError, r"Attribute named new_attr is not a Module", ""):
to_test_backend_selective(torch.jit.script(mod), {"forward": ""}, ["new_attr"])
with self.assertRaisesRegexWithHighlight(
RuntimeError, r"Attribute named new_attr is not a Module", ""
):
to_test_backend_selective(
torch.jit.script(mod), {"forward": ""}, ["new_attr"]
)
# Check error message thrown when module hierarchy doesn't have unique types.
OuterModule = SelectiveLoweringTest.OuterModule
mod = OuterModule(MiddleModule(BasicModule()), MiddleModule(BasicModule()), MiddleModule(BasicModule()))
mod = OuterModule(
MiddleModule(BasicModule()),
MiddleModule(BasicModule()),
MiddleModule(BasicModule()),
)
with self.assertRaisesRegexWithHighlight(RuntimeError,
r"Selective lowering is only supported for module hierarchies with unique types",
""):
to_test_backend_selective(torch.jit.script(mod), {"forward": ""}, ["sub1.submodule"])
with self.assertRaisesRegexWithHighlight(
RuntimeError,
r"Selective lowering is only supported for module hierarchies with unique types",
"",
):
to_test_backend_selective(
torch.jit.script(mod), {"forward": ""}, ["sub1.submodule"]
)
# This is needed for IS_WINDOWS or IS_MACOS to skip the tests.
@unittest.skipIf(TEST_WITH_ROCM or IS_SANDCASTLE or IS_WINDOWS or IS_MACOS or IS_FBCODE,
"Non-portable load_library call used in test")
@unittest.skipIf(
TEST_WITH_ROCM or IS_SANDCASTLE or IS_WINDOWS or IS_MACOS or IS_FBCODE,
"Non-portable load_library call used in test",
)
class TestBackends(JitTestCase):
"""
This class wraps and invokes all subclasses of JitBackendTestCase so that each one
@ -461,6 +498,7 @@ class TestBackends(JitTestCase):
def test_errors(self):
self.selective_lowering_test.test_errors()
"""
Unit Tests for backend with compiler
This test case and the existing TestBackends are separate because they cover different aspects.
@ -468,6 +506,8 @@ The actual backend implementation in this test is different.
It has a simple demo compiler to test the end-to-end flow in mobile.
However, this test cannot cover the selective_lowering for now, which is covered in TestBackends.
"""
class BasicModuleAdd(torch.nn.Module):
"""
A simple add Module used to test to_backend lowering machinery.
@ -476,9 +516,12 @@ class BasicModuleAdd(torch.nn.Module):
def forward(self, x, h):
return x + h
# This is ignored in IS_WINDOWS or IS_MACOS cases. Hence we need the one in TestBackends.
@unittest.skipIf(TEST_WITH_ROCM or IS_SANDCASTLE or IS_WINDOWS or IS_MACOS or IS_FBCODE,
"Non-portable load_library call used in test")
@unittest.skipIf(
TEST_WITH_ROCM or IS_SANDCASTLE or IS_WINDOWS or IS_MACOS or IS_FBCODE,
"Non-portable load_library call used in test",
)
class JitBackendTestCaseWithCompiler(JitTestCase):
"""
A common base class for JIT backend tests with compilers that contains common utility
@ -487,7 +530,7 @@ class JitBackendTestCaseWithCompiler(JitTestCase):
def setUp(self):
super().setUp()
lib_file_path = find_library_location('libbackend_with_compiler.so')
lib_file_path = find_library_location("libbackend_with_compiler.so")
torch.ops.load_library(str(lib_file_path))
# Subclasses are expected to set up four variables in their setUp methods:
# module - a regular, Python version of the module being tested
@ -524,6 +567,7 @@ class JitBackendTestCaseWithCompiler(JitTestCase):
"""
pass
class BasicModuleTestWithCompiler(JitBackendTestCaseWithCompiler):
"""
Tests for BasicModuleAdd.
@ -541,7 +585,8 @@ class BasicModuleTestWithCompiler(JitBackendTestCaseWithCompiler):
},
}
self.lowered_module = torch._C._jit_to_backend(
"backend_with_compiler_demo", self.scripted_module, compile_spec)
"backend_with_compiler_demo", self.scripted_module, compile_spec
)
# Create mobile version of BasicModuleAdd
buffer = io.BytesIO(self.lowered_module._save_to_buffer_for_lite_interpreter())
buffer.seek(0)
@ -552,6 +597,7 @@ class BasicModuleTestWithCompiler(JitBackendTestCaseWithCompiler):
input = torch.ones(1, dtype=torch.float)
self.check_forward((input, input))
class ErrorMessagesWithCompiler(JitBackendTestCase):
"""
Tests for errors that occur with compiler, specifically:
@ -562,22 +608,31 @@ class ErrorMessagesWithCompiler(JitBackendTestCase):
"""
A module with an operator that is not supported.
"""
def forward(self, x, h):
return x * h
self._loweredmodule.forward()
def test_errors(self):
scripted_module_n = torch.jit.script(ErrorMessagesWithCompiler.ModuleNotSupported())
scripted_module_n = torch.jit.script(
ErrorMessagesWithCompiler.ModuleNotSupported()
)
# Test exception is thrown when lowering a module with an unsupported operator
with self.assertRaisesRegexWithHighlight(RuntimeError,
# Special escape characters are replaced with '.'
r"""The node of aten::mul is not supported in this compiler. .*
with self.assertRaisesRegexWithHighlight(
RuntimeError,
# Special escape characters are replaced with '.'
r"""The node of aten::mul is not supported in this compiler. .*
def forward.self, x, h.:
return x . h
~~~~~ <--- HERE
self._loweredmodule.forward..
""", ""):
lowered_module_n = torch._C._jit_to_backend("backend_with_compiler_demo", scripted_module_n, {"forward": {"": ""}})
""",
"",
):
lowered_module_n = torch._C._jit_to_backend(
"backend_with_compiler_demo", scripted_module_n, {"forward": {"": ""}}
)
class CompModuleTestWithCompiler(JitBackendTestCase):
"""
@ -588,6 +643,7 @@ class CompModuleTestWithCompiler(JitBackendTestCase):
"""
A simple subtraction Module to be used in CompModule.
"""
def forward(self, x, h):
return x - h
@ -617,14 +673,19 @@ class CompModuleTestWithCompiler(JitBackendTestCase):
},
}
lowered_add = torch._C._jit_to_backend(
"backend_with_compiler_demo", torch.jit.script(BasicModuleAdd()), compile_spec)
"backend_with_compiler_demo",
torch.jit.script(BasicModuleAdd()),
compile_spec,
)
lowered_sub = torch._C._jit_to_backend(
"backend_with_compiler_demo",
torch.jit.script(CompModuleTestWithCompiler.BasicModuleSub()),
{"forward": {"": ""}}
{"forward": {"": ""}},
)
self.module = CompModuleTestWithCompiler.CompModule(lowered_add, lowered_sub)
self.scripted_module = torch.jit.script(CompModuleTestWithCompiler.CompModule(lowered_add, lowered_sub))
self.scripted_module = torch.jit.script(
CompModuleTestWithCompiler.CompModule(lowered_add, lowered_sub)
)
# No backend version of CompModule currently, so this is filler.
self.lowered_module = self.scripted_module
# Create a mobile version of CompModule from JIT version
@ -640,9 +701,12 @@ class CompModuleTestWithCompiler(JitBackendTestCase):
# Test forward.
self.check_function("forward", (input1, input2, input2))
# This is needed for IS_WINDOWS or IS_MACOS to skip the tests.
@unittest.skipIf(IS_SANDCASTLE or IS_WINDOWS or IS_MACOS or IS_FBCODE,
"Non-portable load_library call used in test")
@unittest.skipIf(
IS_SANDCASTLE or IS_WINDOWS or IS_MACOS or IS_FBCODE,
"Non-portable load_library call used in test",
)
class TestBackendsWithCompiler(JitTestCase):
"""
This class wraps and invokes all subclasses of JitBackendTestCaseWithCompiler
@ -711,7 +775,6 @@ class CompModuleTestSameNameWithCompiler(JitBackendTestCase):
y = s * (c * d)
return y
def setUp(self):
super().setUp()
@ -728,6 +791,7 @@ class CompModuleTestSameNameWithCompiler(JitBackendTestCase):
# Test forward.
self.check_function("forward", (a, b, s))
class AddedAttributesTest(JitBackendTestCase):
"""
Tests for adding attributes to a model after lowering.
@ -747,11 +811,19 @@ class AddedAttributesTest(JitBackendTestCase):
input = [(torch.ones(5),)]
pre_bundled = self.lowered_module(*input[0])
# Attach bundled inputs which adds several attributes and functions to the model
self.lowered_module = torch.utils.bundled_inputs.augment_model_with_bundled_inputs(lowered_module, input) # noqa: F821
post_bundled = self.lowered_module(*self.lowered_module.get_all_bundled_inputs()[0])
self.lowered_module = (
torch.utils.bundled_inputs.augment_model_with_bundled_inputs(
lowered_module, input # noqa: F821
)
)
post_bundled = self.lowered_module(
*self.lowered_module.get_all_bundled_inputs()[0]
)
# Save and load the lowered module.
self.save_load()
# Use bundled after save and load to prove its preserved
post_load = self.lowered_module(*self.lowered_module.get_all_bundled_inputs()[0])
post_load = self.lowered_module(
*self.lowered_module.get_all_bundled_inputs()[0]
)
self.assertEqual(pre_bundled, post_bundled)
self.assertEqual(post_bundled, post_load)

View File

@ -56,7 +56,6 @@ class TestBatchMM(JitTestCase):
actual = test_batch_mm_scripted(*tensors)
self.assertEqual(expected, actual, atol=1e-9, rtol=1e-9)
def test_batch_mm_permitted_mutation(self):
def test_batch_mm(
T1: torch.Tensor,

View File

@ -1,8 +1,8 @@
# Owner(s): ["oncall: jit"]
import inspect
import os
import sys
import inspect
import unittest
from typing import Dict, List
@ -14,10 +14,12 @@ pytorch_test_dir = os.path.dirname(os.path.dirname(os.path.realpath(__file__)))
sys.path.append(pytorch_test_dir)
from torch.testing._internal.jit_utils import JitTestCase, RUN_CUDA
if __name__ == '__main__':
raise RuntimeError("This test file is not meant to be run directly, use:\n\n"
"\tpython test/test_jit.py TESTNAME\n\n"
"instead.")
if __name__ == "__main__":
raise RuntimeError(
"This test file is not meant to be run directly, use:\n\n"
"\tpython test/test_jit.py TESTNAME\n\n"
"instead."
)
class TestBuiltins(JitTestCase):
@ -86,24 +88,27 @@ class TestBuiltins(JitTestCase):
self.checkScript(fn, ([1, 2, 3],))
with self.assertRaisesRegexWithHighlight(RuntimeError, "undefined value", "a"):
@torch.jit.script
def fn(x):
a = x ** 2
a = x**2
del a
return a # noqa: F821
with self.assertRaisesRegexWithHighlight(RuntimeError, "undefined value", "a"):
@torch.jit.script
def fn(x):
a = x ** 2
a = x**2
if a:
del a
return a
with self.assertRaisesRegexWithHighlight(RuntimeError, "undefined value", "b"):
@torch.jit.script
def fn(x):
a = x ** 2
a = x**2
del b # noqa: F821
return a
@ -124,7 +129,7 @@ class TestBuiltins(JitTestCase):
self.assertEqual(py_out, jit_out)
def del_dict_multiple_operands(x: Dict[str, int]) -> Dict[str, int]:
del x['hi'], x['there']
del x["hi"], x["there"]
return x
py_out = del_dict_multiple_operands({"hi": 5, "there": 6})
@ -137,7 +142,7 @@ class TestTensorBuiltins(JitTestCase):
def should_keep(tensor, name):
if inspect.isroutine(getattr(tensor, name)):
return False
if name.startswith('_'):
if name.startswith("_"):
return False
return True
@ -145,8 +150,8 @@ class TestTensorBuiltins(JitTestCase):
keys = dir(tensor)
# real and imag are only implemented for complex tensors.
self.assertRaises(RuntimeError, lambda: should_keep(tensor, 'imag'))
keys.remove('imag')
self.assertRaises(RuntimeError, lambda: should_keep(tensor, "imag"))
keys.remove("imag")
properties = [p for p in keys if should_keep(tensor, p)]
@ -158,16 +163,16 @@ class TestTensorBuiltins(JitTestCase):
EQUALITY_MISMATCH = {
# TorchScript doesn't have real enums so they return an int instead
# of the actual value
'dtype',
'layout',
"dtype",
"layout",
}
MISSING_PROPERTIES = {
'grad_fn',
"grad_fn",
# This is an undocumented property so it's not included
"output_nr",
# This has a longer implementation, maybe not worth copying to
# TorchScript if named tensors don't work there anyways
'names',
"names",
}
for p in properties:
@ -232,7 +237,8 @@ class TestTensorBuiltins(JitTestCase):
def func():
c = 1
return c.add(1)
with self.assertRaisesRegex(RuntimeError, 'object has no attribute or method'):
with self.assertRaisesRegex(RuntimeError, "object has no attribute or method"):
torch.jit.script(func)
# testing implicit conversion of tensors to scalars to match function arguments
@ -265,10 +271,12 @@ class TestTensorBuiltins(JitTestCase):
x = torch.zeros(10)
# float tensor, float tensor with grad, int tensor (can't set grad on int tensor)
tensors = [torch.tensor(1.1),
torch.tensor(1.1, requires_grad=True),
torch.tensor(0),
torch.tensor([2])]
tensors = [
torch.tensor(1.1),
torch.tensor(1.1, requires_grad=True),
torch.tensor(0),
torch.tensor([2]),
]
script_funs = [tensor_to_int_script, tensor_to_float_script]
funs = [tensor_to_int, tensor_to_float]
@ -286,4 +294,6 @@ class TestTensorBuiltins(JitTestCase):
# assert result or exception equal for each (function, inputs)
for tensor in tensors:
for i in range(len(script_funs)):
self.assertEqual(test_func(script_funs[i], x, tensor), test_func(funs[i], x, tensor))
self.assertEqual(
test_func(script_funs[i], x, tensor), test_func(funs[i], x, tensor)
)

View File

@ -4,24 +4,28 @@ import io
import os
import sys
import unittest
from typing import Any
import torch
import torch.nn as nn
from torch.testing import FileCheck
from typing import Any
# Make the helper files in test/ importable
pytorch_test_dir = os.path.dirname(os.path.dirname(os.path.realpath(__file__)))
sys.path.append(pytorch_test_dir)
from torch.testing._internal.jit_utils import JitTestCase, make_global
from typing import Dict, Iterable, List, Optional, Tuple
import torch.testing._internal.jit_utils
from torch.testing._internal.common_utils import IS_SANDCASTLE, skipIfTorchDynamo
from typing import List, Tuple, Iterable, Optional, Dict
from torch.testing._internal.jit_utils import JitTestCase, make_global
if __name__ == "__main__":
raise RuntimeError(
"This test file is not meant to be run directly, use:\n\n"
"\tpython test/test_jit.py TESTNAME\n\n"
"instead."
)
if __name__ == '__main__':
raise RuntimeError("This test file is not meant to be run directly, use:\n\n"
"\tpython test/test_jit.py TESTNAME\n\n"
"instead.")
class TestClassType(JitTestCase):
def test_reference_semantics(self):
@ -29,6 +33,7 @@ class TestClassType(JitTestCase):
Test that modifications made to a class instance in TorchScript
are visible in eager.
"""
class Foo:
def __init__(self, a: int):
self.a = a
@ -92,12 +97,12 @@ class TestClassType(JitTestCase):
pass
def __contains__(self, key: str) -> bool:
return key == 'hi'
return key == "hi"
@torch.jit.script
def fn():
foo = FooTest()
return 'hi' in foo, 'no' in foo
return "hi" in foo, "no" in foo
self.assertEqual(fn(), (True, False))
@ -118,7 +123,10 @@ class TestClassType(JitTestCase):
self.assertEqual(fn(1), 3)
def test_set_attr_type_mismatch(self):
with self.assertRaisesRegexWithHighlight(RuntimeError, "Wrong type for attribute assignment", "self.foo = 10"):
with self.assertRaisesRegexWithHighlight(
RuntimeError, "Wrong type for attribute assignment", "self.foo = 10"
):
@torch.jit.script
class FooTest:
def __init__(self, x):
@ -126,7 +134,10 @@ class TestClassType(JitTestCase):
self.foo = 10 # should error since int != Tensor
def test_get_attr_not_initialized(self):
with self.assertRaisesRegexWithHighlight(RuntimeError, "object has no attribute or method", "self.asdf"):
with self.assertRaisesRegexWithHighlight(
RuntimeError, "object has no attribute or method", "self.asdf"
):
@torch.jit.script
class FooTest:
def __init__(self, x):
@ -136,7 +147,10 @@ class TestClassType(JitTestCase):
return self.asdf # asdf isn't an attr
def test_set_attr_non_initialized(self):
with self.assertRaisesRegexWithHighlight(RuntimeError, "Tried to set nonexistent attribute", "self.bar = y"):
with self.assertRaisesRegexWithHighlight(
RuntimeError, "Tried to set nonexistent attribute", "self.bar = y"
):
@torch.jit.script
class FooTest:
def __init__(self, x):
@ -153,12 +167,16 @@ class TestClassType(JitTestCase):
Expected a value of type 'Optional[int]' for argument 'size' but instead found type 'Tensor'.
"""
with self.assertRaisesRegexWithHighlight(RuntimeError, "nearest", ""):
@torch.jit.script
def FooTest(x):
return torch.nn.functional.interpolate(x, 'bad')
return torch.nn.functional.interpolate(x, "bad")
def test_type_annotations(self):
with self.assertRaisesRegexWithHighlight(RuntimeError, "Expected a value of type \'bool", ""):
with self.assertRaisesRegexWithHighlight(
RuntimeError, "Expected a value of type 'bool", ""
):
@torch.jit.script # noqa: B903
class FooTest: # noqa: B903
def __init__(self, x: bool) -> None:
@ -171,7 +189,10 @@ class TestClassType(JitTestCase):
fn(2)
def test_conditional_set_attr(self):
with self.assertRaisesRegexWithHighlight(RuntimeError, "assignment cannot be in a control-flow block", ""):
with self.assertRaisesRegexWithHighlight(
RuntimeError, "assignment cannot be in a control-flow block", ""
):
@torch.jit.script
class FooTest:
def __init__(self, x):
@ -236,7 +257,6 @@ class TestClassType(JitTestCase):
# classes are globally registered for now, so we need to clear the JIT
# registry to simulate loading a new model
buffer.seek(0)
m_loaded = torch.jit.load(buffer)
@ -320,7 +340,7 @@ class TestClassType(JitTestCase):
self.x = x
self.y = y
make_global(Foo) # see [local resolution in python]
make_global(Foo) # see [local resolution in python]
@torch.jit.script
def use_foo(foo: Foo) -> Foo:
@ -419,15 +439,22 @@ class TestClassType(JitTestCase):
self.assertEqual(test_nested_inside_tuple(), [(1, 11), (1, 12)])
with self.assertRaisesRegexWithHighlight(RuntimeError, "bool\' for argument \'reverse", ""):
with self.assertRaisesRegexWithHighlight(
RuntimeError, "bool' for argument 'reverse", ""
):
@torch.jit.script
def test():
li = [Foo(1)]
li.sort(li)
return li
test()
with self.assertRaisesRegexWithHighlight(RuntimeError, "must define a __lt__", ""):
with self.assertRaisesRegexWithHighlight(
RuntimeError, "must define a __lt__", ""
):
@torch.jit.script
class NoMethod:
def __init__(self):
@ -438,6 +465,7 @@ class TestClassType(JitTestCase):
li = [NoMethod(), NoMethod()]
li.sort()
return li
test()
@torch.jit.script
@ -449,12 +477,16 @@ class TestClassType(JitTestCase):
def __lt__(self, other):
pass
with self.assertRaisesRegexWithHighlight(RuntimeError, "must define a __lt__", ""):
with self.assertRaisesRegexWithHighlight(
RuntimeError, "must define a __lt__", ""
):
@torch.jit.script
def test():
li = [WrongLt(), WrongLt()]
li.sort()
return li
test()
def test_class_inheritance(self):
@ -466,18 +498,21 @@ class TestClassType(JitTestCase):
def two(self, x):
return x + self.b
with self.assertRaisesRegexWithHighlight(RuntimeError, "does not support inheritance", ""):
with self.assertRaisesRegexWithHighlight(
RuntimeError, "does not support inheritance", ""
):
@torch.jit.script
class Derived(Base):
def two(self, x):
return x + self.b + 2
def test_class_inheritance_implicit(self):
"""
Test that inheritance is detected in
implicit scripting codepaths (e.g. try_ann_to_type).
"""
class A:
def __init__(self, t):
self.t = t
@ -502,14 +537,16 @@ class TestClassType(JitTestCase):
else:
return B.f(x.t)
with self.assertRaisesRegexWithHighlight(RuntimeError, "object has no attribute or method", ""):
with self.assertRaisesRegexWithHighlight(
RuntimeError, "object has no attribute or method", ""
):
sc = torch.jit.script(fun)
@skipIfTorchDynamo("Test does not work with TorchDynamo")
@unittest.skipIf(IS_SANDCASTLE, "Importing like this doesn't work in fbcode")
def test_imported_classes(self):
import jit._imported_class_test.foo
import jit._imported_class_test.bar
import jit._imported_class_test.foo
import jit._imported_class_test.very.very.nested
class MyMod(torch.jit.ScriptModule):
@ -593,6 +630,7 @@ class TestClassType(JitTestCase):
def one(self, x, y):
return x + y
# missing two
@torch.jit.script
@ -616,6 +654,7 @@ class TestClassType(JitTestCase):
x = c[i].one(x, x)
x = c[i].two(x)
return x
self.checkScript(use_them, (torch.rand(3, 4),))
@torch.jit.script
@ -626,22 +665,33 @@ class TestClassType(JitTestCase):
def inherit(x: OneTwoThree) -> OneTwo:
return as_interface(x)
with self.assertRaisesRegexWithHighlight(RuntimeError, "does not have method", ""):
with self.assertRaisesRegexWithHighlight(
RuntimeError, "does not have method", ""
):
@torch.jit.script
def wrong1():
return as_interface(NotMember())
with self.assertRaisesRegexWithHighlight(RuntimeError, "is not compatible with interface", ""):
with self.assertRaisesRegexWithHighlight(
RuntimeError, "is not compatible with interface", ""
):
@torch.jit.script
def wrong2():
return as_interface(NotMember2())
with self.assertRaisesRegexWithHighlight(RuntimeError, "does not have method", ""):
with self.assertRaisesRegexWithHighlight(
RuntimeError, "does not have method", ""
):
@torch.jit.script
def wrong3():
return inherit(as_interface(Foo()))
with self.assertRaisesRegexWithHighlight(RuntimeError, "is not compatible with interface", ""):
with self.assertRaisesRegexWithHighlight(
RuntimeError, "is not compatible with interface", ""
):
@torch.jit.script
def wrong4(x: OneTwoWrong) -> int:
@ -656,7 +706,7 @@ class TestClassType(JitTestCase):
def forward(self, x):
return self.proxy_mod.two(x)
TestPyAssign.__annotations__ = {'proxy_mod': OneTwo}
TestPyAssign.__annotations__ = {"proxy_mod": OneTwo}
input = torch.rand(3, 4)
scripted_pyassign_mod = torch.jit.script(TestPyAssign())
@ -671,10 +721,11 @@ class TestClassType(JitTestCase):
def forward(self, x):
return self.proxy_mod.two(x)
TestPyAssignError.__annotations__ = {'proxy_mod': OneTwoThree}
TestPyAssignError.__annotations__ = {"proxy_mod": OneTwoThree}
with self.assertRaisesRegexWithHighlight(RuntimeError,
"is not compatible with interface __torch__", ""):
with self.assertRaisesRegexWithHighlight(
RuntimeError, "is not compatible with interface __torch__", ""
):
torch.jit.script(TestPyAssignError(Foo()))
# test pure python object assignment to interface fails
@ -682,8 +733,9 @@ class TestClassType(JitTestCase):
def __init__(self):
pass
with self.assertRaisesRegexWithHighlight(RuntimeError,
"the value is not a TorchScript compatible type", ""):
with self.assertRaisesRegexWithHighlight(
RuntimeError, "the value is not a TorchScript compatible type", ""
):
torch.jit.script(TestPyAssignError(PyClass()))
# TODO test: interface-interface class-interface inheritance errors,
# NamedTuple inheritance errors
@ -729,7 +781,7 @@ class TestClassType(JitTestCase):
return self.x * other
def __pow__(self, other: int) -> int:
return int(self.x ** other)
return int(self.x**other)
def __truediv__(self, other: int) -> float:
return self.x / other
@ -773,54 +825,89 @@ class TestClassType(JitTestCase):
def __call__(self, val: int) -> int:
return self.x * val * 3
make_global(Foo) # see [local resolution in python]
def add():
return MyClass(4) + 3
def sub(): # noqa: E306
return MyClass(4) - 3
def mul(): # noqa: E306
return MyClass(4) * 3
def pow(): # noqa: E306
return MyClass(4) ** 3
def truediv(): # noqa: E306
return MyClass(4) / 3
def ne(): # noqa: E306
return MyClass(4) != 3
def eq(): # noqa: E306
return MyClass(4) == 3
def lt(): # noqa: E306
return MyClass(4) < 3
def gt(): # noqa: E306
return MyClass(4) > 3
def le(): # noqa: E306
return MyClass(4) <= 3
def ge(): # noqa: E306
return MyClass(4) >= 3
def _and(): # noqa: E306
return MyClass(4) & 3
def _or(): # noqa: E306
return MyClass(4) | 3
def _xor(): # noqa: E306
return MyClass(4) ^ 3
def getitem(): # noqa: E306
return MyClass(4)[1]
def setitem(): # noqa: E306
a = MyClass(4)
a[1] = 5
return a.x
def call(): # noqa: E306
a = MyClass(5)
return a(2)
ops = [add, sub, mul, pow, ne, eq, lt, gt, le, ge, _and, _or, _xor, getitem, setitem, call]
ops = [
add,
sub,
mul,
pow,
ne,
eq,
lt,
gt,
le,
ge,
_and,
_or,
_xor,
getitem,
setitem,
call,
]
ops.append(truediv)
for func in ops:
self.checkScript(func, ())
with self.assertRaisesRegexWithHighlight(RuntimeError, "object has no attribute or method", ""):
with self.assertRaisesRegexWithHighlight(
RuntimeError, "object has no attribute or method", ""
):
@torch.jit.script
def test():
return Foo(torch.tensor(1)) + Foo(torch.tensor(1))
@ -852,7 +939,7 @@ class TestClassType(JitTestCase):
fn = torch.jit.script(test)
self.assertEqual(fn(Foo(0.5)), test(0.5))
self.assertEqual(fn(Foo(0.)), test(0.0))
self.assertEqual(fn(Foo(0.0)), test(0.0))
# str has slightly different formatting
self.assertTrue("0.5" in (str(Foo(0.5))))
self.assertTrue("0." in (str(Foo(0.0))))
@ -865,7 +952,10 @@ class TestClassType(JitTestCase):
def __bool__(self):
return (1, 2)
with self.assertRaisesRegexWithHighlight(RuntimeError, "expected a bool expression for condition", ""):
with self.assertRaisesRegexWithHighlight(
RuntimeError, "expected a bool expression for condition", ""
):
@torch.jit.script
def test():
if BadBool():
@ -921,6 +1011,7 @@ class TestClassType(JitTestCase):
Recursive class types not yet supported. We should give a good error message.
"""
with self.assertRaises(RuntimeError):
@torch.jit.script # noqa: B903
class Tree: # noqa: B903
def __init__(self):
@ -940,7 +1031,7 @@ class TestClassType(JitTestCase):
return x, y
# Test serialization/deserialization of class constant
for c in (2, 1.0, None, True, 'str', (2, 3), [5.9, 7.3]):
for c in (2, 1.0, None, True, "str", (2, 3), [5.9, 7.3]):
m = torch.jit.script(M(c))
buffer = io.BytesIO()
torch.jit.save(m, buffer)
@ -954,28 +1045,31 @@ class TestClassType(JitTestCase):
def test_py_class_to_ivalue_missing_attribute(self):
class Foo:
i : int
f : float
i: int
f: float
def __init__(self, i : int, f : float):
def __init__(self, i: int, f: float):
self.i = i
self.f = f
make_global(Foo) # see [local resolution in python]
@torch.jit.script
def test_fn(x : Foo) -> float:
def test_fn(x: Foo) -> float:
return x.i + x.f
test_fn(Foo(3, 4.0))
with self.assertRaisesRegexWithHighlight(RuntimeError, 'missing attribute i', ""):
with self.assertRaisesRegexWithHighlight(
RuntimeError, "missing attribute i", ""
):
test_fn(torch.rand(3, 4))
def test_unused_method(self):
"""
Test unused methods on scripted classes.
"""
@torch.jit.script
class Unused:
def __init__(self):
@ -1028,12 +1122,13 @@ class TestClassType(JitTestCase):
Test that a scripted class can have a method that refers to the class itself
in its type annotations.
"""
@torch.jit.script
class Meta:
def __init__(self, a: int):
self.a = a
def method(self, other: List['Meta']) -> 'Meta':
def method(self, other: List["Meta"]) -> "Meta":
return Meta(len(other))
class ModuleWithMeta(torch.nn.Module):
@ -1051,19 +1146,20 @@ class TestClassType(JitTestCase):
"""
Test that annotating container attributes with types works correctly
"""
@torch.jit.script
class CompetitiveLinkingTokenReplacementUtils:
def __init__(self):
self.my_list : List[Tuple[float, int, int]] = []
self.my_dict : Dict[int, int] = {}
self.my_list: List[Tuple[float, int, int]] = []
self.my_dict: Dict[int, int] = {}
@torch.jit.script
def foo():
y = CompetitiveLinkingTokenReplacementUtils()
new_dict : Dict[int, int] = {1: 1, 2: 2}
new_dict: Dict[int, int] = {1: 1, 2: 2}
y.my_dict = new_dict
new_list : List[Tuple[float, int, int]] = [(1.0, 1, 1)]
new_list: List[Tuple[float, int, int]] = [(1.0, 1, 1)]
y.my_list = new_list
return y
@ -1071,6 +1167,7 @@ class TestClassType(JitTestCase):
"""
Test that methods on class types can have default arguments.
"""
@torch.jit.script
class ClassWithDefaultArgs:
def __init__(
@ -1105,7 +1202,9 @@ class TestClassType(JitTestCase):
return obj.int + obj.list[2] + obj.dict[1]
def override_defaults() -> int:
obj: ClassWithDefaultArgs = ClassWithDefaultArgs(3, [9, 10, 11], (12, 13, 14), {3: 4}, "str")
obj: ClassWithDefaultArgs = ClassWithDefaultArgs(
3, [9, 10, 11], (12, 13, 14), {3: 4}, "str"
)
s: int = obj.int
for x in obj.list:
@ -1154,7 +1253,7 @@ class TestClassType(JitTestCase):
# The constructor of this class below has mutable arguments. This should throw
# an error.
class ClassWithMutableArgs: # noqa: B903
class ClassWithMutableArgs: # noqa: B903
def __init__(
self,
a: List[int] = [1, 2, 3], # noqa: B006
@ -1164,13 +1263,16 @@ class TestClassType(JitTestCase):
def should_fail():
obj: ClassWithMutableArgs = ClassWithMutableArgs()
with self.assertRaisesRegexWithHighlight(RuntimeError, "Mutable default parameters are not supported", ""):
with self.assertRaisesRegexWithHighlight(
RuntimeError, "Mutable default parameters are not supported", ""
):
torch.jit.script(should_fail)
def test_staticmethod(self):
"""
Test static methods on class types.
"""
@torch.jit.script
class ClassWithStaticMethod:
def __init__(self, a: int, b: int):
@ -1183,22 +1285,22 @@ class TestClassType(JitTestCase):
def get_b(self):
return self.b
def __eq__(self, other: 'ClassWithStaticMethod'):
def __eq__(self, other: "ClassWithStaticMethod"):
return self.a == other.a and self.b == other.b
# staticmethod that calls constructor.
@staticmethod
def create(args: List['ClassWithStaticMethod']) -> 'ClassWithStaticMethod':
def create(args: List["ClassWithStaticMethod"]) -> "ClassWithStaticMethod":
return ClassWithStaticMethod(args[0].a, args[0].b)
# staticmethod that calls another staticmethod.
@staticmethod
def create_from(a: int, b: int) -> 'ClassWithStaticMethod':
def create_from(a: int, b: int) -> "ClassWithStaticMethod":
a = ClassWithStaticMethod(a, b)
return ClassWithStaticMethod.create([a])
# Script function that calls staticmethod.
def test_function(a: int, b: int) -> 'ClassWithStaticMethod':
def test_function(a: int, b: int) -> "ClassWithStaticMethod":
return ClassWithStaticMethod.create_from(a, b)
make_global(ClassWithStaticMethod)
@ -1209,21 +1311,22 @@ class TestClassType(JitTestCase):
"""
Test classmethods on class types.
"""
@torch.jit.script
class ClassWithClassMethod:
def __init__(self, a: int):
self.a: int = a
def __eq__(self, other: 'ClassWithClassMethod'):
def __eq__(self, other: "ClassWithClassMethod"):
return self.a == other.a
@classmethod
def create(cls, a: int) -> 'ClassWithClassMethod':
def create(cls, a: int) -> "ClassWithClassMethod":
return cls(a)
make_global(ClassWithClassMethod)
def test_function(a: int) -> 'ClassWithClassMethod':
def test_function(a: int) -> "ClassWithClassMethod":
x = ClassWithClassMethod(a)
# Support calling classmethod with an instance
# Calling with the class is not supported.
@ -1236,6 +1339,7 @@ class TestClassType(JitTestCase):
"""
Test that a scripted class can make use of the @property decorator.
"""
def free_function(x: int) -> int:
return x + 1
@ -1308,13 +1412,22 @@ class TestClassType(JitTestCase):
return self.props.attr + no_setter.attr + method_uses_property.forward()
self.checkModule(ModuleWithProperties(5), (5, 6, 7, 8,))
self.checkModule(
ModuleWithProperties(5),
(
5,
6,
7,
8,
),
)
def test_custom_delete(self):
"""
Test that del can be called on an instance of a class that
overrides __delitem__.
"""
class Example:
def __init__(self):
self._data: Dict[str, torch.Tensor] = {"1": torch.tensor(1.0)}
@ -1346,7 +1459,9 @@ class TestClassType(JitTestCase):
del example[key]
return example.check(key)
with self.assertRaisesRegexWithHighlight(RuntimeError, r"Class does not define __delitem__", "example[key]"):
with self.assertRaisesRegexWithHighlight(
RuntimeError, r"Class does not define __delitem__", "example[key]"
):
self.checkScript(fn, ())
def test_recursive_script_builtin_type_resolution(self):
@ -1369,7 +1484,7 @@ class TestClassType(JitTestCase):
def g(self, x: device_t) -> device_ty:
return x
def h(self, a: 'A') -> 'A':
def h(self, a: "A") -> "A":
return A()
def i(self, a: List[int]) -> int:
@ -1404,14 +1519,14 @@ class TestClassType(JitTestCase):
Test resolution of built-in torch types(e.g. torch.Tensor, torch.device) when a class is recursively compiled
when compiling a module.
"""
class Wrapper():
class Wrapper:
def __init__(self, t):
self.t = t
def to(self, l: List[torch.device], device: Optional[torch.device] = None):
return self.t.to(device=device)
class A(nn.Module):
def forward(self):
return Wrapper(torch.rand(4, 4))
@ -1424,6 +1539,7 @@ class TestClassType(JitTestCase):
Test that the error message displayed when convering a class type
to an IValue that has an attribute of the wrong type.
"""
@torch.jit.script # noqa: B903
class ValHolder: # noqa: B903
def __init__(self, val):
@ -1442,7 +1558,9 @@ class TestClassType(JitTestCase):
mod = self.mod2
return mod.val
with self.assertRaisesRegexWithHighlight(RuntimeError, "Could not cast attribute 'val' to type Tensor", ""):
with self.assertRaisesRegexWithHighlight(
RuntimeError, "Could not cast attribute 'val' to type Tensor", ""
):
torch.jit.script(Mod())
def test_recursive_scripting(self):
@ -1450,6 +1568,7 @@ class TestClassType(JitTestCase):
Test that class types are recursively scripted when an Python instance of one
is encountered as a module attribute.
"""
class Class:
def __init__(self, a: int):
self.a = a
@ -1473,6 +1592,7 @@ class TestClassType(JitTestCase):
are added as failed attributes and do not cause compilation itself
to fail unless they are used in scripted code.
"""
class UnscriptableClass:
def __init__(self, a: int):
self.a = a
@ -1490,7 +1610,9 @@ class TestClassType(JitTestCase):
def forward(self) -> bool:
return self.obj.get_a()
with self.assertRaisesRegexWithHighlight(RuntimeError, "failed to convert Python type", ""):
with self.assertRaisesRegexWithHighlight(
RuntimeError, "failed to convert Python type", ""
):
torch.jit.script(ShouldNotCompile(UnscriptableClass(4)))
# This Module has an attribute of type UnscriptableClass
@ -1509,7 +1631,6 @@ class TestClassType(JitTestCase):
self.checkModule(ShouldCompile(UnscriptableClass(4)), (4,))
def test_unresolved_class_attributes(self):
class UnresolvedAttrClass:
def __init__(self):
@ -1538,7 +1659,9 @@ class TestClassType(JitTestCase):
u = UnresolvedAttrClass()
return u.attr_e
error_message_regex = "object has no attribute or method.*is defined as a class attribute"
error_message_regex = (
"object has no attribute or method.*is defined as a class attribute"
)
for fn in (fn_a, fn_b, fn_c, fn_d, fn_e):
with self.assertRaisesRegex(RuntimeError, error_message_regex):
torch.jit.script(fn)

View File

@ -1,19 +1,21 @@
# Owner(s): ["oncall: jit"]
import torch
import cmath
import os
import sys
from torch.testing._internal.jit_utils import JitTestCase, execWrapper
from torch.testing._internal.common_utils import IS_MACOS
from typing import List, Dict
from itertools import product
from textwrap import dedent
import cmath
from typing import Dict, List
import torch
from torch.testing._internal.common_utils import IS_MACOS
from torch.testing._internal.jit_utils import execWrapper, JitTestCase
# Make the helper files in test/ importable
pytorch_test_dir = os.path.dirname(os.path.dirname(os.path.realpath(__file__)))
sys.path.append(pytorch_test_dir)
class TestComplex(JitTestCase):
def test_script(self):
def fn(a: complex):
@ -32,7 +34,7 @@ class TestComplex(JitTestCase):
def fn(a: Dict[complex, complex], key: complex) -> complex:
return a[key]
input = {2 + 3j : 2 - 3j, -4.3 - 2j: 3j}
input = {2 + 3j: 2 - 3j, -4.3 - 2j: 3j}
self.checkScript(fn, (input, -4.3 - 2j))
def test_pickle(self):
@ -41,7 +43,7 @@ class TestComplex(JitTestCase):
super().__init__()
self.a = 3 + 5j
self.b = [2 + 3j, 3 + 4j, 0 - 3j, -4 + 0j]
self.c = {2 + 3j : 2 - 3j, -4.3 - 2j: 3j}
self.c = {2 + 3j: 2 - 3j, -4.3 - 2j: 3j}
@torch.jit.script_method
def forward(self, b: int):
@ -50,7 +52,7 @@ class TestComplex(JitTestCase):
loaded = self.getExportImportCopy(ComplexModule())
self.assertEqual(loaded.a, 3 + 5j)
self.assertEqual(loaded.b, [2 + 3j, 3 + 4j, -3j, -4])
self.assertEqual(loaded.c, {2 + 3j : 2 - 3j, -4.3 - 2j: 3j})
self.assertEqual(loaded.c, {2 + 3j: 2 - 3j, -4.3 - 2j: 3j})
self.assertEqual(loaded(2), 2 + 2j)
def test_complex_parse(self):
@ -65,14 +67,19 @@ class TestComplex(JitTestCase):
self.checkScript(fn, (t1, t2, 2))
def test_complex_constants_and_ops(self):
vals = ([0.0, 1.0, 2.2, -1.0, -0.0, -2.2, 1, 0, 2]
+ [10.0 ** i for i in range(2)] + [-(10.0 ** i) for i in range(2)])
vals = (
[0.0, 1.0, 2.2, -1.0, -0.0, -2.2, 1, 0, 2]
+ [10.0**i for i in range(2)]
+ [-(10.0**i) for i in range(2)]
)
complex_vals = tuple(complex(x, y) for x, y in product(vals, vals))
funcs_template = dedent('''
funcs_template = dedent(
"""
def func(a: complex):
return cmath.{func_or_const}(a)
''')
"""
)
def checkCmath(func_name, funcs_template=funcs_template):
funcs_str = funcs_template.format(func_or_const=func_name)
@ -80,11 +87,13 @@ class TestComplex(JitTestCase):
execWrapper(funcs_str, globals(), scope)
cu = torch.jit.CompilationUnit(funcs_str)
f_script = cu.func
f = scope['func']
f = scope["func"]
if func_name in ['isinf', 'isnan', 'isfinite']:
new_vals = vals + ([float('inf'), float('nan'), -1 * float('inf')])
final_vals = tuple(complex(x, y) for x, y in product(new_vals, new_vals))
if func_name in ["isinf", "isnan", "isfinite"]:
new_vals = vals + ([float("inf"), float("nan"), -1 * float("inf")])
final_vals = tuple(
complex(x, y) for x, y in product(new_vals, new_vals)
)
else:
final_vals = complex_vals
@ -107,8 +116,27 @@ class TestComplex(JitTestCase):
msg = f"Failed on {func_name} with input {a}. Python: {res_python}, Script: {res_script}"
self.assertEqual(res_python, res_script, msg=msg)
unary_ops = ['log', 'log10', 'sqrt', 'exp', 'sin', 'cos', 'asin', 'acos', 'atan', 'sinh', 'cosh',
'tanh', 'asinh', 'acosh', 'atanh', 'phase', 'isinf', 'isnan', 'isfinite']
unary_ops = [
"log",
"log10",
"sqrt",
"exp",
"sin",
"cos",
"asin",
"acos",
"atan",
"sinh",
"cosh",
"tanh",
"asinh",
"acosh",
"atanh",
"phase",
"isinf",
"isnan",
"isfinite",
]
# --- Unary ops ---
for op in unary_ops:
@ -118,7 +146,7 @@ class TestComplex(JitTestCase):
return abs(x)
for val in complex_vals:
self.checkScript(fn, (val, ))
self.checkScript(fn, (val,))
def pow_complex_float(x: complex, y: float):
return pow(x, y)
@ -126,7 +154,6 @@ class TestComplex(JitTestCase):
def pow_float_complex(x: float, y: complex):
return pow(x, y)
self.checkScript(pow_float_complex, (2, 3j))
self.checkScript(pow_complex_float, (3j, 2))
@ -135,7 +162,7 @@ class TestComplex(JitTestCase):
for x, y in zip(complex_vals, complex_vals):
# Reference: https://github.com/pytorch/pytorch/issues/54622
if (x == 0):
if x == 0:
continue
self.checkScript(pow_complex_complex, (x, y))
@ -143,16 +170,25 @@ class TestComplex(JitTestCase):
# --- Binary op ---
def rect_fn(x: float, y: float):
return cmath.rect(x, y)
for x, y in product(vals, vals):
self.checkScript(rect_fn, (x, y, ))
func_constants_template = dedent('''
for x, y in product(vals, vals):
self.checkScript(
rect_fn,
(
x,
y,
),
)
func_constants_template = dedent(
"""
def func():
return cmath.{func_or_const}
''')
float_consts = ['pi', 'e', 'tau', 'inf', 'nan']
complex_consts = ['infj', 'nanj']
for x in (float_consts + complex_consts):
"""
)
float_consts = ["pi", "e", "tau", "inf", "nan"]
complex_consts = ["infj", "nanj"]
for x in float_consts + complex_consts:
checkCmath(x, funcs_template=func_constants_template)
def test_infj_nanj_pickle(self):
@ -177,77 +213,293 @@ class TestComplex(JitTestCase):
def fn_int(real: int, img: int):
return complex(real, img)
self.checkScript(fn_int, (0, 0, ))
self.checkScript(fn_int, (-1234, 0, ))
self.checkScript(fn_int, (0, -1256, ))
self.checkScript(fn_int, (-167, -1256, ))
self.checkScript(
fn_int,
(
0,
0,
),
)
self.checkScript(
fn_int,
(
-1234,
0,
),
)
self.checkScript(
fn_int,
(
0,
-1256,
),
)
self.checkScript(
fn_int,
(
-167,
-1256,
),
)
def fn_float(real: float, img: float):
return complex(real, img)
self.checkScript(fn_float, (0.0, 0.0, ))
self.checkScript(fn_float, (-1234.78, 0, ))
self.checkScript(fn_float, (0, 56.18, ))
self.checkScript(fn_float, (-1.9, -19.8, ))
self.checkScript(
fn_float,
(
0.0,
0.0,
),
)
self.checkScript(
fn_float,
(
-1234.78,
0,
),
)
self.checkScript(
fn_float,
(
0,
56.18,
),
)
self.checkScript(
fn_float,
(
-1.9,
-19.8,
),
)
def fn_bool(real: bool, img: bool):
return complex(real, img)
self.checkScript(fn_bool, (True, True, ))
self.checkScript(fn_bool, (False, False, ))
self.checkScript(fn_bool, (False, True, ))
self.checkScript(fn_bool, (True, False, ))
self.checkScript(
fn_bool,
(
True,
True,
),
)
self.checkScript(
fn_bool,
(
False,
False,
),
)
self.checkScript(
fn_bool,
(
False,
True,
),
)
self.checkScript(
fn_bool,
(
True,
False,
),
)
def fn_bool_int(real: bool, img: int):
return complex(real, img)
self.checkScript(fn_bool_int, (True, 0, ))
self.checkScript(fn_bool_int, (False, 0, ))
self.checkScript(fn_bool_int, (False, -1, ))
self.checkScript(fn_bool_int, (True, 3, ))
self.checkScript(
fn_bool_int,
(
True,
0,
),
)
self.checkScript(
fn_bool_int,
(
False,
0,
),
)
self.checkScript(
fn_bool_int,
(
False,
-1,
),
)
self.checkScript(
fn_bool_int,
(
True,
3,
),
)
def fn_int_bool(real: int, img: bool):
return complex(real, img)
self.checkScript(fn_int_bool, (0, True, ))
self.checkScript(fn_int_bool, (0, False, ))
self.checkScript(fn_int_bool, (-3, True, ))
self.checkScript(fn_int_bool, (6, False, ))
self.checkScript(
fn_int_bool,
(
0,
True,
),
)
self.checkScript(
fn_int_bool,
(
0,
False,
),
)
self.checkScript(
fn_int_bool,
(
-3,
True,
),
)
self.checkScript(
fn_int_bool,
(
6,
False,
),
)
def fn_bool_float(real: bool, img: float):
return complex(real, img)
self.checkScript(fn_bool_float, (True, 0.0, ))
self.checkScript(fn_bool_float, (False, 0.0, ))
self.checkScript(fn_bool_float, (False, -1.0, ))
self.checkScript(fn_bool_float, (True, 3.0, ))
self.checkScript(
fn_bool_float,
(
True,
0.0,
),
)
self.checkScript(
fn_bool_float,
(
False,
0.0,
),
)
self.checkScript(
fn_bool_float,
(
False,
-1.0,
),
)
self.checkScript(
fn_bool_float,
(
True,
3.0,
),
)
def fn_float_bool(real: float, img: bool):
return complex(real, img)
self.checkScript(fn_float_bool, (0.0, True, ))
self.checkScript(fn_float_bool, (0.0, False, ))
self.checkScript(fn_float_bool, (-3.0, True, ))
self.checkScript(fn_float_bool, (6.0, False, ))
self.checkScript(
fn_float_bool,
(
0.0,
True,
),
)
self.checkScript(
fn_float_bool,
(
0.0,
False,
),
)
self.checkScript(
fn_float_bool,
(
-3.0,
True,
),
)
self.checkScript(
fn_float_bool,
(
6.0,
False,
),
)
def fn_float_int(real: float, img: int):
return complex(real, img)
self.checkScript(fn_float_int, (0.0, 1, ))
self.checkScript(fn_float_int, (0.0, -1, ))
self.checkScript(fn_float_int, (1.8, -3, ))
self.checkScript(fn_float_int, (2.7, 8, ))
self.checkScript(
fn_float_int,
(
0.0,
1,
),
)
self.checkScript(
fn_float_int,
(
0.0,
-1,
),
)
self.checkScript(
fn_float_int,
(
1.8,
-3,
),
)
self.checkScript(
fn_float_int,
(
2.7,
8,
),
)
def fn_int_float(real: int, img: float):
return complex(real, img)
self.checkScript(fn_int_float, (1, 0.0, ))
self.checkScript(fn_int_float, (-1, 1.7, ))
self.checkScript(fn_int_float, (-3, 0.0, ))
self.checkScript(fn_int_float, (2, -8.9, ))
self.checkScript(
fn_int_float,
(
1,
0.0,
),
)
self.checkScript(
fn_int_float,
(
-1,
1.7,
),
)
self.checkScript(
fn_int_float,
(
-3,
0.0,
),
)
self.checkScript(
fn_int_float,
(
2,
-8.9,
),
)
def test_torch_complex_constructor_with_tensor(self):
tensors = ([torch.rand(1), torch.randint(-5, 5, (1, )), torch.tensor([False])])
tensors = [torch.rand(1), torch.randint(-5, 5, (1,)), torch.tensor([False])]
def fn_tensor_float(real, img: float):
return complex(real, img)
@ -280,7 +532,13 @@ class TestComplex(JitTestCase):
return complex(real, img) + complex(2)
for x, y in product(tensors, tensors):
self.checkScript(fn_tensor_tensor, (x, y, ))
self.checkScript(
fn_tensor_tensor,
(
x,
y,
),
)
def test_comparison_ops(self):
def fn1(a: complex, b: complex):
@ -316,7 +574,7 @@ class TestComplex(JitTestCase):
def fn(x: List[complex]):
return sum(x)
self.checkScript(fn, (torch.randn(4, dtype=torch.cdouble).tolist(), ))
self.checkScript(fn, (torch.randn(4, dtype=torch.cdouble).tolist(),))
def test_tensor_attributes(self):
def tensor_real(x):
@ -326,8 +584,8 @@ class TestComplex(JitTestCase):
return x.imag
t = torch.randn(2, 3, dtype=torch.cdouble)
self.checkScript(tensor_real, (t, ))
self.checkScript(tensor_imag, (t, ))
self.checkScript(tensor_real, (t,))
self.checkScript(tensor_imag, (t,))
def test_binary_op_complex_tensor(self):
def mul(x: complex, y: torch.Tensor):
@ -350,7 +608,7 @@ class TestComplex(JitTestCase):
ops = [mul, add, eq, ne, sub, div]
for shape in [(1, ), (2, 2)]:
for shape in [(1,), (2, 2)]:
x = 0.71 + 0.71j
y = torch.randn(shape, dtype=torch.cfloat)
for op in ops:

View File

@ -10,18 +10,29 @@ import torch
# Make the helper files in test/ importable
pytorch_test_dir = os.path.dirname(os.path.dirname(os.path.realpath(__file__)))
sys.path.append(pytorch_test_dir)
from torch.testing._internal.jit_utils import JitTestCase, enable_profiling_mode
from torch.testing._internal.jit_metaprogramming_utils import try_get_nn_module_compiled_mod_and_inputs, \
get_nn_mod_test_name, get_all_nn_module_tests, nn_functional_tests, get_nn_functional_compiled_fn_and_inputs
from torch.testing._internal.common_utils import run_tests, set_default_dtype, suppress_warnings, IS_FBCODE
from torch.testing._internal.common_utils import (
IS_FBCODE,
run_tests,
set_default_dtype,
suppress_warnings,
)
from torch.testing._internal.jit_metaprogramming_utils import (
get_all_nn_module_tests,
get_nn_functional_compiled_fn_and_inputs,
get_nn_mod_test_name,
nn_functional_tests,
try_get_nn_module_compiled_mod_and_inputs,
)
from torch.testing._internal.jit_utils import enable_profiling_mode, JitTestCase
def num_ifs_loops(graph):
graph_str = str(graph)
# only look at body of graph
graph_body = graph_str[0:graph_str.find("return")]
graph_body = graph_str[0 : graph_str.find("return")]
return graph_body.count("prim::Loop") + graph_body.count("prim::If")
def num_non_tensor_nodes(block):
num_non_tensor = 0
for node in block.nodes():
@ -40,6 +51,7 @@ def num_non_tensor_nodes(block):
num_non_tensor += int(not tensor_out)
return num_non_tensor
class TestComplexity(JitTestCase):
def setUp(self):
super().setUp()
@ -90,5 +102,6 @@ class TestComplexity(JitTestCase):
for line in stats:
print(line)
if __name__ == '__main__':
if __name__ == "__main__":
run_tests()

View File

@ -2,16 +2,18 @@
import os
import sys
import unittest
from itertools import product
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.testing import FileCheck
import unittest
try:
import torchvision
HAS_TORCHVISION = True
except ImportError:
HAS_TORCHVISION = False
@ -22,10 +24,12 @@ pytorch_test_dir = os.path.dirname(os.path.dirname(os.path.realpath(__file__)))
sys.path.append(pytorch_test_dir)
from torch.testing._internal.jit_utils import JitTestCase
if __name__ == '__main__':
raise RuntimeError("This test file is not meant to be run directly, use:\n\n"
"\tpython test/test_jit.py TESTNAME\n\n"
"instead.")
if __name__ == "__main__":
raise RuntimeError(
"This test file is not meant to be run directly, use:\n\n"
"\tpython test/test_jit.py TESTNAME\n\n"
"instead."
)
activations = [
F.celu,
@ -41,6 +45,7 @@ activations = [
F.silu,
]
class TestFunctionalToInplaceActivation(JitTestCase):
def test_check_no_type_promotion(self):
dtypes = [
@ -67,6 +72,7 @@ class TestFunctionalToInplaceActivation(JitTestCase):
def test_functional_to_inplace_activation(self):
for activation in activations:
def test_basic(x):
y = x + 1
z = activation(y)
@ -76,7 +82,7 @@ class TestFunctionalToInplaceActivation(JitTestCase):
self.run_pass("inline", fn.graph)
self.run_pass("constant_propagation", fn.graph)
FileCheck().check(f"aten::{activation.__name__}(").run(fn.graph)
self.run_pass('functional_to_inplace_activation', fn.graph)
self.run_pass("functional_to_inplace_activation", fn.graph)
FileCheck().check_not(f"aten::{activation.__name__}(").run(fn.graph)
FileCheck().check(f"aten::{activation.__name__}_").run(fn.graph)
inp = torch.rand([2, 2])
@ -91,7 +97,7 @@ class TestFunctionalToInplaceActivation(JitTestCase):
return z
fn = torch.jit.script(test1)
self.run_pass('functional_to_inplace_activation', fn.graph)
self.run_pass("functional_to_inplace_activation", fn.graph)
FileCheck().check_not("aten::sigmoid_").run(fn.graph)
# inplace conversion should not happen because y is alias
@ -102,7 +108,7 @@ class TestFunctionalToInplaceActivation(JitTestCase):
return z
fn = torch.jit.script(test2)
self.run_pass('functional_to_inplace_activation', fn.graph)
self.run_pass("functional_to_inplace_activation", fn.graph)
FileCheck().check_not("aten::relu_").run(fn.graph)
# inplace conversion should not happen because self.x is
@ -117,22 +123,33 @@ class TestFunctionalToInplaceActivation(JitTestCase):
return y
fn = torch.jit.script(Test3(torch.rand([2, 2])).eval())
self.run_pass('functional_to_inplace_activation', fn.graph)
self.run_pass("functional_to_inplace_activation", fn.graph)
FileCheck().check_not("aten::relu_").run(fn.graph)
@skipIfNoTorchVision
def test_resnet18_correctness(self):
model = torchvision.models.resnet18()
frozen_model = torch.jit.freeze(torch.jit.script(model.eval()))
N, C, H, W, = 10, 3, 224, 224
(
N,
C,
H,
W,
) = (
10,
3,
224,
224,
)
inp = torch.randn(N, C, H, W)
self.run_pass('functional_to_inplace_activation', frozen_model.graph)
self.run_pass("functional_to_inplace_activation", frozen_model.graph)
self.assertEqual(model(inp), frozen_model(inp))
class TestInplaceToFunctionalActivation(JitTestCase):
def test_inplace_to_functional_activation(self):
for activation in activations:
def test_basic(x):
y = x + 1
activation(y, inplace=True)
@ -142,7 +159,7 @@ class TestInplaceToFunctionalActivation(JitTestCase):
self.run_pass("inline", fn.graph)
self.run_pass("constant_propagation", fn.graph)
FileCheck().check(f"aten::{activation.__name__}_").run(fn.graph)
self.run_pass('inplace_to_functional_activation', fn.graph)
self.run_pass("inplace_to_functional_activation", fn.graph)
FileCheck().check_not(f"aten::{activation.__name__}_").run(fn.graph)
FileCheck().check(f"aten::{activation.__name__}(").run(fn.graph)
@ -151,6 +168,7 @@ class TestInplaceToFunctionalActivation(JitTestCase):
torch.sigmoid_,
torch.tanh_,
]:
def test_basic(x):
y = x + 1
activation(y)
@ -160,7 +178,7 @@ class TestInplaceToFunctionalActivation(JitTestCase):
self.run_pass("inline", fn.graph)
self.run_pass("constant_propagation", fn.graph)
FileCheck().check(f"aten::{activation.__name__}").run(fn.graph)
self.run_pass('inplace_to_functional_activation', fn.graph)
self.run_pass("inplace_to_functional_activation", fn.graph)
FileCheck().check_not(f"aten::{activation.__name__}").run(fn.graph)
FileCheck().check(f"aten::{activation.__name__[:-1]}(").run(fn.graph)
@ -171,7 +189,17 @@ class TestInplaceToFunctionalActivation(JitTestCase):
def test_resnet18_correctness(self):
model = torchvision.models.resnet18()
frozen_model = torch.jit.freeze(torch.jit.script(model.eval()))
N, C, H, W, = 10, 3, 224, 224
(
N,
C,
H,
W,
) = (
10,
3,
224,
224,
)
inp = torch.randn(N, C, H, W)
self.run_pass('inplace_to_functional_activation', frozen_model.graph)
self.run_pass("inplace_to_functional_activation", frozen_model.graph)
self.assertEqual(model(inp), frozen_model(inp))

View File

@ -1,16 +1,21 @@
# Owner(s): ["oncall: jit"]
import gc
import os
import sys
import gc
import unittest
from typing import NamedTuple
import torch
from typing import NamedTuple
from torch.testing import FileCheck
from torch.testing._internal.jit_utils import JitTestCase
from torch.testing._internal.common_utils import skipIfRocm, skipCUDANonDefaultStreamIf, NoTest, TEST_CUDA
from torch.testing._internal.common_cuda import TEST_MULTIGPU
from torch.testing._internal.common_utils import (
NoTest,
skipCUDANonDefaultStreamIf,
skipIfRocm,
TEST_CUDA,
)
from torch.testing._internal.jit_utils import JitTestCase
# Make the helper files in test/ importable
pytorch_test_dir = os.path.dirname(os.path.dirname(os.path.realpath(__file__)))
@ -18,7 +23,7 @@ sys.path.append(pytorch_test_dir)
# If GPU is not available, then do not run the tests
if not TEST_CUDA:
print('CUDA not available, skipping tests', file=sys.stderr)
print("CUDA not available, skipping tests", file=sys.stderr)
JitTestCase = NoTest # noqa: F811
TEST_LARGE_TENSOR = TEST_CUDA
@ -36,10 +41,12 @@ if __name__ == "__main__":
"instead."
)
class TestCUDA(JitTestCase):
"""
A suite of tests for the CUDA API in TorchScript.
"""
def tearDown(self):
gc.collect()
torch.cuda.empty_cache()
@ -54,10 +61,10 @@ class TestCUDA(JitTestCase):
def test_device_synchronize():
prev_current_device_index = torch.cuda.current_device()
torch.cuda.synchronize()
torch.cuda.synchronize('cuda')
torch.cuda.synchronize('cuda:0')
torch.cuda.synchronize("cuda")
torch.cuda.synchronize("cuda:0")
torch.cuda.synchronize(0)
torch.cuda.synchronize(torch.device('cuda:1'))
torch.cuda.synchronize(torch.device("cuda:1"))
after_current_device_index = torch.cuda.current_device()
# Check if the current device index is same as the device index before
@ -66,7 +73,7 @@ class TestCUDA(JitTestCase):
@torch.jit.script
def test_multi_device_synchronize():
torch.cuda.synchronize(torch.device('cuda:0'))
torch.cuda.synchronize(torch.device("cuda:0"))
prev_current_device_index = torch.cuda.current_device()
torch.cuda.synchronize(1)
after_current_device_index = torch.cuda.current_device()
@ -76,11 +83,9 @@ class TestCUDA(JitTestCase):
return prev_current_device_index == after_current_device_index
self.assertTrue(test_device_synchronize)
FileCheck().check("cuda::synchronize(") \
.run(test_device_synchronize.graph)
FileCheck().check("cuda::synchronize(").run(test_device_synchronize.graph)
self.assertTrue(test_multi_device_synchronize)
FileCheck().check("cuda::synchronize(") \
.run(test_multi_device_synchronize.graph)
FileCheck().check("cuda::synchronize(").run(test_multi_device_synchronize.graph)
def test_stream_args(self):
# Test stream creation with default arguments
@ -165,7 +170,6 @@ class TestCUDA(JitTestCase):
@unittest.skipIf(not TEST_LARGE_TENSOR, "not enough memory")
@skipCUDANonDefaultStreamIf(True)
def test_streams_and_events(self):
# Test default_stream API by passing device ID as an argument and
# and check if the stream device index matches with the device ID
@torch.jit.script
@ -182,14 +186,14 @@ class TestCUDA(JitTestCase):
# This test checks for the default stream ID is set to 0 on the device
@torch.jit.script
def test_default_streams():
s0 = torch.cuda.default_stream(torch.device('cuda:0'))
s1 = torch.cuda.default_stream(torch.device('cuda:1'))
s0 = torch.cuda.default_stream(torch.device("cuda:0"))
s1 = torch.cuda.default_stream(torch.device("cuda:1"))
d = torch.device('cuda:1')
d = torch.device("cuda:1")
# Check the current stream id and default id are same
# on the current device. The current device id by default is 0
s2 = torch.cuda.current_stream(torch.device('cuda:0'))
s2 = torch.cuda.current_stream(torch.device("cuda:0"))
check_s2 = s2.id() == s0.id()
check_d0 = torch.cuda.current_device() == s2.device_index()
@ -203,9 +207,25 @@ class TestCUDA(JitTestCase):
# Check if the current device was reset to 0
is_device_d0 = torch.cuda.current_device() == s2.device_index()
return s0.device_index(), s1.device_index(), check_s2, check_s3, check_d0, check_d1, is_device_d0
return (
s0.device_index(),
s1.device_index(),
check_s2,
check_s3,
check_d0,
check_d1,
is_device_d0,
)
d0, d1, check_s2, check_s3, check_d0, check_d1, is_device_d0 = test_default_streams()
(
d0,
d1,
check_s2,
check_s3,
check_d0,
check_d1,
is_device_d0,
) = test_default_streams()
self.assertEqual(d0, 0)
self.assertEqual(d1, 1)
@ -228,12 +248,21 @@ class TestCUDA(JitTestCase):
with torch.cuda.stream(None):
cur_device_index = torch.cuda.current_device()
is_device_index_same = cur_device_index == device_index
is_current_stream_same = torch.cuda.current_stream(device).id() == current_stream.id()
is_default_stream_same = torch.cuda.default_stream(device).id() == default_stream.id()
is_current_stream_same = (
torch.cuda.current_stream(device).id() == current_stream.id()
)
is_default_stream_same = (
torch.cuda.default_stream(device).id() == default_stream.id()
)
# Check if the device index, current stream and default streams have not changed
are_streams_same = is_device_index_same and is_current_stream_same and is_default_stream_same
are_streams_same = (
is_device_index_same
and is_current_stream_same
and is_default_stream_same
)
return are_streams_same
self.assertTrue(test_set_none_stream())
# This test checks if the Device Context manager is a no op
@ -246,6 +275,7 @@ class TestCUDA(JitTestCase):
# Check if the current device is the same
is_device_same = torch.cuda.current_device() == device_index
return is_device_same
self.assertTrue(test_set_device_none())
# Check if a CUDA JIT stream is created
@ -260,15 +290,15 @@ class TestCUDA(JitTestCase):
# Class used to store results for the test: test_get_stream.
class Result(NamedTuple):
t1 : torch.Tensor
t2 : torch.Tensor
is_current_and_default_stream_same : bool
is_default_and_user_stream_not_same : bool
is_stream_set : bool
is_stream_reset : bool
default_stream_query : bool
default_stream_id : int
user_stream_id : int
t1: torch.Tensor
t2: torch.Tensor
is_current_and_default_stream_same: bool
is_default_and_user_stream_not_same: bool
is_stream_set: bool
is_stream_reset: bool
default_stream_query: bool
default_stream_id: int
user_stream_id: int
# The test aims at checking different stream proporties.
@torch.jit.script
@ -280,15 +310,23 @@ class TestCUDA(JitTestCase):
user_stream = torch.cuda.Stream()
# Check if the current and default streams are the same on the device
is_current_and_default_stream_same = current_stream.id() == default_stream.id()
is_current_and_default_stream_same = (
current_stream.id() == default_stream.id()
)
# Check if user stream and default stream are not the same on the device
is_default_and_user_stream_not_same = default_stream.id() != user_stream.id()
is_default_and_user_stream_not_same = (
default_stream.id() != user_stream.id()
)
with torch.cuda.stream(user_stream):
is_stream_set = torch.cuda.current_stream(device).id() == user_stream.id()
is_stream_set = (
torch.cuda.current_stream(device).id() == user_stream.id()
)
# Check if the stream was reset to current_stream
is_stream_reset = torch.cuda.current_stream(device).id() == current_stream.id()
is_stream_reset = (
torch.cuda.current_stream(device).id() == current_stream.id()
)
tensor1 = torch.rand(10000, 10000, device="cuda")
tensor2 = torch.mm(tensor1, tensor1).to("cuda")
@ -297,9 +335,16 @@ class TestCUDA(JitTestCase):
# Capture all the results in the class Result
res = Result(
tensor1, tensor2, is_current_and_default_stream_same,
is_default_and_user_stream_not_same, is_stream_set,
is_stream_reset, default_stream_query, default_stream.id(), user_stream.id())
tensor1,
tensor2,
is_current_and_default_stream_same,
is_default_and_user_stream_not_same,
is_stream_set,
is_stream_reset,
default_stream_query,
default_stream.id(),
user_stream.id(),
)
return res
result = test_get_stream()
@ -310,8 +355,12 @@ class TestCUDA(JitTestCase):
self.assertTrue(result.is_stream_set)
self.assertTrue(result.is_stream_reset)
self.assertTrue(result.default_stream_query)
self.assertEqual(result.default_stream_id, 0) # Check if the default stream ID is always 0
self.assertNotEqual(result.user_stream_id, 0) # Check if the user stream is always non zero
self.assertEqual(
result.default_stream_id, 0
) # Check if the default stream ID is always 0
self.assertNotEqual(
result.user_stream_id, 0
) # Check if the user stream is always non zero
# Test the stream context manager. This test checks if the stream is switched
# to the user stream on using the stream context manager.
@ -329,14 +378,20 @@ class TestCUDA(JitTestCase):
# Wait for B to be computed
user_stream.synchronize()
# Check if the stream has been reset on the current device
is_stream_reset = torch.cuda.current_stream(device).id() == current_stream.id()
is_stream_reset = (
torch.cuda.current_stream(device).id() == current_stream.id()
)
return A, B, check, is_stream_reset
A, B, is_stream_set, is_stream_reset = test_stream_context()
self.assertEqual(torch.matmul(A, A), B)
self.assertTrue(is_stream_set, "Error: Current stream was not set to user stream!")
self.assertTrue(is_stream_reset, "Error: The stream was not restored to previous stream!")
self.assertTrue(
is_stream_set, "Error: Current stream was not set to user stream!"
)
self.assertTrue(
is_stream_reset, "Error: The stream was not restored to previous stream!"
)
# Test multiple nested streams. Check if the operations are computed as expected on the streams
# This test has been adapted from the eager mode tests available at test/test_cuda.py
@ -372,11 +427,24 @@ class TestCUDA(JitTestCase):
# Check if the stream and device has been restored to previous stream and device
is_device_current = torch.cuda.current_device() == prev_device_index
is_stream_current = torch.cuda.current_stream(device).id() == prev_current_stream.id()
is_stream_current = (
torch.cuda.current_stream(device).id() == prev_current_stream.id()
)
check_stream = is_stream_s1 and is_stream_s2 and is_stream_s1_after and is_stream_current
check_device = is_device_s1 and is_device_s2 and is_device_s1_after and is_device_current
check_stream = (
is_stream_s1
and is_stream_s2
and is_stream_s1_after
and is_stream_current
)
check_device = (
is_device_s1
and is_device_s2
and is_device_s1_after
and is_device_current
)
return A, B, C, D, check_stream, check_device
A, B, C, D, check_stream, check_device = test_multiple_stream()
self.assertEqual(torch.matmul(A, A), C)
@ -401,7 +469,9 @@ class TestCUDA(JitTestCase):
B = torch.mm(A, A).to("cuda")
s1.record_event(event)
# Check if the current_stream is reset
is_current_stream_1 = torch.cuda.current_stream(device).id() == prev_current_stream.id()
is_current_stream_1 = (
torch.cuda.current_stream(device).id() == prev_current_stream.id()
)
# Wait for ops on s1 to be computed
s2.wait_event(event)
with torch.cuda.stream(s2):
@ -410,9 +480,16 @@ class TestCUDA(JitTestCase):
# Wait for C to be computed
s2.synchronize()
# Check if the current_stream is reset
is_current_stream_2 = torch.cuda.current_stream(device).id() == prev_current_stream.id()
is_current_stream_2 = (
torch.cuda.current_stream(device).id() == prev_current_stream.id()
)
check_stream = is_current_stream_1 and is_current_stream_2 and is_stream_s1 and is_stream_s2
check_stream = (
is_current_stream_1
and is_current_stream_2
and is_stream_s1
and is_stream_s2
)
return A, B, C, check_stream
A, B, C, check_stream = test_data_dependency_between_streams()
@ -425,6 +502,7 @@ class TestCUDA(JitTestCase):
def test_simple_event():
e = torch.cuda.Event(True, False, False)
return e is not None
self.assertTrue(test_simple_event(), "Could not create CUDA Event!")
# Record the CUDA event for operation torch.mm on the current stream
@ -474,6 +552,7 @@ class TestCUDA(JitTestCase):
# not necessary to check e_tik and e_tok, as elapsed_time would throw
# exception if otherwise.
return e_tik.elapsed_time(e_tok)
self.assertGreater(test_stream_synchronize(), 0)
# Test event synchronization for the event that records a stream doing
@ -536,12 +615,13 @@ class TestCUDA(JitTestCase):
# not necessary to check e_tik and e_tok, as elapsed_time would throw
# exception if otherwise.
return e_tik.elapsed_time(e_tok)
self.assertGreater(test_event_wait(), 0)
# Test for stream wait_event. Checks if the stream waits on the event
@torch.jit.script
def test_wait_event():
d1 = torch.device('cuda:1')
d1 = torch.device("cuda:1")
with torch.cuda.device(d1):
s0 = torch.cuda.current_stream(d1)
@ -550,11 +630,12 @@ class TestCUDA(JitTestCase):
e0 = torch.cuda.Event(False, False, False)
s0.record_event(e0)
s1 = torch.cuda.current_stream(torch.device('cuda:0'))
s1 = torch.cuda.current_stream(torch.device("cuda:0"))
s1.wait_event(e0)
s1.synchronize()
return e0.query() and s0.query() and s1.query()
self.assertTrue(test_wait_event())
# Test if a scripted module with cuda streams can be saved, loaded and executed

View File

@ -11,33 +11,37 @@ pytorch_test_dir = os.path.dirname(os.path.dirname(os.path.realpath(__file__)))
sys.path.append(pytorch_test_dir)
from torch.testing._internal.jit_utils import JitTestCase
if __name__ == '__main__':
raise RuntimeError("This test file is not meant to be run directly, use:\n\n"
"\tpython test/test_jit.py TESTNAME\n\n"
"instead.")
if __name__ == "__main__":
raise RuntimeError(
"This test file is not meant to be run directly, use:\n\n"
"\tpython test/test_jit.py TESTNAME\n\n"
"instead."
)
def canonical(graph):
return torch._C._jit_pass_canonicalize(graph).str(False)
class TestCustomOperators(JitTestCase):
class TestCustomOperators(JitTestCase):
def test_dynamic_op_registry(self):
from torch._ops import _OpNamespace
self.assertTrue(hasattr(torch, 'ops'))
if '_test' in torch.ops.__dict__:
torch.ops.__dict__.pop('_test')
self.assertTrue(hasattr(torch, "ops"))
if "_test" in torch.ops.__dict__:
torch.ops.__dict__.pop("_test")
# Don't use `hasattr()` because it will call `__getattr__`.
self.assertNotIn('_test', torch.ops.__dict__)
self.assertNotIn("_test", torch.ops.__dict__)
torch.ops._test
self.assertIn('_test', torch.ops.__dict__)
self.assertIn("_test", torch.ops.__dict__)
self.assertEqual(type(torch.ops._test), _OpNamespace)
self.assertNotIn('leaky_relu', torch.ops._test.__dict__)
self.assertNotIn("leaky_relu", torch.ops._test.__dict__)
op = torch.ops._test.leaky_relu
self.assertTrue(callable(op))
self.assertIn('leaky_relu', torch.ops._test.__dict__)
self.assertIn("leaky_relu", torch.ops._test.__dict__)
op2 = torch.ops._test.leaky_relu
self.assertEqual(op, op2)
@ -46,7 +50,7 @@ class TestCustomOperators(JitTestCase):
with self.assertRaisesRegexWithHighlight(
AttributeError,
f"Invalid attribute '{attr}' for '_OpNamespace' '_test'",
""
"",
):
getattr(torch.ops._test, attr)
@ -63,15 +67,13 @@ class TestCustomOperators(JitTestCase):
with self.assertRaisesRegexWithHighlight(
RuntimeError,
r"aten::relu\(\) expected at most 1 argument\(s\) but received 2 argument\(s\)",
""
"",
):
torch.ops.aten.relu(1, 2)
def test_passing_too_few_args(self):
with self.assertRaisesRegexWithHighlight(
RuntimeError,
r"aten::relu\(\) is missing value for argument 'self'.",
""
RuntimeError, r"aten::relu\(\) is missing value for argument 'self'.", ""
):
torch.ops.aten.relu()
@ -79,7 +81,7 @@ class TestCustomOperators(JitTestCase):
with self.assertRaisesRegexWithHighlight(
RuntimeError,
r"aten::type_as\(\) is missing value for argument 'other'.",
""
"",
):
torch.ops.aten.type_as(torch.ones(5, 5))
@ -87,7 +89,7 @@ class TestCustomOperators(JitTestCase):
with self.assertRaisesRegexWithHighlight(
RuntimeError,
"Unknown keyword argument 'foo' for operator '_test::leaky_relu'",
""
"",
):
torch.ops._test.leaky_relu(torch.ones(5), foo=torch.ones(5))
@ -102,6 +104,7 @@ class TestCustomOperators(JitTestCase):
@torch.jit.script
def func(x):
return torch.ops.aten.relu(x)
input = torch.ones(5, 5)
self.assertEqual(func(input), input.relu())
@ -110,28 +113,37 @@ class TestCustomOperators(JitTestCase):
func = torch.jit.trace(torch.ops.aten.relu, [input])
self.assertEqual(func(input), input.relu())
@unittest.skip("Need to figure out default dtype differences between fbcode and oss")
@unittest.skip(
"Need to figure out default dtype differences between fbcode and oss"
)
def test_script_graph_for_custom_ops_matches_traced_graph(self):
input = torch.ones(5, 5)
trace = torch.jit.trace(torch.ops.aten.relu, [input])
self.assertExpectedInline(canonical(trace.graph), '''\
self.assertExpectedInline(
canonical(trace.graph),
"""\
graph(%0 : Float(5, 5)):
%1 : Float(5, 5) = aten::relu(%0)
return (%1)
''')
""",
)
def test_script_graph_contains_custom_op(self):
@torch.jit.script
def func(x):
return torch.ops.aten.relu(x)
self.assertExpectedInline(canonical(func.graph), '''\
self.assertExpectedInline(
canonical(func.graph),
"""\
graph(%x.1 : Tensor):
%1 : Tensor = aten::relu(%x.1)
return (%1)
''')
""",
)
def test_generic_list(self):
self.assertEqual(torch.ops._test.get_first([['hello']]), 'hello')
self.assertEqual(torch.ops._test.get_first([["hello"]]), "hello")
# https://github.com/pytorch/pytorch/issues/80508
def test_where_no_scalar(self):

View File

@ -13,17 +13,21 @@ pytorch_test_dir = os.path.dirname(os.path.dirname(os.path.realpath(__file__)))
sys.path.append(pytorch_test_dir)
from torch.testing._internal.jit_utils import JitTestCase, RUN_CUDA_MULTI_GPU
if __name__ == '__main__':
raise RuntimeError("This test file is not meant to be run directly, use:\n\n"
"\tpython test/test_jit.py TESTNAME\n\n"
"instead.")
if __name__ == "__main__":
raise RuntimeError(
"This test file is not meant to be run directly, use:\n\n"
"\tpython test/test_jit.py TESTNAME\n\n"
"instead."
)
class TestDataParallel(JitTestCase):
class Mpy(torch.nn.Module):
def __init__(self):
super(TestDataParallel.Mpy, self).__init__()
self.m = nn.Sequential(nn.Linear(2, 2), nn.BatchNorm1d(2),
nn.ReLU(), nn.Linear(2, 2))
self.m = nn.Sequential(
nn.Linear(2, 2), nn.BatchNorm1d(2), nn.ReLU(), nn.Linear(2, 2)
)
@torch.jit.ignore
def forward(self, input):
@ -50,13 +54,13 @@ class TestDataParallel(JitTestCase):
return self.m2(x)
class Msm(torch.jit.ScriptModule):
__constants__ = ['m']
__constants__ = ["m"]
def __init__(self):
super(TestDataParallel.Msm, self).__init__()
self.m = nn.Sequential(nn.Linear(2, 2), nn.BatchNorm1d(2),
nn.ReLU(), nn.Linear(2, 2))
self.m = nn.Sequential(
nn.Linear(2, 2), nn.BatchNorm1d(2), nn.ReLU(), nn.Linear(2, 2)
)
@torch.jit.script_method
def forward(self, input):
@ -140,7 +144,7 @@ class TestDataParallel(JitTestCase):
# Use .data here to avoid version counter bump.
# The graph created by the following forward will be wrong but
# we never backward through them so it's fine
p.data -= 1. * p.grad
p.data -= 1.0 * p.grad
second_forward = module(x)
# replica which is on the same GPU has a shallow copy of the original

View File

@ -1,14 +1,16 @@
# Owner(s): ["oncall: jit"]
# flake8: noqa
import sys
import unittest
from dataclasses import dataclass, field, InitVar
from enum import Enum
from typing import List, Optional
import torch
from hypothesis import given, settings, strategies as st
from torch.testing._internal.jit_utils import JitTestCase
from typing import List, Optional
import sys
import torch
import unittest
from enum import Enum
# Example jittable dataclass
@dataclass(order=True)
@ -20,8 +22,8 @@ class Point:
def __post_init__(self):
self.norm = (torch.tensor(self.x) ** 2 + torch.tensor(self.y) ** 2) ** 0.5
class MixupScheme(Enum):
class MixupScheme(Enum):
INPUT = ["input"]
MANIFOLD = [
@ -38,6 +40,7 @@ class MixupParams:
self.alpha = alpha
self.scheme = scheme
class MixupScheme2(Enum):
A = 1
B = 2
@ -49,6 +52,7 @@ class MixupParams2:
self.alpha = alpha
self.scheme = scheme
@dataclass
class MixupParams3:
def __init__(self, alpha: float = 0.125, scheme: MixupScheme2 = MixupScheme2.A):
@ -59,11 +63,11 @@ class MixupParams3:
# Make sure the Meta internal tooling doesn't raise an overflow error
NonHugeFloats = st.floats(min_value=-1e4, max_value=1e4, allow_nan=False)
class TestDataclasses(JitTestCase):
class TestDataclasses(JitTestCase):
@classmethod
def tearDownClass(cls):
torch._C._jit_clear_class_registry()
torch._C._jit_clear_class_registry()
def test_init_vars(self):
@torch.jit.script
@ -75,7 +79,9 @@ class TestDataclasses(JitTestCase):
norm: Optional[torch.Tensor] = None
def __post_init__(self, norm_p: int):
self.norm = (torch.tensor(self.x) ** norm_p + torch.tensor(self.y) ** norm_p) ** (1 / norm_p)
self.norm = (
torch.tensor(self.x) ** norm_p + torch.tensor(self.y) ** norm_p
) ** (1 / norm_p)
def fn(x: float, y: float, p: int):
pt = Point2(x, y, p)
@ -88,6 +94,7 @@ class TestDataclasses(JitTestCase):
@given(NonHugeFloats, NonHugeFloats)
def test__post_init__(self, x, y):
P = torch.jit.script(Point)
def fn(x: float, y: float):
pt = P(x, y)
return pt.norm
@ -95,7 +102,9 @@ class TestDataclasses(JitTestCase):
self.checkScript(fn, [x, y])
@settings(deadline=None)
@given(st.tuples(NonHugeFloats, NonHugeFloats), st.tuples(NonHugeFloats, NonHugeFloats))
@given(
st.tuples(NonHugeFloats, NonHugeFloats), st.tuples(NonHugeFloats, NonHugeFloats)
)
def test_comparators(self, pt1, pt2):
x1, y1 = pt1
x2, y2 = pt2
@ -122,6 +131,7 @@ class TestDataclasses(JitTestCase):
with self.assertRaises(NotImplementedError):
torch.jit.script(Foo)
def fn():
foo = Foo()
return foo.x
@ -137,7 +147,7 @@ class TestDataclasses(JitTestCase):
a: int
b: int
def __eq__(self, other: 'CustomEq') -> bool:
def __eq__(self, other: "CustomEq") -> bool:
return self.a == other.a # ignore the b field
def fn(a: int, b1: int, b2: int):
@ -154,9 +164,7 @@ class TestDataclasses(JitTestCase):
torch.jit.script(MixupParams2) # don't throw
def test_use_unregistered_dataclass_raises(self):
def f(a: MixupParams3):
return 0

View File

@ -1,12 +1,12 @@
# Owner(s): ["oncall: jit"]
from itertools import product
import unittest
from itertools import product
import torch
from torch.jit._passes._property_propagation import apply_input_props_using_example
from torch.testing._internal.common_utils import TEST_CUDA
from torch.testing._internal.jit_utils import JitTestCase
from torch.jit._passes._property_propagation import apply_input_props_using_example
try:
from torchvision import models

View File

@ -7,19 +7,19 @@ from unittest.case import expectedFailure
import torch
from torch import complex32, float32, float64, int32, int64
from torch.jit._passes import _property_propagation
from torch.testing._internal.common_device_type import (
instantiate_device_type_tests,
ops,
)
from torch.testing._internal.common_methods_invocations import (
SampleInput,
op_db,
sample_inputs_adaptive_avg_pool2d,
sample_inputs_conv2d,
SampleInput,
)
from torch.testing._internal.common_utils import set_default_dtype, first_sample
from torch.testing._internal.jit_utils import JitTestCase
from torch.testing._internal.common_utils import first_sample, set_default_dtype
from torch.testing._internal.jit_metaprogramming_utils import create_traced_fn
from torch.testing._internal.common_device_type import (
ops,
instantiate_device_type_tests,
)
from torch.testing._internal.common_methods_invocations import op_db
from torch.testing._internal.jit_utils import JitTestCase
"""
Dtype Analysis relies on symbolic shape analysis, which is still in beta
@ -274,7 +274,9 @@ class TestDtypeAnalysis(TestDtypeBase):
):
for dtype in (torch.int8, torch.float64):
# Gets default version for conv2d
sample_input: SampleInput = list(inputs_fn(None, "cpu", dtype, False))[-1]
sample_input: SampleInput = list(inputs_fn(None, "cpu", dtype, False))[
-1
]
input_args = [sample_input.input, *sample_input.args]
self.assert_dtype_equal_custom_args(fn, input_args)
@ -352,7 +354,9 @@ class TestDtypeCustomRules(TestDtypeBase):
# Run the Dtype Analysis
graph = traced_fn.graph # Note this is a cached graph
input_tensors = [t for t in input_args if isinstance(t, torch.Tensor)]
input_tensors += [v for v in sample_input.kwargs.values() if isinstance(v, torch.Tensor)]
input_tensors += [
v for v in sample_input.kwargs.values() if isinstance(v, torch.Tensor)
]
self.prop_dtype_on_graph(graph, input_tensors)
self.assert_output_dtype_equal(expected_res, graph)

View File

@ -2,21 +2,24 @@
import os
import sys
from enum import Enum
from typing import Any, List
import torch
from torch.testing import FileCheck
from enum import Enum
from typing import Any, List
# Make the helper files in test/ importable
pytorch_test_dir = os.path.dirname(os.path.dirname(os.path.realpath(__file__)))
sys.path.append(pytorch_test_dir)
from torch.testing._internal.jit_utils import JitTestCase, make_global
if __name__ == '__main__':
raise RuntimeError("This test file is not meant to be run directly, use:\n\n"
"\tpython test/test_jit.py TESTNAME\n\n"
"instead.")
if __name__ == "__main__":
raise RuntimeError(
"This test file is not meant to be run directly, use:\n\n"
"\tpython test/test_jit.py TESTNAME\n\n"
"instead."
)
class TestEnum(JitTestCase):
def test_enum_value_types(self):
@ -38,11 +41,9 @@ class TestEnum(JitTestCase):
def supported_enum_types(a: IntEnum, b: FloatEnum, c: StringEnum):
return (a.name, b.name, c.name)
FileCheck() \
.check("IntEnum") \
.check("FloatEnum") \
.check("StringEnum") \
.run(str(supported_enum_types.graph))
FileCheck().check("IntEnum").check("FloatEnum").check("StringEnum").run(
str(supported_enum_types.graph)
)
class TensorEnum(Enum):
FOO = torch.tensor(0)
@ -54,7 +55,9 @@ class TestEnum(JitTestCase):
return a.name
# TODO: rewrite code so that the highlight is not empty.
with self.assertRaisesRegexWithHighlight(RuntimeError, "Cannot create Enum with value type 'Tensor'", ""):
with self.assertRaisesRegexWithHighlight(
RuntimeError, "Cannot create Enum with value type 'Tensor'", ""
):
torch.jit.script(unsupported_enum_types)
def test_enum_comp(self):
@ -88,11 +91,9 @@ class TestEnum(JitTestCase):
def enum_comp(x: Foo) -> bool:
return x == Bar.ITEM1
FileCheck() \
.check("prim::Constant") \
.check_same("Bar.ITEM1") \
.check("aten::eq") \
.run(str(enum_comp.graph))
FileCheck().check("prim::Constant").check_same("Bar.ITEM1").check(
"aten::eq"
).run(str(enum_comp.graph))
self.assertEqual(enum_comp(Foo.ITEM1), False)
@ -107,7 +108,9 @@ class TestEnum(JitTestCase):
return x == y
# TODO: rewrite code so that the highlight is not empty.
with self.assertRaisesRegexWithHighlight(RuntimeError, "Could not unify type list", ""):
with self.assertRaisesRegexWithHighlight(
RuntimeError, "Could not unify type list", ""
):
torch.jit.script(enum_comp)
def test_enum_name(self):
@ -121,11 +124,9 @@ class TestEnum(JitTestCase):
def enum_name(x: Color) -> str:
return x.name
FileCheck() \
.check("Color") \
.check_next("prim::EnumName") \
.check_next("return") \
.run(str(enum_name.graph))
FileCheck().check("Color").check_next("prim::EnumName").check_next(
"return"
).run(str(enum_name.graph))
self.assertEqual(enum_name(Color.RED), Color.RED.name)
self.assertEqual(enum_name(Color.GREEN), Color.GREEN.name)
@ -141,11 +142,9 @@ class TestEnum(JitTestCase):
def enum_value(x: Color) -> int:
return x.value
FileCheck() \
.check("Color") \
.check_next("prim::EnumValue") \
.check_next("return") \
.run(str(enum_value.graph))
FileCheck().check("Color").check_next("prim::EnumValue").check_next(
"return"
).run(str(enum_value.graph))
self.assertEqual(enum_value(Color.RED), Color.RED.value)
self.assertEqual(enum_value(Color.GREEN), Color.GREEN.value)
@ -161,11 +160,9 @@ class TestEnum(JitTestCase):
def enum_const(x: Color) -> bool:
return x == Color.RED
FileCheck() \
.check("prim::Constant[value=__torch__.jit.test_enum.Color.RED]") \
.check_next("aten::eq") \
.check_next("return") \
.run(str(enum_const.graph))
FileCheck().check(
"prim::Constant[value=__torch__.jit.test_enum.Color.RED]"
).check_next("aten::eq").check_next("return").run(str(enum_const.graph))
self.assertEqual(enum_const(Color.RED), True)
self.assertEqual(enum_const(Color.GREEN), False)
@ -183,7 +180,9 @@ class TestEnum(JitTestCase):
else:
return False
with self.assertRaisesRegexWithHighlight(RuntimeError, "has no attribute 'PURPLE'", "Color.PURPLE"):
with self.assertRaisesRegexWithHighlight(
RuntimeError, "has no attribute 'PURPLE'", "Color.PURPLE"
):
torch.jit.script(enum_const)
def test_enum_ivalue_type(self):
@ -197,10 +196,9 @@ class TestEnum(JitTestCase):
def is_color_enum(x: Any):
return isinstance(x, Color)
FileCheck() \
.check("prim::isinstance[types=[Enum<__torch__.jit.test_enum.Color>]]") \
.check_next("return") \
.run(str(is_color_enum.graph))
FileCheck().check(
"prim::isinstance[types=[Enum<__torch__.jit.test_enum.Color>]]"
).check_next("return").run(str(is_color_enum.graph))
self.assertEqual(is_color_enum(Color.RED), True)
self.assertEqual(is_color_enum(Color.GREEN), True)
@ -217,10 +215,9 @@ class TestEnum(JitTestCase):
def closed_over_aliased_type():
return a.RED.value
FileCheck() \
.check("prim::Constant[value={}]".format(a.RED.value)) \
.check_next("return") \
.run(str(closed_over_aliased_type.graph))
FileCheck().check("prim::Constant[value={}]".format(a.RED.value)).check_next(
"return"
).run(str(closed_over_aliased_type.graph))
self.assertEqual(closed_over_aliased_type(), Color.RED.value)
@ -230,10 +227,9 @@ class TestEnum(JitTestCase):
def closed_over_aliased_value():
return b.value
FileCheck() \
.check("prim::Constant[value={}]".format(b.value)) \
.check_next("return") \
.run(str(closed_over_aliased_value.graph))
FileCheck().check("prim::Constant[value={}]".format(b.value)).check_next(
"return"
).run(str(closed_over_aliased_value.graph))
self.assertEqual(closed_over_aliased_value(), Color.RED.value)
@ -253,13 +249,9 @@ class TestEnum(JitTestCase):
m = TestModule(Color.RED)
scripted = torch.jit.script(m)
FileCheck() \
.check("TestModule") \
.check_next("Color") \
.check_same("prim::GetAttr[name=\"e\"]") \
.check_next("prim::EnumValue") \
.check_next("return") \
.run(str(scripted.graph))
FileCheck().check("TestModule").check_next("Color").check_same(
'prim::GetAttr[name="e"]'
).check_next("prim::EnumValue").check_next("return").run(str(scripted.graph))
self.assertEqual(scripted(), Color.RED.value)
@ -316,16 +308,12 @@ class TestEnum(JitTestCase):
m = TestModule(Color.RED)
scripted = torch.jit.script(m)
FileCheck() \
.check("TestModule") \
.check_next("Color") \
.check_same("prim::GetAttr[name=\"e\"]") \
.check_next("return") \
.run(str(scripted.graph))
FileCheck().check("TestModule").check_next("Color").check_same(
'prim::GetAttr[name="e"]'
).check_next("return").run(str(scripted.graph))
self.assertEqual(scripted(), Color.RED)
def test_enum_iterate(self):
class Color(Enum):
RED = 1
@ -342,12 +330,9 @@ class TestEnum(JitTestCase):
make_global(Color)
scripted = torch.jit.script(iterate_enum)
FileCheck() \
.check("Enum<__torch__.jit.test_enum.Color>[]") \
.check_same("Color.RED") \
.check_same("Color.GREEN") \
.check_same("Color.BLUE") \
.run(str(scripted.graph))
FileCheck().check("Enum<__torch__.jit.test_enum.Color>[]").check_same(
"Color.RED"
).check_same("Color.GREEN").check_same("Color.BLUE").run(str(scripted.graph))
# PURPLE always appears last because we follow Python's Enum definition order.
self.assertEqual(scripted(Color.RED), [Color.GREEN.value, Color.BLUE.value])
@ -355,7 +340,6 @@ class TestEnum(JitTestCase):
# Tests that explicitly and/or repeatedly scripting an Enum class is permitted.
def test_enum_explicit_script(self):
@torch.jit.script
class Color(Enum):
RED = 1

View File

@ -1,11 +1,13 @@
# Owner(s): ["oncall: jit"]
from torch.testing._internal.common_utils import TestCase
import torch
from torch import nn
from torch.testing._internal.common_utils import TestCase
r"""
Test TorchScript exception handling.
"""
class TestException(TestCase):
def test_pyop_exception_message(self):
class Foo(torch.jit.ScriptModule):
@ -16,31 +18,40 @@ class TestException(TestCase):
@torch.jit.script_method
def forward(self, x):
return self.conv(x)
foo = Foo()
# testing that the correct error message propagates
with self.assertRaisesRegex(RuntimeError, r"Expected 3D \(unbatched\) or 4D \(batched\) input to conv2d"):
with self.assertRaisesRegex(
RuntimeError, r"Expected 3D \(unbatched\) or 4D \(batched\) input to conv2d"
):
foo(torch.ones([123])) # wrong size
def test_builtin_error_messsage(self):
with self.assertRaisesRegex(RuntimeError, "Arguments for call are not valid"):
@torch.jit.script
def close_match(x):
return x.masked_fill(True)
with self.assertRaisesRegex(RuntimeError, "This op may not exist or may not be currently "
"supported in TorchScript"):
with self.assertRaisesRegex(
RuntimeError,
"This op may not exist or may not be currently " "supported in TorchScript",
):
@torch.jit.script
def unknown_op(x):
torch.set_anomaly_enabled(True)
return x
def test_exceptions(self):
cu = torch.jit.CompilationUnit('''
cu = torch.jit.CompilationUnit(
"""
def foo(cond):
if bool(cond):
raise ValueError(3)
return 1
''')
"""
)
cu.foo(torch.tensor(0))
with self.assertRaisesRegex(torch.jit.Error, "3"):
@ -97,6 +108,7 @@ class TestException(TestCase):
else:
raise Exception("Hi")
return a
self.assertEqual(foo(), 1)
@torch.jit.script
@ -114,11 +126,13 @@ class TestException(TestCase):
no_message()
def test_assertions(self):
cu = torch.jit.CompilationUnit('''
cu = torch.jit.CompilationUnit(
"""
def foo(cond):
assert bool(cond), "hi"
return 0
''')
"""
)
cu.foo(torch.tensor(1))
with self.assertRaisesRegex(torch.jit.Error, "AssertionError: hi"):
@ -142,7 +156,9 @@ class TestException(TestCase):
def fn(x):
return python_op(x)
with self.assertRaisesRegex(RuntimeError, "operation failed in the TorchScript interpreter"):
with self.assertRaisesRegex(
RuntimeError, "operation failed in the TorchScript interpreter"
):
fn(torch.tensor(4))
def test_dict_expansion_raises_error(self):
@ -150,8 +166,9 @@ class TestException(TestCase):
d = {"foo": 1, "bar": 2, "baz": 3}
return {**d}
with self.assertRaisesRegex(torch.jit.frontend.NotSupportedError,
"Dict expansion "):
with self.assertRaisesRegex(
torch.jit.frontend.NotSupportedError, "Dict expansion "
):
torch.jit.script(fn)
def test_custom_python_exception(self):
@ -162,7 +179,9 @@ class TestException(TestCase):
def fn():
raise MyValueError("test custom exception")
with self.assertRaisesRegex(torch.jit.Error, "jit.test_exception.MyValueError: test custom exception"):
with self.assertRaisesRegex(
torch.jit.Error, "jit.test_exception.MyValueError: test custom exception"
):
fn()
def test_custom_python_exception_defined_elsewhere(self):
@ -171,5 +190,9 @@ class TestException(TestCase):
@torch.jit.script
def fn():
raise MyKeyError("This is a user defined key error")
with self.assertRaisesRegex(torch.jit.Error, "jit.myexception.MyKeyError: This is a user defined key error"):
with self.assertRaisesRegex(
torch.jit.Error,
"jit.myexception.MyKeyError: This is a user defined key error",
):
fn()

File diff suppressed because it is too large Load Diff

View File

@ -11,10 +11,13 @@ pytorch_test_dir = os.path.dirname(os.path.dirname(os.path.realpath(__file__)))
sys.path.append(pytorch_test_dir)
from torch.testing._internal.jit_utils import JitTestCase
if __name__ == '__main__':
raise RuntimeError("This test file is not meant to be run directly, use:\n\n"
"\tpython test/test_jit.py TESTNAME\n\n"
"instead.")
if __name__ == "__main__":
raise RuntimeError(
"This test file is not meant to be run directly, use:\n\n"
"\tpython test/test_jit.py TESTNAME\n\n"
"instead."
)
class TestFunctionalBlocks(JitTestCase):
def test_subgraph_creation(self):
@ -30,14 +33,22 @@ class TestFunctionalBlocks(JitTestCase):
return x + y + z
graph = torch.jit.script(fn).graph
self.run_pass('create_functional_graphs', graph)
self.run_pass("create_functional_graphs", graph)
# all uses of x and y should be sunk
FileCheck().check(r"%x").check_not(r"%x").check("FunctionalGraph").check(r"%x").run(graph)
FileCheck().check(r"%y").check_not(r"%y").check("FunctionalGraph").check(r"%y").run(graph)
FileCheck().check(r"%x").check_not(r"%x").check("FunctionalGraph").check(
r"%x"
).run(graph)
FileCheck().check(r"%y").check_not(r"%y").check("FunctionalGraph").check(
r"%y"
).run(graph)
# Don't allow any outputs which escape scope, so there is one final addition in the graph
FileCheck().check("Tensor = prim::Functional").check_next("aten::add").run(graph)
FileCheck().check("Tensor = prim::Functional").check_next("aten::add").run(
graph
)
# z + 1, z.add_(2) considered non functional, z = z * z should be considered functional
FileCheck().check("add").check("add_").check_not("mul").check("FunctionalGraph").run(graph)
FileCheck().check("add").check("add_").check_not("mul").check(
"FunctionalGraph"
).run(graph)

View File

@ -3,9 +3,11 @@
import torch
from torch.testing._internal.jit_utils import JitTestCase
class TestFuserCommon(JitTestCase):
def test_autodiff_fallback(self):
for rq in [True, False]:
@torch.jit.script
def fn(x):
return torch.max(x**2.0, x**3.0)

View File

@ -1,9 +1,9 @@
# Owner(s): ["oncall: jit"]
from torch.testing._internal.jit_utils import JitTestCase
import torch
import torch._C
from torch.testing import FileCheck
from torch.testing._internal.jit_utils import JitTestCase
class TestGraphRewritePasses(JitTestCase):

View File

@ -3,9 +3,9 @@
import os
import sys
import torch
from typing import List, Tuple
from typing import Tuple, List
import torch
# Make the helper files in test/ importable
pytorch_test_dir = os.path.dirname(os.path.dirname(os.path.realpath(__file__)))
@ -13,9 +13,12 @@ sys.path.append(pytorch_test_dir)
from torch.testing._internal.jit_utils import JitTestCase
if __name__ == "__main__":
raise RuntimeError("This test file is not meant to be run directly, use:\n\n"
"\tpython test/test_jit.py TESTNAME\n\n"
"instead.")
raise RuntimeError(
"This test file is not meant to be run directly, use:\n\n"
"\tpython test/test_jit.py TESTNAME\n\n"
"instead."
)
class TestHash(JitTestCase):
def test_hash_tuple(self):
@ -38,6 +41,7 @@ class TestHash(JitTestCase):
def test_hash_tensor(self):
"""Tensors should hash by identity"""
def fn(t1, t2):
return hash(t1) == hash(t2)
@ -74,7 +78,7 @@ class TestHash(JitTestCase):
self.checkScript(fn, (1.2345, 6.789))
self.checkScript(fn, (1.2345, float("inf")))
self.checkScript(fn, (float("inf"), float("inf")))
self.checkScript(fn, (1.2345, float('nan')))
self.checkScript(fn, (1.2345, float("nan")))
if sys.version_info < (3, 10):
# Hash of two nans are not guaranteed to be equal. From https://docs.python.org/3/whatsnew/3.10.html :
# Hashes of NaN values of both float type and decimal.Decimal type now depend on object identity.
@ -103,9 +107,9 @@ class TestHash(JitTestCase):
def fn(d1: torch.device, d2: torch.device):
return hash(d1) == hash(d2)
gpu0 = torch.device('cuda:0')
gpu1 = torch.device('cuda:1')
cpu = torch.device('cpu')
gpu0 = torch.device("cuda:0")
gpu1 = torch.device("cuda:1")
cpu = torch.device("cpu")
self.checkScript(fn, (gpu0, gpu0))
self.checkScript(fn, (gpu0, gpu1))
self.checkScript(fn, (gpu0, cpu))

View File

@ -6,21 +6,29 @@ import unittest
from typing import Tuple
import torch
from jit.test_hooks_modules import (
ModuleDirectforwardSubmodCall, ModuleForwardSingleInput,
ModuleForwardTupleInput, create_forward_tuple_input,
create_module_forward_multiple_inputs, create_module_forward_single_input,
create_forward_tuple_input,
create_module_forward_multiple_inputs,
create_module_forward_single_input,
create_module_hook_return_nothing,
create_module_multiple_hooks_multiple_inputs,
create_module_multiple_hooks_single_input, create_module_no_forward_input,
create_module_same_hook_repeated, create_submodule_forward_multiple_inputs,
create_module_multiple_hooks_single_input,
create_module_no_forward_input,
create_module_same_hook_repeated,
create_submodule_forward_multiple_inputs,
create_submodule_forward_single_input,
create_submodule_forward_single_input_return_not_tupled,
create_submodule_hook_return_nothing,
create_submodule_multiple_hooks_multiple_inputs,
create_submodule_multiple_hooks_single_input,
create_submodule_no_forward_input, create_submodule_same_hook_repeated,
create_submodule_to_call_directly_with_hooks)
create_submodule_no_forward_input,
create_submodule_same_hook_repeated,
create_submodule_to_call_directly_with_hooks,
ModuleDirectforwardSubmodCall,
ModuleForwardSingleInput,
ModuleForwardTupleInput,
)
# Make the helper files in test/ importable
pytorch_test_dir = os.path.dirname(os.path.dirname(os.path.realpath(__file__)))
@ -37,7 +45,6 @@ if __name__ == "__main__":
# Tests for JIT forward hooks and pre-hooks
class TestHooks(JitTestCase):
def test_module_no_forward_input(self):
self.checkModule(create_module_no_forward_input(), ())
@ -73,7 +80,8 @@ class TestHooks(JitTestCase):
def test_submodule_multiple_hooks_multiple_inputs(self):
self.checkModule(
create_submodule_multiple_hooks_multiple_inputs(), (["a"], "no_pre_hook"),
create_submodule_multiple_hooks_multiple_inputs(),
(["a"], "no_pre_hook"),
)
def test_submodule_forward_single_input(self):
@ -242,7 +250,8 @@ class TestHooks(JitTestCase):
m.register_forward_pre_hook(pre_hook_wrong_input1)
with self.assertRaisesRegex(
RuntimeError, "has the wrong inner types for the input tuple argument",
RuntimeError,
"has the wrong inner types for the input tuple argument",
):
torch.jit.script(m)
@ -278,7 +287,8 @@ class TestHooks(JitTestCase):
m.register_forward_pre_hook(pre_hook_wrong_output)
with self.assertRaisesRegex(
RuntimeError, "returned the wrong type of: 'int'",
RuntimeError,
"returned the wrong type of: 'int'",
):
torch.jit.script(m)

View File

@ -1,8 +1,9 @@
# Owner(s): ["oncall: jit"]
import torch
from typing import List, Tuple
import torch
class SubmoduleNoForwardInputs(torch.nn.Module):
def __init__(self, name):

View File

@ -2,6 +2,7 @@
import os
import sys
import torch
from torch._C import parse_ir
from torch.testing import FileCheck
@ -11,10 +12,13 @@ pytorch_test_dir = os.path.dirname(os.path.dirname(os.path.realpath(__file__)))
sys.path.append(pytorch_test_dir)
from torch.testing._internal.jit_utils import JitTestCase
if __name__ == '__main__':
raise RuntimeError("This test file is not meant to be run directly, use:\n\n"
"\tpython test/test_jit.py TESTNAME\n\n"
"instead.")
if __name__ == "__main__":
raise RuntimeError(
"This test file is not meant to be run directly, use:\n\n"
"\tpython test/test_jit.py TESTNAME\n\n"
"instead."
)
# Tests that Python slice class is supported in TorchScript
class TestIgnorableArgs(JitTestCase):
@ -44,11 +48,14 @@ class TestIgnorableArgs(JitTestCase):
# We ignore trailing arguments after start=2 for dim 0
# and after end=1 for dim 1
# because in %16, %15 and %0 are default values for the schema.
FileCheck().check("torch.slice(torch.slice(torch.tensor(_0), 0, 2), 1, None, 1)").run(src)
FileCheck().check(
"torch.slice(torch.slice(torch.tensor(_0), 0, 2), 1, None, 1)"
).run(src)
self.assertEqual(function(), function_copy())
def test_add_out_ignorable_args(self):
@torch.jit.script
def fn(x: torch.Tensor, y: torch.Tensor):
torch.add(x, y, out=y)
FileCheck().check("torch.add(x, y, out=y)").run(fn.code)

View File

@ -9,13 +9,16 @@ import torch
# Make the helper files in test/ importable
pytorch_test_dir = os.path.dirname(os.path.dirname(os.path.realpath(__file__)))
sys.path.append(pytorch_test_dir)
from torch.testing._internal.jit_utils import JitTestCase
from torch.jit.frontend import _IS_ASTUNPARSE_INSTALLED
from torch.testing._internal.jit_utils import JitTestCase
if __name__ == "__main__":
raise RuntimeError("This test file is not meant to be run directly, use:\n\n"
"\tpython test/test_jit.py TESTNAME\n\n"
"instead.")
raise RuntimeError(
"This test file is not meant to be run directly, use:\n\n"
"\tpython test/test_jit.py TESTNAME\n\n"
"instead."
)
class TestIgnoreContextManager(JitTestCase):
@unittest.skipUnless(_IS_ASTUNPARSE_INSTALLED, "astunparse package is required")
@ -26,11 +29,14 @@ class TestIgnoreContextManager(JitTestCase):
b: int = 5
c: int = 0
d: int = 6
with torch.jit._IgnoreContextManager(a="inp:int", b="inp:int", c="out:int", d="out:int"):
with torch.jit._IgnoreContextManager(
a="inp:int", b="inp:int", c="out:int", d="out:int"
):
l = [2 for i in range(a) if i > 2]
c = l[0] + a + b
d = 9
return c + d
model = A()
s = torch.jit.script(model)
self.assertEqual(s(), model())
@ -41,10 +47,13 @@ class TestIgnoreContextManager(JitTestCase):
a: int = 4
b: int = 5
c: int = 0
with torch.jit._IgnoreContextManager(a="inp:int", b="inp:int", c="out:int"):
with torch.jit._IgnoreContextManager(
a="inp:int", b="inp:int", c="out:int"
):
l = [2 for i in range(a) if i > 2]
c = l[0] + a + b
return c
model = B()
s = torch.jit.script(model)
self.assertEqual(s(), 11)
@ -58,6 +67,7 @@ class TestIgnoreContextManager(JitTestCase):
l = [2 for i in range(a) if i > 2]
b = l[0] + a
return b
model = C()
s = torch.jit.script(model)
self.assertEqual(s(), 6)
@ -72,6 +82,7 @@ class TestIgnoreContextManager(JitTestCase):
with torch.jit._IgnoreContextManager(a="inp:int", b="inp:int"):
l = [2 + b for i in range(a) if i > 2]
return a
model = A()
s = torch.jit.script(model)
self.assertEqual(s(), 4)
@ -85,6 +96,7 @@ class TestIgnoreContextManager(JitTestCase):
c = [2 for i in range(7) if i > 2]
c[0] = 3
return c[0] + c[1]
model = A()
s = torch.jit.script(model)
self.assertEqual(s(), 5)

View File

@ -2,10 +2,10 @@
import os
import sys
import warnings
from typing import Any, Dict, List, Optional, Tuple
import torch
import warnings
from typing import List, Any, Dict, Tuple, Optional
# Make the helper files in test/ importable
pytorch_test_dir = os.path.dirname(os.path.dirname(os.path.realpath(__file__)))
@ -19,6 +19,7 @@ if __name__ == "__main__":
"instead."
)
# Tests for torch.jit.isinstance
class TestIsinstance(JitTestCase):
def test_int(self):
@ -223,28 +224,42 @@ class TestIsinstance(JitTestCase):
x = ["1", "2", "3"]
err_msg = "Attempted to use List without a contained type. " \
err_msg = (
"Attempted to use List without a contained type. "
r"Please add a contained type, e.g. List\[int\]"
)
with self.assertRaisesRegex(RuntimeError, err_msg,):
with self.assertRaisesRegex(
RuntimeError,
err_msg,
):
torch.jit.script(list_no_contained_type)
with self.assertRaisesRegex(RuntimeError, err_msg,):
with self.assertRaisesRegex(
RuntimeError,
err_msg,
):
list_no_contained_type(x)
def test_tuple_no_contained_type(self):
def tuple_no_contained_type(x: Any):
assert torch.jit.isinstance(x, Tuple)
x = ("1", "2", "3")
err_msg = "Attempted to use Tuple without a contained type. " \
err_msg = (
"Attempted to use Tuple without a contained type. "
r"Please add a contained type, e.g. Tuple\[int\]"
)
with self.assertRaisesRegex(RuntimeError, err_msg,):
with self.assertRaisesRegex(
RuntimeError,
err_msg,
):
torch.jit.script(tuple_no_contained_type)
with self.assertRaisesRegex(RuntimeError, err_msg,):
with self.assertRaisesRegex(
RuntimeError,
err_msg,
):
tuple_no_contained_type(x)
def test_optional_no_contained_type(self):
@ -253,12 +268,20 @@ class TestIsinstance(JitTestCase):
x = ("1", "2", "3")
err_msg = "Attempted to use Optional without a contained type. " \
err_msg = (
"Attempted to use Optional without a contained type. "
r"Please add a contained type, e.g. Optional\[int\]"
)
with self.assertRaisesRegex(RuntimeError, err_msg,):
with self.assertRaisesRegex(
RuntimeError,
err_msg,
):
torch.jit.script(optional_no_contained_type)
with self.assertRaisesRegex(RuntimeError, err_msg,):
with self.assertRaisesRegex(
RuntimeError,
err_msg,
):
optional_no_contained_type(x)
def test_dict_no_contained_type(self):
@ -267,12 +290,20 @@ class TestIsinstance(JitTestCase):
x = {"a": "aa"}
err_msg = "Attempted to use Dict without contained types. " \
err_msg = (
"Attempted to use Dict without contained types. "
r"Please add contained type, e.g. Dict\[int, int\]"
)
with self.assertRaisesRegex(RuntimeError, err_msg,):
with self.assertRaisesRegex(
RuntimeError,
err_msg,
):
torch.jit.script(dict_no_contained_type)
with self.assertRaisesRegex(RuntimeError, err_msg,):
with self.assertRaisesRegex(
RuntimeError,
err_msg,
):
dict_no_contained_type(x)
def test_tuple_rhs(self):

View File

@ -13,10 +13,13 @@ pytorch_test_dir = os.path.dirname(os.path.dirname(os.path.realpath(__file__)))
sys.path.append(pytorch_test_dir)
from torch.testing._internal.jit_utils import JitTestCase
if __name__ == '__main__':
raise RuntimeError("This test file is not meant to be run directly, use:\n\n"
"\tpython test/test_jit.py TESTNAME\n\n"
"instead.")
if __name__ == "__main__":
raise RuntimeError(
"This test file is not meant to be run directly, use:\n\n"
"\tpython test/test_jit.py TESTNAME\n\n"
"instead."
)
# Tests various JIT-related utility functions.
class TestJitUtils(JitTestCase):
@ -24,58 +27,71 @@ class TestJitUtils(JitTestCase):
def test_get_callable_argument_names_positional_or_keyword(self):
def fn_positional_or_keyword_args_only(x, y):
return x + y
self.assertEqual(
["x", "y"],
torch._jit_internal.get_callable_argument_names(fn_positional_or_keyword_args_only))
torch._jit_internal.get_callable_argument_names(
fn_positional_or_keyword_args_only
),
)
# Tests that POSITIONAL_ONLY arguments are ignored.
def test_get_callable_argument_names_positional_only(self):
code = dedent('''
code = dedent(
"""
def fn_positional_only_arg(x, /, y):
return x + y
''')
"""
)
fn_positional_only_arg = jit_utils._get_py3_code(code, 'fn_positional_only_arg')
fn_positional_only_arg = jit_utils._get_py3_code(code, "fn_positional_only_arg")
self.assertEqual(
["y"],
torch._jit_internal.get_callable_argument_names(fn_positional_only_arg))
torch._jit_internal.get_callable_argument_names(fn_positional_only_arg),
)
# Tests that VAR_POSITIONAL arguments are ignored.
def test_get_callable_argument_names_var_positional(self):
# Tests that VAR_POSITIONAL arguments are ignored.
def fn_var_positional_arg(x, *arg):
return x + arg[0]
self.assertEqual(
["x"],
torch._jit_internal.get_callable_argument_names(fn_var_positional_arg))
torch._jit_internal.get_callable_argument_names(fn_var_positional_arg),
)
# Tests that KEYWORD_ONLY arguments are ignored.
def test_get_callable_argument_names_keyword_only(self):
def fn_keyword_only_arg(x, *, y):
return x + y
self.assertEqual(
["x"],
torch._jit_internal.get_callable_argument_names(fn_keyword_only_arg))
["x"], torch._jit_internal.get_callable_argument_names(fn_keyword_only_arg)
)
# Tests that VAR_KEYWORD arguments are ignored.
def test_get_callable_argument_names_var_keyword(self):
def fn_var_keyword_arg(**args):
return args['x'] + args['y']
return args["x"] + args["y"]
self.assertEqual(
[],
torch._jit_internal.get_callable_argument_names(fn_var_keyword_arg))
[], torch._jit_internal.get_callable_argument_names(fn_var_keyword_arg)
)
# Tests that a function signature containing various different types of
# arguments are ignored.
def test_get_callable_argument_names_hybrid(self):
code = dedent('''
code = dedent(
"""
def fn_hybrid_args(x, /, y, *args, **kwargs):
return x + y + args[0] + kwargs['z']
''')
fn_hybrid_args = jit_utils._get_py3_code(code, 'fn_hybrid_args')
"""
)
fn_hybrid_args = jit_utils._get_py3_code(code, "fn_hybrid_args")
self.assertEqual(
["y"],
torch._jit_internal.get_callable_argument_names(fn_hybrid_args))
["y"], torch._jit_internal.get_callable_argument_names(fn_hybrid_args)
)
def test_checkscriptassertraisesregex(self):
def fn():
@ -84,22 +100,18 @@ class TestJitUtils(JitTestCase):
self.checkScriptRaisesRegex(fn, (), Exception, "range", name="fn")
s = dedent("""
s = dedent(
"""
def fn():
tup = (1, 2)
return tup[2]
""")
"""
)
self.checkScriptRaisesRegex(s, (), Exception, "range", name="fn")
def test_no_tracer_warn_context_manager(self):
torch._C._jit_set_tracer_state_warn(True)
with jit_utils.NoTracerWarnContextManager() as no_warn:
self.assertEqual(
False,
torch._C._jit_get_tracer_state_warn()
)
self.assertEqual(
True,
torch._C._jit_get_tracer_state_warn()
)
self.assertEqual(False, torch._C._jit_get_tracer_state_warn())
self.assertEqual(True, torch._C._jit_get_tracer_state_warn())

File diff suppressed because it is too large Load Diff

View File

@ -10,10 +10,13 @@ pytorch_test_dir = os.path.dirname(os.path.dirname(os.path.realpath(__file__)))
sys.path.append(pytorch_test_dir)
from torch.testing._internal.jit_utils import JitTestCase
if __name__ == '__main__':
raise RuntimeError("This test file is not meant to be run directly, use:\n\n"
"\tpython test/test_jit.py TESTNAME\n\n"
"instead.")
if __name__ == "__main__":
raise RuntimeError(
"This test file is not meant to be run directly, use:\n\n"
"\tpython test/test_jit.py TESTNAME\n\n"
"instead."
)
class TestLogging(JitTestCase):
def test_bump_numeric_counter(self):
@ -22,30 +25,29 @@ class TestLogging(JitTestCase):
def forward(self, x):
for i in range(x.size(0)):
x += 1.0
torch.jit._logging.add_stat_value('foo', 1)
torch.jit._logging.add_stat_value("foo", 1)
if bool(x.sum() > 0.0):
torch.jit._logging.add_stat_value('positive', 1)
torch.jit._logging.add_stat_value("positive", 1)
else:
torch.jit._logging.add_stat_value('negative', 1)
torch.jit._logging.add_stat_value("negative", 1)
return x
logger = torch.jit._logging.LockingLogger()
old_logger = torch.jit._logging.set_logger(logger)
try:
mtl = ModuleThatLogs()
for i in range(5):
mtl(torch.rand(3, 4, 5))
self.assertEqual(logger.get_counter_val('foo'), 15)
self.assertEqual(logger.get_counter_val('positive'), 5)
self.assertEqual(logger.get_counter_val("foo"), 15)
self.assertEqual(logger.get_counter_val("positive"), 5)
finally:
torch.jit._logging.set_logger(old_logger)
def test_trace_numeric_counter(self):
def foo(x):
torch.jit._logging.add_stat_value('foo', 1)
torch.jit._logging.add_stat_value("foo", 1)
return x + 1.0
traced = torch.jit.trace(foo, torch.rand(3, 4))
@ -54,7 +56,7 @@ class TestLogging(JitTestCase):
try:
traced(torch.rand(3, 4))
self.assertEqual(logger.get_counter_val('foo'), 1)
self.assertEqual(logger.get_counter_val("foo"), 1)
finally:
torch.jit._logging.set_logger(old_logger)
@ -65,7 +67,7 @@ class TestLogging(JitTestCase):
for i in range(30):
x += 1.0
tp_end = torch.jit._logging.time_point()
torch.jit._logging.add_stat_value('mytimer', tp_end - tp_start)
torch.jit._logging.add_stat_value("mytimer", tp_end - tp_start)
return x
mtm = ModuleThatTimes()
@ -73,7 +75,7 @@ class TestLogging(JitTestCase):
old_logger = torch.jit._logging.set_logger(logger)
try:
mtm(torch.rand(3, 4))
self.assertGreater(logger.get_counter_val('mytimer'), 0)
self.assertGreater(logger.get_counter_val("mytimer"), 0)
finally:
torch.jit._logging.set_logger(old_logger)
@ -85,7 +87,7 @@ class TestLogging(JitTestCase):
for i in range(30):
x += 1.0
tp_end = torch.jit._logging.time_point()
torch.jit._logging.add_stat_value('mytimer', tp_end - tp_start)
torch.jit._logging.add_stat_value("mytimer", tp_end - tp_start)
return x
mtm = ModuleThatTimes()
@ -93,27 +95,27 @@ class TestLogging(JitTestCase):
old_logger = torch.jit._logging.set_logger(logger)
try:
mtm(torch.rand(3, 4))
self.assertGreater(logger.get_counter_val('mytimer'), 0)
self.assertGreater(logger.get_counter_val("mytimer"), 0)
finally:
torch.jit._logging.set_logger(old_logger)
def test_counter_aggregation(self):
def foo(x):
for i in range(3):
torch.jit._logging.add_stat_value('foo', 1)
torch.jit._logging.add_stat_value("foo", 1)
return x + 1.0
traced = torch.jit.trace(foo, torch.rand(3, 4))
logger = torch.jit._logging.LockingLogger()
logger.set_aggregation_type('foo', torch.jit._logging.AggregationType.AVG)
logger.set_aggregation_type("foo", torch.jit._logging.AggregationType.AVG)
old_logger = torch.jit._logging.set_logger(logger)
try:
traced(torch.rand(3, 4))
self.assertEqual(logger.get_counter_val('foo'), 1)
self.assertEqual(logger.get_counter_val("foo"), 1)
finally:
torch.jit._logging.set_logger(old_logger)
def test_logging_levels_set(self):
torch._C._jit_set_logging_option('foo')
self.assertEqual('foo', torch._C._jit_get_logging_option())
torch._C._jit_set_logging_option("foo")
self.assertEqual("foo", torch._C._jit_get_logging_option())

View File

@ -1,28 +1,32 @@
# Owner(s): ["oncall: jit"]
from typing import Any, Dict, List, Optional, Tuple
from torch.testing._internal.jit_utils import JitTestCase, make_global
from torch.testing import FileCheck
from torch import jit
from jit.test_module_interface import TestModuleInterface # noqa: F401
import os
import sys
import torch
import torch.testing._internal.jit_utils
import torch.nn as nn
import unittest
from typing import Any, Dict, List, Optional, Tuple
import torch
import torch.nn as nn
import torch.testing._internal.jit_utils
from torch import jit
from torch.testing import FileCheck
from torch.testing._internal.common_utils import freeze_rng_state
from torch.testing._internal.jit_utils import RUN_CUDA_HALF
from torch.testing._internal.jit_utils import JitTestCase, make_global, RUN_CUDA_HALF
from jit.test_module_interface import TestModuleInterface # noqa: F401
# Make the helper files in test/ importable
pytorch_test_dir = os.path.dirname(os.path.dirname(os.path.realpath(__file__)))
sys.path.append(pytorch_test_dir)
if __name__ == '__main__':
raise RuntimeError("This test file is not meant to be run directly, use:\n\n"
"\tpython test/test_jit.py TESTNAME\n\n"
"instead.")
if __name__ == "__main__":
raise RuntimeError(
"This test file is not meant to be run directly, use:\n\n"
"\tpython test/test_jit.py TESTNAME\n\n"
"instead."
)
class TestMisc(JitTestCase):
def test_joined_str(self):
@ -30,12 +34,12 @@ class TestMisc(JitTestCase):
hello, test = "Hello", "test"
print(f"{hello + ' ' + test}, I'm a {test}")
print("format blank")
hi = 'hi'
hi = "hi"
print(f"stuff before {hi}")
print(f"{hi} stuff after")
return x + 1
x = torch.arange(4., requires_grad=True)
x = torch.arange(4.0, requires_grad=True)
# TODO: Add support for f-strings in string parser frontend
# self.checkScript(func, [x], optimize=True, capture_output=True)
@ -50,10 +54,14 @@ class TestMisc(JitTestCase):
self.assertEqual(captured, captured_script)
def test_kwarg_support(self):
with self.assertRaisesRegex(torch.jit.frontend.NotSupportedError, "variable number of arguments"):
with self.assertRaisesRegex(
torch.jit.frontend.NotSupportedError, "variable number of arguments"
):
class M(torch.nn.Module):
def forward(self, *, n_tokens: int, device_name: str = 2):
pass
torch.jit.script(M())
class M(torch.nn.Module):
@ -62,32 +70,35 @@ class TestMisc(JitTestCase):
sm = torch.jit.script(M())
with self.assertRaisesRegex(RuntimeError, "missing value for argument 'n_tokens'"):
with self.assertRaisesRegex(
RuntimeError, "missing value for argument 'n_tokens'"
):
sm()
with self.assertRaisesRegex(RuntimeError, "positional arg"):
sm(3, 'hello')
sm(3, "hello")
self.assertEqual(sm(n_tokens=3, device_name='hello'), (3, 'hello'))
self.assertEqual(sm(n_tokens=3, device_name="hello"), (3, "hello"))
def test_tuple_subscripted_assign(self):
with self.assertRaisesRegex(RuntimeError, "subscripted assignment"):
@torch.jit.script
def foo(a: Tuple[int, int]) -> None:
a[0] = a[1]
with self.assertRaisesRegex(RuntimeError, "augmented assignment"):
@torch.jit.script
def bar(a: Tuple[int, int]) -> None:
a[0] += a[1]
def test_subexpression_List_Future(self):
@torch.jit.script
def fn(x: List[torch.jit.Future[int]]) -> torch.jit.Future[int]:
return x[0]
FileCheck().check('Future[int]').check('Future[int]').run(fn.graph)
FileCheck().check("Future[int]").check("Future[int]").run(fn.graph)
def test_subexpression_Future_annotate(self):
@torch.jit.script
@ -110,36 +121,40 @@ class TestMisc(JitTestCase):
if isinstance(x, str):
return x
return "foo"
forward = torch.jit.script(forward)
self.assertEqual(forward(1), "foo")
self.assertEqual(forward("bar"), "bar")
def test_subexpression_Tuple_int_int_Future(self):
@torch.jit.script
def fn(x: Tuple[int, int, torch.jit.Future[int]]) -> Tuple[int, torch.jit.Future[int]]:
def fn(
x: Tuple[int, int, torch.jit.Future[int]]
) -> Tuple[int, torch.jit.Future[int]]:
return x[0], x[2]
FileCheck().check('(int, int, Future[int])').check('(int, Future[int])').run(fn.graph)
FileCheck().check("(int, int, Future[int])").check("(int, Future[int])").run(
fn.graph
)
def test_subexpression_Dict_int_Future(self):
@torch.jit.script
def fn(x: Dict[int, torch.jit.Future[int]], y: int) -> torch.jit.Future[int]:
return x[y]
FileCheck().check('Dict(int, Future(int))').check('Future[int]').run(fn.graph)
FileCheck().check("Dict(int, Future(int))").check("Future[int]").run(fn.graph)
def test_subexpression_Optional(self):
@torch.jit.script
def fn(x: Optional[Dict[int, torch.jit.Future[int]]]) -> Optional[torch.jit.Future[int]]:
def fn(
x: Optional[Dict[int, torch.jit.Future[int]]]
) -> Optional[torch.jit.Future[int]]:
if x is not None:
return x[0]
else:
return None
FileCheck().check('Dict(int, Future(int))?').run(fn.graph)
FileCheck().check("Dict(int, Future(int))?").run(fn.graph)
def test_if_returning_any(self):
"""
@ -147,6 +162,7 @@ class TestMisc(JitTestCase):
types early from each branch when the return
type of the function is Any.
"""
def if_function(inp: torch.Tensor) -> Any:
if inp.shape[0] == 1:
return inp * inp
@ -156,14 +172,23 @@ class TestMisc(JitTestCase):
self.checkScript(if_function, (torch.randn(5),))
def test_hacked_twin(self):
def gen_data():
with freeze_rng_state():
return torch.randn(10), torch.randint(10, (20,)), torch.randn(20)
input, index, value, = gen_data()
input1, index1, value1, = gen_data()
out1 = torch.ops.aten.index_put.hacked_twin(input, [index], value, accumulate=False)
(
input,
index,
value,
) = gen_data()
(
input1,
index1,
value1,
) = gen_data()
out1 = torch.ops.aten.index_put.hacked_twin(
input, [index], value, accumulate=False
)
out2 = torch.index_put(input1, [index1], value1, accumulate=False)
self.assertEqual(out1, out2)
@ -172,14 +197,23 @@ class TestMisc(JitTestCase):
self.assertEqual(input, input1)
def test_unsafe_hacked_twin(self):
def gen_data():
with freeze_rng_state():
return torch.randn(10), torch.randint(10, (20,)), torch.randn(20)
input, index, value, = gen_data()
input1, index1, value1, = gen_data()
out1 = torch.ops.aten._unsafe_index_put.hacked_twin(input, [index], value, accumulate=False)
(
input,
index,
value,
) = gen_data()
(
input1,
index1,
value1,
) = gen_data()
out1 = torch.ops.aten._unsafe_index_put.hacked_twin(
input, [index], value, accumulate=False
)
out2 = torch.index_put(input1, [index1], value1, accumulate=False)
self.assertEqual(out1, out2)
@ -188,7 +222,9 @@ class TestMisc(JitTestCase):
self.assertEqual(input, input1)
def index_put_fn(input, index, value):
return torch.ops.aten._unsafe_index_put(input, [index], value, accumulate=False)
return torch.ops.aten._unsafe_index_put(
input, [index], value, accumulate=False
)
input2, index2, value2 = gen_data()
script_index_put_fn = torch.jit.script(index_put_fn)
@ -197,7 +233,9 @@ class TestMisc(JitTestCase):
self.assertEqual(expect, actual)
def index_fn(input, index, value):
return torch.ops.aten._unsafe_index_put(input, [index], value, accumulate=False)
return torch.ops.aten._unsafe_index_put(
input, [index], value, accumulate=False
)
script_index_fn = torch.jit.script(index_fn)
expect = index_fn(input2.clone(), index2, value2)
@ -205,7 +243,6 @@ class TestMisc(JitTestCase):
self.assertEqual(expect, actual)
def test_export_opnames_interface(self):
@torch.jit.interface
class OneTwoModule(nn.Module):
def one(self, x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
@ -240,7 +277,7 @@ class TestMisc(JitTestCase):
make_global(OneTwoModule)
class M(nn.Module):
sub : OneTwoModule
sub: OneTwoModule
def __init__(self):
super().__init__()
@ -254,12 +291,18 @@ class TestMisc(JitTestCase):
torch._C._enable_mobile_interface_call_export()
scripted_M_mod = torch.jit.script(M())
self.assertTrue({'aten::mul.Scalar', 'aten::mul.Tensor', 'aten::reciprocal'}.issubset(
set(torch.jit.export_opnames(scripted_M_mod))))
self.assertTrue(
{"aten::mul.Scalar", "aten::mul.Tensor", "aten::reciprocal"}.issubset(
set(torch.jit.export_opnames(scripted_M_mod))
)
)
scripted_M_mod.sub = torch.jit.script(FooMod())
self.assertTrue({'aten::add.Tensor', 'aten::mul.Scalar'}.issubset(
set(torch.jit.export_opnames(scripted_M_mod))))
self.assertTrue(
{"aten::add.Tensor", "aten::mul.Scalar"}.issubset(
set(torch.jit.export_opnames(scripted_M_mod))
)
)
def test_math_inf(self):
from math import inf
@ -292,7 +335,6 @@ class TestMisc(JitTestCase):
with self.assertRaises(RuntimeError):
torch.jit.script(non_temporary_fail)
@torch.jit.script
def test_return():
return []
@ -335,7 +377,9 @@ class TestMisc(JitTestCase):
def multiple_args():
return torch.LongTensor(1, [2])
with self.assertRaisesRegex(RuntimeError, "multiple positional arguments that were not all integers"):
with self.assertRaisesRegex(
RuntimeError, "multiple positional arguments that were not all integers"
):
torch.jit.script(multiple_args)
# kwarg bad schema
@ -345,7 +389,6 @@ class TestMisc(JitTestCase):
with self.assertRaisesRegex(RuntimeError, "hello"):
torch.jit.script(bad_kwarg)
def test_broadcasting_list(self):
"""
Test BroadcastingList and torch.nn._size_N_t alias
@ -360,7 +403,7 @@ class TestMisc(JitTestCase):
return x[0] + x[1]
self.assertTrue(torch.jit.script(sum_i)(4) == 8)
self.assertTrue(torch.jit.script(sum_f)(4.5) == 9.)
self.assertTrue(torch.jit.script(sum_f)(4.5) == 9.0)
def test_parse_ir_annotate(self):
ir = """
@ -397,7 +440,6 @@ class TestMisc(JitTestCase):
self.assertTrue(ret.numel() == 1)
self.assertTrue(len(ret.size()) == 1)
def test_script_many_decorators(self):
def no_op_decorator(f):
return f
@ -410,7 +452,9 @@ class TestMisc(JitTestCase):
def foo(x, dim: int):
return x.unsqueeze(dim)
x = torch.randn(1,)
x = torch.randn(
1,
)
expected = foo(x, 0)
scripted = torch.jit.script(foo)
actual = scripted(x, 0)
@ -421,10 +465,10 @@ class TestMisc(JitTestCase):
# https://github.com/pytorch/pytorch/issues/75476
def fn(p: torch.Tensor, gamma: float = 2.0) -> torch.Tensor:
p = torch.sigmoid(p)
result = p ** gamma
result = p**gamma
return result
x = torch.rand((2, 2), dtype=torch.half, device='cuda')
x = torch.rand((2, 2), dtype=torch.half, device="cuda")
ref = fn(x)
@ -450,8 +494,12 @@ class TestMisc(JitTestCase):
# We want "Scalar" to come before "complex".
op, override_names = torch._C._jit_get_operation("aten::add")
print(override_names)
complex_indices = [i for i, name in enumerate(override_names) if name == "complex"]
Scalar_indices = [i for i, name in enumerate(override_names) if name == "Scalar"]
complex_indices = [
i for i, name in enumerate(override_names) if name == "complex"
]
Scalar_indices = [
i for i, name in enumerate(override_names) if name == "Scalar"
]
self.assertTrue(len(complex_indices) > 0)
self.assertTrue(len(Scalar_indices) > 0)

View File

@ -3,27 +3,33 @@
import os
import sys
import unittest
from torch.testing._internal.common_utils import (
enable_profiling_mode_for_profiling_tests, GRAPH_EXECUTOR, ProfilingMode,
set_default_dtype,
)
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.testing._internal.common_utils import (
enable_profiling_mode_for_profiling_tests,
GRAPH_EXECUTOR,
ProfilingMode,
set_default_dtype,
)
# Make the helper files in test/ importable
pytorch_test_dir = os.path.dirname(os.path.dirname(os.path.realpath(__file__)))
sys.path.append(pytorch_test_dir)
from torch.testing._internal.jit_utils import JitTestCase, RUN_CUDA
from torch.testing._internal.common_utils import slowTest, suppress_warnings
from torch.testing._internal.jit_utils import JitTestCase, RUN_CUDA
if __name__ == '__main__':
raise RuntimeError("This test file is not meant to be run directly, use:\n\n"
"\tpython test/test_jit.py TESTNAME\n\n"
"instead.")
if __name__ == "__main__":
raise RuntimeError(
"This test file is not meant to be run directly, use:\n\n"
"\tpython test/test_jit.py TESTNAME\n\n"
"instead."
)
try:
import torchvision
HAS_TORCHVISION = True
except ImportError:
HAS_TORCHVISION = False
@ -31,6 +37,7 @@ except RuntimeError:
HAS_TORCHVISION = False
skipIfNoTorchVision = unittest.skipIf(not HAS_TORCHVISION, "no torchvision")
class MnistNet(nn.Module):
def __init__(self):
super().__init__()
@ -49,6 +56,7 @@ class MnistNet(nn.Module):
x = self.fc2(x)
return F.log_softmax(x, dim=1)
class TestModels(JitTestCase):
@staticmethod
def _test_dcgan_models(self, device, check_export_import=True):
@ -102,31 +110,38 @@ class TestModels(JitTestCase):
nn.LeakyReLU(0.2, inplace=True),
# state size. (ndf*8) x 4 x 4
nn.Conv2d(ndf * 8, 1, 4, 1, 0, bias=False),
nn.Sigmoid()
nn.Sigmoid(),
)
def forward(self, input):
return self.main(input).view(-1, 1).squeeze(1)
bs, nz, ngf, nc, ndf = 5, 6, 9, 3, 10
self.checkTrace(DCGANGenerator(nz, ngf, nc).to(device),
(torch.rand(bs, nz, 1, 1, device=device),),
export_import=check_export_import)
example_input = DCGANGenerator(nz, ngf, nc).to(device)(torch.rand(bs, nz, 1, 1, device=device))
self.checkTrace(DCGANDiscriminator(nc, ndf).to(device), (example_input,),
export_import=check_export_import)
self.checkTrace(
DCGANGenerator(nz, ngf, nc).to(device),
(torch.rand(bs, nz, 1, 1, device=device),),
export_import=check_export_import,
)
example_input = DCGANGenerator(nz, ngf, nc).to(device)(
torch.rand(bs, nz, 1, 1, device=device)
)
self.checkTrace(
DCGANDiscriminator(nc, ndf).to(device),
(example_input,),
export_import=check_export_import,
)
def test_dcgan_models(self):
# Note: Can sometimes fail with low precision if run with float dtype
with set_default_dtype(torch.double):
self._test_dcgan_models(self, device='cpu')
self._test_dcgan_models(self, device="cpu")
@unittest.skipIf(not RUN_CUDA, "no CUDA")
def test_dcgan_models_cuda(self):
# Note: Can sometimes fail with low precision if run with float dtype
with set_default_dtype(torch.double):
# XXX: export_import on CUDA modules doesn't work (#11480)
self._test_dcgan_models(self, device='cuda', check_export_import=False)
self._test_dcgan_models(self, device="cuda", check_export_import=False)
@staticmethod
def _test_neural_style(self, device, check_export_import=True):
@ -147,9 +162,13 @@ class TestModels(JitTestCase):
self.res4 = ResidualBlock(128)
self.res5 = ResidualBlock(128)
# Upsampling Layers
self.deconv1 = UpsampleConvLayer(128, 64, kernel_size=3, stride=1, upsample=2)
self.deconv1 = UpsampleConvLayer(
128, 64, kernel_size=3, stride=1, upsample=2
)
self.in4 = torch.nn.InstanceNorm2d(64, affine=True)
self.deconv2 = UpsampleConvLayer(64, 32, kernel_size=3, stride=1, upsample=2)
self.deconv2 = UpsampleConvLayer(
64, 32, kernel_size=3, stride=1, upsample=2
)
self.in5 = torch.nn.InstanceNorm2d(32, affine=True)
self.deconv3 = ConvLayer(32, 3, kernel_size=9, stride=1)
# Non-linearities
@ -174,7 +193,9 @@ class TestModels(JitTestCase):
super().__init__()
reflection_padding = kernel_size // 2
self.reflection_pad = torch.nn.ReflectionPad2d(reflection_padding)
self.conv2d = torch.nn.Conv2d(in_channels, out_channels, kernel_size, stride)
self.conv2d = torch.nn.Conv2d(
in_channels, out_channels, kernel_size, stride
)
def forward(self, x):
out = self.reflection_pad(x)
@ -209,14 +230,20 @@ class TestModels(JitTestCase):
ref: http://distill.pub/2016/deconv-checkerboard/
"""
def __init__(self, in_channels, out_channels, kernel_size, stride, upsample=None):
def __init__(
self, in_channels, out_channels, kernel_size, stride, upsample=None
):
super().__init__()
self.upsample = upsample
if upsample:
self.upsample_layer = torch.nn.Upsample(mode='nearest', scale_factor=upsample)
self.upsample_layer = torch.nn.Upsample(
mode="nearest", scale_factor=upsample
)
reflection_padding = kernel_size // 2
self.reflection_pad = torch.nn.ReflectionPad2d(reflection_padding)
self.conv2d = torch.nn.Conv2d(in_channels, out_channels, kernel_size, stride)
self.conv2d = torch.nn.Conv2d(
in_channels, out_channels, kernel_size, stride
)
def forward(self, x):
x_in = x
@ -226,44 +253,54 @@ class TestModels(JitTestCase):
out = self.conv2d(out)
return out
self.checkTrace(TransformerNet(), (torch.rand(5, 3, 16, 16),), export_import=check_export_import)
self.checkTrace(
TransformerNet(),
(torch.rand(5, 3, 16, 16),),
export_import=check_export_import,
)
@slowTest
def test_neural_style(self):
self._test_neural_style(self, device='cpu')
self._test_neural_style(self, device="cpu")
@unittest.skipIf(not RUN_CUDA, "no CUDA")
def test_neural_style_cuda(self):
# XXX: export_import on CUDA modules doesn't work (#11480)
self._test_neural_style(self, device='cuda', check_export_import=False)
self._test_neural_style(self, device="cuda", check_export_import=False)
@unittest.skipIf(GRAPH_EXECUTOR == ProfilingMode.LEGACY, "Bug found in deprecated executor")
@unittest.skipIf(
GRAPH_EXECUTOR == ProfilingMode.LEGACY, "Bug found in deprecated executor"
)
@staticmethod
def _test_mnist(self, device, check_export_import=True):
# eval() is present because dropout makes this nondeterministic
with enable_profiling_mode_for_profiling_tests():
self.checkTrace(MnistNet().to(device).eval(), (torch.rand(5, 1, 28, 28, device=device),),
export_import=check_export_import)
self.checkTrace(
MnistNet().to(device).eval(),
(torch.rand(5, 1, 28, 28, device=device),),
export_import=check_export_import,
)
def test_mnist(self):
self._test_mnist(self, device='cpu')
self._test_mnist(self, device="cpu")
@unittest.skipIf(not RUN_CUDA, "no CUDA")
def test_mnist_cuda(self):
# XXX: export_import on CUDA modules doesn't work (#11480)
self._test_mnist(self, device='cuda', check_export_import=False)
self._test_mnist(self, device="cuda", check_export_import=False)
@unittest.skipIf(not RUN_CUDA, "no CUDA")
def test_mnist_training_leaks_no_memory_cuda(self):
net = MnistNet().cuda()
# MnistNet uses dropout, don't check its trace
traced_net = torch.jit.trace(net, [torch.randn(5, 1, 28, 28, device='cuda')],
check_trace=False)
traced_net = torch.jit.trace(
net, [torch.randn(5, 1, 28, 28, device="cuda")], check_trace=False
)
def train(iters):
for _ in range(iters):
# Get some fake data
inp = torch.randn(5, 1, 28, 28, device='cuda')
inp = torch.randn(5, 1, 28, 28, device="cuda")
out = traced_net(inp)
# Here's some fake loss
@ -292,21 +329,23 @@ class TestModels(JitTestCase):
return F.softmax(action_scores, dim=1)
with enable_profiling_mode_for_profiling_tests():
self.checkTrace(Policy().to(device), (torch.rand(1, 4, device=device),),
export_import=test_export_import)
self.checkTrace(
Policy().to(device),
(torch.rand(1, 4, device=device),),
export_import=test_export_import,
)
def test_reinforcement_learning(self):
self._test_reinforcement_learning(self, device='cpu')
self._test_reinforcement_learning(self, device="cpu")
@unittest.skipIf(not RUN_CUDA, "no CUDA")
def test_reinforcement_learning_cuda(self):
# XXX: export_import on CUDA modules doesn't work (#11480)
self._test_reinforcement_learning(self, device='cuda', test_export_import=False)
self._test_reinforcement_learning(self, device="cuda", test_export_import=False)
@staticmethod
def _test_snli(self, device, check_export_import=True):
class Bottle(nn.Module):
def forward(self, input):
if len(input.size()) <= 2:
return super().forward(input)
@ -318,25 +357,31 @@ class TestModels(JitTestCase):
pass
class Encoder(nn.Module):
def __init__(self, config):
super().__init__()
self.config = config
input_size = config.d_proj if config.projection else config.d_embed
dropout = 0 if config.n_layers == 1 else config.dp_ratio
self.rnn = nn.LSTM(input_size=input_size, hidden_size=config.d_hidden,
num_layers=config.n_layers, dropout=dropout,
bidirectional=config.birnn)
self.rnn = nn.LSTM(
input_size=input_size,
hidden_size=config.d_hidden,
num_layers=config.n_layers,
dropout=dropout,
bidirectional=config.birnn,
)
def forward(self, inputs):
batch_size = inputs.size()[1]
state_shape = self.config.n_cells, batch_size, self.config.d_hidden
h0 = c0 = inputs.new_zeros(state_shape)
outputs, (ht, ct) = self.rnn(inputs, (h0, c0))
return ht[-1] if not self.config.birnn else ht[-2:].transpose(0, 1).contiguous().view(batch_size, -1)
return (
ht[-1]
if not self.config.birnn
else ht[-2:].transpose(0, 1).contiguous().view(batch_size, -1)
)
class SNLIClassifier(nn.Module):
def __init__(self, config):
super().__init__()
self.config = config
@ -359,7 +404,8 @@ class TestModels(JitTestCase):
Linear(*lin_config),
self.relu,
self.dropout,
Linear(seq_in_size, config.d_out))
Linear(seq_in_size, config.d_out),
)
def forward(self, premise, hypothesis):
prem_embed = self.embed(premise)
@ -391,22 +437,25 @@ class TestModels(JitTestCase):
premise = torch.LongTensor(48, 64).random_(0, 100).to(device)
hypothesis = torch.LongTensor(24, 64).random_(0, 100).to(device)
self.checkTrace(SNLIClassifier(Config()).to(device), (premise, hypothesis),
inputs_require_grads=False, export_import=check_export_import)
self.checkTrace(
SNLIClassifier(Config()).to(device),
(premise, hypothesis),
inputs_require_grads=False,
export_import=check_export_import,
)
@slowTest
def test_snli(self):
self._test_snli(self, device='cpu')
self._test_snli(self, device="cpu")
@unittest.skipIf(not RUN_CUDA, "no CUDA")
def test_snli_cuda(self):
# XXX: export_import on CUDA modules doesn't work (#11480)
self._test_snli(self, device='cuda', check_export_import=False)
self._test_snli(self, device="cuda", check_export_import=False)
@staticmethod
def _test_super_resolution(self, device, check_export_import=True):
class Net(nn.Module):
def __init__(self, upscale_factor):
super().__init__()
@ -414,7 +463,7 @@ class TestModels(JitTestCase):
self.conv1 = nn.Conv2d(1, 64, (5, 5), (1, 1), (2, 2))
self.conv2 = nn.Conv2d(64, 64, (3, 3), (1, 1), (1, 1))
self.conv3 = nn.Conv2d(64, 32, (3, 3), (1, 1), (1, 1))
self.conv4 = nn.Conv2d(32, upscale_factor ** 2, (3, 3), (1, 1), (1, 1))
self.conv4 = nn.Conv2d(32, upscale_factor**2, (3, 3), (1, 1), (1, 1))
self.pixel_shuffle = nn.PixelShuffle(upscale_factor)
def forward(self, x):
@ -425,17 +474,20 @@ class TestModels(JitTestCase):
return x
net = Net(upscale_factor=4).to(device)
self.checkTrace(net, (torch.rand(5, 1, 32, 32, device=device),),
export_import=check_export_import)
self.checkTrace(
net,
(torch.rand(5, 1, 32, 32, device=device),),
export_import=check_export_import,
)
@slowTest
def test_super_resolution(self):
self._test_super_resolution(self, device='cpu')
self._test_super_resolution(self, device="cpu")
@unittest.skipIf(not RUN_CUDA, 'no CUDA')
@unittest.skipIf(not RUN_CUDA, "no CUDA")
def test_super_resolution_cuda(self):
# XXX: export_import on CUDA modules doesn't work (#11480)
self._test_super_resolution(self, device='cuda', check_export_import=False)
self._test_super_resolution(self, device="cuda", check_export_import=False)
@suppress_warnings
def test_time_sequence_prediction(self):
@ -485,8 +537,7 @@ class TestModels(JitTestCase):
# disabled due to a jitter issues that will be fixed by using load/store in the compiler
with torch._jit_internal._disable_emit_hooks():
# TODO: toggle export_import once above issues are fixed
self.checkTrace(Traced(), (torch.rand(3, 4),),
export_import=False)
self.checkTrace(Traced(), (torch.rand(3, 4),), export_import=False)
@staticmethod
def _test_vae(self, device, check_export_import=True):
@ -523,22 +574,27 @@ class TestModels(JitTestCase):
with enable_profiling_mode_for_profiling_tests():
# eval() is present because randn_like makes this nondeterministic
self.checkTrace(VAE().to(device).eval(), (torch.rand(128, 1, 28, 28, device=device),),
export_import=check_export_import)
self.checkTrace(
VAE().to(device).eval(),
(torch.rand(128, 1, 28, 28, device=device),),
export_import=check_export_import,
)
def test_vae(self):
self._test_vae(self, device='cpu')
self._test_vae(self, device="cpu")
@unittest.skipIf(not RUN_CUDA, "no CUDA")
def test_vae_cuda(self):
# XXX: export_import on CUDA modules doesn't work (#11480)
self._test_vae(self, device='cuda', check_export_import=False)
self._test_vae(self, device="cuda", check_export_import=False)
@slowTest
@skipIfNoTorchVision
def test_script_module_trace_resnet18(self):
x = torch.ones(1, 3, 224, 224)
m_orig = torch.jit.trace(torchvision.models.resnet18(), torch.ones(1, 3, 224, 224))
m_orig = torch.jit.trace(
torchvision.models.resnet18(), torch.ones(1, 3, 224, 224)
)
m_import = self.getExportImportCopy(m_orig)
input = torch.randn(1, 3, 224, 224, requires_grad=True)
@ -559,16 +615,24 @@ class TestModels(JitTestCase):
def test_script_module_script_resnet(self):
def conv1x1(in_planes, out_planes, stride=1):
"""1x1 convolution"""
return nn.Conv2d(in_planes, out_planes, kernel_size=1, stride=stride, bias=False)
return nn.Conv2d(
in_planes, out_planes, kernel_size=1, stride=stride, bias=False
)
def conv3x3(in_planes, out_planes, stride=1):
"""3x3 convolution with padding"""
return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride,
padding=1, bias=False)
return nn.Conv2d(
in_planes,
out_planes,
kernel_size=3,
stride=stride,
padding=1,
bias=False,
)
class BasicBlock(torch.jit.ScriptModule):
expansion = 1
__constants__ = ['downsample']
__constants__ = ["downsample"]
def __init__(self, inplanes, planes, stride=1, downsample=None):
super().__init__()
@ -600,13 +664,14 @@ class TestModels(JitTestCase):
return out
class ResNet(torch.jit.ScriptModule):
__constants__ = ['layer1', 'layer2', 'layer3', 'layer4']
__constants__ = ["layer1", "layer2", "layer3", "layer4"]
def __init__(self, block, layers, num_classes=1000):
super().__init__()
self.inplanes = 64
self.conv1 = nn.Conv2d(3, 64, kernel_size=7, stride=2, padding=3,
bias=False)
self.conv1 = nn.Conv2d(
3, 64, kernel_size=7, stride=2, padding=3, bias=False
)
self.bn1 = nn.BatchNorm2d(64)
self.relu = nn.ReLU(inplace=True)
self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
@ -619,7 +684,9 @@ class TestModels(JitTestCase):
for m in self.modules():
if isinstance(m, nn.Conv2d):
nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
nn.init.kaiming_normal_(
m.weight, mode="fan_out", nonlinearity="relu"
)
elif isinstance(m, nn.BatchNorm2d):
nn.init.constant_(m.weight, 1)
nn.init.constant_(m.bias, 0)
@ -679,8 +746,10 @@ class TestModels(JitTestCase):
x = torch.ones(1, 3, 224, 224)
model = torchvision.models.AlexNet()
with torch.random.fork_rng(devices=[]):
g, outputs, inputs = torch.jit._get_trace_graph(model, x, return_inputs=True)
self.run_pass('cse', g)
g, outputs, inputs = torch.jit._get_trace_graph(
model, x, return_inputs=True
)
self.run_pass("cse", g)
m = self.createFunctionFromGraph(g)
with torch.random.fork_rng(devices=[]):
self.assertEqual(outputs, m(*inputs))

View File

@ -1,19 +1,23 @@
# Owner(s): ["oncall: jit"]
import torch
import os
import sys
from typing import Any, Dict, List
import torch
from torch.testing._internal.jit_utils import JitTestCase
from typing import Dict, Any, List
# Make the helper files in test/ importable
pytorch_test_dir = os.path.dirname(os.path.dirname(os.path.realpath(__file__)))
sys.path.append(pytorch_test_dir)
if __name__ == '__main__':
raise RuntimeError("This test file is not meant to be run directly, use:\n\n"
"\tpython test/test_jit.py TESTNAME\n\n"
"instead.")
if __name__ == "__main__":
raise RuntimeError(
"This test file is not meant to be run directly, use:\n\n"
"\tpython test/test_jit.py TESTNAME\n\n"
"instead."
)
class TestModuleAPIs(JitTestCase):
def test_default_state_dict_methods(self):
@ -52,18 +56,23 @@ class TestModuleAPIs(JitTestCase):
return x
@torch.jit.export
def _save_to_state_dict(self, destination: Dict[str, torch.Tensor],
prefix: str, keep_vars: bool):
def _save_to_state_dict(
self, destination: Dict[str, torch.Tensor], prefix: str, keep_vars: bool
):
self.customized_save_state_dict_called = True
return {"dummy": torch.ones(1)}
@torch.jit.export
def _load_from_state_dict(self,
state_dict: Dict[str, torch.Tensor],
prefix: str, local_metadata: Any,
strict: bool, missing_keys: List[str],
unexpected_keys: List[str],
error_msgs: List[str]):
def _load_from_state_dict(
self,
state_dict: Dict[str, torch.Tensor],
prefix: str,
local_metadata: Any,
strict: bool,
missing_keys: List[str],
unexpected_keys: List[str],
error_msgs: List[str],
):
self.customized_load_state_dict_called = True
return
@ -94,18 +103,23 @@ class TestModuleAPIs(JitTestCase):
return x
@torch.jit.export
def _save_to_state_dict(self, destination: Dict[str, torch.Tensor],
prefix: str, keep_vars: bool):
def _save_to_state_dict(
self, destination: Dict[str, torch.Tensor], prefix: str, keep_vars: bool
):
self.customized_save_state_dict_called = True
return {"dummy": torch.ones(1)}
@torch.jit.export
def _load_from_state_dict(self,
state_dict: Dict[str, torch.Tensor],
prefix: str, local_metadata: Any,
strict: bool, missing_keys: List[str],
unexpected_keys: List[str],
error_msgs: List[str]):
def _load_from_state_dict(
self,
state_dict: Dict[str, torch.Tensor],
prefix: str,
local_metadata: Any,
strict: bool,
missing_keys: List[str],
unexpected_keys: List[str],
error_msgs: List[str],
):
self.customized_load_state_dict_called = True
return

View File

@ -2,9 +2,10 @@
import os
import sys
from collections import OrderedDict
from typing import Any, List, Tuple
from collections import OrderedDict
import torch
import torch.nn as nn
from torch.testing._internal.jit_utils import JitTestCase
@ -13,10 +14,13 @@ from torch.testing._internal.jit_utils import JitTestCase
pytorch_test_dir = os.path.dirname(os.path.dirname(os.path.realpath(__file__)))
sys.path.append(pytorch_test_dir)
if __name__ == '__main__':
raise RuntimeError("This test file is not meant to be run directly, use:\n\n"
"\tpython test/test_jit.py TESTNAME\n\n"
"instead.")
if __name__ == "__main__":
raise RuntimeError(
"This test file is not meant to be run directly, use:\n\n"
"\tpython test/test_jit.py TESTNAME\n\n"
"instead."
)
class TestModuleContainers(JitTestCase):
def test_sequential_intermediary_types(self):
@ -54,11 +58,13 @@ class TestModuleContainers(JitTestCase):
class M(torch.nn.Module):
def __init__(self):
super().__init__()
modules = OrderedDict([
('one', Inner()),
('two', Inner2()),
('three', Inner3()),
])
modules = OrderedDict(
[
("one", Inner()),
("two", Inner2()),
("three", Inner3()),
]
)
self.moduledict = nn.ModuleDict(modules)
def forward(self, x, skip_name):
@ -115,7 +121,6 @@ class TestModuleContainers(JitTestCase):
return x, x2, names, iter
for name in ["", "one", "two", "three"]:
inp = torch.tensor(1)
self.checkModule(M(), (inp, name))
@ -136,7 +141,7 @@ class TestModuleContainers(JitTestCase):
x = mod(x)
return x - 5
self.checkModule(CustomSequential(), (torch.tensor(.5),))
self.checkModule(CustomSequential(), (torch.tensor(0.5),))
class CustomModuleList(nn.ModuleList):
def __init__(self):
@ -148,16 +153,19 @@ class TestModuleContainers(JitTestCase):
x = mod(x)
return x - 5
self.checkModule(CustomModuleList(), (torch.tensor(.5),))
self.checkModule(CustomModuleList(), (torch.tensor(0.5),))
class CustomModuleDict(nn.ModuleDict):
def __init__(self):
super().__init__(
OrderedDict([
('one', Inner()),
('two', nn.ReLU()),
('three', Inner()),
]))
OrderedDict(
[
("one", Inner()),
("two", nn.ReLU()),
("three", Inner()),
]
)
)
def forward(self, x):
x = x + 3
@ -167,7 +175,7 @@ class TestModuleContainers(JitTestCase):
names.append(name)
return names, x - 5
self.checkModule(CustomModuleDict(), (torch.tensor(.5),))
self.checkModule(CustomModuleDict(), (torch.tensor(0.5),))
def test_script_module_list_sequential(self):
class M(torch.jit.ScriptModule):
@ -225,7 +233,9 @@ class TestModuleContainers(JitTestCase):
def forward(self, v):
return self.mods[-11].forward(v)
with self.assertRaisesRegexWithHighlight(Exception, "Index -11 out of range", "self.mods[-11]"):
with self.assertRaisesRegexWithHighlight(
Exception, "Index -11 out of range", "self.mods[-11]"
):
torch.jit.script(M2())
class M3(M):
@ -233,7 +243,9 @@ class TestModuleContainers(JitTestCase):
i = 3
return self.mods[i].forward(v)
with self.assertRaisesRegexWithHighlight(Exception, "Enumeration is supported", "self.mods[i]"):
with self.assertRaisesRegexWithHighlight(
Exception, "Enumeration is supported", "self.mods[i]"
):
torch.jit.script(M3())
class M4(M):
@ -273,17 +285,23 @@ class TestModuleContainers(JitTestCase):
self.moduledict = CustomModuleDict({"submod": self.submod})
def forward(self, inputs):
assert self.modulelist[0] is self.submod, "__getitem__ failing for ModuleList"
assert (
self.modulelist[0] is self.submod
), "__getitem__ failing for ModuleList"
assert len(self.modulelist) == 1, "__len__ failing for ModuleList"
for module in self.modulelist:
assert module is self.submod, "__iter__ failing for ModuleList"
assert self.sequential[0] is self.submod, "__getitem__ failing for Sequential"
assert (
self.sequential[0] is self.submod
), "__getitem__ failing for Sequential"
assert len(self.sequential) == 1, "__len__ failing for Sequential"
for module in self.sequential:
assert module is self.submod, "__iter__ failing for Sequential"
assert self.moduledict["submod"] is self.submod, "__getitem__ failing for ModuleDict"
assert (
self.moduledict["submod"] is self.submod
), "__getitem__ failing for ModuleDict"
assert len(self.moduledict) == 1, "__len__ failing for ModuleDict"
# note: unable to index moduledict with a string variable currently
@ -345,12 +363,13 @@ class TestModuleContainers(JitTestCase):
super().__init__()
self.relu = torch.jit.script(torch.nn.ReLU())
self.tanh = torch.jit.script(torch.nn.Tanh())
self.moduledict = torch.nn.ModuleDict({"relu": self.relu,
"tanh": self.tanh})
self.moduledict = torch.nn.ModuleDict(
{"relu": self.relu, "tanh": self.tanh}
)
def forward(self, input):
assert self.moduledict['relu'] is self.relu
assert self.moduledict['tanh'] is self.tanh
assert self.moduledict["relu"] is self.relu
assert self.moduledict["tanh"] is self.tanh
return input
m = MyModule()
@ -360,31 +379,34 @@ class TestModuleContainers(JitTestCase):
class BadModule(torch.nn.Module):
def __init__(self):
super().__init__()
self.moduledict = torch.nn.ModuleDict({"foo": None,
"bar": None})
self.moduledict = torch.nn.ModuleDict({"foo": None, "bar": None})
def forward(self, input):
assert self.moduledict['blah'] == "blah", "this is a keyerror"
assert self.moduledict["blah"] == "blah", "this is a keyerror"
with self.assertRaisesRegexWithHighlight(RuntimeError, "Key Error, blah", "self.moduledict['blah'"):
with self.assertRaisesRegexWithHighlight(
RuntimeError, "Key Error, blah", 'self.moduledict["blah"'
):
b = BadModule()
torch.jit.script(b)
class AnotherBadModule(torch.nn.Module):
def __init__(self):
super().__init__()
self.moduledict = torch.nn.ModuleDict({"foo": None,
"bar": None})
self.moduledict = torch.nn.ModuleDict({"foo": None, "bar": None})
def forward(self, input):
idx = 'blah'
idx = "blah"
assert self.moduledict[idx] == "blah", "this is a string literal error"
with self.assertRaisesRegexWithHighlight(RuntimeError, "Unable to extract string literal index. "
"ModuleDict indexing is only supported with string literals. "
"For example, 'i = \"a\"; self.layers\\[i\\]\\(x\\)' will fail "
"because i is not a literal.",
"self.moduledict[idx]"):
with self.assertRaisesRegexWithHighlight(
RuntimeError,
"Unable to extract string literal index. "
"ModuleDict indexing is only supported with string literals. "
"For example, 'i = \"a\"; self.layers\\[i\\]\\(x\\)' will fail "
"because i is not a literal.",
"self.moduledict[idx]",
):
b = AnotherBadModule()
torch.jit.script(b)
@ -393,6 +415,7 @@ class TestModuleContainers(JitTestCase):
Test that an attempt to script a module with a regular list attribute
containing other modules fails with a relevant error message.
"""
class Mod(torch.nn.Module):
def __init__(self):
super().__init__()
@ -422,7 +445,9 @@ class TestModuleContainers(JitTestCase):
self.moduledict = CustomModuleDict()
def forward(self, inputs):
assert "submod" not in self.moduledict, "__contains__ fails for ModuleDict"
assert (
"submod" not in self.moduledict
), "__contains__ fails for ModuleDict"
return inputs
m = MyModule()
@ -433,6 +458,7 @@ class TestModuleContainers(JitTestCase):
Test that a type annotation can be provided for a ModuleDict that allows
non-static indexing.
"""
@torch.jit.interface
class ModuleInterface(torch.nn.Module):
def forward(self, inp: Any) -> Any:
@ -485,7 +511,9 @@ class TestModuleContainers(JitTestCase):
submodule: ModuleInterface = self.d[key]
return submodule.forward(x)
with self.assertRaisesRegexWithHighlight(RuntimeError, r"Attribute module is not of annotated type", "self.d[key]"):
with self.assertRaisesRegexWithHighlight(
RuntimeError, r"Attribute module is not of annotated type", "self.d[key]"
):
torch.jit.script(ModWithWrongAnnotation())
def test_typed_module_list(self):
@ -493,6 +521,7 @@ class TestModuleContainers(JitTestCase):
Test that a type annotation can be provided for a ModuleList that allows
non-static indexing.
"""
@torch.jit.interface
class ModuleInterface(torch.nn.Module):
def forward(self, inp: Any) -> Any:
@ -545,7 +574,9 @@ class TestModuleContainers(JitTestCase):
submodule: ModuleInterface = self.l[idx]
return submodule.forward(x)
with self.assertRaisesRegexWithHighlight(RuntimeError, r"Attribute 0 is not of annotated type", "self.l[idx]"):
with self.assertRaisesRegexWithHighlight(
RuntimeError, r"Attribute 0 is not of annotated type", "self.l[idx]"
):
torch.jit.script(ModWithWrongAnnotation())
def test_module_properties(self):
@ -596,10 +627,34 @@ class TestModuleContainers(JitTestCase):
def attr(self):
return self.a + 1
self.checkModule(ModuleWithProperties(5), (5, 6,))
self.checkModule(ModuleWithProperties(5), (-5, -6,))
self.checkModule(ModuleWithNoSetter(5), (5, 6,))
self.checkModule(ModuleWithNoSetter(5), (-5, -6,))
self.checkModule(
ModuleWithProperties(5),
(
5,
6,
),
)
self.checkModule(
ModuleWithProperties(5),
(
-5,
-6,
),
)
self.checkModule(
ModuleWithNoSetter(5),
(
5,
6,
),
)
self.checkModule(
ModuleWithNoSetter(5),
(
-5,
-6,
),
)
mod = ModuleWithProperties(3)
scripted_mod = torch.jit.script(mod)
@ -625,7 +680,6 @@ class TestModuleContainers(JitTestCase):
def forward(self, x):
return self.linear(self.linear(x))
class N(nn.Module):
def __init__(self):
super().__init__()
@ -659,7 +713,9 @@ class TestModuleContainers(JitTestCase):
def __init__(self):
super().__init__()
self.module_list = nn.ModuleList([nn.Linear(1, 1) for _ in range(10)])
self.parameter_list = nn.ParameterList([nn.Parameter(torch.zeros(1)) for _ in range(10)])
self.parameter_list = nn.ParameterList(
[nn.Parameter(torch.zeros(1)) for _ in range(10)]
)
def forward(self, x):
self.module_list[0]
@ -673,7 +729,9 @@ class TestModuleContainers(JitTestCase):
def __init__(self):
super().__init__()
self.module_list = nn.ModuleList([nn.Linear(1, 1) for _ in range(10)])
self.parameter_list = nn.ParameterList([nn.Parameter(torch.zeros(1)) for _ in range(10)])
self.parameter_list = nn.ParameterList(
[nn.Parameter(torch.zeros(1)) for _ in range(10)]
)
def forward(self, x):
r = x
@ -687,9 +745,14 @@ class TestModuleContainers(JitTestCase):
class MyModule(nn.Module):
def __init__(self):
super().__init__()
self.parameter_dict = nn.ParameterDict({k: nn.Parameter(torch.zeros(1)) for k in ['a', 'b', 'c']})
self.parameter_dict = nn.ParameterDict(
{k: nn.Parameter(torch.zeros(1)) for k in ["a", "b", "c"]}
)
def forward(self, x):
return self.parameter_dict['a'] * x + self.parameter_dict['b'] * self.parameter_dict['c']
return (
self.parameter_dict["a"] * x
+ self.parameter_dict["b"] * self.parameter_dict["c"]
)
self.checkModule(MyModule(), (torch.ones(1),))

View File

@ -1,10 +1,11 @@
# Owner(s): ["oncall: jit"]
from typing import List, Any
import torch
import torch.nn as nn
import os
import sys
from typing import Any, List
import torch
import torch.nn as nn
from torch import Tensor
from torch.testing._internal.jit_utils import JitTestCase, make_global
@ -12,10 +13,13 @@ from torch.testing._internal.jit_utils import JitTestCase, make_global
pytorch_test_dir = os.path.dirname(os.path.dirname(os.path.realpath(__file__)))
sys.path.append(pytorch_test_dir)
if __name__ == '__main__':
raise RuntimeError("This test file is not meant to be run directly, use:\n\n"
"\tpython test/test_jit.py TESTNAME\n\n"
"instead.")
if __name__ == "__main__":
raise RuntimeError(
"This test file is not meant to be run directly, use:\n\n"
"\tpython test/test_jit.py TESTNAME\n\n"
"instead."
)
class OrigModule(nn.Module):
def one(self, inp1: Tensor, inp2: Tensor) -> Tensor:
@ -27,6 +31,7 @@ class OrigModule(nn.Module):
def forward(self, input: Tensor) -> Tensor:
return input + self.one(input, input) + 1
class NewModule(nn.Module):
def one(self, inp1: Tensor, inp2: Tensor) -> Tensor:
return inp1 * inp2 + 1
@ -34,6 +39,7 @@ class NewModule(nn.Module):
def forward(self, input: Tensor) -> Tensor:
return self.one(input, input + 1)
class TestModuleInterface(JitTestCase):
def test_not_submodule_interface_call(self):
@torch.jit.interface
@ -42,7 +48,7 @@ class TestModuleInterface(JitTestCase):
pass
class TestNotModuleInterfaceCall(nn.Module):
proxy_mod : ModuleInterface
proxy_mod: ModuleInterface
def __init__(self):
super().__init__()
@ -51,7 +57,9 @@ class TestModuleInterface(JitTestCase):
def forward(self, input: Tensor) -> Tensor:
return self.proxy_mod.two(input)
with self.assertRaisesRegexWithHighlight(RuntimeError, "object has no attribute or method", "self.proxy_mod.two"):
with self.assertRaisesRegexWithHighlight(
RuntimeError, "object has no attribute or method", "self.proxy_mod.two"
):
torch.jit.script(TestNotModuleInterfaceCall())
def test_module_interface(self):
@ -108,17 +116,37 @@ class TestModuleInterface(JitTestCase):
scripted_foo_mod = torch.jit.script(FooMod())
scripted_bar_mod = torch.jit.script(BarMod())
self.checkScript(use_module_interface,
([scripted_foo_mod, scripted_bar_mod], torch.rand(3, 4),))
self.checkScript(use_class_interface,
([scripted_foo_mod, scripted_bar_mod], torch.rand(3, 4),))
self.checkScript(
use_module_interface,
(
[scripted_foo_mod, scripted_bar_mod],
torch.rand(3, 4),
),
)
self.checkScript(
use_class_interface,
(
[scripted_foo_mod, scripted_bar_mod],
torch.rand(3, 4),
),
)
def call_module_interface_on_other_method(mod_interface: OneTwoModule, x: Tensor) -> Tensor:
def call_module_interface_on_other_method(
mod_interface: OneTwoModule, x: Tensor
) -> Tensor:
return mod_interface.forward2(x)
# ensure error out when we call the module on the method other than the interface specified.
with self.assertRaisesRegexWithHighlight(RuntimeError, "object has no attribute or method", "mod_interface.forward2"):
self.checkScript(call_module_interface_on_other_method, (scripted_bar_mod, torch.rand(3, 4),))
with self.assertRaisesRegexWithHighlight(
RuntimeError, "object has no attribute or method", "mod_interface.forward2"
):
self.checkScript(
call_module_interface_on_other_method,
(
scripted_bar_mod,
torch.rand(3, 4),
),
)
def test_module_doc_string(self):
@torch.jit.interface
@ -135,7 +163,7 @@ class TestModuleInterface(JitTestCase):
r"""stuff 3"""
class TestModule(nn.Module):
proxy_mod : TestInterface
proxy_mod: TestInterface
def __init__(self):
super().__init__()
@ -178,7 +206,9 @@ class TestModuleInterface(JitTestCase):
return self.one(self.two(x), x)
# check class object is not a subtype of module interface
with self.assertRaisesRegex(RuntimeError, "ScriptModule class can be subtype of module interface"):
with self.assertRaisesRegex(
RuntimeError, "ScriptModule class can be subtype of module interface"
):
as_module_interface(Foo())
class WrongMod(nn.Module):
@ -233,9 +263,11 @@ class TestModuleInterface(JitTestCase):
as_tensor_to_any(torch.jit.script(TensorToAnyImplB()))
as_any_to_any(torch.jit.script(AnyToAnyImpl()))
def test_module_interface_inheritance(self):
with self.assertRaisesRegex(RuntimeError, "does not support inheritance yet. Please directly"):
with self.assertRaisesRegex(
RuntimeError, "does not support inheritance yet. Please directly"
):
@torch.jit.interface
class InheritMod(nn.ReLU):
def three(self, x: Tensor) -> Tensor:
@ -251,7 +283,7 @@ class TestModuleInterface(JitTestCase):
pass
class TestModule(nn.Module):
proxy_mod : ModuleInterface
proxy_mod: ModuleInterface
def __init__(self):
super().__init__()
@ -269,7 +301,9 @@ class TestModuleInterface(JitTestCase):
self.assertEqual(scripted_mod(input), input * (input + 1) + 1)
# module swap with non-scripted module should throw error
with self.assertRaisesRegex(RuntimeError, "a ScriptModule with non-scripted module"):
with self.assertRaisesRegex(
RuntimeError, "a ScriptModule with non-scripted module"
):
scripted_mod.proxy_mod = NewModule()
def test_module_swap_wrong_module(self):
@ -286,7 +320,7 @@ class TestModuleInterface(JitTestCase):
return input + 1
class TestModule(nn.Module):
proxy_mod : ModuleInterface
proxy_mod: ModuleInterface
def __init__(self):
super().__init__()
@ -310,7 +344,7 @@ class TestModuleInterface(JitTestCase):
pass
class TestModule(nn.Module):
proxy_mod : ModuleInterface
proxy_mod: ModuleInterface
def __init__(self):
super().__init__()
@ -358,9 +392,11 @@ class TestModuleInterface(JitTestCase):
# proxy mod is swapped with the new ScriptModule that share the same JIT type, should succeed.
scripted_no_module_interface.proxy_mod = torch.jit.script(OrigModule())
# proxy_mod is neither a module interface or have the same JIT type, should fail
with self.assertRaisesRegex(RuntimeError,
r"Expected a value of type '__torch__.jit.test_module_interface.OrigModule \(.*\)' " +
r"for field 'proxy_mod', but found '__torch__.jit.test_module_interface.NewModule \(.*\)'"):
with self.assertRaisesRegex(
RuntimeError,
r"Expected a value of type '__torch__.jit.test_module_interface.OrigModule \(.*\)' "
+ r"for field 'proxy_mod', but found '__torch__.jit.test_module_interface.NewModule \(.*\)'",
):
scripted_no_module_interface.proxy_mod = torch.jit.script(NewModule())
def test_script_module_as_interface_swap(self):
@ -391,7 +427,7 @@ class TestModuleInterface(JitTestCase):
return self.one(input, input + 1)
class TestNNModuleWithScriptModule(nn.Module):
proxy_mod : ModuleInterface
proxy_mod: ModuleInterface
def __init__(self):
super().__init__()
@ -432,7 +468,7 @@ class TestModuleInterface(JitTestCase):
pass
class TestModule(torch.nn.Module):
proxy_mod : ModInterface
proxy_mod: ModInterface
def __init__(self):
super().__init__()
@ -480,7 +516,7 @@ class TestModuleInterface(JitTestCase):
pass
class TestModule(torch.nn.Module):
proxy_mod : ModInterface
proxy_mod: ModInterface
def __init__(self):
super().__init__()
@ -523,7 +559,7 @@ class TestModuleInterface(JitTestCase):
pass
class TestModule(torch.nn.Module):
proxy_mod : ModInterface
proxy_mod: ModInterface
def __init__(self):
super().__init__()
@ -568,7 +604,7 @@ class TestModuleInterface(JitTestCase):
pass
class TestModule(torch.nn.Module):
proxy_mod : ModInterface
proxy_mod: ModInterface
def __init__(self):
super().__init__()
@ -583,7 +619,9 @@ class TestModuleInterface(JitTestCase):
m = torch.jit.script(TestModule())
m.eval()
with self.assertRaisesRegex(RuntimeError, "Freezing does not support SetAttr on an interface type."):
with self.assertRaisesRegex(
RuntimeError, "Freezing does not support SetAttr on an interface type."
):
mf = torch._C._freeze_module(m._c, freezeInterfaces=True)
def test_freeze_module_with_interface_and_fork(self):
@ -610,7 +648,7 @@ class TestModuleInterface(JitTestCase):
pass
class TestModule(torch.nn.Module):
proxy_mod : ModInterface
proxy_mod: ModInterface
def __init__(self):
super().__init__()
@ -644,7 +682,7 @@ class TestModuleInterface(JitTestCase):
pass
class TestModule(nn.Module):
proxy_mod : ModuleInterface
proxy_mod: ModuleInterface
def __init__(self):
super().__init__()

View File

@ -1,18 +1,22 @@
# Owner(s): ["oncall: jit"]
import torch
import os
import sys
import torch
from torch.testing._internal.jit_utils import JitTestCase
# Make the helper files in test/ importable
pytorch_test_dir = os.path.dirname(os.path.dirname(os.path.realpath(__file__)))
sys.path.append(pytorch_test_dir)
if __name__ == '__main__':
raise RuntimeError("This test file is not meant to be run directly, use:\n\n"
"\tpython test/test_jit.py TESTNAME\n\n"
"instead.")
if __name__ == "__main__":
raise RuntimeError(
"This test file is not meant to be run directly, use:\n\n"
"\tpython test/test_jit.py TESTNAME\n\n"
"instead."
)
class TestModules(JitTestCase):
def test_script_module_with_constants_list(self):

View File

@ -4,10 +4,13 @@ import torch
from torch.testing import FileCheck
from torch.testing._internal.jit_utils import JitTestCase
if __name__ == '__main__':
raise RuntimeError("This test file is not meant to be run directly, use:\n\n"
"\tpython test/test_jit.py TESTNAME\n\n"
"instead.")
if __name__ == "__main__":
raise RuntimeError(
"This test file is not meant to be run directly, use:\n\n"
"\tpython test/test_jit.py TESTNAME\n\n"
"instead."
)
class TestOpDecompositions(JitTestCase):
def test_op_decomposition(self):
@ -31,7 +34,9 @@ class TestOpDecompositions(JitTestCase):
def square_decomp(x):
return torch.pow(x, 2)
torch.jit._register_decomposition(torch.ops.aten.square.default, square_decomp.graph)
torch.jit._register_decomposition(
torch.ops.aten.square.default, square_decomp.graph
)
torch._C._jit_pass_run_decompositions(foo.graph)
FileCheck().check_not("aten::square").check("aten::pow").run(foo.graph)
x = torch.rand([4])

View File

@ -3,8 +3,9 @@
import torch
import torch._C
import torch.nn.functional as F
from torch.testing._internal.jit_utils import JitTestCase
from torch.testing._internal.common_utils import skipIfNoXNNPACK
from torch.testing._internal.jit_utils import JitTestCase
class TestOptimizeForMobilePreserveDebugInfo(JitTestCase):
def check_replacement(
@ -133,10 +134,8 @@ class TestOptimizeForMobilePreserveDebugInfo(JitTestCase):
"prepacked::linear_clamp_run": "aten::linear",
"prepacked::conv2d_clamp_prepack": "aten::conv2d",
"prepacked::conv2d_clamp_run": "aten::conv2d",
"prepacked::conv2d_transpose_clamp_prepack":
"aten::conv_transpose2d",
"prepacked::conv2d_transpose_clamp_run":
"aten::conv_transpose2d",
"prepacked::conv2d_transpose_clamp_prepack": "aten::conv_transpose2d",
"prepacked::conv2d_transpose_clamp_run": "aten::conv_transpose2d",
},
jit_pass=torch._C._jit_pass_insert_prepacked_ops,
)
@ -147,7 +146,7 @@ class TestOptimizeForMobilePreserveDebugInfo(JitTestCase):
model=torch.jit.trace(torch.nn.Linear(5, 4), torch.rand(3, 2, 5)),
replacements={
"prepacked::linear_clamp_prepack": "aten::linear",
"prepacked::linear_clamp_run": "aten::linear"
"prepacked::linear_clamp_run": "aten::linear",
},
jit_pass=torch._C._jit_pass_insert_prepacked_ops,
)
@ -223,11 +222,9 @@ class TestOptimizeForMobilePreserveDebugInfo(JitTestCase):
self.check_replacement(
model=model,
replacements={
"prepacked::linear_clamp_prepack":
"prepacked::linear_clamp_prepack",
"prepacked::linear_clamp_prepack": "prepacked::linear_clamp_prepack",
"prepacked::linear_clamp_run": linear_activation_kind,
"prepacked::conv2d_clamp_prepack":
"prepacked::conv2d_clamp_prepack",
"prepacked::conv2d_clamp_prepack": "prepacked::conv2d_clamp_prepack",
"prepacked::conv2d_clamp_run": conv2d_activation_kind,
},
jit_pass=torch._C._jit_pass_fuse_clamp_w_prepacked_linear_conv,
@ -239,7 +236,7 @@ class TestOptimizeForMobilePreserveDebugInfo(JitTestCase):
linear_activation=F.hardtanh,
linear_activation_kind="aten::hardtanh",
conv2d_activation=F.hardtanh_,
conv2d_activation_kind="aten::hardtanh_"
conv2d_activation_kind="aten::hardtanh_",
)
@skipIfNoXNNPACK
@ -248,7 +245,7 @@ class TestOptimizeForMobilePreserveDebugInfo(JitTestCase):
linear_activation=F.hardtanh_,
linear_activation_kind="aten::hardtanh_",
conv2d_activation=F.hardtanh,
conv2d_activation_kind="aten::hardtanh"
conv2d_activation_kind="aten::hardtanh",
)
@skipIfNoXNNPACK
@ -257,7 +254,7 @@ class TestOptimizeForMobilePreserveDebugInfo(JitTestCase):
linear_activation=F.relu,
linear_activation_kind="aten::relu",
conv2d_activation=F.relu_,
conv2d_activation_kind="aten::relu_"
conv2d_activation_kind="aten::relu_",
)
@skipIfNoXNNPACK
@ -266,5 +263,5 @@ class TestOptimizeForMobilePreserveDebugInfo(JitTestCase):
linear_activation=F.relu_,
linear_activation_kind="aten::relu_",
conv2d_activation=F.relu,
conv2d_activation_kind="aten::relu"
conv2d_activation_kind="aten::relu",
)

View File

@ -2,16 +2,19 @@
import torch
from torch import nn
import torch.nn.utils.parametrize as parametrize
from torch import nn
from torch.testing._internal.jit_utils import JitTestCase
if __name__ == '__main__':
raise RuntimeError("This test file is not meant to be run directly, use:\n\n"
"\tpython test/test_jit.py TESTNAME\n\n"
"instead.")
if __name__ == "__main__":
raise RuntimeError(
"This test file is not meant to be run directly, use:\n\n"
"\tpython test/test_jit.py TESTNAME\n\n"
"instead."
)
class TestParametrization(JitTestCase):
# Define some parametrization
@ -29,7 +32,7 @@ class TestParametrization(JitTestCase):
# Check the tracing works. Because traced functions cannot be called
# directly, we run the comparison on the activations.
traced_model = torch.jit.trace_module(model, {'forward': x})
traced_model = torch.jit.trace_module(model, {"forward": x})
y_hat = traced_model(x)
self.assertEqual(y, y_hat)
@ -39,10 +42,9 @@ class TestParametrization(JitTestCase):
self.assertEqual(y, y_hat)
# Check the tracing throws an error when caching
with self.assertRaisesRegex(RuntimeError,
'Cannot trace a model while caching'):
with self.assertRaisesRegex(RuntimeError, "Cannot trace a model while caching"):
with parametrize.cached():
traced_model = torch.jit.trace_module(model, {'forward': x})
traced_model = torch.jit.trace_module(model, {"forward": x})
def test_scriptable(self):
# TODO: Need to fix the scripting in parametrizations
@ -65,5 +67,5 @@ class TestParametrization(JitTestCase):
self.assertEqual(y, y_hat)
# Check the scripting process throws an error when caching
with self.assertRaisesRegex(RuntimeError, 'Caching is not implemented'):
with self.assertRaisesRegex(RuntimeError, "Caching is not implemented"):
scripted_model = torch.jit.trace_module(model)

View File

@ -2,18 +2,22 @@
import os
import sys
from typing import Any, Dict, List, NamedTuple, Optional, Tuple # noqa: F401
import torch
from torch.testing._internal.jit_utils import JitTestCase, make_global
from torch.jit._monkeytype_config import _IS_MONKEYTYPE_INSTALLED
from typing import List, Dict, Tuple, Any, Optional, NamedTuple # noqa: F401
from torch.testing._internal.common_utils import NoTest
from torch.testing._internal.jit_utils import JitTestCase, make_global
# Make the helper files in test/ importable
pytorch_test_dir = os.path.dirname(os.path.dirname(os.path.realpath(__file__)))
sys.path.append(pytorch_test_dir)
if not _IS_MONKEYTYPE_INSTALLED:
print("monkeytype is not installed. Skipping tests for Profile-Directed Typing", file=sys.stderr)
print(
"monkeytype is not installed. Skipping tests for Profile-Directed Typing",
file=sys.stderr,
)
JitTestCase = NoTest # type: ignore[misc, assignment] # noqa: F811
if __name__ == "__main__":
@ -23,10 +27,12 @@ if __name__ == "__main__":
"instead."
)
class TestPDT(JitTestCase):
"""
A suite of tests for profile directed typing in TorchScript.
"""
def test_nn_module(self):
class TestPDTModel(torch.nn.Module):
def forward(self, x) -> Any:
@ -39,8 +45,14 @@ class TestPDT(JitTestCase):
make_global(TestPDTModel)
pdt_model = TestPDTModel()
inp: List[Tuple[Any, ...]] = [(20, ), (2.7, ), (False, ), ]
scripted_pdt_model = torch.jit.script(pdt_model, example_inputs={pdt_model: inp})
inp: List[Tuple[Any, ...]] = [
(20,),
(2.7,),
(False,),
]
scripted_pdt_model = torch.jit.script(
pdt_model, example_inputs={pdt_model: inp}
)
self.assertEqual(scripted_pdt_model(50), pdt_model(50))
self.assertEqual(scripted_pdt_model(1.8), pdt_model(1.8))
self.assertTrue(scripted_pdt_model(True), pdt_model(True))
@ -63,8 +75,10 @@ class TestPDT(JitTestCase):
make_global(NestedPDTInner, NestedModulePDTWrapper)
inner_pdt_model = NestedPDTInner()
wrapped_pdt_model = NestedModulePDTWrapper(inner_pdt_model)
inp: List[Tuple[Any, ...]] = [(20, ), (False, )]
scripted_pdt_model = torch.jit.script(wrapped_pdt_model, example_inputs={wrapped_pdt_model: inp})
inp: List[Tuple[Any, ...]] = [(20,), (False,)]
scripted_pdt_model = torch.jit.script(
wrapped_pdt_model, example_inputs={wrapped_pdt_model: inp}
)
self.assertEqual(scripted_pdt_model(30), wrapped_pdt_model(30))
self.assertEqual(scripted_pdt_model(1.9), wrapped_pdt_model(1.9))
self.assertTrue(scripted_pdt_model(True), wrapped_pdt_model(True))
@ -87,10 +101,18 @@ class TestPDT(JitTestCase):
make_global(NestedModulePDTInner, NestedModulePDTOuter)
inner_pdt_model = NestedModulePDTInner()
outer_pdt_model = NestedModulePDTOuter(inner_pdt_model)
inner_input: List[Tuple[Any, ...]] = [(10, 10), (1.9, 20), ]
outer_input: List[Tuple[Any, ...]] = [(20, ), (False, )]
scripted_pdt_model = torch.jit.script(outer_pdt_model, example_inputs={inner_pdt_model: inner_input,
outer_pdt_model: outer_input, })
inner_input: List[Tuple[Any, ...]] = [
(10, 10),
(1.9, 20),
]
outer_input: List[Tuple[Any, ...]] = [(20,), (False,)]
scripted_pdt_model = torch.jit.script(
outer_pdt_model,
example_inputs={
inner_pdt_model: inner_input,
outer_pdt_model: outer_input,
},
)
self.assertEqual(scripted_pdt_model(30), outer_pdt_model(30))
self.assertEqual(scripted_pdt_model(1.9), outer_pdt_model(1.9))
self.assertTrue(scripted_pdt_model(True), outer_pdt_model(True))
@ -109,8 +131,10 @@ class TestPDT(JitTestCase):
make_global(NestedFunctionInForward)
pdt_model = NestedFunctionInForward()
inp: List[Tuple[Any, ...]] = [(-1, ), (False, )]
scripted_pdt_model = torch.jit.script(pdt_model, example_inputs={pdt_model: inp})
inp: List[Tuple[Any, ...]] = [(-1,), (False,)]
scripted_pdt_model = torch.jit.script(
pdt_model, example_inputs={pdt_model: inp}
)
self.assertEqual(scripted_pdt_model(30), pdt_model(30))
self.assertEqual(scripted_pdt_model(True), pdt_model(True))
@ -126,14 +150,26 @@ class TestPDT(JitTestCase):
else:
return -1
make_global(TestModelWithExport)
pdt_model = TestModelWithExport()
inp: List[Tuple[Any, ...]] = [(20, 10, ), (2.7, 8.9, ), ]
scripted_pdt_model = torch.jit.script(pdt_model, example_inputs={pdt_model.fn: inp})
inp: List[Tuple[Any, ...]] = [
(
20,
10,
),
(
2.7,
8.9,
),
]
scripted_pdt_model = torch.jit.script(
pdt_model, example_inputs={pdt_model.fn: inp}
)
self.assertEqual(scripted_pdt_model.fn(10, 90), pdt_model.fn(10, 90))
self.assertEqual(scripted_pdt_model.fn(1.8, 2.2), pdt_model.fn(1.8, 2.2))
self.assertTrue(scripted_pdt_model.fn(torch.ones(1), 2), pdt_model.fn(torch.ones(1), 2))
self.assertTrue(
scripted_pdt_model.fn(torch.ones(1), 2), pdt_model.fn(torch.ones(1), 2)
)
def test_class_methods(self):
class PDTModel:
@ -142,10 +178,34 @@ class TestPDT(JitTestCase):
make_global(PDTModel)
pdt_model = PDTModel()
inp: List[Tuple[Any, ...]] = [([10, 20, ], ), ]
scripted_pdt_model = torch.jit.script(PDTModel, example_inputs={pdt_model.test_sum: inp})
inp: List[Tuple[Any, ...]] = [
(
[
10,
20,
],
),
]
scripted_pdt_model = torch.jit.script(
PDTModel, example_inputs={pdt_model.test_sum: inp}
)
script_model = scripted_pdt_model()
self.assertEqual(script_model.test_sum([10, 20, 30, ], ), pdt_model.test_sum([10, 20, 30, ], ))
self.assertEqual(
script_model.test_sum(
[
10,
20,
30,
],
),
pdt_model.test_sum(
[
10,
20,
30,
],
),
)
def test_class_with_multiple_methods(self):
class PDTModelWithManyMethods:
@ -160,14 +220,64 @@ class TestPDT(JitTestCase):
make_global(PDTModelWithManyMethods)
pdt_model = PDTModelWithManyMethods()
list_inp: List[Tuple[Any, ...]] = [([1.2, 2.3, ], ), ]
str_inp: List[Tuple[Any, ...]] = [("abc", "b", ), ]
scripted_pdt_model = torch.jit.script(PDTModelWithManyMethods, example_inputs={pdt_model.test_list_to_dict: list_inp,
pdt_model.test_substring: str_inp})
list_inp: List[Tuple[Any, ...]] = [
(
[
1.2,
2.3,
],
),
]
str_inp: List[Tuple[Any, ...]] = [
(
"abc",
"b",
),
]
scripted_pdt_model = torch.jit.script(
PDTModelWithManyMethods,
example_inputs={
pdt_model.test_list_to_dict: list_inp,
pdt_model.test_substring: str_inp,
},
)
script_model = scripted_pdt_model()
self.assertEqual(script_model.test_list_to_dict([1.1, 2.2, 3.3, ], ), pdt_model.test_list_to_dict([1.1, 2.2, 3.3, ], ))
self.assertEqual(script_model.test_substring("helloworld", "world", ), pdt_model.test_substring("helloworld", "world", ))
self.assertEqual(script_model.test_substring("helloworld", "def", ), pdt_model.test_substring("helloworld", "def", ))
self.assertEqual(
script_model.test_list_to_dict(
[
1.1,
2.2,
3.3,
],
),
pdt_model.test_list_to_dict(
[
1.1,
2.2,
3.3,
],
),
)
self.assertEqual(
script_model.test_substring(
"helloworld",
"world",
),
pdt_model.test_substring(
"helloworld",
"world",
),
)
self.assertEqual(
script_model.test_substring(
"helloworld",
"def",
),
pdt_model.test_substring(
"helloworld",
"def",
),
)
def test_multiple_class_with_same_method(self):
class PDTModelOne:
@ -181,16 +291,69 @@ class TestPDT(JitTestCase):
make_global(PDTModelOne, PDTModelTwo)
pdt_model_one = PDTModelOne()
pdt_model_two = PDTModelTwo()
dict_inp: List[Tuple[Any, ...]] = [({1.2: True, 2.3: False, }, 1.2), ]
list_inp: List[Tuple[Any, ...]] = [(["abc", "b", ], "c"), ]
scripted_pdt_model_one = torch.jit.script(PDTModelOne, example_inputs={pdt_model_one.test_find: dict_inp})
scripted_pdt_model_two = torch.jit.script(PDTModelTwo, example_inputs={pdt_model_two.test_find: list_inp})
dict_inp: List[Tuple[Any, ...]] = [
(
{
1.2: True,
2.3: False,
},
1.2,
),
]
list_inp: List[Tuple[Any, ...]] = [
(
[
"abc",
"b",
],
"c",
),
]
scripted_pdt_model_one = torch.jit.script(
PDTModelOne, example_inputs={pdt_model_one.test_find: dict_inp}
)
scripted_pdt_model_two = torch.jit.script(
PDTModelTwo, example_inputs={pdt_model_two.test_find: list_inp}
)
script_model_one, script_model_two = scripted_pdt_model_one(), scripted_pdt_model_two()
self.assertEqual(script_model_one.test_find({1.1: True, 2.2: True, 3.3: False, }, 4.4),
pdt_model_one.test_find({1.1: True, 2.2: True, 3.3: False, }, 4.4))
self.assertEqual(script_model_two.test_find(["hello", "world", ], "world"),
pdt_model_two.test_find(["hello", "world", ], "world"))
script_model_one, script_model_two = (
scripted_pdt_model_one(),
scripted_pdt_model_two(),
)
self.assertEqual(
script_model_one.test_find(
{
1.1: True,
2.2: True,
3.3: False,
},
4.4,
),
pdt_model_one.test_find(
{
1.1: True,
2.2: True,
3.3: False,
},
4.4,
),
)
self.assertEqual(
script_model_two.test_find(
[
"hello",
"world",
],
"world",
),
pdt_model_two.test_find(
[
"hello",
"world",
],
"world",
),
)
def test_pdt(self):
def test_sum(a, b):
@ -218,7 +381,9 @@ class TestPDT(JitTestCase):
return torch.complex(real, img)
make_global(test_args_complex)
scripted_fn_complex = torch.jit.script(test_args_complex, example_inputs=[(torch.rand(3, 4), torch.rand(3, 4))])
scripted_fn_complex = torch.jit.script(
test_args_complex, example_inputs=[(torch.rand(3, 4), torch.rand(3, 4))]
)
arg1, arg2 = torch.rand(3, 4), torch.rand(3, 4)
self.assertEqual(scripted_fn_complex(arg1, arg2), test_args_complex(arg1, arg2))
@ -248,25 +413,49 @@ class TestPDT(JitTestCase):
make_global(test_list_and_tuple)
scripted_fn_float_list_input = torch.jit.script(test_list_and_tuple, example_inputs=[([4.9, 8.9],)])
self.assertEqual(scripted_fn_float_list_input([11.9, 7.6]), test_list_and_tuple([11.9, 7.6]))
scripted_fn_float_list_input = torch.jit.script(
test_list_and_tuple, example_inputs=[([4.9, 8.9],)]
)
self.assertEqual(
scripted_fn_float_list_input([11.9, 7.6]), test_list_and_tuple([11.9, 7.6])
)
scripted_fn_bool_list_input = torch.jit.script(test_list_and_tuple, example_inputs=[([True, False, True],)])
self.assertEqual(scripted_fn_bool_list_input([True, True, True]), test_list_and_tuple([True, True, True]))
scripted_fn_bool_list_input = torch.jit.script(
test_list_and_tuple, example_inputs=[([True, False, True],)]
)
self.assertEqual(
scripted_fn_bool_list_input([True, True, True]),
test_list_and_tuple([True, True, True]),
)
scripted_fn_int_list_input = torch.jit.script(test_list_and_tuple, example_inputs=[([3, 4, 5], )])
self.assertEqual(scripted_fn_int_list_input([1, 2, 3]), test_list_and_tuple([1, 2, 3]))
scripted_fn_int_list_input = torch.jit.script(
test_list_and_tuple, example_inputs=[([3, 4, 5],)]
)
self.assertEqual(
scripted_fn_int_list_input([1, 2, 3]), test_list_and_tuple([1, 2, 3])
)
scripted_fn_float_tuple_input = torch.jit.script(test_list_and_tuple, example_inputs=[((4.9, 8.9),)])
self.assertEqual(scripted_fn_float_tuple_input((11.9, 7.6)), test_list_and_tuple((11.9, 7.6)))
scripted_fn_float_tuple_input = torch.jit.script(
test_list_and_tuple, example_inputs=[((4.9, 8.9),)]
)
self.assertEqual(
scripted_fn_float_tuple_input((11.9, 7.6)), test_list_and_tuple((11.9, 7.6))
)
scripted_fn_bool_tuple_input = torch.jit.script(test_list_and_tuple,
example_inputs=[((True, False, True),)])
self.assertEqual(scripted_fn_bool_tuple_input((True, True, True)),
test_list_and_tuple((True, True, True)))
scripted_fn_bool_tuple_input = torch.jit.script(
test_list_and_tuple, example_inputs=[((True, False, True),)]
)
self.assertEqual(
scripted_fn_bool_tuple_input((True, True, True)),
test_list_and_tuple((True, True, True)),
)
scripted_fn_int_tuple_input = torch.jit.script(test_list_and_tuple, example_inputs=[((3, 4, 5), )])
self.assertEqual(scripted_fn_int_tuple_input((1, 2, 3)), test_list_and_tuple((1, 2, 3)))
scripted_fn_int_tuple_input = torch.jit.script(
test_list_and_tuple, example_inputs=[((3, 4, 5),)]
)
self.assertEqual(
scripted_fn_int_tuple_input((1, 2, 3)), test_list_and_tuple((1, 2, 3))
)
def test_nested_list_and_tuple(self):
def test_nested_list(inp):
@ -282,43 +471,207 @@ class TestPDT(JitTestCase):
make_global(test_nested_list, test_nested_tuple)
list_inp = [[1, 2, 3, ], [5, 6, 7, ]]
scripted_fn = torch.jit.script(test_nested_list, example_inputs=[(list_inp, ), ])
inp = [[0, 4, 7, ], [8, 11, ], [6, -1, -20, ]]
self.assertEqual(scripted_fn(inp, ), test_nested_list(inp, ))
list_inp = [
[
1,
2,
3,
],
[
5,
6,
7,
],
]
scripted_fn = torch.jit.script(
test_nested_list,
example_inputs=[
(list_inp,),
],
)
inp = [
[
0,
4,
7,
],
[
8,
11,
],
[
6,
-1,
-20,
],
]
self.assertEqual(
scripted_fn(
inp,
),
test_nested_list(
inp,
),
)
list_inp = ([1, 2, 3, ], [5, 6, 7, ])
scripted_fn = torch.jit.script(test_nested_list, example_inputs=[(list_inp, ), ])
inp = ([0, 4, 7, ], [8, 11, ], [6, -1, -20, ])
self.assertEqual(scripted_fn(inp, ), test_nested_list(inp, ))
list_inp = (
[
1,
2,
3,
],
[
5,
6,
7,
],
)
scripted_fn = torch.jit.script(
test_nested_list,
example_inputs=[
(list_inp,),
],
)
inp = (
[
0,
4,
7,
],
[
8,
11,
],
[
6,
-1,
-20,
],
)
self.assertEqual(
scripted_fn(
inp,
),
test_nested_list(
inp,
),
)
tup_inp = [(1.0, 2.6, 3.7, ), (5.7, 6.1, 1.7, )]
scripted_fn = torch.jit.script(test_nested_tuple, example_inputs=[(tup_inp, ), ])
inp = [(1.0, 4.1, 7.4, ), (4.8, 1.1, -1.2, ), (6.3, -1.3, -2.0, )]
self.assertEqual(scripted_fn(inp, ), test_nested_tuple(inp, ))
tup_inp = [
(
1.0,
2.6,
3.7,
),
(
5.7,
6.1,
1.7,
),
]
scripted_fn = torch.jit.script(
test_nested_tuple,
example_inputs=[
(tup_inp,),
],
)
inp = [
(
1.0,
4.1,
7.4,
),
(
4.8,
1.1,
-1.2,
),
(
6.3,
-1.3,
-2.0,
),
]
self.assertEqual(
scripted_fn(
inp,
),
test_nested_tuple(
inp,
),
)
tup_inp = ((True, False, True, ), (False, False, False, ))
scripted_fn = torch.jit.script(test_nested_tuple, example_inputs=[(tup_inp, ), ])
inp = ((True, True, True, ), (False, False, True, ))
self.assertEqual(scripted_fn(inp, ), test_nested_tuple(inp, ))
tup_inp = (
(
True,
False,
True,
),
(
False,
False,
False,
),
)
scripted_fn = torch.jit.script(
test_nested_tuple,
example_inputs=[
(tup_inp,),
],
)
inp = (
(
True,
True,
True,
),
(
False,
False,
True,
),
)
self.assertEqual(
scripted_fn(
inp,
),
test_nested_tuple(
inp,
),
)
def test_pdt_dict(self):
def test_dict(a):
return a['foo']
return a["foo"]
def test_dict_int_list(a):
return a[1]
make_global(test_dict, test_dict_int_list)
str_bool_inp = {'foo' : True, 'bar': False}
str_bool_inp = {"foo": True, "bar": False}
scripted_fn = torch.jit.script(test_dict, example_inputs=[(str_bool_inp,)])
self.assertEqual(scripted_fn({'foo' : False, 'bar': True}, ), test_dict({'foo' : False, 'bar': True}, ))
self.assertEqual(
scripted_fn(
{"foo": False, "bar": True},
),
test_dict(
{"foo": False, "bar": True},
),
)
str_list_inp = {0 : [True, False], 1: [False, True]}
scripted_fn = torch.jit.script(test_dict_int_list, example_inputs=[(str_list_inp,)])
self.assertEqual(scripted_fn({0 : [False, False], 1: [True, True]}, ),
test_dict_int_list({0 : [False, False], 1: [True, True]}, ))
str_list_inp = {0: [True, False], 1: [False, True]}
scripted_fn = torch.jit.script(
test_dict_int_list, example_inputs=[(str_list_inp,)]
)
self.assertEqual(
scripted_fn(
{0: [False, False], 1: [True, True]},
),
test_dict_int_list(
{0: [False, False], 1: [True, True]},
),
)
def test_any(self):
def test_multiple_types(a):
@ -337,20 +690,36 @@ class TestPDT(JitTestCase):
make_global(test_multiple_types, test_multiple_type_refinement)
scripted_fn = torch.jit.script(test_multiple_types, example_inputs=[(1,), ("abc", ), (8.9,), ([3, 4, 5], )])
scripted_fn = torch.jit.script(
test_multiple_types, example_inputs=[(1,), ("abc",), (8.9,), ([3, 4, 5],)]
)
self.assertEqual(scripted_fn(10), test_multiple_types(10))
self.assertEqual(scripted_fn("def"), test_multiple_types("def"))
self.assertEqual(scripted_fn(7.89999), test_multiple_types(7.89999))
self.assertEqual(scripted_fn([10, 11, 14]), test_multiple_types([10, 11, 14]))
scripted_fn = torch.jit.script(test_multiple_type_refinement, example_inputs=[(1,), ("abc", ), (8.9,),
([3, 4, 5],), (True, ), ({"a": True}, ), ])
scripted_fn = torch.jit.script(
test_multiple_type_refinement,
example_inputs=[
(1,),
("abc",),
(8.9,),
([3, 4, 5],),
(True,),
({"a": True},),
],
)
self.assertEqual(scripted_fn(10), test_multiple_type_refinement(10))
self.assertEqual(scripted_fn("def"), test_multiple_type_refinement("def"))
self.assertEqual(scripted_fn(7.89999), test_multiple_type_refinement(7.89999))
self.assertEqual(scripted_fn([10, 11, 14]), test_multiple_type_refinement([10, 11, 14]))
self.assertEqual(
scripted_fn([10, 11, 14]), test_multiple_type_refinement([10, 11, 14])
)
self.assertEqual(scripted_fn(False), test_multiple_type_refinement(False))
self.assertEqual(scripted_fn({"abc" : True, "def": False}), test_multiple_type_refinement({"abc" : True, "def": False}))
self.assertEqual(
scripted_fn({"abc": True, "def": False}),
test_multiple_type_refinement({"abc": True, "def": False}),
)
def test_class_as_profiled_types(self):
class UserDefinedClass:
@ -369,9 +738,33 @@ class TestPDT(JitTestCase):
make_global(UserDefinedClass, test_model)
user_class = UserDefinedClass()
scripted_fn = torch.jit.script(test_model, example_inputs=[(10, user_class, ), (10.9, user_class, ), ])
self.assertEqual(scripted_fn(100, user_class, ), test_model(100, user_class))
self.assertEqual(scripted_fn(1.9, user_class, ), test_model(1.9, user_class))
scripted_fn = torch.jit.script(
test_model,
example_inputs=[
(
10,
user_class,
),
(
10.9,
user_class,
),
],
)
self.assertEqual(
scripted_fn(
100,
user_class,
),
test_model(100, user_class),
)
self.assertEqual(
scripted_fn(
1.9,
user_class,
),
test_model(1.9, user_class),
)
def test_class_with_args_as_profiled_types(self):
class ClassWithArgs:
@ -391,8 +784,26 @@ class TestPDT(JitTestCase):
make_global(ClassWithArgs, test_model_with_args)
user_class = ClassWithArgs(False)
scripted_fn = torch.jit.script(test_model_with_args, example_inputs=[(10, user_class, ), (10.9, user_class, ), ])
self.assertEqual(scripted_fn(100, ClassWithArgs(True), ), test_model_with_args(100, ClassWithArgs(True)))
scripted_fn = torch.jit.script(
test_model_with_args,
example_inputs=[
(
10,
user_class,
),
(
10.9,
user_class,
),
],
)
self.assertEqual(
scripted_fn(
100,
ClassWithArgs(True),
),
test_model_with_args(100, ClassWithArgs(True)),
)
def test_nn_parameter_as_arg(self):
class TestNNParameter(torch.nn.Module):
@ -408,7 +819,14 @@ class TestPDT(JitTestCase):
make_global(TestNNParameter)
pdt_model = TestNNParameter()
scripted_fn = torch.jit.script(pdt_model, example_inputs={pdt_model: [(10, ), ], })
scripted_fn = torch.jit.script(
pdt_model,
example_inputs={
pdt_model: [
(10,),
],
},
)
self.assertEqual(scripted_fn(20), pdt_model(20))
def test_fx_tracing_with_typing(self):
@ -422,7 +840,19 @@ class TestPDT(JitTestCase):
make_global(FXModel, FXModelOutput)
pdt_model = FXModel()
scripted_fn = torch.jit.script(pdt_model, example_inputs={pdt_model: [([10, 20, ], ), ], })
scripted_fn = torch.jit.script(
pdt_model,
example_inputs={
pdt_model: [
(
[
10,
20,
],
),
],
},
)
self.assertEqual(scripted_fn([20]), pdt_model([20]))
def test_nonetype_as_optional_of_type(self):
@ -434,11 +864,34 @@ class TestPDT(JitTestCase):
make_global(test_none)
scripted_fn = torch.jit.script(test_none, example_inputs=[(None, ), (10.6, )])
self.assertEqual(scripted_fn(30.9, ), test_none(30.9, ))
scripted_fn = torch.jit.script(test_none, example_inputs=[(None,), (10.6,)])
self.assertEqual(
scripted_fn(
30.9,
),
test_none(
30.9,
),
)
scripted_fn = torch.jit.script(test_none, example_inputs=[(None, ), (10, )])
self.assertEqual(scripted_fn(2, ), test_none(2, ))
scripted_fn = torch.jit.script(test_none, example_inputs=[(None,), (10,)])
self.assertEqual(
scripted_fn(
2,
),
test_none(
2,
),
)
scripted_fn = torch.jit.script(test_none, example_inputs=[(None, ), (torch.Tensor(1), )])
self.assertEqual(scripted_fn(torch.ones(1), ), test_none(torch.ones(1), ))
scripted_fn = torch.jit.script(
test_none, example_inputs=[(None,), (torch.Tensor(1),)]
)
self.assertEqual(
scripted_fn(
torch.ones(1),
),
test_none(
torch.ones(1),
),
)

View File

@ -1,17 +1,20 @@
# Owner(s): ["oncall: jit"]
import torch
from torch.testing._internal.jit_utils import JitTestCase, RUN_CUDA, _inline_everything
from torch import nn
from torch.testing import FileCheck
import unittest
from typing import Callable, List
import unittest
import torch
from torch import nn
from torch.testing import FileCheck
from torch.testing._internal.jit_utils import _inline_everything, JitTestCase, RUN_CUDA
if __name__ == "__main__":
raise RuntimeError(
"This test file is not meant to be run directly, use:\n\n"
"\tpython test/test_jit.py TESTNAME\n\n"
"instead."
)
if __name__ == '__main__':
raise RuntimeError("This test file is not meant to be run directly, use:\n\n"
"\tpython test/test_jit.py TESTNAME\n\n"
"instead.")
class TestPeephole(JitTestCase):
def test_peephole_with_writes(self):
@ -62,11 +65,11 @@ class TestPeephole(JitTestCase):
tf = torch.jit.trace(f, (a, b))
FileCheck().check("type_as").run(str(tf.graph))
self.run_pass('peephole', tf.graph)
self.run_pass("peephole", tf.graph)
FileCheck().check_not("type_as").run(str(tf.graph))
tf2 = torch.jit.trace(f, (a, c))
s = str(tf2.graph)
self.run_pass('peephole', tf2.graph)
self.run_pass("peephole", tf2.graph)
self.assertEqual(s, str(s))
def test_peephole_dynamic(self):
@ -83,7 +86,7 @@ class TestPeephole(JitTestCase):
def foo(x, y, z):
return len([x, y, z])
self.run_pass('peephole', foo.graph)
self.run_pass("peephole", foo.graph)
FileCheck().check("value=3").check_next("return").run(foo.graph)
@torch.jit.script
@ -93,7 +96,7 @@ class TestPeephole(JitTestCase):
li.append(x)
return len([x, y, z])
self.run_pass('peephole', foo.graph)
self.run_pass("peephole", foo.graph)
FileCheck().check_not("aten::len").run(foo.graph)
@torch.jit.script
@ -102,7 +105,7 @@ class TestPeephole(JitTestCase):
return li[1], li[-2]
FileCheck().check("aten::__getitem__").run(foo.graph)
self.run_pass('peephole', foo.graph)
self.run_pass("peephole", foo.graph)
FileCheck().check_not("aten::__getitem__").run(foo.graph)
@torch.jit.script
@ -110,7 +113,7 @@ class TestPeephole(JitTestCase):
li = [x, y, z]
return li[-7]
self.run_pass('peephole', foo.graph)
self.run_pass("peephole", foo.graph)
FileCheck().check("aten::__getitem__").run(foo.graph)
@torch.jit.script
@ -120,25 +123,25 @@ class TestPeephole(JitTestCase):
li.append(x)
return li[-2]
self.run_pass('peephole', foo.graph)
self.run_pass("peephole", foo.graph)
FileCheck().check("aten::__getitem__").run(foo.graph)
@unittest.skipIf(not RUN_CUDA, "cpp tests require CUDA")
def test_peephole_cuda(self):
a = torch.tensor([0.4], device='cpu')
b = torch.tensor([0.7], device='cuda')
c = torch.tensor([0.7], device='cuda')
a = torch.tensor([0.4], device="cpu")
b = torch.tensor([0.7], device="cuda")
c = torch.tensor([0.7], device="cuda")
def f(x, y):
return x.type_as(y)
trace = torch.jit.trace(f, (a, c))
s = str(trace.graph)
self.run_pass('peephole', trace.graph)
self.run_pass("peephole", trace.graph)
self.assertEqual(s, str(trace.graph))
trace = torch.jit.trace(f, (b, c))
self.run_pass('peephole', trace.graph)
self.run_pass('dce', trace.graph)
self.run_pass("peephole", trace.graph)
self.run_pass("dce", trace.graph)
FileCheck().check_not("type_as").run(str(trace.graph))
@_inline_everything
@ -152,7 +155,7 @@ class TestPeephole(JitTestCase):
return refine(torch.tensor(4))
FileCheck().check("prim::unchecked_cast").run(test.graph)
self.run_pass('peephole', test.graph)
self.run_pass("peephole", test.graph)
FileCheck().check_not("prim::unchecked_cast").run(test.graph)
# refinement not optimzied out
@ -166,7 +169,7 @@ class TestPeephole(JitTestCase):
self.checkScript(is_int_tensor, (torch.tensor(2),))
self.checkScript(is_int_tensor, (torch.tensor(2.5),))
graph = torch.jit.script(is_int_tensor).graph
self.run_pass('peephole', graph)
self.run_pass("peephole", graph)
FileCheck().check("prim::unchecked_cast").run(graph)
def test_short_circuit_optimization(self):
@ -174,8 +177,11 @@ class TestPeephole(JitTestCase):
def const_expressions(x):
# type: (int) -> Tuple[bool, bool]
return x == 1 and False, x == 1 or True
self.run_pass('constant_propagation', const_expressions.graph)
FileCheck().check_not("prim::If").check_not("aten::eq").run(const_expressions.graph)
self.run_pass("constant_propagation", const_expressions.graph)
FileCheck().check_not("prim::If").check_not("aten::eq").run(
const_expressions.graph
)
self.assertEqual(const_expressions(1), (False, True))
@torch.jit.script
@ -183,15 +189,18 @@ class TestPeephole(JitTestCase):
# type: (int) -> Tuple[bool, bool]
return x == 1 and True, x == 1 or False
self.run_pass('peephole', redundant_expressions.graph)
self.run_pass("peephole", redundant_expressions.graph)
self.assertEqual(redundant_expressions(1), (True, True))
self.assertEqual(redundant_expressions(0), (False, False))
# and True / or False are removed from graph
FileCheck().check("aten::eq").check_not("prim::If").run(redundant_expressions.graph)
FileCheck().check("aten::eq").check_not("prim::If").run(
redundant_expressions.graph
)
def test_conv_dim_folding(self):
modules = [nn.Conv1d, nn.Conv2d, nn.Conv3d]
for mod in modules:
class ConvDim(torch.nn.Module):
def __init__(self):
super().__init__()
@ -233,7 +242,6 @@ class TestPeephole(JitTestCase):
FileCheck().check_count("aten::sub", 2, exactly=True).run(op_graph)
FileCheck().check_count("aten::rsub", 0, exactly=True).run(op_graph)
def test_normalized_is_op(self):
def convertible_is_op(x: bool, y: bool):
return x is True, False is x, x is y
@ -558,7 +566,7 @@ class TestPeephole(JitTestCase):
def foo4():
x = torch.zeros([2, 2])
return x + 0.
return x + 0.0
funcs = foo1, foo2, foo3, foo4
inps = (torch.ones([2]),), (), (), ()
@ -582,7 +590,7 @@ class TestPeephole(JitTestCase):
self.assertEqual(func(torch.ones([2, 2])), func_s(torch.ones([2, 2])))
def func(x):
return (x + 0.) - 5
return (x + 0.0) - 5
func_s = torch.jit.script(func)
inp = next(func_s.graph.inputs())
@ -640,6 +648,7 @@ class TestPeephole(JitTestCase):
return z
else:
return z2
out = next(foo.graph.findNode("prim::If").outputs())
out.setType(torch._C.OptionalType(torch._C.IntType.get()))
self.run_pass("peephole", foo.graph)
@ -665,12 +674,13 @@ class TestPeephole(JitTestCase):
_6 = torch.add(1 * torch.sub(_3, 3) // 1, 1) / 1
return [_5, int(_6)]
FileCheck().check("aten::add").check("aten::sub") \
.check("aten::mul").check("aten::floordiv") \
.check("aten::div").run(foo.graph)
FileCheck().check("aten::add").check("aten::sub").check("aten::mul").check(
"aten::floordiv"
).check("aten::div").run(foo.graph)
self.run_pass("peephole", foo.graph)
FileCheck().check("graph").check("):") \
.check_next("ListConstruct").check_next("return").run(foo.graph)
FileCheck().check("graph").check("):").check_next("ListConstruct").check_next(
"return"
).run(foo.graph)
self.assertEqual(foo(0, 1, 2, 3), [1, 3])
def test_peephole_dict_getitem_simple(self):
@ -687,9 +697,9 @@ class TestPeephole(JitTestCase):
@torch.jit.script
def foo(a: int, b: int):
d = {'0': a, '1': b}
x = d['1']
y = d['0']
d = {"0": a, "1": b}
x = d["1"]
y = d["0"]
return x, y
self.run_pass("peephole", foo.graph)
@ -815,14 +825,14 @@ class TestPeephole(JitTestCase):
graph = torch.jit.script(foo).graph
self.run_pass("peephole", graph)
FileCheck().check_not("aten::slice").run(graph)
self.checkScript(foo, (3, ))
self.checkScript(foo, (3,))
def test_peephole_slice_one_empty_arg(self):
def check_helper(fn: Callable[[int], None]) -> None:
graph = torch.jit.script(fn).graph
self.run_pass("peephole", graph)
FileCheck().check_not("aten::slice").run(graph)
self.checkScript(fn, (3, ))
self.checkScript(fn, (3,))
def foo(x: int):
return [1, 2, x, 4, 5, 6, 7][1::2]
@ -844,7 +854,7 @@ class TestPeephole(JitTestCase):
graph = torch.jit.script(fn).graph
self.run_pass("peephole", graph)
FileCheck().check_not("aten::slice").run(graph)
self.checkScript(fn, (3, ))
self.checkScript(fn, (3,))
def foo(x: int):
return [1, 2, x, 4, 5, 6, 7][::2]

View File

@ -9,12 +9,15 @@ from torch.testing._internal.common_utils import skipIfTorchDynamo
# Make the helper files in test/ importable
pytorch_test_dir = os.path.dirname(os.path.dirname(os.path.realpath(__file__)))
sys.path.append(pytorch_test_dir)
from torch.testing._internal.jit_utils import JitTestCase, warmup_backward, FileCheck
from torch.testing._internal.jit_utils import FileCheck, JitTestCase, warmup_backward
if __name__ == "__main__":
raise RuntimeError(
"This test file is not meant to be run directly, use:\n\n"
"\tpython test/test_jit.py TESTNAME\n\n"
"instead."
)
if __name__ == '__main__':
raise RuntimeError("This test file is not meant to be run directly, use:\n\n"
"\tpython test/test_jit.py TESTNAME\n\n"
"instead.")
@skipIfTorchDynamo()
class TestProfiler(JitTestCase):
@ -58,8 +61,9 @@ class TestProfiler(JitTestCase):
# item & add should not get pulled into the fusion group -
# we expect to see Fusion Group (item / add) Fusion Group in ir dump
FileCheck().check("TensorExpr").check("Scalar = aten::item").check_next("Tensor = aten::add").check("TensorExpr").run(g)
FileCheck().check("TensorExpr").check("Scalar = aten::item").check_next(
"Tensor = aten::add"
).check("TensorExpr").run(g)
@torch.jit.script
def non_const_dtype(x, y, cond: bool):
@ -70,7 +74,9 @@ class TestProfiler(JitTestCase):
non_const_dtype(x, x, True)
g = torch.jit.last_executed_optimized_graph()
# because dtype is non-const, sum should not get pulled into the Fusion Group
FileCheck().check("TensorExpr").check("TensorExpr").check_not("aten::sum").run(g)
FileCheck().check("TensorExpr").check("TensorExpr").check_not("aten::sum").run(
g
)
def test_specialize_backward(self):
def test_fuse(a, b):
@ -118,13 +124,15 @@ class TestProfiler(JitTestCase):
d = c * b
return d
x = torch.tensor([.5])
x = torch.tensor([0.5])
for _ in range(3):
test_fuse(x, x)
g = torch.jit.last_executed_optimized_graph()
# Types should remain specialized for typecheck outputs & fusion outputs
FileCheck().check("Double(").check_same("prim::TypeCheck").check_same("\n").check("Double").check_same("TensorExpr").run(g)
FileCheck().check("Double(").check_same("prim::TypeCheck").check_same(
"\n"
).check("Double").check_same("TensorExpr").run(g)
# other outputs should not be specialized
FileCheck().check("Tensor = prim::If").run(g)
@ -201,7 +209,9 @@ class TestProfiler(JitTestCase):
foo(x, y)
foo(x, y)
g = torch.jit.last_executed_optimized_graph()
FileCheck().check("CallFunction").check_next("Tensor = prim::TupleUnpack").run(g)
FileCheck().check("CallFunction").check_next("Tensor = prim::TupleUnpack").run(
g
)
def test_autograd_fallback_graph(self):
@torch.jit.script

View File

@ -1,13 +1,13 @@
# Owner(s): ["oncall: jit"]
import os
import random
import sys
import tempfile
import random
from textwrap import dedent
import torch
from torch.testing._internal.jit_utils import JitTestCase, execWrapper
from torch.testing._internal.jit_utils import execWrapper, JitTestCase
# Make the helper files in test/ importable
pytorch_test_dir = os.path.dirname(os.path.dirname(os.path.realpath(__file__)))
@ -20,14 +20,17 @@ if __name__ == "__main__":
"instead."
)
def get_fn(file_name, script_path):
import importlib.util
spec = importlib.util.spec_from_file_location(file_name, script_path)
module = importlib.util.module_from_spec(spec)
spec.loader.exec_module(module)
fn = module.fn
return fn
class TestPythonBuiltinOP(JitTestCase):
def test_add(self):
def func(a, b):
@ -48,16 +51,18 @@ class TestPythonBuiltinOP(JitTestCase):
self.checkScript(func, (a, b), optimize=True)
def test_matmul_py3(self):
code = dedent("""
code = dedent(
"""
def fn(a, b):
return a @ b
""")
"""
)
with tempfile.TemporaryDirectory() as tmp_dir:
script_path = os.path.join(tmp_dir, 'script.py')
with open(script_path, 'w') as f:
script_path = os.path.join(tmp_dir, "script.py")
with open(script_path, "w") as f:
f.write(code)
fn = get_fn('test_matmul_py3', script_path)
fn = get_fn("test_matmul_py3", script_path)
a = torch.rand(4, 3, requires_grad=True)
b = torch.rand(3, 2, requires_grad=True)
@ -65,18 +70,18 @@ class TestPythonBuiltinOP(JitTestCase):
def test_pow(self):
def func(a, b):
return a ** b
return a**b
def func2(a, b, c, d):
return c + a ** b ** d
return c + a**b**d
def func3(a, b):
# type: (int, float) -> float
return a ** b
return a**b
def func4():
# type: () -> float
return 2 ** -2
return 2**-2
def func5(x, y):
return x.item() ** y.item()
@ -90,7 +95,12 @@ class TestPythonBuiltinOP(JitTestCase):
self.checkScript(func3, (4, -0.5), optimize=True)
self.checkScript(func4, ())
inputs = [torch.tensor(2), torch.tensor(-2), torch.tensor(.5), torch.tensor(.2)]
inputs = [
torch.tensor(2),
torch.tensor(-2),
torch.tensor(0.5),
torch.tensor(0.2),
]
for x in inputs:
for y in inputs:
if x < 0:
@ -100,7 +110,7 @@ class TestPythonBuiltinOP(JitTestCase):
def test_triple(self):
def func(x):
return 3. * x
return 3.0 * x
x = torch.rand(1, dtype=torch.float, requires_grad=True)
self.checkScript(func, [x], optimize=True)
@ -154,22 +164,36 @@ class TestPythonBuiltinOP(JitTestCase):
def test_stepped_tuple_slicing(self):
def check_slicing_tuple(slicing, tuple_type, tuple):
template = dedent("""
template = dedent(
"""
def func(x):
# type: ({}) -> Any
return x{}
""")
"""
)
self._check_code(template.format(tuple_type, slicing), "func", [tuple])
check_slicing_tuple("[-3:3:2]", "Tuple[int, int, int]", (0, 1, 2))
check_slicing_tuple("[::55]", "Tuple[int, int, int, int, int]", (0, 1, 2, 3, 4))
check_slicing_tuple("[:4:4]", "Tuple[int, int, int, int, int]", (0, 1, 2, 3, 4))
check_slicing_tuple("[::-1]", "Tuple[int, int, int, int, int, int, int]", (0, 1, 2, 3, 4, 5, 6))
check_slicing_tuple("[7:5:2]", "Tuple[int, int, int, int, int, int, int]", (0, 1, 2, 3, 4, 5, 6))
check_slicing_tuple("[5:7:-2]", "Tuple[int, int, int, int, int, int, int]", (0, 1, 2, 3, 4, 5, 6))
check_slicing_tuple(
"[::-1]", "Tuple[int, int, int, int, int, int, int]", (0, 1, 2, 3, 4, 5, 6)
)
check_slicing_tuple(
"[7:5:2]", "Tuple[int, int, int, int, int, int, int]", (0, 1, 2, 3, 4, 5, 6)
)
check_slicing_tuple(
"[5:7:-2]",
"Tuple[int, int, int, int, int, int, int]",
(0, 1, 2, 3, 4, 5, 6),
)
check_slicing_tuple("[::-2]", "Tuple[int, int, int, int, int]", (0, 1, 2, 3, 4))
check_slicing_tuple("[:4:-3]", "Tuple[int, int, int, int, int, int]", (0, 1, 2, 3, 4, 5))
check_slicing_tuple("[3::-2]", "Tuple[int, int, int, int, int]", (0, 1, 2, 3, 4))
check_slicing_tuple(
"[:4:-3]", "Tuple[int, int, int, int, int, int]", (0, 1, 2, 3, 4, 5)
)
check_slicing_tuple(
"[3::-2]", "Tuple[int, int, int, int, int]", (0, 1, 2, 3, 4)
)
def test_index(self):
def consec(size, start=0):
@ -177,10 +201,12 @@ class TestPythonBuiltinOP(JitTestCase):
return torch.arange(numel).view(size)
def check_indexing(indexing, tensor):
template = dedent("""
template = dedent(
"""
def func(x):
return x{}
""")
"""
)
self._check_code(template.format(indexing), "func", [tensor])
@ -188,62 +214,66 @@ class TestPythonBuiltinOP(JitTestCase):
value1 = torch.tensor(value1)
value2 = torch.tensor(value2)
template = dedent("""
template = dedent(
"""
def func(x, value1, value2):
i = int(value1)
j = int(value2)
return x{}
""")
"""
)
self._check_code(template.format(indexing), "func", [tensor, value1, value2])
self._check_code(
template.format(indexing), "func", [tensor, value1, value2]
)
# basic slices
check_indexing('[0]', consec((3, 3)))
check_indexing('[1]', consec((3, 3), 10))
check_indexing('[2]', consec((3, 3), 19))
check_indexing('[2]', consec((3,)))
check_indexing('[-1]', consec((3, 3), 19))
check_indexing('[0:2]', consec((3, 3, 3)))
check_indexing('[1:-1]', consec((3, 3, 3)))
check_indexing('[-3:-1]', consec((6, 3)))
check_indexing('[1:]', consec((3, 3)))
check_indexing('[:1]', consec((3, 3)))
check_indexing('[:]', consec((3, 2)))
check_indexing("[0]", consec((3, 3)))
check_indexing("[1]", consec((3, 3), 10))
check_indexing("[2]", consec((3, 3), 19))
check_indexing("[2]", consec((3,)))
check_indexing("[-1]", consec((3, 3), 19))
check_indexing("[0:2]", consec((3, 3, 3)))
check_indexing("[1:-1]", consec((3, 3, 3)))
check_indexing("[-3:-1]", consec((6, 3)))
check_indexing("[1:]", consec((3, 3)))
check_indexing("[:1]", consec((3, 3)))
check_indexing("[:]", consec((3, 2)))
# multi-dim: indexes
check_indexing('[0, 1]', consec((3, 3)))
check_indexing('[0, 1]', consec((3, 3, 2)))
check_indexing('[1, 0, 2]', consec((3, 3, 3)))
check_indexing('[2, -1]', consec((3, 3)))
check_indexing("[0, 1]", consec((3, 3)))
check_indexing("[0, 1]", consec((3, 3, 2)))
check_indexing("[1, 0, 2]", consec((3, 3, 3)))
check_indexing("[2, -1]", consec((3, 3)))
# multi-dim: mixed slicing and indexing
check_indexing('[0, 1:2]', consec((3, 3)))
check_indexing('[0, :1]', consec((3, 3, 2)))
check_indexing('[1, 2:]', consec((3, 3, 3)))
check_indexing('[-1, 1:, 0]', consec((3, 3, 3, 3)))
check_indexing('[1:, -1, 0]', consec((3, 3, 3, 3)))
check_indexing('[-1, 2:, 1:2]', consec((3, 3, 3, 3)))
check_indexing('[-1, 1:, 0]', consec((3, 3, 3, 3)))
check_indexing('[-1, :, 0, 2]', consec((3, 3, 3, 3)))
check_indexing("[0, 1:2]", consec((3, 3)))
check_indexing("[0, :1]", consec((3, 3, 2)))
check_indexing("[1, 2:]", consec((3, 3, 3)))
check_indexing("[-1, 1:, 0]", consec((3, 3, 3, 3)))
check_indexing("[1:, -1, 0]", consec((3, 3, 3, 3)))
check_indexing("[-1, 2:, 1:2]", consec((3, 3, 3, 3)))
check_indexing("[-1, 1:, 0]", consec((3, 3, 3, 3)))
check_indexing("[-1, :, 0, 2]", consec((3, 3, 3, 3)))
# zero-sized slices
check_indexing('[0:0]', consec((2, 2)))
check_indexing('[0:0, 1]', consec((3, 3)))
check_indexing("[0:0]", consec((2, 2)))
check_indexing("[0:0, 1]", consec((3, 3)))
# trivial expression usage
check_indexing('[1+1]', consec((3, 3)))
check_indexing('[1:(0 + 2)]', consec((3, 3, 3)))
check_indexing("[1+1]", consec((3, 3)))
check_indexing("[1:(0 + 2)]", consec((3, 3, 3)))
# None for new dimensions
check_indexing('[None, 0]', consec((3, 3)))
check_indexing('[1, None]', consec((3, 3), 10))
check_indexing('[None, None, 2]', consec((3, 3), 19))
check_indexing('[None, 2, None]', consec((3,)))
check_indexing('[0:2, None]', consec((3, 3, 3)))
check_indexing('[None, 1:-1]', consec((3, 3, 3)))
check_indexing('[None, -3:-1, None]', consec((6, 3)))
check_indexing('[-1, None, 2:, None, 1:2]', consec((3, 3, 3, 3)))
check_indexing('[None, -1, None, 2:, None, 1:2, None]', consec((3, 3, 3, 3)))
check_indexing("[None, 0]", consec((3, 3)))
check_indexing("[1, None]", consec((3, 3), 10))
check_indexing("[None, None, 2]", consec((3, 3), 19))
check_indexing("[None, 2, None]", consec((3,)))
check_indexing("[0:2, None]", consec((3, 3, 3)))
check_indexing("[None, 1:-1]", consec((3, 3, 3)))
check_indexing("[None, -3:-1, None]", consec((6, 3)))
check_indexing("[-1, None, 2:, None, 1:2]", consec((3, 3, 3, 3)))
check_indexing("[None, -1, None, 2:, None, 1:2, None]", consec((3, 3, 3, 3)))
# dynamic expression usage
check_dynamic_indexing("[i + j]", consec((3, 3)), 0, 1)
@ -257,10 +287,12 @@ class TestPythonBuiltinOP(JitTestCase):
def check_indexing(indexing, tensor, **kwargs):
indices_dict = kwargs
template = dedent("""
template = dedent(
"""
def func(x{formals}):
return x{expr}
""")
"""
)
formals = []
values = []
@ -268,17 +300,18 @@ class TestPythonBuiltinOP(JitTestCase):
formals.append(formal)
values.append(value)
formals = ''.join(map(', {}'.format, formals))
formals = "".join(map(", {}".format, formals))
inputs = [tensor] + values
self._check_code(template.format(formals=formals, expr=indexing),
"func", inputs)
self._check_code(
template.format(formals=formals, expr=indexing), "func", inputs
)
# Indexing with tensor (basic)
check_indexing('[i]', consec((3, 3)), i=torch.tensor([0]))
check_indexing('[i]', consec((3, 3)), i=torch.tensor(1))
check_indexing('[i]', consec((3, 3)), i=torch.tensor([-2]))
check_indexing('[i]', consec((3, 3), 2), i=torch.tensor([0, 0]))
check_indexing('[i]', consec((3, 3, 2, 2)), i=torch.tensor([0, -2, 1]))
check_indexing("[i]", consec((3, 3)), i=torch.tensor([0]))
check_indexing("[i]", consec((3, 3)), i=torch.tensor(1))
check_indexing("[i]", consec((3, 3)), i=torch.tensor([-2]))
check_indexing("[i]", consec((3, 3), 2), i=torch.tensor([0, 0]))
check_indexing("[i]", consec((3, 3, 2, 2)), i=torch.tensor([0, -2, 1]))
# NB: indexing with tensors and indexing with sequences can be implemented
# in a very similar way (sequences are converted to tensors), so only one
@ -290,49 +323,49 @@ class TestPythonBuiltinOP(JitTestCase):
inp = consec((4, 8, 5))
to_check = [
# [[0, 1, 3]]
['[i]', {'i': [0, 1, 3]}],
["[i]", {"i": [0, 1, 3]}],
# [[0, 2], [1, 3]]
['[i, j]', {'i': [0, 2], 'j': [1, 3]}],
["[i, j]", {"i": [0, 2], "j": [1, 3]}],
# [[[0, 1], [0, 1]], [[0, 1], [0, 1]]]
['[i, j]', {'i': [[0, 1], [0, 1]], 'j': [[0, 1], [0, 1]]}],
["[i, j]", {"i": [[0, 1], [0, 1]], "j": [[0, 1], [0, 1]]}],
# [[0, 2], [1, 3], [1, 1]]
['[i, j, k]', {'i': [0, 2], 'j': [1, 3], 'k': [1, 1]}],
["[i, j, k]", {"i": [0, 2], "j": [1, 3], "k": [1, 1]}],
# [[0, 2], 1, [1, 1]]
['[i, j, k]', {'i': [0, 2], 'j': 1, 'k': [1, 1]}],
["[i, j, k]", {"i": [0, 2], "j": 1, "k": [1, 1]}],
# [:, :, [0, 3, 4]]
['[:, :, i]', {'i': [0, 3, 4]}],
["[:, :, i]", {"i": [0, 3, 4]}],
# [:, [2, 4, 5, 7], 2:4]
['[:, i, 2:4]', {'i': [0, 2, 3]}],
["[:, i, 2:4]", {"i": [0, 2, 3]}],
# [[2, 3], :, :]
['[i, :, :]', {'i': [2, 3]}],
["[i, :, :]", {"i": [2, 3]}],
# [:, [0, 2, 3], [1, 3, 4]]
['[:, i, j]', {'i': [0, 2, 3], 'j': [1, 3, 4]}],
["[:, i, j]", {"i": [0, 2, 3], "j": [1, 3, 4]}],
# [:, [0], [1, 2, 4]]
['[:, i, j]', {'i': [0], 'j': [1, 2, 4]}],
["[:, i, j]", {"i": [0], "j": [1, 2, 4]}],
# [:, [0, 1, 3], [4]]
['[:, i, j]', {'i': [0, 1, 3], 'j': [4]}],
["[:, i, j]", {"i": [0, 1, 3], "j": [4]}],
# [:, [[0, 1], [1, 0]], [[2, 3]]]
['[:, i, j]', {'i': [[0, 1], [1, 0]], 'j': [[2, 3]]}],
["[:, i, j]", {"i": [[0, 1], [1, 0]], "j": [[2, 3]]}],
# [:, [[0, 1], [2, 3]], [[0]]]
['[:, i, j]', {'i': [[0, 1], [2, 3]], 'j': [[0]]}],
["[:, i, j]", {"i": [[0, 1], [2, 3]], "j": [[0]]}],
# [:, [[5, 6]], [[0, 3], [4, 4]]]
['[:, i, j]', {'i': [[5, 6]], 'j': [[0, 3], [4, 4]]}],
["[:, i, j]", {"i": [[5, 6]], "j": [[0, 3], [4, 4]]}],
# [[0, 2, 3], [1, 3, 4], :]
['[i, j, :]', {'i': [0, 2, 3], 'j': [1, 3, 4]}],
["[i, j, :]", {"i": [0, 2, 3], "j": [1, 3, 4]}],
# [0, [1, 2, 4], :]
['[i, j, :]', {'i': 0, 'j': [1, 2, 4]}],
["[i, j, :]", {"i": 0, "j": [1, 2, 4]}],
# [[0, 1, 3], 4, :]
['[i, j, :]', {'i': [0, 1, 3], 'j': 4}],
["[i, j, :]", {"i": [0, 1, 3], "j": 4}],
# [[[0, 1], [1, 0]], [[2, 1], [3, 5]], :]
['[i, j, :]', {'i': [[0, 1], [1, 0]], 'j': [[2, 1], [3, 5]]}],
["[i, j, :]", {"i": [[0, 1], [1, 0]], "j": [[2, 1], [3, 5]]}],
# [[[0, 1], [1, 0]], [[2, 3]], :]
['[i, j, :]', {'i': [[0, 1], [1, 0]], 'j': [[2, 3]]}],
["[i, j, :]", {"i": [[0, 1], [1, 0]], "j": [[2, 3]]}],
# [[[0, 1], [2, 3]], [[0]], :]
['[i, j, :]', {'i': [[0, 1], [2, 3]], 'j': [[0]]}],
["[i, j, :]", {"i": [[0, 1], [2, 3]], "j": [[0]]}],
# [[[2, 1]], [[0, 3], [4, 4]], :]
['[i, j, :]', {'i': [[2, 1]], 'j': [[0, 3], [4, 4]]}],
["[i, j, :]", {"i": [[2, 1]], "j": [[0, 3], [4, 4]]}],
# [[[2]], [[0, 3], [4, 1]], 0:2]
['[i, j, 0:2]', {'i': [[2]], 'j': [[0, 3], [4, 1]]}],
["[i, j, 0:2]", {"i": [[2]], "j": [[0, 3], [4, 1]]}],
]
for expr, argdict in to_check:
@ -372,29 +405,35 @@ class TestPythonBuiltinOP(JitTestCase):
for _ in range(100):
indices = [random.choice(vals) for _ in range(4)]
indices[random.randint(0, len(indices) - 1)] = "..."
test_str = dedent("""
test_str = dedent(
"""
def f():
x = torch.ones(10, 9, 8, 7, 6)
return x{indices}.shape
""".format(indices=indices))
test_str = test_str.replace(r"'", r'')
""".format(
indices=indices
)
)
test_str = test_str.replace(r"'", r"")
scope = {}
execWrapper(test_str, globals(), scope)
cu = torch.jit.CompilationUnit(test_str)
res1 = cu.f()
res2 = scope['f']()
res2 = scope["f"]()
self.assertEqual(res1, res2)
def test_inf(self):
@torch.jit.script
def foo(a):
return a < float('inf')
return a < float("inf")
s = torch.rand(1)
self.assertTrue(foo(s))
@torch.jit.script
def bar(a):
return a > float('-inf')
return a > float("-inf")
s = torch.rand(1)
self.assertTrue(foo(s))
@ -414,19 +453,22 @@ class TestPythonBuiltinOP(JitTestCase):
def test_str_to_float(self):
@torch.jit.script
def foo(a):
return 0.5 == float('0.5 hello')
return 0.5 == float("0.5 hello")
s = torch.rand(1)
with self.assertRaisesRegex(RuntimeError, "could not convert string to float"):
self.assertTrue(foo(s))
@torch.jit.script
def foo(a):
return 0.5 == float('0.5')
return 0.5 == float("0.5")
s = torch.rand(1)
self.assertTrue(foo(s))
@torch.jit.script
def foo(a):
return 0. == float('0')
return 0.0 == float("0")
s = torch.rand(1)
self.assertTrue(foo(s))

View File

@ -1,22 +1,26 @@
# Owner(s): ["oncall: jit"]
import torch
from torch.testing import FileCheck
from torch.testing._internal.jit_utils import JitTestCase
from torch.testing._internal.common_utils import IS_MACOS
import numpy as np
import unittest
if __name__ == '__main__':
raise RuntimeError("This test file is not meant to be run directly, use:\n\n"
"\tpython test/test_jit.py TESTNAME\n\n"
"instead.")
import numpy as np
import torch
from torch.testing import FileCheck
from torch.testing._internal.common_utils import IS_MACOS
from torch.testing._internal.jit_utils import JitTestCase
if __name__ == "__main__":
raise RuntimeError(
"This test file is not meant to be run directly, use:\n\n"
"\tpython test/test_jit.py TESTNAME\n\n"
"instead."
)
class TestPythonIr(JitTestCase):
def test_param_strides(self):
def trace_me(arg):
return arg
t = torch.zeros(1, 3, 16, 16)
traced = torch.jit.trace(trace_me, t)
value = list(traced.graph.param_node().outputs())[0]
@ -78,8 +82,12 @@ class TestPythonIr(JitTestCase):
g = foo.graph
muls = g.findAllNodes("aten::mul")
scalar_muls = filter(lambda x: x.matches("aten::mul(Tensor self, Scalar other) -> Tensor"), muls)
mul_constant_int = filter(lambda x: isinstance(list(x.inputs())[1].toIValue(), int), scalar_muls)
scalar_muls = filter(
lambda x: x.matches("aten::mul(Tensor self, Scalar other) -> Tensor"), muls
)
mul_constant_int = filter(
lambda x: isinstance(list(x.inputs())[1].toIValue(), int), scalar_muls
)
for mul in mul_constant_int:
with g.insert_point_guard(mul):
outputs = g.insertGraph(unrolled_mul.graph, list(mul.inputs()))

View File

@ -5,25 +5,31 @@ import re
import sys
import types
import typing
import typing_extensions
from typing import List, Dict, Optional, Tuple
from collections import OrderedDict
from typing import Dict, List, Optional, Tuple
import torch
import torch.jit.frontend
import torch.nn as nn
import typing_extensions
from torch import Tensor
from torch.testing import FileCheck
from collections import OrderedDict
# Make the helper files in test/ importable
pytorch_test_dir = os.path.dirname(os.path.dirname(os.path.realpath(__file__)))
sys.path.append(pytorch_test_dir)
from torch.testing._internal.jit_utils import JitTestCase, _tmp_donotuse_dont_inline_everything
from torch.testing._internal.jit_utils import (
_tmp_donotuse_dont_inline_everything,
JitTestCase,
)
if __name__ == "__main__":
raise RuntimeError(
"This test file is not meant to be run directly, use:\n\n"
"\tpython test/test_jit.py TESTNAME\n\n"
"instead."
)
if __name__ == '__main__':
raise RuntimeError("This test file is not meant to be run directly, use:\n\n"
"\tpython test/test_jit.py TESTNAME\n\n"
"instead.")
class TestRecursiveScript(JitTestCase):
def test_inferred_nonetype(self):
@ -87,7 +93,9 @@ class TestRecursiveScript(JitTestCase):
return self.fn(x)
m = M(fn)
with self.assertRaisesRegexWithHighlight(RuntimeError, "failed to compile", "i_dont_exist"):
with self.assertRaisesRegexWithHighlight(
RuntimeError, "failed to compile", "i_dont_exist"
):
torch.jit.script(m)
def test_init_error(self):
@ -119,12 +127,12 @@ class TestRecursiveScript(JitTestCase):
# sm1 was created while m had training = True
self.assertTrue(sm1.training)
self.assertEqual(sm1.training, sm1._c.getattr('training'))
self.assertEqual(sm1.training, sm1._c.getattr("training"))
self.assertEqual(sm1(), 2)
# sm2 was created after m was eval'ed
self.assertFalse(sm2.training)
self.assertEqual(sm2.training, sm2._c.getattr('training'))
self.assertEqual(sm2.training, sm2._c.getattr("training"))
self.assertEqual(sm2(), 0)
def test_module_name(self):
@ -165,7 +173,7 @@ class TestRecursiveScript(JitTestCase):
def test_constants_with_final(self):
class M1(torch.nn.Module):
x : torch.jit.Final[int]
x: torch.jit.Final[int]
def __init__(self):
super().__init__()
@ -177,7 +185,7 @@ class TestRecursiveScript(JitTestCase):
self.checkModule(M1(), (torch.randn(2, 2),))
class M2(torch.nn.Module):
x : typing_extensions.Final[int]
x: typing_extensions.Final[int]
def __init__(self):
super().__init__()
@ -189,7 +197,7 @@ class TestRecursiveScript(JitTestCase):
self.checkModule(M2(), (torch.randn(2, 2),))
class M3(torch.nn.Module):
x : typing.Final[int]
x: typing.Final[int]
def __init__(self):
super().__init__()
@ -206,12 +214,15 @@ class TestRecursiveScript(JitTestCase):
def unscriptable(self):
return "a" + 200
class TestModule(torch.nn.Module):
def forward(self, x):
return MyScriptClass()
with self.assertRaisesRegexWithHighlight(torch.jit.frontend.FrontendError, "Cannot instantiate class", "MyScriptClass"):
with self.assertRaisesRegexWithHighlight(
torch.jit.frontend.FrontendError,
"Cannot instantiate class",
"MyScriptClass",
):
t = torch.jit.script(TestModule())
def test_method_call(self):
@ -246,13 +257,13 @@ class TestRecursiveScript(JitTestCase):
print(m)
f = FileCheck()
f.check('MyModule')
f.check('Conv2d')
f.check('Linear')
f.check('Submodule')
f.check("MyModule")
f.check("Conv2d")
f.check("Linear")
f.check("Submodule")
f.run(out[0])
self.assertEqual(m.original_name, 'MyModule')
self.assertEqual(m.original_name, "MyModule")
def test_dir(self):
def test_module_dir(mod):
@ -260,8 +271,17 @@ class TestRecursiveScript(JitTestCase):
scripted_mod = torch.jit.script(mod)
dir_scripted = set(dir(scripted_mod))
# set not currently copied over
ignore_set = ["training", "__delitem__", "__setitem__", "clear", "items",
"keys", "pop", "update", "values"]
ignore_set = [
"training",
"__delitem__",
"__setitem__",
"clear",
"items",
"keys",
"pop",
"update",
"values",
]
for attr in dir_set:
if attr in ignore_set:
continue
@ -283,7 +303,9 @@ class TestRecursiveScript(JitTestCase):
linear = nn.Linear(10, 10)
test_module_dir(nn.Sequential(conv, linear))
test_module_dir(nn.ModuleDict(OrderedDict([("conv", conv), ("linear", linear)])))
test_module_dir(
nn.ModuleDict(OrderedDict([("conv", conv), ("linear", linear)]))
)
def test_class_compile(self):
def other_fn(a: int, b: Tensor) -> Tensor:
@ -296,7 +318,6 @@ class TestRecursiveScript(JitTestCase):
def helper(self, a):
return self.x + a + other_fn(self.x, a)
class N(torch.nn.Module):
def forward(self, x):
b = B(x)
@ -411,7 +432,7 @@ class TestRecursiveScript(JitTestCase):
def test_module_basic(self):
class Other(torch.nn.Module):
__constants__ = ['x']
__constants__ = ["x"]
def __init__(self, x):
super().__init__()
@ -426,7 +447,6 @@ class TestRecursiveScript(JitTestCase):
def forward(self, t):
return t + self.x + self.param
class M(torch.nn.Module):
def __init__(self):
super().__init__()
@ -439,7 +459,7 @@ class TestRecursiveScript(JitTestCase):
def test_module_function_export(self):
class Other(torch.nn.Module):
__constants__ = ['x']
__constants__ = ["x"]
def __init__(self, x):
super().__init__()
@ -453,7 +473,6 @@ class TestRecursiveScript(JitTestCase):
def forward(self, t):
return t + self.x + self.param
class M(torch.nn.Module):
def __init__(self):
super().__init__()
@ -473,9 +492,7 @@ class TestRecursiveScript(JitTestCase):
def __init__(self):
super().__init__()
self.sequential = nn.Sequential(
Inner(),
Inner(),
nn.Sequential(Inner(), Inner())
Inner(), Inner(), nn.Sequential(Inner(), Inner())
)
self.module_list = nn.ModuleList([Inner(), Inner()])
@ -511,12 +528,14 @@ class TestRecursiveScript(JitTestCase):
self.sequential = nn.Sequential(
SeluButReluWhenScripted(),
SeluButReluWhenScripted(),
nn.Sequential(SeluButReluWhenScripted(), shared, SeluButReluWhenScripted()),
nn.Sequential(
SeluButReluWhenScripted(), shared, SeluButReluWhenScripted()
),
shared,
)
self.module_list = nn.ModuleList([SeluButReluWhenScripted(),
shared,
SeluButReluWhenScripted()])
self.module_list = nn.ModuleList(
[SeluButReluWhenScripted(), shared, SeluButReluWhenScripted()]
)
def forward(self, x):
for mod in self.module_list:
@ -553,7 +572,8 @@ class TestRecursiveScript(JitTestCase):
self.assertEqual(obj(1, 2), 3)
self.assertEqual(obj(1, 2, 3, 4), 10)
with self.assertRaisesRegex(
torch.jit.frontend.NotSupportedError, expected_regex="can't take variable number of arguments"
torch.jit.frontend.NotSupportedError,
expected_regex="can't take variable number of arguments",
):
torch.jit.script(obj)
@ -568,7 +588,10 @@ class TestRecursiveScript(JitTestCase):
self.assertEqual(jit_obj(1, 2), 3)
with self.assertRaisesRegex(
RuntimeError, expected_regex=re.escape("expected at most 2 argument(s) but received 4 argument(s)")
RuntimeError,
expected_regex=re.escape(
"expected at most 2 argument(s) but received 4 argument(s)"
),
):
jit_obj(1, 2, 3, 4)
@ -598,27 +621,26 @@ class TestRecursiveScript(JitTestCase):
def __getstate__(self):
return (self.a, self.inner)
untyped_values = (
('my_dict', {"I": "am", "a test": "test"}),
('my_float', 2.3),
('my_int', 99),
('my_bool', False),
('my_tuple', (1, 2, 3, 4)),
('my_list', [(1, 2), (3, 4)]),
("my_dict", {"I": "am", "a test": "test"}),
("my_float", 2.3),
("my_int", 99),
("my_bool", False),
("my_tuple", (1, 2, 3, 4)),
("my_list", [(1, 2), (3, 4)]),
# ('my_tensor', torch.randn(2, 2)),
('my_int_list', [1, 2, 3, 4]),
("my_int_list", [1, 2, 3, 4]),
# ('my_tensor_list', [torch.ones(2, 2) + i for i in range(4)]),
('my_bool_list', [True, True, False, True]),
('my_float_list', [1., 2., 3., 4.]),
('my_str_list', ['hello', 'bye']),
("my_bool_list", [True, True, False, True]),
("my_float_list", [1.0, 2.0, 3.0, 4.0]),
("my_str_list", ["hello", "bye"]),
)
typed_values = (
('my_empty_list', []),
('my_empty_dict', {}),
('my_none', None),
('my_object', Foo()),
('my_object2', SFoo()),
("my_empty_list", []),
("my_empty_dict", {}),
("my_none", None),
("my_object", Foo()),
("my_object2", SFoo()),
)
class M(torch.nn.Module):
@ -659,11 +681,11 @@ class TestRecursiveScript(JitTestCase):
# since there's no string frontend for Python classes (so the `define`)
# trick doesn't work.
M.__annotations__ = {
'my_empty_list': List[int],
'my_empty_dict': Dict[str, int],
'my_none': Optional[int],
'my_object': Foo,
'my_object2': SFoo,
"my_empty_list": List[int],
"my_empty_dict": Dict[str, int],
"my_none": Optional[int],
"my_object": Foo,
"my_object2": SFoo,
}
m = M()
@ -694,7 +716,7 @@ class TestRecursiveScript(JitTestCase):
return self.encoder(x)
m = M()
self.checkModule(m, (torch.randn(5, 5), ))
self.checkModule(m, (torch.randn(5, 5),))
def test_inner_traced_module(self):
class Dummy(nn.Module):
@ -715,12 +737,13 @@ class TestRecursiveScript(JitTestCase):
dummy = torch.jit.trace(Dummy(), torch.randn(1, 2))
dummies = nn.ModuleList([dummy])
model = Model(dummies)
self.checkModule(model, (torch.rand(5, 5), ))
self.checkModule(model, (torch.rand(5, 5),))
def test_script_loaded_module(self):
"""
Test that we can hold a loaded ScriptModule as a submodule.
"""
class Dummy(nn.Module):
def forward(self, x):
return x
@ -736,7 +759,7 @@ class TestRecursiveScript(JitTestCase):
def forward(self, input):
return self.encoder(input)
self.checkModule(ContainsLoaded(), (torch.rand(2, 3), ))
self.checkModule(ContainsLoaded(), (torch.rand(2, 3),))
def test_optional_module(self):
class Dummy(nn.Module):

View File

@ -2,20 +2,23 @@
import os
import sys
from typing import List
import torch
from torch.testing import FileCheck
from typing import List
# Make the helper files in test/ importable
pytorch_test_dir = os.path.dirname(os.path.dirname(os.path.realpath(__file__)))
sys.path.append(pytorch_test_dir)
from torch.testing._internal.jit_utils import JitTestCase, freeze_rng_state
from torch.testing._internal.jit_utils import freeze_rng_state, JitTestCase
if __name__ == "__main__":
raise RuntimeError(
"This test file is not meant to be run directly, use:\n\n"
"\tpython test/test_jit.py TESTNAME\n\n"
"instead."
)
if __name__ == '__main__':
raise RuntimeError("This test file is not meant to be run directly, use:\n\n"
"\tpython test/test_jit.py TESTNAME\n\n"
"instead.")
class TestRemoveMutation(JitTestCase):
def test_aten_inplace(self):
@ -26,7 +29,7 @@ class TestRemoveMutation(JitTestCase):
fn = torch.jit.script(test_not_new_alias)
graph = fn.graph
self.run_pass('remove_mutation', graph)
self.run_pass("remove_mutation", graph)
FileCheck().check("aten::add_").run(graph)
self.assertEqual(fn(torch.ones([2, 2])), test_not_new_alias(torch.ones([2, 2])))
@ -38,7 +41,7 @@ class TestRemoveMutation(JitTestCase):
# there is no functional equivalent of x[0] = ...
fn = torch.jit.script(test_no_lowering)
graph = fn.graph
self.run_pass('remove_mutation', graph)
self.run_pass("remove_mutation", graph)
FileCheck().check("aten::copy_").run(graph)
self.assertEqual(fn(), test_no_lowering())
@ -50,7 +53,7 @@ class TestRemoveMutation(JitTestCase):
fn = torch.jit.script(test_move_before_not_valid)
graph = fn.graph
self.run_pass('remove_mutation', graph)
self.run_pass("remove_mutation", graph)
FileCheck().check("aten::add_").run(graph)
self.assertEqual(fn(), test_move_before_not_valid())
@ -63,7 +66,7 @@ class TestRemoveMutation(JitTestCase):
fn = torch.jit.script(test_successful)
graph = fn.graph
self.run_pass('remove_mutation', graph)
self.run_pass("remove_mutation", graph)
FileCheck().check_not("aten::add_").run(graph)
self.assertEqual(test_successful(), fn())
@ -77,7 +80,7 @@ class TestRemoveMutation(JitTestCase):
fn = torch.jit.script(test_intermediary_use)
graph = fn.graph
FileCheck().check_count("aten::add_", 2).run(graph)
self.run_pass('remove_mutation', graph)
self.run_pass("remove_mutation", graph)
# Unable to remove the second add_ because of the y = x + 4 use
# In the future we could duplicating the value of x as a temporary and replacing
# its intermediary use (so long as aliasing is safe)
@ -96,7 +99,7 @@ class TestRemoveMutation(JitTestCase):
out_eager = foo(torch.tensor(5), True)
foo_script = torch.jit.script(foo)
FileCheck().check("aten::add_").run(foo_script.graph)
self.run_pass('remove_mutation', foo_script.graph)
self.run_pass("remove_mutation", foo_script.graph)
FileCheck().check_not("aten::add_").run(foo_script.graph)
self.assertEqual(out_eager, foo_script(torch.tensor(5), True))
@ -113,8 +116,8 @@ class TestRemoveMutation(JitTestCase):
y = x.add_(2)
return y, li
self.run_pass('inline', foo.graph)
self.run_pass('remove_mutation', foo.graph)
self.run_pass("inline", foo.graph)
self.run_pass("remove_mutation", foo.graph)
FileCheck().check("aten::add_").run(foo.graph)
@torch.jit.script
@ -126,8 +129,8 @@ class TestRemoveMutation(JitTestCase):
z = x.add_(2)
return z
self.run_pass('inline', foo.graph)
self.run_pass('remove_mutation', foo.graph)
self.run_pass("inline", foo.graph)
self.run_pass("remove_mutation", foo.graph)
FileCheck().check("aten::add_").run(foo.graph)
def test_special_mapped_op(self):
@ -140,7 +143,7 @@ class TestRemoveMutation(JitTestCase):
fn = torch.jit.script(test_successful)
graph = fn.graph
self.run_pass('remove_mutation', graph)
self.run_pass("remove_mutation", graph)
FileCheck().check_not("aten::zero_").check_not("aten::fill_").run(graph)
self.assertEqual(test_successful(), fn())
@ -154,8 +157,8 @@ class TestRemoveMutation(JitTestCase):
fn = torch.jit.script(test_successful)
graph = fn.graph
self.run_pass('remove_mutation', graph)
FileCheck().check_not('aten::fill_').run(graph)
self.run_pass("remove_mutation", graph)
FileCheck().check_not("aten::fill_").run(graph)
def normal():
# NOTE: For some unknown reason, the
@ -167,7 +170,7 @@ class TestRemoveMutation(JitTestCase):
fn = torch.jit.script(normal)
graph = fn.graph
self.run_pass('remove_mutation', graph)
self.run_pass("remove_mutation", graph)
FileCheck().check_not("normal_").run(graph)
with freeze_rng_state():
out_eager = normal()
@ -181,10 +184,12 @@ class TestRemoveMutation(JitTestCase):
fn = torch.jit.script(successful_remove)
graph = fn.graph
self.run_pass('loop_unrolling', graph)
self.run_pass('remove_mutation', graph)
self.run_pass('constant_propagation', graph)
FileCheck().check("graph").check_next("Constant").check_next("return").run(graph)
self.run_pass("loop_unrolling", graph)
self.run_pass("remove_mutation", graph)
self.run_pass("constant_propagation", graph)
FileCheck().check("graph").check_next("Constant").check_next("return").run(
graph
)
self.assertEqual(successful_remove(), successful_remove())
def intermediary_use():
@ -196,14 +201,14 @@ class TestRemoveMutation(JitTestCase):
fn = torch.jit.script(intermediary_use)
graph = fn.graph
FileCheck().check("append").run(graph)
self.run_pass('remove_mutation', graph)
self.run_pass("remove_mutation", graph)
# it is possible to remove the append here but don't currently have the logic for it
FileCheck().check_not("append").run(graph)
self.assertEqual(intermediary_use(), fn())
def test_lists_insert(self):
def successful_remove():
a : List[int] = []
a: List[int] = []
a.insert(0, 1)
a.insert(0, 2)
a.insert(-10, 3)
@ -215,7 +220,9 @@ class TestRemoveMutation(JitTestCase):
graph = fn.graph
torch._C._jit_pass_remove_mutation(graph)
torch._C._jit_pass_constant_propagation(graph)
FileCheck().check("graph").check_next("Constant").check_next("return").run(graph)
FileCheck().check("graph").check_next("Constant").check_next("return").run(
graph
)
self.assertEqual(successful_remove(), fn())
def test_list_indexing_removal(self):
@ -271,6 +278,7 @@ class TestRemoveMutation(JitTestCase):
def test_common_pytorch_list_ops(self):
for op in ["cat", "stack", "vstack", "hstack", "dstack"]:
class OpMod(torch.nn.Module):
def __init__(self, op):
super().__init__()
@ -285,7 +293,7 @@ class TestRemoveMutation(JitTestCase):
torch_op = getattr(torch, op)
mod = OpMod(torch_op)
mod_script = torch.jit.script(mod)
self.run_pass('remove_mutation', mod_script.forward.graph)
self.run_pass("remove_mutation", mod_script.forward.graph)
FileCheck().check_not("aten::add_").run(mod_script.forward.graph)
self.assertEqual(mod(), mod_script())
@ -299,7 +307,6 @@ class TestRemoveMutation(JitTestCase):
self.assertEqual(sums, [ten.sum() for ten in result])
@torch.jit.script
def test_multiple_uses():
x = torch.tensor([1, 2, 3, 4])
@ -307,5 +314,5 @@ class TestRemoveMutation(JitTestCase):
y = [x, x]
return torch.cat(y), y
self.run_pass('remove_mutation', mod_script.forward.graph)
self.run_pass("remove_mutation", mod_script.forward.graph)
FileCheck().check("aten::add_").run(test_multiple_uses.graph)

View File

@ -8,12 +8,12 @@ from typing import NamedTuple, Optional
import torch
from torch import Tensor
from torch.testing._internal.common_utils import TemporaryFileName, skipIfTorchDynamo
from torch.testing._internal.common_utils import skipIfTorchDynamo, TemporaryFileName
# Make the helper files in test/ importable
pytorch_test_dir = os.path.dirname(os.path.dirname(os.path.realpath(__file__)))
sys.path.append(pytorch_test_dir)
from torch.testing._internal.jit_utils import JitTestCase, clear_class_registry
from torch.testing._internal.jit_utils import clear_class_registry, JitTestCase
if __name__ == "__main__":
@ -439,7 +439,7 @@ class TestSaveLoad(JitTestCase):
global FooTuple # see [local resolution in python]
class FooTuple(NamedTuple):
a: 'int'
a: "int"
class MyModule(torch.nn.Module):
def forward(self, x: FooTuple) -> torch.Tensor:
@ -608,7 +608,6 @@ class TestSaveLoad(JitTestCase):
self.assertTrue(m_params["bar.bias"].is_cpu)
self.assertTrue(m_loaded_params["bar.bias"].is_cpu)
def test_save_load_with_saved_traced_inputs(self):
"""
Check that saving and loading with traced inputs works as expected
@ -637,14 +636,18 @@ class TestSaveLoad(JitTestCase):
# Validate that with no input specified the traced inputs are stored
traced_module = torch.jit.trace(module, input_tensor)
traced_inputs = list(traced_module.graph.inputs())
self.assertEqual(traced_module._c._retrieve_traced_inputs()['forward'], [input_tensor])
self.assertEqual(
traced_module._c._retrieve_traced_inputs()["forward"], [input_tensor]
)
with TemporaryFileName() as fname:
path = pathlib.Path(fname)
traced_module.save(path)
loaded_module = torch.jit.load(path, _restore_shapes=True)
loaded_inputs = list(loaded_module.graph.inputs())
self.assertEqual(traced_inputs[1].type(), loaded_inputs[1].type())
self.assertEqual(traced_inputs[1].type().sizes(), loaded_inputs[1].type().sizes())
self.assertEqual(
traced_inputs[1].type().sizes(), loaded_inputs[1].type().sizes()
)
# Validate that if no shapes are requested previous functionality remains
loaded_module = torch.jit.load(path)
loaded_inputs = list(loaded_module.graph.inputs())
@ -672,7 +675,7 @@ class TestSaveLoad(JitTestCase):
"1000": (
torch.tensor([0]),
torch.tensor([], dtype=torch.int64),
torch.tensor([])
torch.tensor([]),
)
}
traced_inputs, loaded_inputs = get_loaded_inputs(input1)
@ -683,28 +686,32 @@ class TestSaveLoad(JitTestCase):
"1000": (
torch.tensor([0]),
torch.tensor([1500000, 1500004], dtype=torch.int64),
torch.tensor([2.0, 3.0])
torch.tensor([2.0, 3.0]),
)
}
traced_inputs, loaded_inputs = get_loaded_inputs(input2)
self.assertEqual(traced_inputs[1].type(), loaded_inputs[1].type())
# Testing list
input3 = [torch.tensor([0]),
torch.tensor([1500000, 1500004], dtype=torch.int64),
torch.tensor([2.0, 3.0])]
input3 = [
torch.tensor([0]),
torch.tensor([1500000, 1500004], dtype=torch.int64),
torch.tensor([2.0, 3.0]),
]
traced_inputs, loaded_inputs = get_loaded_inputs(input3)
self.assertEqual(traced_inputs[1].type(), loaded_inputs[1].type())
# Testing list of dict of list
input4 = [{
"1000": (
torch.tensor([0]),
torch.tensor([1500000, 1500004], dtype=torch.int64),
torch.tensor([2.0, 3.0])
)
}]
input4 = [
{
"1000": (
torch.tensor([0]),
torch.tensor([1500000, 1500004], dtype=torch.int64),
torch.tensor([2.0, 3.0]),
)
}
]
traced_inputs, loaded_inputs = get_loaded_inputs(input4)
self.assertEqual(traced_inputs[1].type(), loaded_inputs[1].type())
@ -715,14 +722,17 @@ class TestSaveLoad(JitTestCase):
Check if the model with string > 4GB can be loaded.
"""
import psutil
if psutil.virtual_memory().available < 60 * 1024 * 1024 * 1024:
# Profiled the test execution, and got this number to be safe to run the test
self.skipTest("Doesn't have enough memory to run test_save_load_large_string_attribute")
self.skipTest(
"Doesn't have enough memory to run test_save_load_large_string_attribute"
)
class Model(torch.nn.Module):
def __init__(self):
super().__init__()
self.x = "x" * (2 ** 32 + 1)
self.x = "x" * (2**32 + 1)
def forward(self, i) -> int:
return len(self.x) + i.numel()
@ -793,12 +803,8 @@ class TestSaveLoadFlatbuffer(JitTestCase):
class ContainsBoth(torch.nn.Module):
def __init__(self):
super().__init__()
self.add_module(
"second", torch.jit.load(second_saved_module)
)
self.add_module(
"first", torch.jit.load(first_saved_module)
)
self.add_module("second", torch.jit.load(second_saved_module))
self.add_module("first", torch.jit.load(first_saved_module))
def forward(self, x):
x = self.first(x)
@ -846,12 +852,8 @@ class TestSaveLoadFlatbuffer(JitTestCase):
class ContainsBoth(torch.nn.Module):
def __init__(self):
super().__init__()
self.add_module(
"second", torch.jit.load(second_saved_module)
)
self.add_module(
"first", torch.jit.load(first_saved_module)
)
self.add_module("second", torch.jit.load(second_saved_module))
self.add_module("first", torch.jit.load(first_saved_module))
def forward(self, x):
x = self.first(x)
@ -931,12 +933,8 @@ class TestSaveLoadFlatbuffer(JitTestCase):
class ContainsBoth(torch.nn.Module):
def __init__(self):
super().__init__()
self.add_module(
"second", torch.jit.load(second_saved_module)
)
self.add_module(
"first", torch.jit.load(first_saved_module)
)
self.add_module("second", torch.jit.load(second_saved_module))
self.add_module("first", torch.jit.load(first_saved_module))
def forward(self, x):
x = self.first(x)
@ -1035,12 +1033,8 @@ class TestSaveLoadFlatbuffer(JitTestCase):
class ContainsBoth(torch.nn.Module):
def __init__(self):
super().__init__()
self.add_module(
"second", torch.jit.load(second_saved_module)
)
self.add_module(
"first", torch.jit.load(first_saved_module)
)
self.add_module("second", torch.jit.load(second_saved_module))
self.add_module("first", torch.jit.load(first_saved_module))
def forward(self, x):
x, named_tuple_1 = self.first(x)
@ -1118,18 +1112,18 @@ class TestSaveLoadFlatbuffer(JitTestCase):
first_script_module = torch.jit.script(Foo())
first_saved_module = io.BytesIO()
torch.jit.save_jit_module_to_flatbuffer(
first_script_module, first_saved_module)
torch.jit.save_jit_module_to_flatbuffer(first_script_module, first_saved_module)
first_saved_module.seek(0)
ff_info = torch.jit._serialization.get_flatbuffer_module_info(first_saved_module)
self.assertEqual(ff_info['bytecode_version'], 9)
self.assertEqual(ff_info['operator_version'], 1)
self.assertEqual(ff_info['type_names'], set())
self.assertEqual(ff_info['opname_to_num_args'], {'aten::linear': 3})
self.assertEqual(len(ff_info['function_names']), 1)
self.assertTrue(next(iter(ff_info['function_names'])).endswith('forward'))
ff_info = torch.jit._serialization.get_flatbuffer_module_info(
first_saved_module
)
self.assertEqual(ff_info["bytecode_version"], 9)
self.assertEqual(ff_info["operator_version"], 1)
self.assertEqual(ff_info["type_names"], set())
self.assertEqual(ff_info["opname_to_num_args"], {"aten::linear": 3})
self.assertEqual(len(ff_info["function_names"]), 1)
self.assertTrue(next(iter(ff_info["function_names"])).endswith("forward"))
def test_save_load_params_buffers_submodules(self):
"""
@ -1179,7 +1173,6 @@ class TestSaveLoadFlatbuffer(JitTestCase):
self.assertEqual(m_name, loaded_name)
self.assertEqual(m_buffer, loaded_buffer)
def test_save_load_with_extra_files(self):
"""
Check that parameters, buffers, and submodules are the same after loading.
@ -1194,7 +1187,8 @@ class TestSaveLoadFlatbuffer(JitTestCase):
extra_files = {"abc.json": b"[1,2,3]"}
script_module_io = script_module._save_to_buffer_for_lite_interpreter(
_extra_files=extra_files, _use_flatbuffer=True)
_extra_files=extra_files, _use_flatbuffer=True
)
re_extra_files = {}
torch._C._get_model_extra_files_from_buffer(script_module_io, re_extra_files)

View File

@ -1,20 +1,21 @@
# Owner(s): ["oncall: jit"]
from itertools import product as product
import io
import os
import sys
import hypothesis.strategies as st
from hypothesis import example, settings, given
from itertools import product as product
from typing import Union
import hypothesis.strategies as st
import torch
from hypothesis import example, given, settings
# Make the helper files in test/ importable
pytorch_test_dir = os.path.dirname(os.path.dirname(os.path.realpath(__file__)))
sys.path.append(pytorch_test_dir)
from torch.testing._internal.jit_utils import JitTestCase
from torch.jit.mobile import _load_for_lite_interpreter
from torch.testing._internal.jit_utils import JitTestCase
if __name__ == "__main__":
raise RuntimeError(
@ -23,6 +24,7 @@ if __name__ == "__main__":
"instead."
)
class TestSaveLoadForOpVersion(JitTestCase):
# Helper that returns the module after saving and loading
def _save_load_module(self, m):
@ -53,7 +55,6 @@ class TestSaveLoadForOpVersion(JitTestCase):
node_count = sum(str(n).count(kind) for n in m.graph.nodes())
self.assertEqual(node_count, count)
"""
Tests that verify Torchscript remaps aten::div(_) from versions 0-3
to call either aten::true_divide(_), if an input is a float type,
@ -62,16 +63,21 @@ class TestSaveLoadForOpVersion(JitTestCase):
div behavior has not yet been updated.
"""
@settings(max_examples=10, deadline=200000) # A total of 10 examples will be generated
@settings(
max_examples=10, deadline=200000
) # A total of 10 examples will be generated
@given(
sample_input=st.tuples(st.integers(min_value=5, max_value=199), st.floats(min_value=5.0, max_value=199.0))
sample_input=st.tuples(
st.integers(min_value=5, max_value=199),
st.floats(min_value=5.0, max_value=199.0),
)
) # Generate a pair (integer, float)
@example((2, 3, 2.0, 3.0)) # Ensure this example will be covered
def test_versioned_div_tensor(self, sample_input):
def historic_div(self, other):
if self.is_floating_point() or other.is_floating_point():
return self.true_divide(other)
return self.divide(other, rounding_mode='trunc')
return self.divide(other, rounding_mode="trunc")
# Tensor x Tensor
class MyModule(torch.nn.Module):
@ -85,7 +91,9 @@ class TestSaveLoadForOpVersion(JitTestCase):
# Loads historic module
try:
v3_mobile_module = _load_for_lite_interpreter(
pytorch_test_dir + "/cpp/jit/upgrader_models/test_versioned_div_tensor_v2.ptl")
pytorch_test_dir
+ "/cpp/jit/upgrader_models/test_versioned_div_tensor_v2.ptl"
)
except Exception as e:
self.skipTest("Failed to load fixture!")
@ -108,16 +116,21 @@ class TestSaveLoadForOpVersion(JitTestCase):
_helper(v3_mobile_module, historic_div)
_helper(current_mobile_module, torch.div)
@settings(max_examples=10, deadline=200000) # A total of 10 examples will be generated
@settings(
max_examples=10, deadline=200000
) # A total of 10 examples will be generated
@given(
sample_input=st.tuples(st.integers(min_value=5, max_value=199), st.floats(min_value=5.0, max_value=199.0))
sample_input=st.tuples(
st.integers(min_value=5, max_value=199),
st.floats(min_value=5.0, max_value=199.0),
)
) # Generate a pair (integer, float)
@example((2, 3, 2.0, 3.0)) # Ensure this example will be covered
def test_versioned_div_tensor_inplace(self, sample_input):
def historic_div_(self, other):
if self.is_floating_point() or other.is_floating_point():
return self.true_divide_(other)
return self.divide_(other, rounding_mode='trunc')
return self.divide_(other, rounding_mode="trunc")
class MyModule(torch.nn.Module):
def forward(self, a, b):
@ -126,7 +139,9 @@ class TestSaveLoadForOpVersion(JitTestCase):
try:
v3_mobile_module = _load_for_lite_interpreter(
pytorch_test_dir + "/cpp/jit/upgrader_models/test_versioned_div_tensor_inplace_v2.ptl")
pytorch_test_dir
+ "/cpp/jit/upgrader_models/test_versioned_div_tensor_inplace_v2.ptl"
)
except Exception as e:
self.skipTest("Failed to load fixture!")
@ -151,16 +166,25 @@ class TestSaveLoadForOpVersion(JitTestCase):
a = torch.tensor((val_a,))
_helper(current_mobile_module, torch.Tensor.div_)
@settings(max_examples=10, deadline=200000) # A total of 10 examples will be generated
@settings(
max_examples=10, deadline=200000
) # A total of 10 examples will be generated
@given(
sample_input=st.tuples(st.integers(min_value=5, max_value=199), st.floats(min_value=5.0, max_value=199.0))
sample_input=st.tuples(
st.integers(min_value=5, max_value=199),
st.floats(min_value=5.0, max_value=199.0),
)
) # Generate a pair (integer, float)
@example((2, 3, 2.0, 3.0)) # Ensure this example will be covered
def test_versioned_div_tensor_out(self, sample_input):
def historic_div_out(self, other, out):
if self.is_floating_point() or other.is_floating_point() or out.is_floating_point():
if (
self.is_floating_point()
or other.is_floating_point()
or out.is_floating_point()
):
return torch.true_divide(self, other, out=out)
return torch.divide(self, other, out=out, rounding_mode='trunc')
return torch.divide(self, other, out=out, rounding_mode="trunc")
class MyModule(torch.nn.Module):
def forward(self, a, b, out):
@ -168,7 +192,9 @@ class TestSaveLoadForOpVersion(JitTestCase):
try:
v3_mobile_module = _load_for_lite_interpreter(
pytorch_test_dir + "/cpp/jit/upgrader_models/test_versioned_div_tensor_out_v2.ptl")
pytorch_test_dir
+ "/cpp/jit/upgrader_models/test_versioned_div_tensor_out_v2.ptl"
)
except Exception as e:
self.skipTest("Failed to load fixture!")
@ -179,6 +205,7 @@ class TestSaveLoadForOpVersion(JitTestCase):
b = torch.tensor((val_b,))
for out in (torch.empty((1,)), torch.empty((1,), dtype=torch.long)):
def _helper(m, fn):
fn_result = None
if fn is torch.div:
@ -196,9 +223,14 @@ class TestSaveLoadForOpVersion(JitTestCase):
_helper(v3_mobile_module, historic_div_out)
_helper(current_mobile_module, torch.div)
@settings(max_examples=10, deadline=200000) # A total of 10 examples will be generated
@settings(
max_examples=10, deadline=200000
) # A total of 10 examples will be generated
@given(
sample_input=st.tuples(st.integers(min_value=5, max_value=199), st.floats(min_value=5.0, max_value=199.0))
sample_input=st.tuples(
st.integers(min_value=5, max_value=199),
st.floats(min_value=5.0, max_value=199.0),
)
) # Generate a pair (integer, float)
@example((2, 3, 2.0, 3.0)) # Ensure this example will be covered
def test_versioned_div_scalar(self, sample_input):
@ -208,7 +240,7 @@ class TestSaveLoadForOpVersion(JitTestCase):
def historic_div_scalar_int(self, other: int):
if self.is_floating_point():
return torch.true_divide(self, other)
return torch.divide(self, other, rounding_mode='trunc')
return torch.divide(self, other, rounding_mode="trunc")
class MyModuleFloat(torch.nn.Module):
def forward(self, a, b: float):
@ -220,9 +252,13 @@ class TestSaveLoadForOpVersion(JitTestCase):
try:
v3_mobile_module_float = _load_for_lite_interpreter(
pytorch_test_dir + "/jit/fixtures/test_versioned_div_scalar_float_v2.ptl")
pytorch_test_dir
+ "/jit/fixtures/test_versioned_div_scalar_float_v2.ptl"
)
v3_mobile_module_int = _load_for_lite_interpreter(
pytorch_test_dir + "/cpp/jit/upgrader_models/test_versioned_div_scalar_int_v2.ptl")
pytorch_test_dir
+ "/cpp/jit/upgrader_models/test_versioned_div_scalar_int_v2.ptl"
)
except Exception as e:
self.skipTest("Failed to load fixture!")
@ -249,9 +285,14 @@ class TestSaveLoadForOpVersion(JitTestCase):
_helper(v3_mobile_module_int, historic_div_scalar_int)
_helper(current_mobile_module_int, torch.div)
@settings(max_examples=10, deadline=200000) # A total of 10 examples will be generated
@settings(
max_examples=10, deadline=200000
) # A total of 10 examples will be generated
@given(
sample_input=st.tuples(st.integers(min_value=5, max_value=199), st.floats(min_value=5.0, max_value=199.0))
sample_input=st.tuples(
st.integers(min_value=5, max_value=199),
st.floats(min_value=5.0, max_value=199.0),
)
) # Generate a pair (integer, float)
@example((2, 3, 2.0, 3.0)) # Ensure this example will be covered
def test_versioned_div_scalar_reciprocal(self, sample_input):
@ -261,7 +302,7 @@ class TestSaveLoadForOpVersion(JitTestCase):
def historic_div_scalar_int_reciprocal(self, other: int):
if self.is_floating_point():
return other / self
return torch.divide(other, self, rounding_mode='trunc')
return torch.divide(other, self, rounding_mode="trunc")
class MyModuleFloat(torch.nn.Module):
def forward(self, a, b: float):
@ -273,9 +314,13 @@ class TestSaveLoadForOpVersion(JitTestCase):
try:
v3_mobile_module_float = _load_for_lite_interpreter(
pytorch_test_dir + "/cpp/jit/upgrader_models/test_versioned_div_scalar_reciprocal_float_v2.ptl")
pytorch_test_dir
+ "/cpp/jit/upgrader_models/test_versioned_div_scalar_reciprocal_float_v2.ptl"
)
v3_mobile_module_int = _load_for_lite_interpreter(
pytorch_test_dir + "/cpp/jit/upgrader_models/test_versioned_div_scalar_reciprocal_int_v2.ptl")
pytorch_test_dir
+ "/cpp/jit/upgrader_models/test_versioned_div_scalar_reciprocal_int_v2.ptl"
)
except Exception as e:
self.skipTest("Failed to load fixture!")
@ -311,9 +356,14 @@ class TestSaveLoadForOpVersion(JitTestCase):
_helper(v3_mobile_module_int, current_mobile_module_int)
_helper(current_mobile_module_int, torch.div)
@settings(max_examples=10, deadline=200000) # A total of 10 examples will be generated
@settings(
max_examples=10, deadline=200000
) # A total of 10 examples will be generated
@given(
sample_input=st.tuples(st.integers(min_value=5, max_value=199), st.floats(min_value=5.0, max_value=199.0))
sample_input=st.tuples(
st.integers(min_value=5, max_value=199),
st.floats(min_value=5.0, max_value=199.0),
)
) # Generate a pair (integer, float)
@example((2, 3, 2.0, 3.0)) # Ensure this example will be covered
def test_versioned_div_scalar_inplace(self, sample_input):
@ -324,7 +374,7 @@ class TestSaveLoadForOpVersion(JitTestCase):
if self.is_floating_point():
return self.true_divide_(other)
return self.divide_(other, rounding_mode='trunc')
return self.divide_(other, rounding_mode="trunc")
class MyModuleFloat(torch.nn.Module):
def forward(self, a, b: float):
@ -338,9 +388,13 @@ class TestSaveLoadForOpVersion(JitTestCase):
try:
v3_mobile_module_float = _load_for_lite_interpreter(
pytorch_test_dir + "/cpp/jit/upgrader_models/test_versioned_div_scalar_inplace_float_v2.ptl")
pytorch_test_dir
+ "/cpp/jit/upgrader_models/test_versioned_div_scalar_inplace_float_v2.ptl"
)
v3_mobile_module_int = _load_for_lite_interpreter(
pytorch_test_dir + "/cpp/jit/upgrader_models/test_versioned_div_scalar_inplace_int_v2.ptl")
pytorch_test_dir
+ "/cpp/jit/upgrader_models/test_versioned_div_scalar_inplace_int_v2.ptl"
)
except Exception as e:
self.skipTest("Failed to load fixture!")
@ -378,14 +432,16 @@ class TestSaveLoadForOpVersion(JitTestCase):
try:
v3_mobile_module = _load_for_lite_interpreter(
pytorch_test_dir + "/cpp/jit/upgrader_models/test_versioned_div_scalar_scalar_v2.ptl")
pytorch_test_dir
+ "/cpp/jit/upgrader_models/test_versioned_div_scalar_scalar_v2.ptl"
)
except Exception as e:
self.skipTest("Failed to load fixture!")
current_mobile_module = self._save_load_mobile_module(MyModule)
def _helper(m, fn):
vals = (5., 3, 2., 7)
vals = (5.0, 3, 2.0, 7)
m_result = m(*vals)
fn_result = fn(*vals)
for mr, hr in zip(m_result, fn_result):
@ -395,13 +451,16 @@ class TestSaveLoadForOpVersion(JitTestCase):
def test_versioned_linspace(self):
class Module(torch.nn.Module):
def forward(self, a: Union[int, float, complex], b: Union[int, float, complex]):
def forward(
self, a: Union[int, float, complex], b: Union[int, float, complex]
):
c = torch.linspace(a, b, steps=5)
d = torch.linspace(a, b, steps=100)
return c, d
scripted_module = torch.jit.load(
pytorch_test_dir + "/jit/fixtures/test_versioned_linspace_v7.ptl")
pytorch_test_dir + "/jit/fixtures/test_versioned_linspace_v7.ptl"
)
buffer = io.BytesIO(scripted_module._save_to_buffer_for_lite_interpreter())
buffer.seek(0)
@ -410,7 +469,7 @@ class TestSaveLoadForOpVersion(JitTestCase):
current_mobile_module = self._save_load_mobile_module(Module)
sample_inputs = ((3, 10), (-10, 10), (4.0, 6.0), (3 + 4j, 4 + 5j))
for (a, b) in sample_inputs:
for a, b in sample_inputs:
(output_with_step, output_without_step) = v7_mobile_module(a, b)
(current_with_step, current_without_step) = current_mobile_module(a, b)
# when no step is given, should have used 100
@ -422,10 +481,17 @@ class TestSaveLoadForOpVersion(JitTestCase):
def test_versioned_linspace_out(self):
class Module(torch.nn.Module):
def forward(self, a: Union[int, float, complex], b: Union[int, float, complex], out: torch.Tensor):
def forward(
self,
a: Union[int, float, complex],
b: Union[int, float, complex],
out: torch.Tensor,
):
return torch.linspace(a, b, steps=100, out=out)
model_path = pytorch_test_dir + "/jit/fixtures/test_versioned_linspace_out_v7.ptl"
model_path = (
pytorch_test_dir + "/jit/fixtures/test_versioned_linspace_out_v7.ptl"
)
loaded_model = torch.jit.load(model_path)
buffer = io.BytesIO(loaded_model._save_to_buffer_for_lite_interpreter())
buffer.seek(0)
@ -433,12 +499,32 @@ class TestSaveLoadForOpVersion(JitTestCase):
current_mobile_module = self._save_load_mobile_module(Module)
sample_inputs = (
(3, 10, torch.empty((100,), dtype=torch.int64), torch.empty((100,), dtype=torch.int64)),
(-10, 10, torch.empty((100,), dtype=torch.int64), torch.empty((100,), dtype=torch.int64)),
(4.0, 6.0, torch.empty((100,), dtype=torch.float64), torch.empty((100,), dtype=torch.float64)),
(3 + 4j, 4 + 5j, torch.empty((100,), dtype=torch.complex64), torch.empty((100,), dtype=torch.complex64)),
(
3,
10,
torch.empty((100,), dtype=torch.int64),
torch.empty((100,), dtype=torch.int64),
),
(
-10,
10,
torch.empty((100,), dtype=torch.int64),
torch.empty((100,), dtype=torch.int64),
),
(
4.0,
6.0,
torch.empty((100,), dtype=torch.float64),
torch.empty((100,), dtype=torch.float64),
),
(
3 + 4j,
4 + 5j,
torch.empty((100,), dtype=torch.complex64),
torch.empty((100,), dtype=torch.complex64),
),
)
for (start, end, out_for_old, out_for_new) in sample_inputs:
for start, end, out_for_old, out_for_new in sample_inputs:
output = v7_mobile_module(start, end, out_for_old)
output_current = current_mobile_module(start, end, out_for_new)
# when no step is given, should have used 100
@ -448,13 +534,16 @@ class TestSaveLoadForOpVersion(JitTestCase):
def test_versioned_logspace(self):
class Module(torch.nn.Module):
def forward(self, a: Union[int, float, complex], b: Union[int, float, complex]):
def forward(
self, a: Union[int, float, complex], b: Union[int, float, complex]
):
c = torch.logspace(a, b, steps=5)
d = torch.logspace(a, b, steps=100)
return c, d
scripted_module = torch.jit.load(
pytorch_test_dir + "/jit/fixtures/test_versioned_logspace_v8.ptl")
pytorch_test_dir + "/jit/fixtures/test_versioned_logspace_v8.ptl"
)
buffer = io.BytesIO(scripted_module._save_to_buffer_for_lite_interpreter())
buffer.seek(0)
@ -463,7 +552,7 @@ class TestSaveLoadForOpVersion(JitTestCase):
current_mobile_module = self._save_load_mobile_module(Module)
sample_inputs = ((3, 10), (-10, 10), (4.0, 6.0), (3 + 4j, 4 + 5j))
for (a, b) in sample_inputs:
for a, b in sample_inputs:
(output_with_step, output_without_step) = v8_mobile_module(a, b)
(current_with_step, current_without_step) = current_mobile_module(a, b)
# when no step is given, should have used 100
@ -475,10 +564,17 @@ class TestSaveLoadForOpVersion(JitTestCase):
def test_versioned_logspace_out(self):
class Module(torch.nn.Module):
def forward(self, a: Union[int, float, complex], b: Union[int, float, complex], out: torch.Tensor):
def forward(
self,
a: Union[int, float, complex],
b: Union[int, float, complex],
out: torch.Tensor,
):
return torch.logspace(a, b, steps=100, out=out)
model_path = pytorch_test_dir + "/jit/fixtures/test_versioned_logspace_out_v8.ptl"
model_path = (
pytorch_test_dir + "/jit/fixtures/test_versioned_logspace_out_v8.ptl"
)
loaded_model = torch.jit.load(model_path)
buffer = io.BytesIO(loaded_model._save_to_buffer_for_lite_interpreter())
buffer.seek(0)
@ -486,12 +582,32 @@ class TestSaveLoadForOpVersion(JitTestCase):
current_mobile_module = self._save_load_mobile_module(Module)
sample_inputs = (
(3, 10, torch.empty((100,), dtype=torch.int64), torch.empty((100,), dtype=torch.int64)),
(-10, 10, torch.empty((100,), dtype=torch.int64), torch.empty((100,), dtype=torch.int64)),
(4.0, 6.0, torch.empty((100,), dtype=torch.float64), torch.empty((100,), dtype=torch.float64)),
(3 + 4j, 4 + 5j, torch.empty((100,), dtype=torch.complex64), torch.empty((100,), dtype=torch.complex64)),
(
3,
10,
torch.empty((100,), dtype=torch.int64),
torch.empty((100,), dtype=torch.int64),
),
(
-10,
10,
torch.empty((100,), dtype=torch.int64),
torch.empty((100,), dtype=torch.int64),
),
(
4.0,
6.0,
torch.empty((100,), dtype=torch.float64),
torch.empty((100,), dtype=torch.float64),
),
(
3 + 4j,
4 + 5j,
torch.empty((100,), dtype=torch.complex64),
torch.empty((100,), dtype=torch.complex64),
),
)
for (start, end, out_for_old, out_for_new) in sample_inputs:
for start, end, out_for_old, out_for_new in sample_inputs:
output = v8_mobile_module(start, end, out_for_old)
output_current = current_mobile_module(start, end, out_for_new)
# when no step is given, should have used 100

View File

@ -11,10 +11,13 @@ pytorch_test_dir = os.path.dirname(os.path.dirname(os.path.realpath(__file__)))
sys.path.append(pytorch_test_dir)
from torch.testing._internal.jit_utils import JitTestCase
if __name__ == '__main__':
raise RuntimeError("This test file is not meant to be run directly, use:\n\n"
"\tpython test/test_jit.py TESTNAME\n\n"
"instead.")
if __name__ == "__main__":
raise RuntimeError(
"This test file is not meant to be run directly, use:\n\n"
"\tpython test/test_jit.py TESTNAME\n\n"
"instead."
)
class Sequence(nn.Module):
def __init__(self):
@ -38,8 +41,8 @@ class Sequence(nn.Module):
outputs = torch.cat(outputs, dim=1)
return outputs
class TestScriptProfile(JitTestCase):
class TestScriptProfile(JitTestCase):
def test_basic(self):
seq = torch.jit.script(Sequence())
p = torch.jit._ScriptProfile()
@ -57,6 +60,7 @@ class TestScriptProfile(JitTestCase):
@torch.jit.script
def fn():
_ = seq(torch.rand((10, 100)))
fn()
p.disable()
@ -83,7 +87,7 @@ class TestScriptProfile(JitTestCase):
seq = Sequence()
@torch.jit.script
def fn(max : int):
def fn(max: int):
_ = seq(torch.rand((10, max)))
p = torch.jit._ScriptProfile()

View File

@ -3,22 +3,24 @@
import os
import sys
import warnings
from typing import Dict, List, Optional
import torch
from typing import List, Dict, Optional
# Make the helper files in test/ importable
pytorch_test_dir = os.path.dirname(os.path.dirname(os.path.realpath(__file__)))
sys.path.append(pytorch_test_dir)
from torch.testing._internal.jit_utils import JitTestCase
if __name__ == '__main__':
raise RuntimeError("This test file is not meant to be run directly, use:\n\n"
"\tpython test/test_jit.py TESTNAME\n\n"
"instead.")
if __name__ == "__main__":
raise RuntimeError(
"This test file is not meant to be run directly, use:\n\n"
"\tpython test/test_jit.py TESTNAME\n\n"
"instead."
)
class TestScriptModuleInstanceAttributeTypeAnnotation(JitTestCase):
# NB: There are no tests for `Tuple` or `NamedTuple` here. In fact,
# reassigning a non-empty Tuple to an attribute previously typed
# as containing an empty Tuple SHOULD fail. See note in `_check.py`
@ -81,7 +83,6 @@ class TestScriptModuleInstanceAttributeTypeAnnotation(JitTestCase):
def test_annotated_class_level_annotation_only(self):
class M(torch.nn.Module):
x: List[int]
def __init__(self):
@ -96,10 +97,8 @@ class TestScriptModuleInstanceAttributeTypeAnnotation(JitTestCase):
self.checkModule(M(), ([1, 2, 3],))
assert len(w) == 0
def test_annotated_class_level_annotation_and_init_annotation(self):
class M(torch.nn.Module):
x: List[int]
def __init__(self):
@ -116,7 +115,6 @@ class TestScriptModuleInstanceAttributeTypeAnnotation(JitTestCase):
def test_annotated_class_level_jit_annotation(self):
class M(torch.nn.Module):
x: List[int]
def __init__(self):
@ -141,12 +139,15 @@ class TestScriptModuleInstanceAttributeTypeAnnotation(JitTestCase):
self.x = x
return 1
with self.assertRaisesRegexWithHighlight(RuntimeError,
"Tried to set nonexistent attribute",
"self.x = x"):
with self.assertWarnsRegex(UserWarning, "doesn't support "
"instance-level annotations on "
"empty non-base types"):
with self.assertRaisesRegexWithHighlight(
RuntimeError, "Tried to set nonexistent attribute", "self.x = x"
):
with self.assertWarnsRegex(
UserWarning,
"doesn't support "
"instance-level annotations on "
"empty non-base types",
):
torch.jit.script(M())
def test_annotated_empty_dict(self):
@ -159,12 +160,15 @@ class TestScriptModuleInstanceAttributeTypeAnnotation(JitTestCase):
self.x = x
return 1
with self.assertRaisesRegexWithHighlight(RuntimeError,
"Tried to set nonexistent attribute",
"self.x = x"):
with self.assertWarnsRegex(UserWarning, "doesn't support "
"instance-level annotations on "
"empty non-base types"):
with self.assertRaisesRegexWithHighlight(
RuntimeError, "Tried to set nonexistent attribute", "self.x = x"
):
with self.assertWarnsRegex(
UserWarning,
"doesn't support "
"instance-level annotations on "
"empty non-base types",
):
torch.jit.script(M())
def test_annotated_empty_optional(self):
@ -177,12 +181,15 @@ class TestScriptModuleInstanceAttributeTypeAnnotation(JitTestCase):
self.x = x
return 1
with self.assertRaisesRegexWithHighlight(RuntimeError,
"Wrong type for attribute assignment",
"self.x = x"):
with self.assertWarnsRegex(UserWarning, "doesn't support "
"instance-level annotations on "
"empty non-base types"):
with self.assertRaisesRegexWithHighlight(
RuntimeError, "Wrong type for attribute assignment", "self.x = x"
):
with self.assertWarnsRegex(
UserWarning,
"doesn't support "
"instance-level annotations on "
"empty non-base types",
):
torch.jit.script(M())
def test_annotated_with_jit_empty_list(self):
@ -195,12 +202,15 @@ class TestScriptModuleInstanceAttributeTypeAnnotation(JitTestCase):
self.x = x
return 1
with self.assertRaisesRegexWithHighlight(RuntimeError,
"Tried to set nonexistent attribute",
"self.x = x"):
with self.assertWarnsRegex(UserWarning, "doesn't support "
"instance-level annotations on "
"empty non-base types"):
with self.assertRaisesRegexWithHighlight(
RuntimeError, "Tried to set nonexistent attribute", "self.x = x"
):
with self.assertWarnsRegex(
UserWarning,
"doesn't support "
"instance-level annotations on "
"empty non-base types",
):
torch.jit.script(M())
def test_annotated_with_jit_empty_dict(self):
@ -213,12 +223,15 @@ class TestScriptModuleInstanceAttributeTypeAnnotation(JitTestCase):
self.x = x
return 1
with self.assertRaisesRegexWithHighlight(RuntimeError,
"Tried to set nonexistent attribute",
"self.x = x"):
with self.assertWarnsRegex(UserWarning, "doesn't support "
"instance-level annotations on "
"empty non-base types"):
with self.assertRaisesRegexWithHighlight(
RuntimeError, "Tried to set nonexistent attribute", "self.x = x"
):
with self.assertWarnsRegex(
UserWarning,
"doesn't support "
"instance-level annotations on "
"empty non-base types",
):
torch.jit.script(M())
def test_annotated_with_jit_empty_optional(self):
@ -231,12 +244,15 @@ class TestScriptModuleInstanceAttributeTypeAnnotation(JitTestCase):
self.x = x
return 1
with self.assertRaisesRegexWithHighlight(RuntimeError,
"Wrong type for attribute assignment",
"self.x = x"):
with self.assertWarnsRegex(UserWarning, "doesn't support "
"instance-level annotations on "
"empty non-base types"):
with self.assertRaisesRegexWithHighlight(
RuntimeError, "Wrong type for attribute assignment", "self.x = x"
):
with self.assertWarnsRegex(
UserWarning,
"doesn't support "
"instance-level annotations on "
"empty non-base types",
):
torch.jit.script(M())
def test_annotated_with_torch_jit_import(self):
@ -251,10 +267,13 @@ class TestScriptModuleInstanceAttributeTypeAnnotation(JitTestCase):
self.x = x
return 1
with self.assertRaisesRegexWithHighlight(RuntimeError,
"Wrong type for attribute assignment",
"self.x = x"):
with self.assertWarnsRegex(UserWarning, "doesn't support "
"instance-level annotations on "
"empty non-base types"):
with self.assertRaisesRegexWithHighlight(
RuntimeError, "Wrong type for attribute assignment", "self.x = x"
):
with self.assertWarnsRegex(
UserWarning,
"doesn't support "
"instance-level annotations on "
"empty non-base types",
):
torch.jit.script(M())

View File

@ -2,19 +2,22 @@
import os
import sys
from typing import List
import torch
from typing import List
# Make the helper files in test/ importable
pytorch_test_dir = os.path.dirname(os.path.dirname(os.path.realpath(__file__)))
sys.path.append(pytorch_test_dir)
from torch.testing._internal.jit_utils import JitTestCase
if __name__ == '__main__':
raise RuntimeError("This test file is not meant to be run directly, use:\n\n"
"\tpython test/test_jit.py TESTNAME\n\n"
"instead.")
if __name__ == "__main__":
raise RuntimeError(
"This test file is not meant to be run directly, use:\n\n"
"\tpython test/test_jit.py TESTNAME\n\n"
"instead."
)
# Tests that Python slice class is supported in TorchScript
class TestSlice(JitTestCase):
@ -22,7 +25,9 @@ class TestSlice(JitTestCase):
def slice_kwarg(x: List[int]):
return x[slice(1, stop=2)]
with self.assertRaisesRegex(RuntimeError, "Slice does not accept any keyword arguments"):
with self.assertRaisesRegex(
RuntimeError, "Slice does not accept any keyword arguments"
):
torch.jit.script(slice_kwarg)
def test_slice_three_nones(self):
@ -46,11 +51,13 @@ class TestSlice(JitTestCase):
def test_slice_stop_only(self):
def fn(x: List[int]):
return x[slice(5)]
self.checkScript(fn, (range(10),))
def test_slice_stop_only_with_nones(self):
def fn(x: List[int]):
return x[slice(None, 5, None)]
self.checkScript(fn, (range(10),))
def test_slice_start_stop(self):
@ -136,8 +143,8 @@ class TestSlice(JitTestCase):
num_outputs = {len(x.output().type().elements()) for x in slices}
# there should be only one tupleSlice with length of 2
self.assertTrue(num_outputs == {2})
self.run_pass('lower_all_tuples', tuple_graph)
self.assertTrue('Tuple' not in str(tuple_graph))
self.run_pass("lower_all_tuples", tuple_graph)
self.assertTrue("Tuple" not in str(tuple_graph))
def test_module_list_slicing(self):
class Bar(torch.nn.Module):

View File

@ -1,8 +1,9 @@
# Owner(s): ["oncall: jit"]
import io
import torch
import unittest
import torch
from torch.testing._internal.common_utils import IS_WINDOWS, TEST_MKL
from torch.testing._internal.jit_utils import JitTestCase
@ -70,9 +71,7 @@ class TestSparse(JitTestCase):
self.a = torch.rand(4, 4).to_sparse_csr()
self.b = torch.rand(4, 4).to_sparse_csr()
def forward(self, x):
return x.matmul(self.a).matmul(self.b)
x = torch.rand(4, 4).to_sparse_csr()

View File

@ -2,65 +2,76 @@
import os
import sys
from typing import List
import torch
from typing import List
# Make the helper files in test/ importable
pytorch_test_dir = os.path.dirname(os.path.dirname(os.path.realpath(__file__)))
sys.path.append(pytorch_test_dir)
from torch.testing._internal.jit_utils import JitTestCase
if __name__ == '__main__':
raise RuntimeError("This test file is not meant to be run directly, use:\n\n"
"\tpython test/test_jit.py TESTNAME\n\n"
"instead.")
if __name__ == "__main__":
raise RuntimeError(
"This test file is not meant to be run directly, use:\n\n"
"\tpython test/test_jit.py TESTNAME\n\n"
"instead."
)
class TestStringFormatting(JitTestCase):
def test_modulo_operator(self):
def fn(dividend: int, divisor: int) -> int:
return dividend % divisor
self.checkScript(fn, (5, 2))
def test_string_interpolation_with_string_placeholder_and_string_variable(self):
def fn(arg1: str):
return "%s in template" % arg1
self.checkScript(fn, ("foo",))
def test_string_interpolation_with_string_placeholder_and_format_string_variable(self):
def test_string_interpolation_with_string_placeholder_and_format_string_variable(
self,
):
def fn(arg1: str):
return arg1 % "foo"
self.checkScript(fn, ("%s in template",))
def test_string_interpolation_with_double_percent_in_string(self):
def fn(arg1: str):
return "%s in template %%" % arg1
self.checkScript(fn, ("foo",))
def test_string_interpolation_with_percent_in_string(self):
@torch.jit.script
def fn(arg1: str) -> str:
return "%s in template %" % arg1 # noqa: F501
return "%s in template %" % arg1 # noqa: F501
with self.assertRaisesRegexWithHighlight(RuntimeError,
"Incomplete format specifier",
"\"%s in template %\" % arg1"):
with self.assertRaisesRegexWithHighlight(
RuntimeError, "Incomplete format specifier", '"%s in template %" % arg1'
):
fn("foo")
def test_string_interpolation_with_string_placeholder_and_digit_variable(self):
def fn(arg1: int) -> str:
return "%s in template" % arg1
self.checkScript(fn, (1,))
def test_string_interpolation_with_digit_placeholder_and_digit_variable(self):
def fn(arg1: int) -> str:
return "%d in template" % arg1
self.checkScript(fn, (1,))
def test_string_interpolation_with_alternate_digit_placeholder(self):
def fn(arg1: int) -> str:
return "%i in template" % arg1
self.checkScript(fn, (1,))
def test_string_interpolation_with_digit_placeholder_and_string_variable(self):
@ -68,9 +79,11 @@ class TestStringFormatting(JitTestCase):
def fn(arg1: str) -> str:
return "%d in template" % arg1
with self.assertRaisesRegexWithHighlight(RuntimeError,
"%d requires a number for formatting, but got String",
"\"%d in template\" % arg1"):
with self.assertRaisesRegexWithHighlight(
RuntimeError,
"%d requires a number for formatting, but got String",
'"%d in template" % arg1',
):
fn("1")
def test_string_interpolation_with_exponent_placeholder_and_string_variable(self):
@ -78,39 +91,51 @@ class TestStringFormatting(JitTestCase):
def fn(arg1: str) -> str:
return "%e in template" % arg1
with self.assertRaisesRegexWithHighlight(RuntimeError,
"%e requires a number for formatting, but got String",
"\"%e in template\" % arg1"):
with self.assertRaisesRegexWithHighlight(
RuntimeError,
"%e requires a number for formatting, but got String",
'"%e in template" % arg1',
):
fn("1")
def test_string_interpolation_with_lowercase_exponent_placeholder_and_digit_variable(self):
def test_string_interpolation_with_lowercase_exponent_placeholder_and_digit_variable(
self,
):
def fn(arg1: int) -> str:
return "%e in template" % arg1
self.checkScript(fn, (1,))
def test_string_interpolation_with_capital_exponent_placeholder_and_digit_variable(self):
def test_string_interpolation_with_capital_exponent_placeholder_and_digit_variable(
self,
):
def fn(arg1: int) -> str:
return "%E in template" % arg1
self.checkScript(fn, (1,))
def test_string_interpolation_with_float_placeholder_and_float_variable(self):
def fn(arg1: float) -> str:
return "%f in template" % arg1
self.checkScript(fn, (1.0,))
def test_string_interpolation_with_float_placeholder_and_digit_variable(self):
def fn(arg1: int) -> str:
return "%f in template" % arg1
self.checkScript(fn, (1,))
def test_string_interpolation_with_char_placeholder_and_char_variable(self):
def fn(arg1: str) -> str:
return "%c in template" % arg1
self.checkScript(fn, ("a",))
def test_string_interpolation_with_char_placeholder_and_digit_variable(self):
def fn(arg1: int) -> str:
return "%c in template" % arg1
self.checkScript(fn, (97,))
def test_string_interpolation_with_char_placeholder_and_true_string_variable(self):
@ -118,19 +143,23 @@ class TestStringFormatting(JitTestCase):
def fn(arg1: str) -> str:
return "%c in template" % arg1
with self.assertRaisesRegexWithHighlight(RuntimeError,
"%c requires an int or char for formatting, but got String",
"\"%c in template\" % arg1"):
with self.assertRaisesRegexWithHighlight(
RuntimeError,
"%c requires an int or char for formatting, but got String",
'"%c in template" % arg1',
):
fn("foo")
def test_string_interpolation_with_multiple_placeholders(self):
def fn(arg1: str, arg2: int, arg3: float) -> str:
return "%s %d %f in template" % (arg1, arg2, arg3)
self.checkScript(fn, ("foo", 1, 1))
def test_string_interpolation_with_subscript(self):
def fn(arg1: List[str]) -> str:
return "%s in template" % arg1[0]
self.checkScript(fn, (["foo", "bar"],))
def test_string_interpolation_with_too_few_arguments(self):
@ -138,27 +167,33 @@ class TestStringFormatting(JitTestCase):
def fn(arg1: str) -> str:
return "%s %s in template" % arg1
with self.assertRaisesRegexWithHighlight(RuntimeError,
"Too few arguments for format string",
"\"%s %s in template\" % arg1"):
with self.assertRaisesRegexWithHighlight(
RuntimeError,
"Too few arguments for format string",
'"%s %s in template" % arg1',
):
fn("foo")
def test_string_interpolation_with_too_many_arguments(self):
@torch.jit.script
def fn(arg1: str, arg2: str) -> str:
return "%s in template" % (arg1, arg2) # noqa: F507
return "%s in template" % (arg1, arg2) # noqa: F507
with self.assertRaisesRegexWithHighlight(RuntimeError,
"Too many arguments for format string",
"\"%s in template\" % (arg1, arg2"):
with self.assertRaisesRegexWithHighlight(
RuntimeError,
"Too many arguments for format string",
'"%s in template" % (arg1, arg2',
):
fn("foo", "bar")
def test_string_interpolation_with_unknown_format_specifier(self):
@torch.jit.script
def fn(arg1: str) -> str:
return "%a in template" % arg1 # noqa: F501
return "%a in template" % arg1 # noqa: F501
with self.assertRaisesRegexWithHighlight(RuntimeError,
"The specifier %a is not supported in TorchScript format strings",
"\"%a in template\" % arg1"):
with self.assertRaisesRegexWithHighlight(
RuntimeError,
"The specifier %a is not supported in TorchScript format strings",
'"%a in template" % arg1',
):
fn("foo")

View File

@ -3,29 +3,36 @@
import operator
import unittest
from textwrap import dedent
from typing import Any, List
import torch
from torch import nn, Tensor
from torch.testing import FileCheck
from torch.testing._internal.common_methods_invocations import sample_inputs_cat_concat
from torch.testing._internal.common_utils import make_tensor
from torch.testing._internal.jit_utils import JitTestCase, execWrapper
from typing import List, Any
from torch.testing._internal.jit_utils import execWrapper, JitTestCase
if __name__ == "__main__":
raise RuntimeError(
"This test file is not meant to be run directly, use:\n\n"
"\tpython test/test_jit.py TESTNAME\n\n"
"instead."
)
if __name__ == '__main__':
raise RuntimeError("This test file is not meant to be run directly, use:\n\n"
"\tpython test/test_jit.py TESTNAME\n\n"
"instead.")
# XXX: still in prototype
class TestSymbolicShapeAnalysis(JitTestCase):
def setUp(self):
super(JitTestCase, self).setUp()
self.prev_symbolic_shapes_test_enabled = torch._C._jit_symbolic_shapes_test_mode_enabled()
self.prev_symbolic_shapes_test_enabled = (
torch._C._jit_symbolic_shapes_test_mode_enabled()
)
torch._C._jit_set_symbolic_shapes_test_mode(True)
def tearDown(self):
torch._C._jit_set_symbolic_shapes_test_mode(self.prev_symbolic_shapes_test_enabled)
torch._C._jit_set_symbolic_shapes_test_mode(
self.prev_symbolic_shapes_test_enabled
)
def test_shape_analysis(self):
@torch.jit.script
@ -115,7 +122,9 @@ class TestSymbolicShapeAnalysis(JitTestCase):
def neg_to_one(li):
return [elem if elem >= 0 else -1 for elem in li]
self.assertEqual(neg_to_one(view.output().type().symbolic_sizes()), [-1, 3, 2, -1])
self.assertEqual(
neg_to_one(view.output().type().symbolic_sizes()), [-1, 3, 2, -1]
)
if_out = next(foo.graph.findNode("prim::If").outputs())
self.assertEqual(neg_to_one(if_out.type().symbolic_sizes()), [-1, 3, -1, -1])
@ -135,9 +144,7 @@ class TestSymbolicShapeAnalysis(JitTestCase):
y = x.mul_(2)
return y
unary_ops = [
mul_inplace
]
unary_ops = [mul_inplace]
for fn in unary_ops:
# t = torch.jit.trace(fn, torch.rand([4, 4])) # For some reason tracing is erroring out.
t = torch.jit.script(fn)
@ -202,7 +209,9 @@ class TestSymbolicShapeAnalysis(JitTestCase):
inputs[1].setType(inputs[1].type().with_sizes([5, 8, sym1]))
torch._C._jit_pass_propagate_shapes_on_graph(graph)
self.assertEqual(next(graph.outputs()).type().symbolic_sizes(), [5, 8, sym1])
self.assertEqual(
next(graph.outputs()).type().symbolic_sizes(), [5, 8, sym1]
)
def test_adaptive_avg_pool2d(self):
inps = [
@ -227,25 +236,105 @@ class TestSymbolicShapeAnalysis(JitTestCase):
self.checkShapeAnalysis(out_size, fn.graph, assert_propagation=True)
def test_conv_deconv(self):
for inp_shape, weight_shape, bias, stride, padding, output_padding, dilation, groups, mod in [
([32, 6, 10], [16, 3, 3], None, 2, 2, 1, 1, 2, torch.nn.functional.conv1d),
([32, 16, 10], [16, 3, 3], None, 2, 2, 1, 1, 2, torch.nn.functional.conv_transpose1d),
([1, 32, 5, 10], [30, 16, 3, 3], None, [2, 2], [0, 0], 0, 1, 2, torch.nn.functional.conv2d),
([1, 30, 5, 10], [30, 16, 3, 3], None, [2, 2], [0, 0], 0, 1, 2, torch.nn.functional.conv_transpose2d),
([3, 14, 10, 66, 55], [2, 7, 7, 4, 4], None, 1, 1, 2, 1, 2, torch.nn.functional.conv3d),
([3, 2, 10, 66, 55], [2, 7, 7, 4, 4], None, 1, 1, 0, 1, 2, torch.nn.functional.conv_transpose3d)]:
for (
inp_shape,
weight_shape,
bias,
stride,
padding,
output_padding,
dilation,
groups,
mod,
) in [
([32, 6, 10], [16, 3, 3], None, 2, 2, 1, 1, 2, torch.nn.functional.conv1d),
(
[32, 16, 10],
[16, 3, 3],
None,
2,
2,
1,
1,
2,
torch.nn.functional.conv_transpose1d,
),
(
[1, 32, 5, 10],
[30, 16, 3, 3],
None,
[2, 2],
[0, 0],
0,
1,
2,
torch.nn.functional.conv2d,
),
(
[1, 30, 5, 10],
[30, 16, 3, 3],
None,
[2, 2],
[0, 0],
0,
1,
2,
torch.nn.functional.conv_transpose2d,
),
(
[3, 14, 10, 66, 55],
[2, 7, 7, 4, 4],
None,
1,
1,
2,
1,
2,
torch.nn.functional.conv3d,
),
(
[3, 2, 10, 66, 55],
[2, 7, 7, 4, 4],
None,
1,
1,
0,
1,
2,
torch.nn.functional.conv_transpose3d,
),
]:
inp = torch.rand(inp_shape)
weight = torch.rand(weight_shape)
if mod in [torch.nn.functional.conv1d, torch.nn.functional.conv2d, torch.nn.functional.conv3d]:
if mod in [
torch.nn.functional.conv1d,
torch.nn.functional.conv2d,
torch.nn.functional.conv3d,
]:
res = mod(inp, weight, bias, stride, padding, dilation, groups).size()
else:
res = mod(inp, weight, bias, stride, padding, output_padding, dilation, groups).size()
res = mod(
inp, weight, bias, stride, padding, output_padding, dilation, groups
).size()
def foo(inp, weight):
if mod in [torch.nn.functional.conv1d, torch.nn.functional.conv2d, torch.nn.functional.conv3d]:
if mod in [
torch.nn.functional.conv1d,
torch.nn.functional.conv2d,
torch.nn.functional.conv3d,
]:
return mod(inp, weight, bias, stride, padding, dilation, groups)
else:
return mod(inp, weight, bias, stride, padding, output_padding, dilation, groups)
return mod(
inp,
weight,
bias,
stride,
padding,
output_padding,
dilation,
groups,
)
fn = torch.jit.trace(foo, (inp, weight))
torch._C._jit_erase_non_input_shape_information(fn.graph)
@ -280,33 +369,58 @@ class TestSymbolicShapeAnalysis(JitTestCase):
]
for inp in inps:
funcs_template = dedent('''
funcs_template = dedent(
"""
def func():
return torch.arange({args})
''')
"""
)
inp_s = str(inp)[1:-1] # remove tuple parens
funcs_str = funcs_template.format(args=inp_s)
scope = {}
execWrapper(funcs_str, globals(), scope)
cu = torch.jit.CompilationUnit(funcs_str)
self.checkShapeAnalysis(list(cu.func().size()), cu.func.graph, assert_propagation=True, constant_prop=False)
self.checkShapeAnalysis(
list(cu.func().size()),
cu.func.graph,
assert_propagation=True,
constant_prop=False,
)
def test_shape_embedding_bag(self):
# TODO: merge into opinfos, having difficulties there
with torch.no_grad():
def make_arg(shape, low=None, high=None):
return make_tensor(shape, device='cpu', dtype=torch.int64,
low=low, high=high, requires_grad=False)
return make_tensor(
shape,
device="cpu",
dtype=torch.int64,
low=low,
high=high,
requires_grad=False,
)
nn_inps = (
(make_arg((40,), 0, 9), torch.nn.Embedding(20, embedding_dim=64, max_norm=1.0)),
(
make_arg((40,), 0, 9),
torch.nn.Embedding(20, embedding_dim=64, max_norm=1.0),
),
(make_arg((2, 4), 0, 9), torch.nn.Embedding(10, 20, sparse=True)),
(make_arg((0,)), torch.nn.Embedding(0, 0, sparse=True)),
(make_arg((2, 4), 0, 9), torch.nn.Embedding(10, 0, sparse=True)),
(make_arg((4,), 0, 21), torch.nn.Embedding(22, 5, max_norm=1.0)),
(make_arg((2,), 0, 1), torch.nn.Embedding.from_pretrained(torch.arange(6.).view(2, 3), max_norm=2.,
norm_type=.5, scale_grad_by_freq=False, sparse=True)),
(
make_arg((2,), 0, 1),
torch.nn.Embedding.from_pretrained(
torch.arange(6.0).view(2, 3),
max_norm=2.0,
norm_type=0.5,
scale_grad_by_freq=False,
sparse=True,
),
),
)
for inp, module in nn_inps:
@ -326,14 +440,16 @@ class TestSymbolicShapeAnalysis(JitTestCase):
fn = torch.jit.trace(foo, (inp.detach(),), check_trace=False)
self.checkShapeAnalysis(out_size, fn.graph, assert_propagation=True, constant_prop=False)
self.checkShapeAnalysis(
out_size, fn.graph, assert_propagation=True, constant_prop=False
)
def test_shape_concat(self):
# TODO: unify with opinfo tests, traces of lists dont preserve sizes in IR
sample_inputs = sample_inputs_cat_concat(None, "cpu", torch.float, False)
class CatMod(nn.Module):
__constants__ = ['dim']
__constants__ = ["dim"]
def __init__(self, dim=0):
super().__init__()
@ -374,16 +490,23 @@ class TestSymbolicShapeAnalysis(JitTestCase):
# Also, as the return shapes are the input, weight, and bias shape, there is no point
# in a really complicated test
input = torch.randn((16, 16, 8, 8), dtype=torch.float32, device="cpu", requires_grad=True)
weight = torch.randn((8, 4, 3, 3), dtype=torch.float32, device="cpu", requires_grad=True)
input = torch.randn(
(16, 16, 8, 8), dtype=torch.float32, device="cpu", requires_grad=True
)
weight = torch.randn(
(8, 4, 3, 3), dtype=torch.float32, device="cpu", requires_grad=True
)
out_grad = torch.randn((16, 8, 8, 8), dtype=torch.float32, device="cpu")
@torch.jit.script
def conv_bwd(input, weight, grad):
bias_sizes = [8, ]
bias_sizes = [
8,
]
args = ([1, 1], [1, 1], [1, 1], False, [0, 0], 4, [True, True, True])
return torch.ops.aten.convolution_backward(grad, input, weight, bias_sizes, *args)
return torch.ops.aten.convolution_backward(
grad, input, weight, bias_sizes, *args
)
self.assert_shape_equal_scripted(conv_bwd, (input, weight, out_grad))
@ -391,15 +514,19 @@ class TestSymbolicShapeAnalysis(JitTestCase):
def conv_bwd_2(input, weight, grad):
bias_sizes = None
args = ([1, 1], [1, 1], [1, 1], False, [0, 0], 4, [True, True, True])
return torch.ops.aten.convolution_backward(grad, input, weight, bias_sizes, *args)
self.assert_shape_equal_scripted(conv_bwd_2, (input, weight, out_grad))
return torch.ops.aten.convolution_backward(
grad, input, weight, bias_sizes, *args
)
self.assert_shape_equal_scripted(conv_bwd_2, (input, weight, out_grad))
def test_returning_input_symbolic_shapes(self):
mm = torch.jit.freeze(torch.jit.script(nn.Conv2d(16, 33, 3, stride=2).eval()))
inps = list(mm.graph.inputs())
inps[1].setType(inps[1].type().with_sizes([None, None, None, None]))
shape_compute_graph = torch._C._jit_pass_propagate_shapes_on_graph_and_build_compute(mm.graph)
shape_compute_graph = (
torch._C._jit_pass_propagate_shapes_on_graph_and_build_compute(mm.graph)
)
g = shape_compute_graph.partial_eval_shape_graph()
# to make into a jit function cant have multiple outputs
g.makeMultiOutputIntoTuple()
@ -412,8 +539,12 @@ class TestSymbolicShapeAnalysis(JitTestCase):
def test_partial_eval_graph_conv(self):
mm = torch.jit.freeze(torch.jit.script(nn.Conv2d(16, 33, 3, stride=2).eval()))
shape_compute_graph = torch._C._jit_pass_propagate_shapes_on_graph_and_build_compute(mm.graph)
output_sizes = mm.graph.findNode("aten::conv2d").output().type().symbolic_sizes()
shape_compute_graph = (
torch._C._jit_pass_propagate_shapes_on_graph_and_build_compute(mm.graph)
)
output_sizes = (
mm.graph.findNode("aten::conv2d").output().type().symbolic_sizes()
)
# calculating 0, 2 and 3 index
for i in [0, 2, 3]:
self.assertTrue(output_sizes[i] < 0)
@ -428,7 +559,9 @@ class TestSymbolicShapeAnalysis(JitTestCase):
for o, oe in zip(output, output_eager[0:1] + output_eager[2:]):
self.assertEqual(o, oe)
def checkSymShapeCompute(self, shape_compute_graph, nodes, node_output_sizes, shape_inputs):
def checkSymShapeCompute(
self, shape_compute_graph, nodes, node_output_sizes, shape_inputs
):
g = shape_compute_graph.partial_eval_shape_graph()
self.assertTrue(len(list(g.inputs())) == len(shape_inputs))
output_sym_map = shape_compute_graph.graph_output_to_symbolic_shape_dim()
@ -451,27 +584,49 @@ class TestSymbolicShapeAnalysis(JitTestCase):
self.assertEqual(sym_outputs[sym_shape_index], output_shape[i])
def test_partial_eval_stitching(self):
conv1 = torch.nn.Conv2d(3, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)
max_pool = torch.nn.MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)
conv2 = nn.Conv2d(64, 128, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False)
conv1 = torch.nn.Conv2d(
3, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False
)
max_pool = torch.nn.MaxPool2d(
kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False
)
conv2 = nn.Conv2d(
64, 128, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False
)
mod = torch.jit.freeze(torch.jit.script(nn.Sequential(conv1, max_pool, conv2).eval()))
mod = torch.jit.freeze(
torch.jit.script(nn.Sequential(conv1, max_pool, conv2).eval())
)
conv1_output = conv1(torch.rand(1, 3, 224, 224))
max_pool_output = max_pool(conv1_output)
conv2_output = conv2(max_pool_output)
shape_compute_graph = torch._C._jit_pass_propagate_shapes_on_graph_and_build_compute(mod.graph)
nodes = [mod.graph.findNode("aten::max_pool2d")] + list(mod.graph.findAllNodes("aten::conv2d"))
output_shapes = [max_pool_output.size(), conv1_output.size(), conv2_output.size()]
self.checkSymShapeCompute(shape_compute_graph, nodes, output_shapes, ([1, 3, 224, 224],))
shape_compute_graph = (
torch._C._jit_pass_propagate_shapes_on_graph_and_build_compute(mod.graph)
)
nodes = [mod.graph.findNode("aten::max_pool2d")] + list(
mod.graph.findAllNodes("aten::conv2d")
)
output_shapes = [
max_pool_output.size(),
conv1_output.size(),
conv2_output.size(),
]
self.checkSymShapeCompute(
shape_compute_graph, nodes, output_shapes, ([1, 3, 224, 224],)
)
def test_refinement_through_graph_stitching(self):
class TwoConvs(torch.nn.Module):
def __init__(self):
super().__init__()
self.conv1 = torch.nn.Conv2d(3, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)
self.conv2 = torch.nn.Conv2d(3, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)
self.conv1 = torch.nn.Conv2d(
3, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False
)
self.conv2 = torch.nn.Conv2d(
3, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False
)
def forward(self, x):
a = self.conv1(x)
@ -495,18 +650,29 @@ class TestSymbolicShapeAnalysis(JitTestCase):
self.assertEqual(out1, out2)
def test_stitching_multi_output(self):
max_pool = torch.nn.MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False, return_indices=True)
max_pool = torch.nn.MaxPool2d(
kernel_size=3,
stride=2,
padding=1,
dilation=1,
ceil_mode=False,
return_indices=True,
)
tensor = torch.rand(1, 3, 224, 224)
mod = torch.jit.trace(max_pool, (tensor,))
mod = torch.jit.freeze(mod.eval())
inp = list(mod.graph.inputs())[1]
inp.setType(inp.type().with_sizes([None, None, None, None]))
output_tensor = list(mod(tensor)[0].size())
self.run_pass('lower_all_tuples', mod.graph)
shape_compute_graph = torch._C._jit_pass_propagate_shapes_on_graph_and_build_compute(mod.graph)
self.run_pass("lower_all_tuples", mod.graph)
shape_compute_graph = (
torch._C._jit_pass_propagate_shapes_on_graph_and_build_compute(mod.graph)
)
max_pool_node = mod.graph.findNode("aten::max_pool2d_with_indices")
outs = list(max_pool_node.outputs())
self.assertEqual(outs[0].type().symbolic_sizes(), outs[1].type().symbolic_sizes())
self.assertEqual(
outs[0].type().symbolic_sizes(), outs[1].type().symbolic_sizes()
)
g = shape_compute_graph.partial_eval_shape_graph()
# to make into a jit function cant have multiple outputs
g.makeMultiOutputIntoTuple()
@ -528,7 +694,6 @@ class TestSymbolicShapeAnalysis(JitTestCase):
self.assertEqual(out, [-2, -3])
def test_stitching_concat(self):
@torch.jit.script
def foo1(a, b, x, y):
return (a / b) + torch.cat([x, y])
@ -542,15 +707,25 @@ class TestSymbolicShapeAnalysis(JitTestCase):
for inp in foo.graph.inputs():
inp.setType(inp.type().with_sizes([None, None]))
shape_compute_graph = torch._C._jit_pass_propagate_shapes_on_graph_and_build_compute(foo.graph)
nodes = [g.findNode("aten::div")] + [g.findNode("aten::add")] + [g.findNode("aten::cat")]
shape_compute_graph = (
torch._C._jit_pass_propagate_shapes_on_graph_and_build_compute(
foo.graph
)
)
nodes = (
[g.findNode("aten::div")]
+ [g.findNode("aten::add")]
+ [g.findNode("aten::cat")]
)
inps = [1, 10], [20, 10], [15, 1], [5, 1]
output_shapes = [[20, 10], [20, 10], [20, 1]]
self.checkSymShapeCompute(shape_compute_graph, nodes, output_shapes, inps)
@unittest.skipIf(not hasattr(torch.jit, "_shapes"), "shape functions not loaded in python")
@unittest.skipIf(
not hasattr(torch.jit, "_shapes"), "shape functions not loaded in python"
)
def test_shape_function_includes(self):
inp_shape = [1, 16, 5, 10]
weight_shape = [33, 16, 3, 3]
@ -559,7 +734,9 @@ class TestSymbolicShapeAnalysis(JitTestCase):
padding = [0, 0]
dilation = [1, 1]
groups = 1
res = torch.jit._shapes.conv2d(inp_shape, weight_shape, bias, stride, padding, dilation, groups)
res = torch.jit._shapes.conv2d(
inp_shape, weight_shape, bias, stride, padding, dilation, groups
)
self.assertEqual(res, [1, 33, 2, 4])
m1_shape = [10, 20]
@ -580,8 +757,11 @@ class TestSymbolicShapeAnalysis(JitTestCase):
def wrong_input_types(x, y):
x: List[int] = []
return x
with self.assertRaisesRegex(RuntimeError, "Expected supertype of int"):
torch._C._jit_register_shape_compute_graph_for_node(node, wrong_input_types.graph)
torch._C._jit_register_shape_compute_graph_for_node(
node, wrong_input_types.graph
)
@torch.jit.script
def wrong_output_types(x: List[int], y: List[int]):
@ -589,7 +769,9 @@ class TestSymbolicShapeAnalysis(JitTestCase):
return x
with self.assertRaisesRegex(RuntimeError, "but got graph_type"):
torch._C._jit_register_shape_compute_graph_for_node(node, wrong_output_types.graph)
torch._C._jit_register_shape_compute_graph_for_node(
node, wrong_output_types.graph
)
@torch.jit.script
def too_many_inputs(x: List[int], y: List[int], z: Any, z2: Any):
@ -597,7 +779,9 @@ class TestSymbolicShapeAnalysis(JitTestCase):
return x
with self.assertRaises(RuntimeError) as error:
torch._C._jit_register_shape_compute_graph_for_node(node, too_many_inputs.graph)
torch._C._jit_register_shape_compute_graph_for_node(
node, too_many_inputs.graph
)
self.assertTrue("fewer arguments than schema" in str(error.exception))
@ -608,9 +792,22 @@ class TestSymbolicShapeAnalysis(JitTestCase):
inputs = list(foo.graph.inputs())
inputs[0].setType(inputs[0].type().with_sizes([8, 2]))
inputs[1].setType(inputs[1].type().with_sizes([8,]))
inputs[1].setType(
inputs[1]
.type()
.with_sizes(
[
8,
]
)
)
torch._C._jit_pass_propagate_shapes_on_graph(foo.graph)
self.assertEqual(next(foo.graph.outputs()).type().sizes(), [8,])
self.assertEqual(
next(foo.graph.outputs()).type().sizes(),
[
8,
],
)
def test_squeeze_dims(self):
@torch.jit.script

View File

@ -10,10 +10,13 @@ pytorch_test_dir = os.path.dirname(os.path.dirname(os.path.realpath(__file__)))
sys.path.append(pytorch_test_dir)
from torch.testing._internal.jit_utils import JitTestCase
if __name__ == '__main__':
raise RuntimeError("This test file is not meant to be run directly, use:\n\n"
"\tpython test/test_jit.py TESTNAME\n\n"
"instead.")
if __name__ == "__main__":
raise RuntimeError(
"This test file is not meant to be run directly, use:\n\n"
"\tpython test/test_jit.py TESTNAME\n\n"
"instead."
)
class TestTensorCreationOps(JitTestCase):
"""
@ -27,7 +30,7 @@ class TestTensorCreationOps(JitTestCase):
# as integers, which are not comparable against eager torch.dtype.
assert perm.dtype == torch.int64
self.checkScript(randperm, (3, ))
self.checkScript(randperm, (3,))
def test_randperm_specifed_dtype(self):
def randperm(x: int):
@ -36,7 +39,7 @@ class TestTensorCreationOps(JitTestCase):
# as integers, which are not comparable against eager torch.dtype.
assert perm.dtype == torch.float
self.checkScript(randperm, (3, ))
self.checkScript(randperm, (3,))
def test_triu_indices_default_dtype(self):
def triu_indices(rows: int, cols: int):

View File

@ -8,8 +8,8 @@ import torch
# Make the helper files in test/ importable
pytorch_test_dir = os.path.dirname(os.path.dirname(os.path.realpath(__file__)))
sys.path.append(pytorch_test_dir)
from torch.testing._internal.jit_utils import JitTestCase
from torch.testing import FileCheck
from torch.testing._internal.jit_utils import JitTestCase
if __name__ == "__main__":
raise RuntimeError(
@ -18,6 +18,7 @@ if __name__ == "__main__":
"instead."
)
class TestTensorMethods(JitTestCase):
def test_getitem(self):
def tensor_getitem(inp: torch.Tensor):
@ -25,7 +26,7 @@ class TestTensorMethods(JitTestCase):
return inp.__getitem__(indices)
inp = torch.rand(3, 4)
self.checkScript(tensor_getitem, (inp, ))
self.checkScript(tensor_getitem, (inp,))
scripted = torch.jit.script(tensor_getitem)
FileCheck().check("aten::index").run(scripted.graph)
@ -35,5 +36,6 @@ class TestTensorMethods(JitTestCase):
return inp.__getitem__()
with self.assertRaisesRegexWithHighlight(
RuntimeError, "expected exactly 1 argument", "inp.__getitem__"):
RuntimeError, "expected exactly 1 argument", "inp.__getitem__"
):
torch.jit.script(tensor_getitem_invalid)

View File

@ -1,27 +1,27 @@
# Owner(s): ["oncall: jit"]
import copy
import io
import os
import sys
import copy
import unittest
from typing import Optional
import torch
from typing import Optional
from torch.testing._internal.common_utils import skipIfTorchDynamo
# Make the helper files in test/ importable
pytorch_test_dir = os.path.dirname(os.path.dirname(os.path.realpath(__file__)))
sys.path.append(pytorch_test_dir)
from torch.testing._internal.jit_utils import JitTestCase
from torch.testing import FileCheck
from torch.testing._internal.common_utils import (
find_library_location,
IS_FBCODE,
IS_MACOS,
IS_SANDCASTLE,
IS_WINDOWS,
find_library_location,
)
from torch.testing import FileCheck
from torch.testing._internal.jit_utils import JitTestCase
if __name__ == "__main__":
raise RuntimeError(
@ -30,14 +30,15 @@ if __name__ == "__main__":
"instead."
)
@skipIfTorchDynamo("skipping as a precaution")
class TestTorchbind(JitTestCase):
def setUp(self):
if IS_SANDCASTLE or IS_MACOS or IS_FBCODE:
raise unittest.SkipTest("non-portable load_library call used in test")
lib_file_path = find_library_location('libtorchbind_test.so')
lib_file_path = find_library_location("libtorchbind_test.so")
if IS_WINDOWS:
lib_file_path = find_library_location('torchbind_test.dll')
lib_file_path = find_library_location("torchbind_test.dll")
torch.ops.load_library(str(lib_file_path))
def test_torchbind(self):
@ -50,15 +51,17 @@ class TestTorchbind(JitTestCase):
val = torch.classes._TorchScriptTesting._Foo(5, 3)
val.increment(1)
return val
test_equality(f, lambda x: x)
with self.assertRaisesRegex(RuntimeError, "Expected a value of type 'int'"):
val = torch.classes._TorchScriptTesting._Foo(5, 3)
val.increment('foo')
val.increment("foo")
def f():
ss = torch.classes._TorchScriptTesting._StackString(["asdf", "bruh"])
return ss.pop()
test_equality(f, lambda x: x)
def f():
@ -66,6 +69,7 @@ class TestTorchbind(JitTestCase):
ss2 = torch.classes._TorchScriptTesting._StackString(["111", "222"])
ss1.push(ss2.pop())
return ss1.pop() + ss2.pop()
test_equality(f, lambda x: x)
# test nn module with prepare_scriptable function
@ -116,8 +120,11 @@ class TestTorchbind(JitTestCase):
scripted = torch.jit.script(foo)
# Ensure we are creating the object and calling __init__
# rather than calling the __init__wrapper nonsense
fc = FileCheck().check('prim::CreateObject()')\
.check('prim::CallMethod[name="__init__"]')
fc = (
FileCheck()
.check("prim::CreateObject()")
.check('prim::CallMethod[name="__init__"]')
)
fc.run(str(scripted.graph))
out = scripted()
self.assertEqual(out.pop(), "mom")
@ -167,7 +174,7 @@ class TestTorchbind(JitTestCase):
out, result = scripted()
self.assertEqual(result, 10)
with self.assertRaisesRegex(RuntimeError, 'can\'t set attribute'):
with self.assertRaisesRegex(RuntimeError, "can't set attribute"):
out.y = 5
def foo_not_setter():
@ -177,9 +184,11 @@ class TestTorchbind(JitTestCase):
# getY method intentionally adds 4 to x
return fooGetterSetter.y
with self.assertRaisesRegexWithHighlight(RuntimeError,
'Tried to set read-only attribute: y',
'fooGetterSetter.y = old + 4'):
with self.assertRaisesRegexWithHighlight(
RuntimeError,
"Tried to set read-only attribute: y",
"fooGetterSetter.y = old + 4",
):
scripted = torch.jit.script(foo_not_setter)
def test_torchbind_def_property_readwrite(self):
@ -196,9 +205,9 @@ class TestTorchbind(JitTestCase):
fooReadWrite.y = 5
return fooReadWrite
with self.assertRaisesRegexWithHighlight(RuntimeError,
'Tried to set read-only attribute: y',
'fooReadWrite.y = 5'):
with self.assertRaisesRegexWithHighlight(
RuntimeError, "Tried to set read-only attribute: y", "fooReadWrite.y = 5"
):
scripted = torch.jit.script(foo_readwrite_error)
def test_torchbind_take_instance_as_method_arg(self):
@ -250,7 +259,9 @@ class TestTorchbind(JitTestCase):
return self.foo_mod.info()
def to_ivalue(self):
torchbind_model = torch.classes._TorchScriptTesting._Foo(self.foo_mod.info(), 1)
torchbind_model = torch.classes._TorchScriptTesting._Foo(
self.foo_mod.info(), 1
)
return FooBar(torchbind_model)
inst = FooBar(torch.classes._TorchScriptTesting._Foo(2, 3))
@ -338,7 +349,7 @@ class TestTorchbind(JitTestCase):
self.assertEqual(torch.zeros(4, 4), traced())
def test_torchbind_pass_wrong_type(self):
with self.assertRaisesRegex(RuntimeError, 'but instead found type \'Tensor\''):
with self.assertRaisesRegex(RuntimeError, "but instead found type 'Tensor'"):
torch.ops._TorchScriptTesting.take_an_instance(torch.rand(3, 4))
def test_torchbind_tracing_nested(self):
@ -368,12 +379,15 @@ class TestTorchbind(JitTestCase):
self.assertEqual(nt_loaded.pop(), exp)
def test_torchbind_instantiate_missing_class(self):
with self.assertRaisesRegex(RuntimeError, 'Tried to instantiate class \'foo.IDontExist\', but it does not exist!'):
with self.assertRaisesRegex(
RuntimeError,
"Tried to instantiate class 'foo.IDontExist', but it does not exist!",
):
torch.classes.foo.IDontExist(3, 4, 5)
def test_torchbind_optional_explicit_attr(self):
class TorchBindOptionalExplicitAttr(torch.nn.Module):
foo : Optional[torch.classes._TorchScriptTesting._StackString]
foo: Optional[torch.classes._TorchScriptTesting._StackString]
def __init__(self):
super().__init__()
@ -384,13 +398,13 @@ class TestTorchbind(JitTestCase):
if foo_obj is not None:
return foo_obj.pop()
else:
return '<None>'
return "<None>"
mod = TorchBindOptionalExplicitAttr()
scripted = torch.jit.script(mod)
def test_torchbind_no_init(self):
with self.assertRaisesRegex(RuntimeError, 'torch::init'):
with self.assertRaisesRegex(RuntimeError, "torch::init"):
x = torch.classes._TorchScriptTesting._NoInit()
def test_profiler_custom_op(self):
@ -401,17 +415,17 @@ class TestTorchbind(JitTestCase):
found_event = False
for e in prof.function_events:
if e.name == '_TorchScriptTesting::take_an_instance':
if e.name == "_TorchScriptTesting::take_an_instance":
found_event = True
self.assertTrue(found_event)
def test_torchbind_getattr(self):
foo = torch.classes._TorchScriptTesting._StackString(["test"])
self.assertEqual(None, getattr(foo, 'bar', None))
self.assertEqual(None, getattr(foo, "bar", None))
def test_torchbind_attr_exception(self):
foo = torch.classes._TorchScriptTesting._StackString(["test"])
with self.assertRaisesRegex(AttributeError, 'does not have a field'):
with self.assertRaisesRegex(AttributeError, "does not have a field"):
foo.bar
def test_lambda_as_constructor(self):

File diff suppressed because it is too large Load Diff

View File

@ -1,21 +1,24 @@
# Owner(s): ["oncall: jit"]
import io
import os
import sys
import io
import torch
# Make the helper files in test/ importable
pytorch_test_dir = os.path.dirname(os.path.dirname(os.path.realpath(__file__)))
sys.path.append(pytorch_test_dir)
from torch.testing._internal.jit_utils import JitTestCase
from torch.testing._internal.common_utils import suppress_warnings
from torch.testing._internal.jit_utils import JitTestCase
if __name__ == "__main__":
raise RuntimeError(
"This test file is not meant to be run directly, use:\n\n"
"\tpython test/test_jit.py TESTNAME\n\n"
"instead."
)
if __name__ == '__main__':
raise RuntimeError("This test file is not meant to be run directly, use:\n\n"
"\tpython test/test_jit.py TESTNAME\n\n"
"instead.")
class TestTypeSharing(JitTestCase):
def assertSameType(self, m1, m2):
@ -42,6 +45,7 @@ class TestTypeSharing(JitTestCase):
def forward(self, x):
return x
a = torch.rand(2, 3)
b = torch.rand(2, 3)
c = torch.rand(2, 3)
@ -53,6 +57,7 @@ class TestTypeSharing(JitTestCase):
"""
Types should be shared even if attribute values differ
"""
class M(torch.nn.Module):
def __init__(self, a, b, c):
super().__init__()
@ -62,6 +67,7 @@ class TestTypeSharing(JitTestCase):
def forward(self, x):
return x
a = torch.rand(2, 3)
b = torch.rand(2, 3)
c = torch.rand(2, 3)
@ -73,6 +79,7 @@ class TestTypeSharing(JitTestCase):
"""
Types should be shared for identical constant values, and different for different constant values
"""
class M(torch.nn.Module):
__constants__ = ["const"]
@ -111,6 +118,7 @@ class TestTypeSharing(JitTestCase):
"""
If submodules differ, the types should differ.
"""
class M(torch.nn.Module):
def __init__(self, in1, out1, in2, out2):
super().__init__()
@ -137,6 +145,7 @@ class TestTypeSharing(JitTestCase):
The same module with an `foo` as a parameter vs. attribute shouldn't
share types
"""
class M(torch.nn.Module):
def __init__(self, foo):
super().__init__()
@ -156,6 +165,7 @@ class TestTypeSharing(JitTestCase):
Even if everything about the module is the same, different originating
classes should prevent type sharing.
"""
class A(torch.nn.Module):
__constants__ = ["const"]
@ -192,6 +202,7 @@ class TestTypeSharing(JitTestCase):
"""
Mutating the value of an attribute should not change type sharing
"""
class M(torch.nn.Module):
def __init__(self, in1, out1, in2, out2):
super().__init__()
@ -214,6 +225,7 @@ class TestTypeSharing(JitTestCase):
"""
Assigning a new (python-only) attribute should not change type sharing
"""
class M(torch.nn.Module):
def __init__(self, in1, out1, in2, out2):
super().__init__()
@ -244,6 +256,7 @@ class TestTypeSharing(JitTestCase):
"""
Attributes whose type cannot be inferred should fail cleanly with nice hints
"""
class M(torch.nn.Module):
def __init__(self):
super().__init__()
@ -255,15 +268,16 @@ class TestTypeSharing(JitTestCase):
return self.foo
m = M()
with self.assertRaisesRegexWithHighlight(RuntimeError,
"failed to convert Python type",
"self.foo"):
with self.assertRaisesRegexWithHighlight(
RuntimeError, "failed to convert Python type", "self.foo"
):
torch.jit.script(m)
def test_script_function_attribute_different(self):
"""
Different functions passed in should lead to different types
"""
@torch.jit.script
def fn1(x):
return x + x
@ -317,6 +331,7 @@ class TestTypeSharing(JitTestCase):
"""
Same functions passed in should lead to same types
"""
@torch.jit.script
def fn(x):
return x + x
@ -338,6 +353,7 @@ class TestTypeSharing(JitTestCase):
"""
Different functions passed in should lead to different types
"""
def fn1(x):
return x + x
@ -361,6 +377,7 @@ class TestTypeSharing(JitTestCase):
"""
Same functions passed in should lead to same types
"""
def fn(x):
return x + x
@ -383,6 +400,7 @@ class TestTypeSharing(JitTestCase):
Since we can't guarantee that methods are the same between different
trace runs, tracing must always generate a unique type.
"""
class M(torch.nn.Module):
def forward(self, x, y):
if x.sum() > y.sum():
@ -429,8 +447,8 @@ class TestTypeSharing(JitTestCase):
def forward(self, x):
return self.traced(x)
a = M((torch.ones(1), ))
b = M((torch.zeros(1), ))
a = M((torch.ones(1),))
b = M((torch.zeros(1),))
self.assertDifferentType(a, b)
def test_loaded_modules_work(self):
@ -465,7 +483,6 @@ class TestTypeSharing(JitTestCase):
buffer.seek(0)
return torch.jit.script(Wrapper(torch.jit.load(buffer)))
a = package(AB())
a()
b = package(A())
@ -476,6 +493,7 @@ class TestTypeSharing(JitTestCase):
We should be able to differentiate between two ModuleDict instances
that have different keys but the same value types.
"""
class A(torch.nn.Module):
def forward(self, x):
return x
@ -488,9 +506,9 @@ class TestTypeSharing(JitTestCase):
def forward(self, x):
return x
a = Foo({'foo': A()})
b = Foo({'bar': A()})
c = Foo({'bar': A()})
a = Foo({"foo": A()})
b = Foo({"bar": A()})
c = Foo({"bar": A()})
self.assertDifferentType(a, b)
self.assertSameType(b, c)
@ -500,13 +518,16 @@ class TestTypeSharing(JitTestCase):
subclass that defines methods in its __init__ are not
shared.
"""
class A(torch.jit.ScriptModule):
def __init__(self, val):
super().__init__()
self.define(f"""
self.define(
f"""
def forward(self) -> int:
return {val}
""")
"""
)
one = A(1)
two = A(2)
@ -518,6 +539,7 @@ class TestTypeSharing(JitTestCase):
"""
Test that type sharing can be disabled.
"""
class A(torch.nn.Module):
def __init__(self, sub):
super().__init__()
@ -555,6 +577,7 @@ class TestTypeSharing(JitTestCase):
Test that types are shared if the exclusion of their
ignored attributes makes them equal.
"""
class A(torch.nn.Module):
__jit_ignored_attributes__ = ["a"]
@ -579,6 +602,7 @@ class TestTypeSharing(JitTestCase):
Test that types are not shared if the exclusion of their
ignored attributes makes them not equal.
"""
class A(torch.nn.Module):
__jit_ignored_attributes__ = ["a"]

View File

@ -1,26 +1,31 @@
# Owner(s): ["oncall: jit"]
from collections import namedtuple
from typing import Dict, Iterator, List, Optional, Tuple
from torch.testing._internal.jit_utils import JitTestCase
from torch.testing import FileCheck
from textwrap import dedent
from jit.test_module_interface import TestModuleInterface # noqa: F401
import inspect
import os
import sys
from collections import namedtuple
from textwrap import dedent
from typing import Dict, Iterator, List, Optional, Tuple
import torch
import torch.testing._internal.jit_utils
from torch.testing import FileCheck
from torch.testing._internal.jit_utils import JitTestCase
from jit.test_module_interface import TestModuleInterface # noqa: F401
# Make the helper files in test/ importable
pytorch_test_dir = os.path.dirname(os.path.dirname(os.path.realpath(__file__)))
sys.path.append(pytorch_test_dir)
if __name__ == '__main__':
raise RuntimeError("This test file is not meant to be run directly, use:\n\n"
"\tpython test/test_jit.py TESTNAME\n\n"
"instead.")
if __name__ == "__main__":
raise RuntimeError(
"This test file is not meant to be run directly, use:\n\n"
"\tpython test/test_jit.py TESTNAME\n\n"
"instead."
)
class TestTypesAndAnnotation(JitTestCase):
def test_pep585_type(self):
@ -30,7 +35,7 @@ class TestTypesAndAnnotation(JitTestCase):
xl: list[tuple[torch.Tensor]] = []
xd: dict[str, int] = {}
xl.append((x,))
xd['foo'] = 1
xd["foo"] = 1
return xl.pop(), xd
self.checkScript(fn, [torch.randn(2, 2)])
@ -47,7 +52,7 @@ class TestTypesAndAnnotation(JitTestCase):
self.checkScript(fn, [torch.randn(2, 2)])
GG = namedtuple('GG', ['f', 'g'])
GG = namedtuple("GG", ["f", "g"])
class Foo(torch.nn.Module):
@torch.jit.ignore
@ -77,13 +82,17 @@ class TestTypesAndAnnotation(JitTestCase):
return x + 10
class M(torch.nn.Module):
def forward(self, in_batch: Dict[str, Optional[torch.Tensor]]) -> torch.Tensor:
def forward(
self, in_batch: Dict[str, Optional[torch.Tensor]]
) -> torch.Tensor:
self.dropout_modality(in_batch)
fn(in_batch)
return torch.tensor(1)
@torch.jit.ignore
def dropout_modality(self, in_batch: Dict[str, Optional[torch.Tensor]]) -> Dict[str, Optional[torch.Tensor]]:
def dropout_modality(
self, in_batch: Dict[str, Optional[torch.Tensor]]
) -> Dict[str, Optional[torch.Tensor]]:
return in_batch
sm = torch.jit.script(M())
@ -111,16 +120,17 @@ class TestTypesAndAnnotation(JitTestCase):
return my_arg + 10
with self.assertRaisesRegex(RuntimeError, "argument 'my_arg'"):
@torch.jit.script
def other_fn(x):
return fn('2')
return fn("2")
def test_type_annotate_py3(self):
def fn():
a : List[int] = []
b : torch.Tensor = torch.ones(2, 2)
c : Optional[torch.Tensor] = None
d : Optional[torch.Tensor] = torch.ones(3, 4)
a: List[int] = []
b: torch.Tensor = torch.ones(2, 2)
c: Optional[torch.Tensor] = None
d: Optional[torch.Tensor] = torch.ones(3, 4)
for _ in range(10):
a.append(4)
c = torch.ones(2, 2)
@ -130,66 +140,88 @@ class TestTypesAndAnnotation(JitTestCase):
self.checkScript(fn, ())
def wrong_type():
wrong : List[int] = [0.5]
wrong: List[int] = [0.5]
return wrong
with self.assertRaisesRegex(RuntimeError, "List type annotation"
r" `List\[int\]` did not match the "
"types of the given list elements"):
with self.assertRaisesRegex(
RuntimeError,
"List type annotation"
r" `List\[int\]` did not match the "
"types of the given list elements",
):
torch.jit.script(wrong_type)
def test_optional_no_element_type_annotation(self):
"""
Test that using an optional with no contained types produces an error.
"""
def fn_with_comment(x: torch.Tensor) -> Optional:
return (x, x)
def annotated_fn(x: torch.Tensor) -> Optional:
return (x, x)
with self.assertRaisesRegex(RuntimeError, r"Attempted to use Optional without a contained type"):
with self.assertRaisesRegex(
RuntimeError, r"Attempted to use Optional without a contained type"
):
cu = torch.jit.CompilationUnit()
cu.define(dedent(inspect.getsource(fn_with_comment)))
with self.assertRaisesRegex(RuntimeError, r"Attempted to use Optional without a contained type"):
with self.assertRaisesRegex(
RuntimeError, r"Attempted to use Optional without a contained type"
):
cu = torch.jit.CompilationUnit()
cu.define(dedent(inspect.getsource(annotated_fn)))
with self.assertRaisesRegex(RuntimeError, r"Attempted to use Optional without a contained type"):
with self.assertRaisesRegex(
RuntimeError, r"Attempted to use Optional without a contained type"
):
torch.jit.script(fn_with_comment)
with self.assertRaisesRegex(RuntimeError, r"Attempted to use Optional without a contained type"):
with self.assertRaisesRegex(
RuntimeError, r"Attempted to use Optional without a contained type"
):
torch.jit.script(annotated_fn)
def test_tuple_no_element_type_annotation(self):
"""
Test that using a tuple with no contained types produces an error.
"""
def fn_with_comment(x: torch.Tensor) -> Tuple:
return (x, x)
def annotated_fn(x: torch.Tensor) -> Tuple:
return (x, x)
with self.assertRaisesRegex(RuntimeError, r"Attempted to use Tuple without a contained type"):
with self.assertRaisesRegex(
RuntimeError, r"Attempted to use Tuple without a contained type"
):
cu = torch.jit.CompilationUnit()
cu.define(dedent(inspect.getsource(fn_with_comment)))
with self.assertRaisesRegex(RuntimeError, r"Attempted to use Tuple without a contained type"):
with self.assertRaisesRegex(
RuntimeError, r"Attempted to use Tuple without a contained type"
):
cu = torch.jit.CompilationUnit()
cu.define(dedent(inspect.getsource(annotated_fn)))
with self.assertRaisesRegex(RuntimeError, r"Attempted to use Tuple without a contained type"):
with self.assertRaisesRegex(
RuntimeError, r"Attempted to use Tuple without a contained type"
):
torch.jit.script(fn_with_comment)
with self.assertRaisesRegex(RuntimeError, r"Attempted to use Tuple without a contained type"):
with self.assertRaisesRegex(
RuntimeError, r"Attempted to use Tuple without a contained type"
):
torch.jit.script(annotated_fn)
def test_ignoring_module_attributes(self):
"""
Test that module attributes can be ignored.
"""
class Sub(torch.nn.Module):
def forward(self, a: int) -> int:
return sum([a])
@ -229,10 +261,11 @@ class TestTypesAndAnnotation(JitTestCase):
mod = ModuleUsesIgnoredAttr(1)
with self.assertRaisesRegexWithHighlight(RuntimeError, r"attribute was ignored during compilation", "self.sub"):
with self.assertRaisesRegexWithHighlight(
RuntimeError, r"attribute was ignored during compilation", "self.sub"
):
scripted_mod = torch.jit.script(mod)
def test_ignoring_fn_with_nonscriptable_types(self):
class CFX:
def __init__(self, a: List[torch.Tensor]) -> None:
@ -246,7 +279,9 @@ class TestTypesAndAnnotation(JitTestCase):
return iter(self.a)
@torch.jit._drop
def __fx_create_arg__(self, tracer: torch.fx.Tracer) -> torch.fx.node.Argument:
def __fx_create_arg__(
self, tracer: torch.fx.Tracer
) -> torch.fx.node.Argument:
# torch.fx classes are not scriptable
return tracer.create_node(
"call_function",
@ -257,35 +292,36 @@ class TestTypesAndAnnotation(JitTestCase):
torch.jit.script(CFX)
def test_unimported_type_resolution(self):
# verify fallback from the python resolver to the c++ resolver
@ torch.jit.script
@torch.jit.script
def fn(x):
# type: (number) -> number
return x + 1
FileCheck().check('Scalar').run(fn.graph)
FileCheck().check("Scalar").run(fn.graph)
def test_parser_bug(self):
def parser_bug(o: Optional[torch.Tensor]):
pass
def test_mismatched_annotation(self):
with self.assertRaisesRegex(RuntimeError, 'annotated with type'):
with self.assertRaisesRegex(RuntimeError, "annotated with type"):
@torch.jit.script
def foo():
x : str = 4
x: str = 4
return x
def test_reannotate(self):
with self.assertRaisesRegex(RuntimeError, 'declare and annotate'):
with self.assertRaisesRegex(RuntimeError, "declare and annotate"):
@torch.jit.script
def foo():
x = 5
if 1 == 1:
x : Optional[int] = 7
x: Optional[int] = 7
def test_annotate_outside_init(self):
msg = "annotations on instance attributes must be declared in __init__"
@ -293,6 +329,7 @@ class TestTypesAndAnnotation(JitTestCase):
# Simple case
with self.assertRaisesRegexWithHighlight(ValueError, msg, highlight):
@torch.jit.script
class BadModule:
def __init__(self, x: int):
@ -303,6 +340,7 @@ class TestTypesAndAnnotation(JitTestCase):
# Type annotation in a loop
with self.assertRaisesRegexWithHighlight(ValueError, msg, highlight):
@torch.jit.script
class BadModuleLoop:
def __init__(self, x: int):
@ -324,8 +362,10 @@ class TestTypesAndAnnotation(JitTestCase):
def test_inferred_type_error_message(self):
inferred_type = torch._C.InferredType("ErrorReason")
with self.assertRaisesRegex(RuntimeError,
"Tried to get the type from an InferredType but the type is null."):
with self.assertRaisesRegex(
RuntimeError,
"Tried to get the type from an InferredType but the type is null.",
):
t = inferred_type.type()
with self.assertRaisesRegex(RuntimeError, "ErrorReason"):

View File

@ -2,12 +2,12 @@
import os
import sys
from collections import namedtuple
from typing import Dict, List, NamedTuple, Tuple
import torch
from torch.testing._internal.jit_utils import JitTestCase, make_global
from torch.testing._internal.common_utils import IS_WINDOWS
from collections import namedtuple
from typing import List, Tuple, Dict, NamedTuple
from torch.testing._internal.jit_utils import JitTestCase, make_global
# Make the helper files in test/ importable
pytorch_test_dir = os.path.dirname(os.path.dirname(os.path.realpath(__file__)))
@ -20,14 +20,15 @@ if __name__ == "__main__":
"instead."
)
class TestTyping(JitTestCase):
def test_dict_in_not_in(self):
def test_in_dict(x):
# type: (Dict[str, int]) -> bool
return 'hi' in x
return "hi" in x
self.checkScript(test_in_dict, ({'hi': 2, 'bye': 3},))
self.checkScript(test_in_dict, ({'bye': 3},))
self.checkScript(test_in_dict, ({"hi": 2, "bye": 3},))
self.checkScript(test_in_dict, ({"bye": 3},))
# Check evaluation order
@torch.jit.script
@ -57,8 +58,8 @@ class TestTyping(JitTestCase):
else:
return True
self.checkScript(test_not_in_dict, ({"hello": 1, "world": 2}, ))
self.checkScript(test_not_in_dict, ({"world": 2}, ))
self.checkScript(test_not_in_dict, ({"hello": 1, "world": 2},))
self.checkScript(test_not_in_dict, ({"world": 2},))
def test_dict_tensor_key(a, t):
# type: (Dict[Tensor, int], Tensor) -> bool
@ -80,9 +81,12 @@ class TestTyping(JitTestCase):
l: List[int] = [1, 2, "foo", 3]
return l
with self.assertRaisesRegex(RuntimeError, "List type annotation"
r" `List\[int\]` did not match the "
"types of the given list elements"):
with self.assertRaisesRegex(
RuntimeError,
"List type annotation"
r" `List\[int\]` did not match the "
"types of the given list elements",
):
torch.jit.script(fn)
def test_dict_type_refinement_annotation_key_mismatch(self):
@ -92,10 +96,13 @@ class TestTyping(JitTestCase):
d: Dict[int, str] = dict(zip(l1, l2))
return d
with self.assertRaisesRegex(RuntimeError, "Dicts may only "
"contain homogeneous keys, but the "
"type of the first generated key "
r"was Union\[int, str\]"):
with self.assertRaisesRegex(
RuntimeError,
"Dicts may only "
"contain homogeneous keys, but the "
"type of the first generated key "
r"was Union\[int, str\]",
):
torch.jit.script(fn)
def test_dict_type_refinement_annotation_value_mismatch(self):
@ -105,28 +112,36 @@ class TestTyping(JitTestCase):
d: Dict[str, int] = dict(zip(l1, l2))
return d
with self.assertRaisesRegex(RuntimeError, "Dict type annotation"
r" `Dict\[str, int\]` did not match"
" the type of an actual value type"
r" `Union\[int, str\]`"):
with self.assertRaisesRegex(
RuntimeError,
"Dict type annotation"
r" `Dict\[str, int\]` did not match"
" the type of an actual value type"
r" `Union\[int, str\]`",
):
torch.jit.script(fn)
def test_dict_invalid_annotations(self):
# Check for invalid value type annotation
def wrong_value_type(dictionary: Dict[str, torch.jit.ScriptModule]):
return
with self.assertRaisesRegex(ValueError, "Unknown type annotation"):
torch.jit.script(wrong_value_type)
# Check for invalid key type annotation
def wrong_key_type(dictionary: Dict[torch.jit.ScriptModule, str]):
return
with self.assertRaisesRegex(ValueError, "Unknown type annotation"):
torch.jit.script(wrong_key_type)
# Check for invalid key and value type annotation
def wrong_key_value_type(dictionary: Dict[torch.jit.ScriptModule, torch.jit.ScriptModule]):
def wrong_key_value_type(
dictionary: Dict[torch.jit.ScriptModule, torch.jit.ScriptModule]
):
return
with self.assertRaisesRegex(ValueError, "Unknown type annotation"):
torch.jit.script(wrong_key_value_type)
@ -138,13 +153,16 @@ class TestTyping(JitTestCase):
_, y = t2
return x + y
t = torch.randn(2, 2), (1, torch.randn(2, 2)),
t = (
torch.randn(2, 2),
(1, torch.randn(2, 2)),
)
f(t, "hi")
graph = f.graph_for(t, "hi")
input_types = list(next(graph.inputs()).type().elements())
w = input_types[0]
self.assertEqual(input_types[0].kind(), 'TensorType')
self.assertEqual(input_types[1].elements()[1].kind(), 'TensorType')
self.assertEqual(input_types[0].kind(), "TensorType")
self.assertEqual(input_types[1].elements()[1].kind(), "TensorType")
def test_tuple_io(self):
def stuff(x):
@ -165,8 +183,7 @@ class TestTyping(JitTestCase):
def foo():
return tuple(1, 2)
self.checkScriptRaisesRegex(foo, (), Exception,
"1 argument")
self.checkScriptRaisesRegex(foo, (), Exception, "1 argument")
def cant_infer_size():
return tuple([1, 2, 3]) # noqa: C409
@ -179,12 +196,14 @@ class TestTyping(JitTestCase):
# type: (int) -> Tuple[Tensor, Tensor]
a = (torch.ones(x), torch.zeros(x))
return a
self.checkScript(stuff2, (3,))
def test_list_io(self):
def stuff3(x):
# type: (List[int]) -> Tuple[Tensor, List[int]]
return torch.ones(x), x
self.checkScript(stuff3, ([3, 2],))
def test_bool_list_io(self):
@ -203,6 +222,7 @@ class TestTyping(JitTestCase):
# type: (Tuple[int, List[List[int]]]) -> int
x, y = z
return y[0][1]
self.checkScript(foo, ((1, [[1, 2], [3, 4]]),))
def test_list_sum(self):
@ -215,12 +235,12 @@ class TestTyping(JitTestCase):
def fn2(x: List[bool]):
return sum(x)
self.checkScript(fn, ([1, 2, 3], ))
self.checkScript(fn1, ([1.0, 2.0, 3.0], ))
self.checkScript(fn1, ([1, 2.8, 3], ))
self.checkScript(fn2, ([True, False, False], ))
self.checkScript(fn2, ([False, False, False], ))
self.checkScript(fn2, ([0, 1, 1, 0], ))
self.checkScript(fn, ([1, 2, 3],))
self.checkScript(fn1, ([1.0, 2.0, 3.0],))
self.checkScript(fn1, ([1, 2.8, 3],))
self.checkScript(fn2, ([True, False, False],))
self.checkScript(fn2, ([False, False, False],))
self.checkScript(fn2, ([0, 1, 1, 0],))
def test_list_unification(self):
def fn():
@ -254,7 +274,6 @@ class TestTyping(JitTestCase):
self.checkScript(self.get_sum_list_fn(), ([1],))
def test_sum_list_literal(self):
def sum_list():
# type: () -> int
sum = 0
@ -266,8 +285,8 @@ class TestTyping(JitTestCase):
self.checkScript(sum_list, ())
def test_sum_list_wrong_type(self):
with self.assertRaisesRegex(RuntimeError, "'int' object is not iterable"):
@torch.jit.script
def sum_list(a):
# type: (int) -> int
@ -280,14 +299,18 @@ class TestTyping(JitTestCase):
sum_list(1)
def test_list_iterables(self):
with self.assertRaisesRegex(RuntimeError, 'List of iterables is not supported currently'):
cu = torch.jit.CompilationUnit('''
with self.assertRaisesRegex(
RuntimeError, "List of iterables is not supported currently"
):
cu = torch.jit.CompilationUnit(
"""
def list_iterables(x):
for i, j in [2, 3, 4], [5, 6, 7]:
x += i
x += j
return x
''')
"""
)
def test_for_in_string(self):
def test_strings(x):
@ -352,36 +375,43 @@ class TestTyping(JitTestCase):
def test_dict_comprehension(self):
def fn():
return {i : chr(i + 65) for i in range(4)}
return {i: chr(i + 65) for i in range(4)}
self.checkScript(fn, ())
def test_dict_comprehension_with_type_annotation(self):
def fn():
d: Dict[int, str] = {i : chr(i + 65) for i in range(4)}
d: Dict[int, str] = {i: chr(i + 65) for i in range(4)}
return d
self.checkScript(fn, ())
with self.assertRaisesRegex(RuntimeError, ""):
with self.assertRaisesRegex(AssertionError, "Expected Dict "
"type annotation for dict "
"comprehension, found "
"Tuple[int, str]"):
with self.assertRaisesRegex(
AssertionError,
"Expected Dict "
"type annotation for dict "
"comprehension, found "
"Tuple[int, str]",
):
@torch.jit.script
def fn():
d: Tuple[int, str] = {i : chr(i + 65) for i in range(4)}
d: Tuple[int, str] = {i: chr(i + 65) for i in range(4)}
return d
def test_dict_comprehension_scope(self):
def comprehension_can_access_outer_scope_variables():
lst = ["foo", "bar", "baz"]
return {l : len(l) for l in lst}
return {l: len(l) for l in lst}
self.checkScript(comprehension_can_access_outer_scope_variables, ())
with self.assertRaisesRegex(RuntimeError, "undefined value i"):
@torch.jit.script
def outer_scope_cannot_access_comprehension_variables():
d = {i : chr(i + 65) for i in range(4)}
d = {i: chr(i + 65) for i in range(4)}
i = i + 1 # noqa: F821
def test_for_tuple_assign(self):
@ -402,22 +432,28 @@ class TestTyping(JitTestCase):
sum += a[1]
return sum
self.checkScript(test_tuple_assign, (((1, 2), (4, 7)), ))
self.checkScript(test_tuple_assign, (((1, 2), (4, 7)),))
def test_single_starred_lhs(self):
with self.assertRaisesRegex(RuntimeError, 'A Starred expression may only appear on the lhs within the presence'
' of another non-starred expression'):
cu = torch.jit.CompilationUnit('''
with self.assertRaisesRegex(
RuntimeError,
"A Starred expression may only appear on the lhs within the presence"
" of another non-starred expression",
):
cu = torch.jit.CompilationUnit(
"""
def single_starred_lhs(x):
a = (x, x, x)
*b, = a
return b
''')
"""
)
def test_singleton_tuple_unpack(self):
def foo(a):
b, = (a,)
(b,) = (a,)
return b + 1
self.checkScript(foo, (torch.rand(3),))
def test_tuple_assignments(self):
@ -441,7 +477,9 @@ class TestTyping(JitTestCase):
a[i], (x[i], b) = 1, (2, 3)
return a[i] + 1, x + 5, b
self.checkScript(subscript_tuple_assign, ([12, 7, 9, 11], torch.tensor((3, 13, 17)), 0))
self.checkScript(
subscript_tuple_assign, ([12, 7, 9, 11], torch.tensor((3, 13, 17)), 0)
)
def star_tuple_assign():
# type: () -> Tuple[int, int, Tuple[int, int], Tuple[int, int]]
@ -455,7 +493,7 @@ class TestTyping(JitTestCase):
a[0] += 1
return a
with self.assertRaisesRegex(RuntimeError, 'does not support augmented assign'):
with self.assertRaisesRegex(RuntimeError, "does not support augmented assign"):
scripted_aug_assign = torch.jit.script(subscript_tuple_augmented_assign)
def test_multiple_assign(self):
@ -505,7 +543,6 @@ class TestTyping(JitTestCase):
# type: (Optional[int]) -> int
return torch.jit._unwrap_optional(x)
@torch.jit.script
def fn(x):
# type: (int) -> int
@ -540,7 +577,7 @@ class TestTyping(JitTestCase):
# type: (Tuple[float, float]) -> int
return opt_list(x) + broadcast_opt_list(x)
self.assertEqual(opt_list_tuple_caller((2., 3.)), 4)
self.assertEqual(opt_list_tuple_caller((2.0, 3.0)), 4)
def test_optional_tuple(self):
def fn(x=None):
@ -556,10 +593,11 @@ class TestTyping(JitTestCase):
def test_namedtuple_redefine(self):
global _1, _2
_1 = namedtuple('GoogLeNetOutputs', ['logits', 'aux_logits2', 'aux_logits1'])
_2 = namedtuple('GoogLeNetOutputs', ['different'])
_1 = namedtuple("GoogLeNetOutputs", ["logits", "aux_logits2", "aux_logits1"])
_2 = namedtuple("GoogLeNetOutputs", ["different"])
with self.assertRaisesRegex(RuntimeError, r"redefine"):
with self.assertRaisesRegex(RuntimeError, r'redefine'):
@torch.jit.script
def foo(x, y):
# type: (_1, _2) -> _1
@ -567,7 +605,9 @@ class TestTyping(JitTestCase):
def test_namedtuple_py2(self):
global _GoogLeNetOutputs # see [local resolution in python]
_GoogLeNetOutputs = namedtuple('GoogLeNetOutputs', ['logits', 'aux_logits2', 'aux_logits1'])
_GoogLeNetOutputs = namedtuple(
"GoogLeNetOutputs", ["logits", "aux_logits2", "aux_logits1"]
)
@torch.jit.script
def foo(x):
@ -575,22 +615,27 @@ class TestTyping(JitTestCase):
return x
vals = torch.rand(3), torch.rand(4), torch.rand(5)
out = foo(_GoogLeNetOutputs(logits=vals[0], aux_logits2=vals[1], aux_logits1=vals[2]))
out = foo(
_GoogLeNetOutputs(logits=vals[0], aux_logits2=vals[1], aux_logits1=vals[2])
)
self.assertEqual(out.logits, vals[0])
self.assertEqual(out.aux_logits2, vals[1])
self.assertEqual(out.aux_logits1, vals[2])
def test_namedtuple_good_error(self):
global _GoogLeNetOutputs # see [local resolution in python]
_GoogLeNetOutputs = namedtuple('GoogLeNetOutputs', ['logits', 'aux_logits2', 'aux_logits1'])
_GoogLeNetOutputs = namedtuple(
"GoogLeNetOutputs", ["logits", "aux_logits2", "aux_logits1"]
)
@torch.jit.script
def foo(x):
# type: (_GoogLeNetOutputs) -> _GoogLeNetOutputs
return x
with self.assertRaisesRegex(RuntimeError,
r'aka NamedTuple\(logits, aux_logits2, aux_logits1\)'):
with self.assertRaisesRegex(
RuntimeError, r"aka NamedTuple\(logits, aux_logits2, aux_logits1\)"
):
out = foo(_GoogLeNetOutputs(logits="3", aux_logits2="4", aux_logits1="5"))
def test_namedtuple_error_source_attribution(self):

View File

@ -3,22 +3,25 @@
import io
import os
import sys
import torch
from torch.testing import FileCheck
from enum import Enum
from textwrap import dedent
from typing import Dict, List, Optional, Tuple, Union
import torch
from torch.testing import FileCheck
# Make the helper files in test/ importable
pytorch_test_dir = os.path.dirname(os.path.dirname(os.path.realpath(__file__)))
sys.path.append(pytorch_test_dir)
from torch.testing._internal.jit_utils import JitTestCase, make_global
if __name__ == '__main__':
raise RuntimeError("This test file is not meant to be run directly, use:\n\n"
"\tpython test/test_jit.py TESTNAME\n\n"
"instead.")
if __name__ == "__main__":
raise RuntimeError(
"This test file is not meant to be run directly, use:\n\n"
"\tpython test/test_jit.py TESTNAME\n\n"
"instead."
)
class TestUnion(JitTestCase):
"""
@ -57,9 +60,12 @@ class TestUnion(JitTestCase):
scripted = torch.jit.script(fn)
with self.assertRaisesRegex(RuntimeError, "Expected a member of"
r" Union\[float, int\] but "
"instead found type str"):
with self.assertRaisesRegex(
RuntimeError,
"Expected a member of"
r" Union\[float, int\] but "
"instead found type str",
):
scripted("1")
def test_union_with_collections(self):
@ -71,22 +77,31 @@ class TestUnion(JitTestCase):
scripted = torch.jit.script(fn)
with self.assertRaisesRegex(RuntimeError, "Expected a member of"
r" Union\[List\[int\], Dict\[str, "
r"int\]\] but instead found type "
r"Dict\[str, str\]"):
with self.assertRaisesRegex(
RuntimeError,
"Expected a member of"
r" Union\[List\[int\], Dict\[str, "
r"int\]\] but instead found type "
r"Dict\[str, str\]",
):
scripted({"foo": "bar", "baz": "qux"})
with self.assertRaisesRegex(RuntimeError, "Expected a member of"
r" Union\[List\[int\], Dict\[str, "
r"int\]\] but instead found type "
r"List\[str\]"):
with self.assertRaisesRegex(
RuntimeError,
"Expected a member of"
r" Union\[List\[int\], Dict\[str, "
r"int\]\] but instead found type "
r"List\[str\]",
):
scripted(["foo", "bar", "baz"])
with self.assertRaisesRegex(RuntimeError, "Expected a member of"
r" Union\[List\[int\], Dict\[str, "
r"int\]\] but instead found type "
"str"):
with self.assertRaisesRegex(
RuntimeError,
"Expected a member of"
r" Union\[List\[int\], Dict\[str, "
r"int\]\] but instead found type "
"str",
):
scripted("1")
def test_union_with_enum(self):
@ -104,16 +119,18 @@ class TestUnion(JitTestCase):
scripted = torch.jit.script(fn)
with self.assertRaisesRegex(RuntimeError, "Expected a member of"
r" Union\[__torch__.jit.test_union."
r"Color, str\] but instead found "
"type int"):
with self.assertRaisesRegex(
RuntimeError,
"Expected a member of"
r" Union\[__torch__.jit.test_union."
r"Color, str\] but instead found "
"type int",
):
scripted(1)
def test_union_in_class_constructor(self):
@torch.jit.script # noqa: B903
class A: # noqa: B903
class A: # noqa: B903
def __init__(self, x: Union[int, str]) -> None:
self.x = x
@ -125,9 +142,12 @@ class TestUnion(JitTestCase):
scripted = torch.jit.script(fn)
with self.assertRaisesRegex(RuntimeError, "Expected a member of"
r" Union\[int, str\] but instead "
r"found type List\[str\]"):
with self.assertRaisesRegex(
RuntimeError,
"Expected a member of"
r" Union\[int, str\] but instead "
r"found type List\[str\]",
):
scripted(["foo", "bar", "baz"])
def test_union_return_type(self):
@ -171,7 +191,7 @@ class TestUnion(JitTestCase):
def test_union_variable_can_be_reassigned(self):
@torch.jit.script
def aux1(i: int):
return int(i ** 2)
return int(i**2)
@torch.jit.script
def aux2(s: str):
@ -225,8 +245,7 @@ class TestUnion(JitTestCase):
s = fn.graph
FileCheck().check("x : Union(float, int, str)") \
.run(s)
FileCheck().check("x : Union(float, int, str)").run(s)
def test_unions_of_a_single_argument_vanish(self):
@torch.jit.script
@ -235,8 +254,7 @@ class TestUnion(JitTestCase):
s = fn.graph
FileCheck().check("x : int") \
.run(s)
FileCheck().check("x : int").run(s)
def test_union_redundant_arguments_are_skipped(self):
@torch.jit.script
@ -245,8 +263,7 @@ class TestUnion(JitTestCase):
s = fn.graph
FileCheck().check("x : Union(int, str)") \
.run(s)
FileCheck().check("x : Union(int, str)").run(s)
def test_union_redundant_arguments_are_skipped_optional(self):
@torch.jit.script
@ -255,8 +272,7 @@ class TestUnion(JitTestCase):
s = fn.graph
FileCheck().check("x : Union(float, int, NoneType)") \
.run(s)
FileCheck().check("x : Union(float, int, NoneType)").run(s)
def test_union_redundant_arguments_are_skipped_subtyping(self):
@torch.jit.script
@ -265,8 +281,7 @@ class TestUnion(JitTestCase):
s = fn.graph
FileCheck().check("x : Union((int?, int), str)") \
.run(s)
FileCheck().check("x : Union((int?, int), str)").run(s)
def test_union_redundant_arguments_are_skipped_container(self):
@torch.jit.script
@ -275,8 +290,7 @@ class TestUnion(JitTestCase):
s = fn.graph
FileCheck().check("x : Union(float[], str[])") \
.run(s)
FileCheck().check("x : Union(float[], str[])").run(s)
def test_union_argument_order_is_ignored(self):
@torch.jit.script
@ -288,8 +302,7 @@ class TestUnion(JitTestCase):
return "foo"
for s in (fn1.graph, fn2.graph):
FileCheck().check("x : Union(int, str)") \
.run(s)
FileCheck().check("x : Union(int, str)").run(s)
def test_union_argument_order_is_ignored_container(self):
@torch.jit.script
@ -301,8 +314,7 @@ class TestUnion(JitTestCase):
return "foo"
for s in (fn1.graph, fn2.graph):
FileCheck().check("x : Union(int[], str[])") \
.run(s)
FileCheck().check("x : Union(int[], str[])").run(s)
def test_union_T_None_is_equivalent_to_optional_T(self):
@torch.jit.script
@ -366,9 +378,9 @@ class TestUnion(JitTestCase):
s = l.code
FileCheck().check("Union[int, NoneType, str]") \
.check("Union[int, NoneType, str]") \
.run(s)
FileCheck().check("Union[int, NoneType, str]").check(
"Union[int, NoneType, str]"
).run(s)
def test_union_subclasses_larger_union(self):
def fn() -> Union[int, str, torch.Tensor]:
@ -386,9 +398,12 @@ class TestUnion(JitTestCase):
x[1] = 2
return x[1]
with self.assertRaisesRegex(RuntimeError, "only int, float, "
"complex, Tensor, device and string keys "
"are supported"):
with self.assertRaisesRegex(
RuntimeError,
"only int, float, "
"complex, Tensor, device and string keys "
"are supported",
):
torch.jit.script(fn)
def test_union_as_dict_value(self):
@ -402,7 +417,6 @@ class TestUnion(JitTestCase):
def test_union_module_with_union_instance_variable(self):
class M(torch.nn.Module):
x: Union[int, str]
def __init__(self, x: Union[int, str]):
@ -413,7 +427,12 @@ class TestUnion(JitTestCase):
self.x = y
return self.x
self.checkModule(M(2,), (1,))
self.checkModule(
M(
2,
),
(1,),
)
self.checkModule(M("bar"), ("foo",))
def test_union_module_with_union_class_variable(self):
@ -508,9 +527,7 @@ class TestUnion(JitTestCase):
s = fn.graph
# Check that we don't have any branching statements
FileCheck().check_not("block0()") \
.check_not("block1()") \
.run(s)
FileCheck().check_not("block0()").check_not("block1()").run(s)
def test_union_type_refinement_statically_true(self):
@torch.jit.script
@ -525,9 +542,7 @@ class TestUnion(JitTestCase):
s = fn.graph
# Check that we don't have any branching statements
FileCheck().check_not("block0()") \
.check_not("block1()") \
.run(s)
FileCheck().check_not("block0()").check_not("block1()").run(s)
def test_union_type_refinement_partial_static_refinement_tuple_rhs(self):
def fn(x: Union[List[int], int]) -> int:
@ -556,7 +571,7 @@ class TestUnion(JitTestCase):
def test_union_type_refinement_internal_declaration(self):
def fn(flag: bool) -> str:
x: Union[int, str, None] = None
if (flag):
if flag:
y = "foo"
else:
y = 1
@ -589,9 +604,12 @@ class TestUnion(JitTestCase):
else:
return "bar"
with self.assertRaisesRegex(RuntimeError, "y is set to type str"
" in the true branch and type int "
"in the false branch"):
with self.assertRaisesRegex(
RuntimeError,
"y is set to type str"
" in the true branch and type int "
"in the false branch",
):
torch.jit.script(fn)
def test_union_branching_does_not_widen_existing_inferred_type(self):
@ -606,9 +624,12 @@ class TestUnion(JitTestCase):
else:
return "baz"
with self.assertRaisesRegex(RuntimeError, "previously had type "
"str but is now being assigned to a"
" value of type int"):
with self.assertRaisesRegex(
RuntimeError,
"previously had type "
"str but is now being assigned to a"
" value of type int",
):
torch.jit.script(fn)
def test_union_schema_matching_on_internal_type(self):
@ -645,8 +666,8 @@ class TestUnion(JitTestCase):
def test_union_memory_aliasing(self):
def fn():
x : List[torch.Tensor] = []
z : List[Optional[List[torch.Tensor]]] = []
x: List[torch.Tensor] = []
z: List[Optional[List[torch.Tensor]]] = []
z.append(x)
x_alias = z[0]
if torch.jit.isinstance(x_alias, List[torch.Tensor]):
@ -682,203 +703,212 @@ class TestUnion(JitTestCase):
code = template.format(ann=ann, lhs=lhs)
with self.assertRaisesRegex(RuntimeError, msg):
cu = torch.jit.CompilationUnit(code, _frames_up=1)
string_frontend = getattr(cu, "fn") # noqa: B009
string_frontend = getattr(cu, "fn") # noqa: B009
def test_union_with_list_assignment(self):
template = dedent('''
template = dedent(
"""
def fn():
x: {ann} = {lhs}
if torch.jit.isinstance(x, List[torch.Tensor]):
x.append(torch.tensor(3))
return x
''')
"""
)
lhs = {"list_literal_empty" : "[]",
"list_literal_of_tensor" : "[torch.arange(3), torch.arange(5)]",
"list_literal_of_str" : "[\"foo\", \"bar\", \"baz\"]",
"list_literal_of_mixed" : "[torch.arange(5), 1]",
"list_comprehension_of_tensor" :
"[torch.add(x, 1) for x in [torch.arange(3), torch.arange(5)]]",
"list_comprehension_of_str" :
"[x + \"!\" for x in [\"foo\", \"bar\", \"baz\"]]",
"list_comprehension_of_mixed" :
"[torch.add(1, x) for x in [torch.arange(5), 1]]"}
lhs = {
"list_literal_empty": "[]",
"list_literal_of_tensor": "[torch.arange(3), torch.arange(5)]",
"list_literal_of_str": '["foo", "bar", "baz"]',
"list_literal_of_mixed": "[torch.arange(5), 1]",
"list_comprehension_of_tensor": "[torch.add(x, 1) for x in [torch.arange(3), torch.arange(5)]]",
"list_comprehension_of_str": '[x + "!" for x in ["foo", "bar", "baz"]]',
"list_comprehension_of_mixed": "[torch.add(1, x) for x in [torch.arange(5), 1]]",
}
"""
Union[List[str], List[torch.Tensor]]
"""
self._assert_raises(template,
"Union[List[str], List[torch.Tensor]]",
lhs["list_literal_empty"],
"there are multiple possible List type "
"candidates in the Union annotation")
self._assert_raises(
template,
"Union[List[str], List[torch.Tensor]]",
lhs["list_literal_empty"],
"there are multiple possible List type "
"candidates in the Union annotation",
)
self._assert_passes(template,
"Union[List[str], List[torch.Tensor]]",
lhs["list_literal_of_tensor"])
self._assert_passes(
template,
"Union[List[str], List[torch.Tensor]]",
lhs["list_literal_of_tensor"],
)
self._assert_passes(template,
"Union[List[str], List[torch.Tensor]]",
lhs["list_literal_of_str"])
self._assert_passes(
template, "Union[List[str], List[torch.Tensor]]", lhs["list_literal_of_str"]
)
self._assert_raises(template,
"Union[List[str], List[torch.Tensor]]",
lhs["list_literal_of_mixed"],
"none of those types match the types of the"
" given list elements")
self._assert_raises(
template,
"Union[List[str], List[torch.Tensor]]",
lhs["list_literal_of_mixed"],
"none of those types match the types of the" " given list elements",
)
self._assert_passes(template,
"Union[List[str], List[torch.Tensor]]",
lhs["list_comprehension_of_tensor"])
self._assert_passes(
template,
"Union[List[str], List[torch.Tensor]]",
lhs["list_comprehension_of_tensor"],
)
self._assert_passes(template,
"Union[List[str], List[torch.Tensor]]",
lhs["list_comprehension_of_str"])
self._assert_passes(
template,
"Union[List[str], List[torch.Tensor]]",
lhs["list_comprehension_of_str"],
)
# TODO: Support mixed list comprehensions
self._assert_raises(template,
"Union[List[str], List[torch.Tensor]]",
lhs["list_comprehension_of_mixed"],
"Arguments for call are not valid")
self._assert_raises(
template,
"Union[List[str], List[torch.Tensor]]",
lhs["list_comprehension_of_mixed"],
"Arguments for call are not valid",
)
"""
Union[int, torch.Tensor]
"""
self._assert_raises(template,
"Union[int, torch.Tensor]",
lhs["list_literal_empty"],
"Expected an Union type annotation with an "
"inner List type")
self._assert_raises(
template,
"Union[int, torch.Tensor]",
lhs["list_literal_empty"],
"Expected an Union type annotation with an " "inner List type",
)
self._assert_raises(template, "Union[int, torch.Tensor]",
lhs["list_literal_of_tensor"],
"Expected an Union type annotation with an "
"inner List type")
self._assert_raises(
template,
"Union[int, torch.Tensor]",
lhs["list_literal_of_tensor"],
"Expected an Union type annotation with an " "inner List type",
)
self._assert_raises(template, "Union[int, torch.Tensor]",
lhs["list_comprehension_of_tensor"],
"Expected an Union type annotation with an "
"inner List type")
self._assert_raises(
template,
"Union[int, torch.Tensor]",
lhs["list_comprehension_of_tensor"],
"Expected an Union type annotation with an " "inner List type",
)
"""
Union[List[torch.Tensor], int]
"""
self._assert_passes(template,
"Union[List[torch.Tensor], int]",
lhs["list_literal_empty"])
self._assert_passes(
template, "Union[List[torch.Tensor], int]", lhs["list_literal_empty"]
)
self._assert_passes(template,
"Union[List[torch.Tensor], int]",
lhs["list_literal_of_tensor"])
self._assert_passes(
template, "Union[List[torch.Tensor], int]", lhs["list_literal_of_tensor"]
)
self._assert_raises(template, "Union[List[torch.Tensor], int]",
lhs["list_literal_of_str"],
r"List type annotation `List\[Tensor\]` did "
"not match the types of the given list "
"elements")
self._assert_raises(
template,
"Union[List[torch.Tensor], int]",
lhs["list_literal_of_str"],
r"List type annotation `List\[Tensor\]` did "
"not match the types of the given list "
"elements",
)
self._assert_raises(template, "Union[List[torch.Tensor], int]",
lhs["list_literal_of_mixed"],
r"List type annotation `List\[Tensor\]` did "
"not match the types of the given list "
"elements")
self._assert_raises(
template,
"Union[List[torch.Tensor], int]",
lhs["list_literal_of_mixed"],
r"List type annotation `List\[Tensor\]` did "
"not match the types of the given list "
"elements",
)
self._assert_passes(template,
"Union[List[torch.Tensor], int]",
lhs["list_comprehension_of_tensor"])
self._assert_passes(
template,
"Union[List[torch.Tensor], int]",
lhs["list_comprehension_of_tensor"],
)
self._assert_raises(template,
"Union[List[torch.Tensor], int]",
lhs["list_comprehension_of_str"],
r"List type annotation `List\[Tensor\]` did "
"not match the types of the given list "
"elements")
self._assert_raises(
template,
"Union[List[torch.Tensor], int]",
lhs["list_comprehension_of_str"],
r"List type annotation `List\[Tensor\]` did "
"not match the types of the given list "
"elements",
)
# TODO(@ansley): Support mixed list comprehensions
self._assert_raises(template,
"Union[List[torch.Tensor], int]",
lhs["list_comprehension_of_mixed"],
"Arguments for call are not valid")
self._assert_raises(
template,
"Union[List[torch.Tensor], int]",
lhs["list_comprehension_of_mixed"],
"Arguments for call are not valid",
)
def test_union_with_dict_assignment(self):
template = dedent('''
template = dedent(
"""
def fn():
x: {ann} = {lhs}
if torch.jit.isinstance(x, Dict[str, torch.Tensor]):
x["foo"] = torch.tensor(3)
return x
''')
"""
)
lhs = {"dict_literal_empty" : "{}",
"dict_literal_of_str_tensor" :
"{\"foo\" : torch.arange(3), \"bar\" : torch.arange(5)}",
"dict_literal_of_str_int" :
"{\"foo\" : 1, \"bar\" : 2}",
"dict_literal_of_mixed" :
"{\"foo\" : torch.arange(3), \"bar\" : 2}",
"dict_comprehension_of_str_tensor" :
"{x : torch.add(y, 1) for x, y in \
zip([\"foo\", \"bar\"], [torch.arange(3), torch.arange(5)])}",
"dict_comprehension_of_str_int" :
"{x : torch.add(y, 1) for x, y in \
zip([\"foo\", \"bar\"], [1, 2]}",
"dict_comprehension_of_mixed" :
"{x : torch.add(y, 1) for x, y in \
zip([\"foo\", \"bar\"], [torch.arange(3), 2])}",
"dict_keyword" :
"dict(foo=torch.arange(3), baz=torch.arange(5))",
"dict_keyword_with_iterable" :
"dict([(\"foo\", torch.arange(3)), (\"bar\", torch.arange(5))])",
"dict_keyword_with_empty_iterable" :
"dict([])",
"dict_keyword_with_internal_aggregate_function" :
"dict(zip([\"foo\", \"bar\"], [torch.arange(3), torch.arange(5)])",
"dict_keyword_with_mapping" :
"dict({\"foo\" : torch.arange(3), \"bar\" : torch.arange(5)})",
"dict_keyword_with_mapping_and_kwargs" :
"dict({\"foo\" : torch.arange(3), \"bar\" : torch.arange(5)}, baz=torch.arange(7))",
}
lhs = {
"dict_literal_empty": "{}",
"dict_literal_of_str_tensor": '{"foo" : torch.arange(3), "bar" : torch.arange(5)}',
"dict_literal_of_str_int": '{"foo" : 1, "bar" : 2}',
"dict_literal_of_mixed": '{"foo" : torch.arange(3), "bar" : 2}',
"dict_comprehension_of_str_tensor": '{x : torch.add(y, 1) for x, y in \
zip(["foo", "bar"], [torch.arange(3), torch.arange(5)])}',
"dict_comprehension_of_str_int": '{x : torch.add(y, 1) for x, y in \
zip(["foo", "bar"], [1, 2]}',
"dict_comprehension_of_mixed": '{x : torch.add(y, 1) for x, y in \
zip(["foo", "bar"], [torch.arange(3), 2])}',
"dict_keyword": "dict(foo=torch.arange(3), baz=torch.arange(5))",
"dict_keyword_with_iterable": 'dict([("foo", torch.arange(3)), ("bar", torch.arange(5))])',
"dict_keyword_with_empty_iterable": "dict([])",
"dict_keyword_with_internal_aggregate_function": 'dict(zip(["foo", "bar"], [torch.arange(3), torch.arange(5)])',
"dict_keyword_with_mapping": 'dict({"foo" : torch.arange(3), "bar" : torch.arange(5)})',
"dict_keyword_with_mapping_and_kwargs": 'dict({"foo" : torch.arange(3), "bar" : torch.arange(5)}, baz=torch.arange(7))',
}
"""
Union[Dict[str, torch.Tensor], Dict[str, int]]
"""
self._assert_raises(template,
"Union[List[str], List[torch.Tensor]]",
lhs["dict_literal_empty"],
"Expected an Union type annotation with an "
"inner Dict type")
self._assert_raises(
template,
"Union[List[str], List[torch.Tensor]]",
lhs["dict_literal_empty"],
"Expected an Union type annotation with an " "inner Dict type",
)
self._assert_passes(template,
"Union[Dict[str, torch.Tensor], Dict[str, int]]",
lhs["dict_literal_of_str_tensor"])
self._assert_passes(
template,
"Union[Dict[str, torch.Tensor], Dict[str, int]]",
lhs["dict_literal_of_str_tensor"],
)
self._assert_passes(template,
"Union[Dict[str, torch.Tensor], Dict[str, int]]",
lhs["dict_literal_of_str_int"])
self._assert_passes(
template,
"Union[Dict[str, torch.Tensor], Dict[str, int]]",
lhs["dict_literal_of_str_int"],
)
self._assert_raises(template, "Union[Dict[str, torch.Tensor], Dict[str, int]]",
lhs["dict_literal_of_mixed"],
"none of those dict types can hold the "
"types of the given keys and values")
self._assert_raises(
template,
"Union[Dict[str, torch.Tensor], Dict[str, int]]",
lhs["dict_literal_of_mixed"],
"none of those dict types can hold the "
"types of the given keys and values",
)
# TODO: String frontend does not support tuple unpacking
# https://github.com/pytorch/pytorch/issues/64096
@ -899,45 +929,57 @@ class TestUnion(JitTestCase):
# TODO(@ansley): Follow-up project needed for full type
# inference with dict keyword (supported for dict comprehension
# and dict literal already; should not be a blocker for anyone)
self._assert_raises(template,
"Union[Dict[str, torch.Tensor], Dict[str, int]]",
lhs["dict_keyword"],
"full type inference is not yet supported")
self._assert_raises(
template,
"Union[Dict[str, torch.Tensor], Dict[str, int]]",
lhs["dict_keyword"],
"full type inference is not yet supported",
)
self._assert_raises(template,
"Union[Dict[str, torch.Tensor], Dict[str, int]]",
lhs["dict_keyword_with_iterable"],
"full type inference is not yet supported")
self._assert_raises(
template,
"Union[Dict[str, torch.Tensor], Dict[str, int]]",
lhs["dict_keyword_with_iterable"],
"full type inference is not yet supported",
)
self._assert_raises(template,
"Union[Dict[str, torch.Tensor], Dict[str, int]]",
lhs["dict_keyword_with_empty_iterable"],
"full type inference is not yet supported")
self._assert_raises(
template,
"Union[Dict[str, torch.Tensor], Dict[str, int]]",
lhs["dict_keyword_with_empty_iterable"],
"full type inference is not yet supported",
)
self._assert_raises(template,
"Union[Dict[str, torch.Tensor], Dict[str, int]]",
lhs["dict_keyword_with_mapping"],
"full type inference is not yet supported")
self._assert_raises(
template,
"Union[Dict[str, torch.Tensor], Dict[str, int]]",
lhs["dict_keyword_with_mapping"],
"full type inference is not yet supported",
)
self._assert_raises(template,
"Union[Dict[str, torch.Tensor], Dict[str, int]]",
lhs["dict_keyword_with_mapping_and_kwargs"],
"full type inference is not yet supported")
self._assert_raises(
template,
"Union[Dict[str, torch.Tensor], Dict[str, int]]",
lhs["dict_keyword_with_mapping_and_kwargs"],
"full type inference is not yet supported",
)
"""
Union[int, torch.Tensor]
"""
self._assert_raises(template,
"Union[int, torch.Tensor]",
lhs["dict_literal_empty"],
"Expected an Union type annotation with "
"an inner Dict type")
self._assert_raises(
template,
"Union[int, torch.Tensor]",
lhs["dict_literal_empty"],
"Expected an Union type annotation with " "an inner Dict type",
)
self._assert_raises(template,
"Union[int, torch.Tensor]",
lhs["dict_literal_of_str_tensor"],
"Expected an Union type annotation with "
"an inner Dict type")
self._assert_raises(
template,
"Union[int, torch.Tensor]",
lhs["dict_literal_of_str_tensor"],
"Expected an Union type annotation with " "an inner Dict type",
)
# See above--string frontend does not support tuple unpacking
# self._assert_raises(template, "Union[int, torch.Tensor]",
@ -947,47 +989,61 @@ class TestUnion(JitTestCase):
"""
Union[Dict[str, torch.Tensor], int]
"""
self._assert_passes(template,
"Union[Dict[str, torch.Tensor], int]",
lhs["dict_literal_empty"])
self._assert_passes(
template, "Union[Dict[str, torch.Tensor], int]", lhs["dict_literal_empty"]
)
self._assert_passes(template,
"Union[Dict[str, torch.Tensor], int]",
lhs["dict_literal_of_str_tensor"])
self._assert_passes(
template,
"Union[Dict[str, torch.Tensor], int]",
lhs["dict_literal_of_str_tensor"],
)
self._assert_raises(template,
"Union[Dict[str, torch.Tensor], int]",
lhs["dict_literal_of_str_int"],
"Type annotation was inferred to be "
r"`Dict\[str, Tensor\]`, but the type of "
"values given by the dict literal is")
self._assert_raises(
template,
"Union[Dict[str, torch.Tensor], int]",
lhs["dict_literal_of_str_int"],
"Type annotation was inferred to be "
r"`Dict\[str, Tensor\]`, but the type of "
"values given by the dict literal is",
)
self._assert_raises(template,
"Union[Dict[str, torch.Tensor], int]",
lhs["dict_literal_of_mixed"],
"Type annotation was inferred to be "
r"`Dict\[str, Tensor\]`, but the type of "
"values given by the dict literal is")
self._assert_raises(
template,
"Union[Dict[str, torch.Tensor], int]",
lhs["dict_literal_of_mixed"],
"Type annotation was inferred to be "
r"`Dict\[str, Tensor\]`, but the type of "
"values given by the dict literal is",
)
self._assert_passes(template,
"Union[Dict[str, torch.Tensor], int]",
lhs["dict_keyword"])
self._assert_passes(
template, "Union[Dict[str, torch.Tensor], int]", lhs["dict_keyword"]
)
self._assert_passes(template,
"Union[Dict[str, torch.Tensor], int]",
lhs["dict_keyword_with_iterable"])
self._assert_passes(
template,
"Union[Dict[str, torch.Tensor], int]",
lhs["dict_keyword_with_iterable"],
)
self._assert_passes(template,
"Union[Dict[str, torch.Tensor], int]",
lhs["dict_keyword_with_empty_iterable"])
self._assert_passes(
template,
"Union[Dict[str, torch.Tensor], int]",
lhs["dict_keyword_with_empty_iterable"],
)
self._assert_passes(template,
"Union[Dict[str, torch.Tensor], int]",
lhs["dict_keyword_with_mapping"])
self._assert_passes(
template,
"Union[Dict[str, torch.Tensor], int]",
lhs["dict_keyword_with_mapping"],
)
self._assert_passes(template,
"Union[Dict[str, torch.Tensor], int]",
lhs["dict_keyword_with_mapping_and_kwargs"])
self._assert_passes(
template,
"Union[Dict[str, torch.Tensor], int]",
lhs["dict_keyword_with_mapping_and_kwargs"],
)
# See above--string frontend does not support tuple unpacking
# self._assert_passes(template,

View File

@ -2,19 +2,21 @@
import os
import sys
import unittest
import torch
import unittest
# Make the helper files in test/ importable
pytorch_test_dir = os.path.dirname(os.path.dirname(os.path.realpath(__file__)))
sys.path.append(pytorch_test_dir)
from torch.testing._internal.jit_utils import JitTestCase
if __name__ == '__main__':
raise RuntimeError("This test file is not meant to be run directly, use:\n\n"
"\tpython test/test_jit.py TESTNAME\n\n"
"instead.")
if __name__ == "__main__":
raise RuntimeError(
"This test file is not meant to be run directly, use:\n\n"
"\tpython test/test_jit.py TESTNAME\n\n"
"instead."
)
# NOTE: FIXING FAILING TESTS
# If you are seeing a test failure from this file, congrats, you improved
@ -22,6 +24,7 @@ if __name__ == '__main__':
# the corresponding section in documentation that states the unsupported behavior.
# see: `jit_unsupported.rst`
class TestUnsupportedOps(JitTestCase):
def test_factory_ops_requires_grad_fail(self):
# Keyword argument {name} unknown is a JIT-only error message,
@ -32,31 +35,31 @@ class TestUnsupportedOps(JitTestCase):
def ones():
return torch.ones([2], requires_grad=True)
with self.assertRaisesRegexWithHighlight(Exception,
"Keyword argument requires_grad unknown",
"torch.ones"):
with self.assertRaisesRegexWithHighlight(
Exception, "Keyword argument requires_grad unknown", "torch.ones"
):
torch.jit.script(ones)
def randn():
return torch.randn([2], requires_grad=True)
with self.assertRaisesRegexWithHighlight(Exception,
"Keyword argument requires_grad unknown",
"torch.randn"):
with self.assertRaisesRegexWithHighlight(
Exception, "Keyword argument requires_grad unknown", "torch.randn"
):
torch.jit.script(randn)
def zeros():
return torch.zeros([2], requires_grad=True)
with self.assertRaisesRegexWithHighlight(Exception,
"Keyword argument requires_grad unknown",
"torch.zeros"):
with self.assertRaisesRegexWithHighlight(
Exception, "Keyword argument requires_grad unknown", "torch.zeros"
):
torch.jit.script(zeros)
@unittest.skipIf(not torch._C.has_lapack, "PyTorch compiled without Lapack")
def test_init_ops(self):
def calculate_gain():
return torch.nn.init.calculate_gain('leaky_relu', 0.2)
return torch.nn.init.calculate_gain("leaky_relu", 0.2)
def eye_():
return torch.nn.init.eye_(torch.zeros([2, 2]))
@ -71,9 +74,16 @@ class TestUnsupportedOps(JitTestCase):
return torch.nn.init.orthogonal_(torch.empty(3, 5))
def sparse():
return torch.nn.init.sparse_(torch.empty(3, 5), sparsity=.1)
return torch.nn.init.sparse_(torch.empty(3, 5), sparsity=0.1)
for func in [calculate_gain, eye_, dirac_, kaiming_uniform_, orthogonal_, sparse]:
for func in [
calculate_gain,
eye_,
dirac_,
kaiming_uniform_,
orthogonal_,
sparse,
]:
# doesn't error in eager
func()
with self.assertRaisesRegex(Exception, ""):

View File

@ -3,20 +3,24 @@
import io
import os
import sys
import torch
import zipfile
from torch.testing import FileCheck
from typing import Union
import torch
from torch.testing import FileCheck
# Make the helper files in test/ importable
pytorch_test_dir = os.path.dirname(os.path.dirname(os.path.realpath(__file__)))
sys.path.append(pytorch_test_dir)
from torch.testing._internal.jit_utils import JitTestCase
if __name__ == '__main__':
raise RuntimeError("This test file is not meant to be run directly, use:\n\n"
"\tpython test/test_jit.py TESTNAME\n\n"
"instead.")
if __name__ == "__main__":
raise RuntimeError(
"This test file is not meant to be run directly, use:\n\n"
"\tpython test/test_jit.py TESTNAME\n\n"
"instead."
)
class TestUpgraders(JitTestCase):
def _load_model_version(self, loaded_model):
@ -28,10 +32,10 @@ class TestUpgraders(JitTestCase):
# in a package between version 3 and 7.
# So we have to check for both.
try:
version = int(zipped_model.read('archive/version').decode("utf-8"))
version = int(zipped_model.read("archive/version").decode("utf-8"))
return version
except KeyError:
version = int(zipped_model.read('archive/.data/version').decode("utf-8"))
version = int(zipped_model.read("archive/.data/version").decode("utf-8"))
return version
# TODO (tugsuu) We should ideally be generating this test cases.
@ -62,15 +66,23 @@ class TestUpgraders(JitTestCase):
upgrader_bumped_version = 3
upgrader_name = "_test_serialization_subcmul_0_2"
upgrader_schema = "aten::_test_serialization_subcmul(Tensor self, Tensor other, Scalar alpha=2) -> Tensor"
dummy_entry = torch._C._UpgraderEntry(upgrader_bumped_version, upgrader_name, upgrader_schema)
dummy_entry = torch._C._UpgraderEntry(
upgrader_bumped_version, upgrader_name, upgrader_schema
)
torch._C._test_only_add_entry_to_op_version_map("aten::_test_serialization_subcmul", dummy_entry)
torch._C._test_only_add_entry_to_op_version_map(
"aten::_test_serialization_subcmul", dummy_entry
)
map_after_test = torch._C._get_operator_version_map()
self.assertTrue("aten::_test_serialization_subcmul" in map_after_test)
self.assertTrue(len(map_after_test) - len(map_before_test) == 1)
torch._C._test_only_remove_entry_to_op_version_map("aten::_test_serialization_subcmul")
torch._C._test_only_remove_entry_to_op_version_map(
"aten::_test_serialization_subcmul"
)
map_after_remove_test = torch._C._get_operator_version_map()
self.assertTrue("aten::_test_serialization_subcmul" not in map_after_remove_test)
self.assertTrue(
"aten::_test_serialization_subcmul" not in map_after_remove_test
)
self.assertEqual(len(map_after_remove_test), len(map_before_test))
def test_populated_test_upgrader_graph(self):
@ -151,7 +163,7 @@ class TestUpgraders(JitTestCase):
model_path = pytorch_test_dir + "/jit/fixtures/test_versioned_linspace_v7.ptl"
loaded_model = torch.jit.load(model_path)
sample_inputs = ((3, 10), (-10, 10), (4.0, 6.0), (3 + 4j, 4 + 5j))
for (a, b) in sample_inputs:
for a, b in sample_inputs:
output_with_step, output_without_step = loaded_model(a, b)
# when no step is given, should have used 100
self.assertTrue(output_without_step.size(dim=0) == 100)
@ -161,7 +173,9 @@ class TestUpgraders(JitTestCase):
self.assertTrue(version == 8)
def test_aten_linspace_out(self):
model_path = pytorch_test_dir + "/jit/fixtures/test_versioned_linspace_out_v7.ptl"
model_path = (
pytorch_test_dir + "/jit/fixtures/test_versioned_linspace_out_v7.ptl"
)
loaded_model = torch.jit.load(model_path)
sample_inputs = (
(3, 10, torch.empty((100,), dtype=torch.int64)),
@ -169,7 +183,7 @@ class TestUpgraders(JitTestCase):
(4.0, 6.0, torch.empty((100,), dtype=torch.float64)),
(3 + 4j, 4 + 5j, torch.empty((100,), dtype=torch.complex64)),
)
for (a, b, c) in sample_inputs:
for a, b, c in sample_inputs:
output = loaded_model(a, b, c)
# when no step is given, should have used 100
self.assertTrue(output.size(dim=0) == 100)
@ -181,7 +195,7 @@ class TestUpgraders(JitTestCase):
model_path = pytorch_test_dir + "/jit/fixtures/test_versioned_logspace_v8.ptl"
loaded_model = torch.jit.load(model_path)
sample_inputs = ((3, 10), (-10, 10), (4.0, 6.0), (3 + 4j, 4 + 5j))
for (a, b) in sample_inputs:
for a, b in sample_inputs:
output_with_step, output_without_step = loaded_model(a, b)
# when no step is given, should have used 100
self.assertTrue(output_without_step.size(dim=0) == 100)
@ -191,7 +205,9 @@ class TestUpgraders(JitTestCase):
self.assertTrue(version == 9)
def test_aten_logspace_out(self):
model_path = pytorch_test_dir + "/jit/fixtures/test_versioned_logspace_out_v8.ptl"
model_path = (
pytorch_test_dir + "/jit/fixtures/test_versioned_logspace_out_v8.ptl"
)
loaded_model = torch.jit.load(model_path)
sample_inputs = (
(3, 10, torch.empty((100,), dtype=torch.int64)),
@ -199,7 +215,7 @@ class TestUpgraders(JitTestCase):
(4.0, 6.0, torch.empty((100,), dtype=torch.float64)),
(3 + 4j, 4 + 5j, torch.empty((100,), dtype=torch.complex64)),
)
for (a, b, c) in sample_inputs:
for a, b, c in sample_inputs:
output = loaded_model(a, b, c)
# when no step is given, should have used 100
self.assertTrue(output.size(dim=0) == 100)
@ -208,21 +224,36 @@ class TestUpgraders(JitTestCase):
self.assertTrue(version == 9)
def test_aten_test_serialization(self):
model_path = pytorch_test_dir + "/jit/fixtures/_test_serialization_subcmul_v2.pt"
model_path = (
pytorch_test_dir + "/jit/fixtures/_test_serialization_subcmul_v2.pt"
)
# add test version entry to the version map
upgrader_bumped_version = 3
upgrader_name = "_test_serialization_subcmul_0_2"
upgrader_schema = "aten::_test_serialization_subcmul(Tensor self, Tensor other, Scalar alpha=2) -> Tensor"
dummy_entry = torch._C._UpgraderEntry(upgrader_bumped_version, upgrader_name, upgrader_schema)
dummy_entry = torch._C._UpgraderEntry(
upgrader_bumped_version, upgrader_name, upgrader_schema
)
torch._C._test_only_add_entry_to_op_version_map("aten::_test_serialization_subcmul", dummy_entry)
torch._C._test_only_add_entry_to_op_version_map(
"aten::_test_serialization_subcmul", dummy_entry
)
# add test upgrader in the upgraders map
@torch.jit.script
def _test_serialization_subcmul_0_2(self: torch.Tensor, other: torch.Tensor, alpha: Union[int, float] = 2) -> torch.Tensor:
def _test_serialization_subcmul_0_2(
self: torch.Tensor, other: torch.Tensor, alpha: Union[int, float] = 2
) -> torch.Tensor:
return other - (self * alpha)
torch._C._test_only_populate_upgraders({"_test_serialization_subcmul_0_2": str(_test_serialization_subcmul_0_2.graph)})
torch._C._test_only_populate_upgraders(
{
"_test_serialization_subcmul_0_2": str(
_test_serialization_subcmul_0_2.graph
)
}
)
# test if the server is able to find the test upgraders and apply to IR
loaded_model = torch.jit.load(model_path)
@ -238,11 +269,21 @@ class TestUpgraders(JitTestCase):
# we check by its' code because graph variable names
# can be different every time
self.assertEqual(loaded_model.code, loaded_model_twice.code)
torch._C._test_only_remove_entry_to_op_version_map("aten::_test_serialization_subcmul")
torch._C._test_only_remove_upgraders({"_test_serialization_subcmul_0_2": str(_test_serialization_subcmul_0_2.graph)})
torch._C._test_only_remove_entry_to_op_version_map(
"aten::_test_serialization_subcmul"
)
torch._C._test_only_remove_upgraders(
{
"_test_serialization_subcmul_0_2": str(
_test_serialization_subcmul_0_2.graph
)
}
)
def test_aten_div_scalar_at_3(self):
model_path = pytorch_test_dir + "/jit/fixtures/test_versioned_div_scalar_float_v3.pt"
model_path = (
pytorch_test_dir + "/jit/fixtures/test_versioned_div_scalar_float_v3.pt"
)
loaded_model = torch.jit.load(model_path)
FileCheck().check("prim::If").run(loaded_model.graph)
FileCheck().check_count("aten::div", 2).run(loaded_model.graph)
@ -254,11 +295,15 @@ class TestUpgraders(JitTestCase):
self.assertEqual(version, 4)
loaded_model_twice = torch.jit.load(buffer)
self.assertEqual(loaded_model(torch.Tensor([5.0, 3.0]), 2.0),
loaded_model_twice(torch.Tensor([5.0, 3.0]), 2.0))
self.assertEqual(
loaded_model(torch.Tensor([5.0, 3.0]), 2.0),
loaded_model_twice(torch.Tensor([5.0, 3.0]), 2.0),
)
def test_aten_div_tensor_out_at_3(self):
model_path = pytorch_test_dir + "/jit/fixtures/test_versioned_div_tensor_out_v3.pt"
model_path = (
pytorch_test_dir + "/jit/fixtures/test_versioned_div_tensor_out_v3.pt"
)
loaded_model = torch.jit.load(model_path)
FileCheck().check("prim::If").run(loaded_model.graph)
FileCheck().check_count("aten::div", 2).run(loaded_model.graph)
@ -274,7 +319,9 @@ class TestUpgraders(JitTestCase):
self.assertEqual(loaded_model.code, loaded_model_twice.code)
def test_aten_full_at_4(self):
model_path = pytorch_test_dir + "/jit/fixtures/test_versioned_full_integer_value_v4.pt"
model_path = (
pytorch_test_dir + "/jit/fixtures/test_versioned_full_integer_value_v4.pt"
)
loaded_model = torch.jit.load(model_path)
FileCheck().check_count("aten::Float", 1).run(loaded_model.graph)
FileCheck().check_count("aten::full", 2).run(loaded_model.graph)
@ -290,7 +337,9 @@ class TestUpgraders(JitTestCase):
self.assertEqual(loaded_model.code, loaded_model_twice.code)
def test_aten_full_out_at_4(self):
model_path = pytorch_test_dir + "/jit/fixtures/test_versioned_full_preserved_v4.pt"
model_path = (
pytorch_test_dir + "/jit/fixtures/test_versioned_full_preserved_v4.pt"
)
loaded_model = torch.jit.load(model_path)
FileCheck().check_count("aten::full", 5).run(loaded_model.graph)
version = self._load_model_version(loaded_model)

View File

@ -1,12 +1,12 @@
# Owner(s): ["oncall: jit"]
import io
import os
import sys
import io
import torch
import warnings
from contextlib import redirect_stderr
import torch
from torch.testing import FileCheck
# Make the helper files in test/ importable
@ -14,10 +14,12 @@ pytorch_test_dir = os.path.dirname(os.path.dirname(os.path.realpath(__file__)))
sys.path.append(pytorch_test_dir)
from torch.testing._internal.jit_utils import JitTestCase
if __name__ == '__main__':
raise RuntimeError("This test file is not meant to be run directly, use:\n\n"
"\tpython test/test_jit.py TESTNAME\n\n"
"instead.")
if __name__ == "__main__":
raise RuntimeError(
"This test file is not meant to be run directly, use:\n\n"
"\tpython test/test_jit.py TESTNAME\n\n"
"instead."
)
class TestWarn(JitTestCase):
@ -30,12 +32,9 @@ class TestWarn(JitTestCase):
with redirect_stderr(f):
fn()
FileCheck() \
.check_count(
str="UserWarning: I am warning you",
count=1,
exactly=True) \
.run(f.getvalue())
FileCheck().check_count(
str="UserWarning: I am warning you", count=1, exactly=True
).run(f.getvalue())
def test_warn_only_once(self):
@torch.jit.script
@ -47,12 +46,9 @@ class TestWarn(JitTestCase):
with redirect_stderr(f):
fn()
FileCheck() \
.check_count(
str="UserWarning: I am warning you",
count=1,
exactly=True) \
.run(f.getvalue())
FileCheck().check_count(
str="UserWarning: I am warning you", count=1, exactly=True
).run(f.getvalue())
def test_warn_only_once_in_loop_func(self):
def w():
@ -67,12 +63,9 @@ class TestWarn(JitTestCase):
with redirect_stderr(f):
fn()
FileCheck() \
.check_count(
str="UserWarning: I am warning you",
count=1,
exactly=True) \
.run(f.getvalue())
FileCheck().check_count(
str="UserWarning: I am warning you", count=1, exactly=True
).run(f.getvalue())
def test_warn_once_per_func(self):
def w1():
@ -90,12 +83,9 @@ class TestWarn(JitTestCase):
with redirect_stderr(f):
fn()
FileCheck() \
.check_count(
str="UserWarning: I am warning you",
count=2,
exactly=True) \
.run(f.getvalue())
FileCheck().check_count(
str="UserWarning: I am warning you", count=2, exactly=True
).run(f.getvalue())
def test_warn_once_per_func_in_loop(self):
def w1():
@ -114,12 +104,9 @@ class TestWarn(JitTestCase):
with redirect_stderr(f):
fn()
FileCheck() \
.check_count(
str="UserWarning: I am warning you",
count=2,
exactly=True) \
.run(f.getvalue())
FileCheck().check_count(
str="UserWarning: I am warning you", count=2, exactly=True
).run(f.getvalue())
def test_warn_multiple_calls_multiple_warnings(self):
@torch.jit.script
@ -131,12 +118,9 @@ class TestWarn(JitTestCase):
fn()
fn()
FileCheck() \
.check_count(
str="UserWarning: I am warning you",
count=2,
exactly=True) \
.run(f.getvalue())
FileCheck().check_count(
str="UserWarning: I am warning you", count=2, exactly=True
).run(f.getvalue())
def test_warn_multiple_calls_same_func_diff_stack(self):
def warn(caller: str):
@ -155,13 +139,10 @@ class TestWarn(JitTestCase):
foo()
bar()
FileCheck() \
.check_count(
str="UserWarning: I am warning you from foo",
count=1,
exactly=True) \
.check_count(
str="UserWarning: I am warning you from bar",
count=1,
exactly=True) \
.run(f.getvalue())
FileCheck().check_count(
str="UserWarning: I am warning you from foo", count=1, exactly=True
).check_count(
str="UserWarning: I am warning you from bar", count=1, exactly=True
).run(
f.getvalue()
)

View File

@ -32,6 +32,7 @@ class TestWith(JitTestCase):
Check that with statements that use the 'as' keyword to bind expressions
to targets work as expected.
"""
@torch.jit.script
class Context:
"""
@ -189,6 +190,7 @@ class TestWith(JitTestCase):
Check that with statements that do not use the 'as' keyword to bind expressions
to targets work as expected.
"""
@torch.jit.script
class Context:
"""
@ -345,6 +347,7 @@ class TestWith(JitTestCase):
Check that exceptions thrown in the bodies of with-statements are
handled correctly.
"""
@torch.jit.script
class Context:
"""
@ -416,15 +419,21 @@ class TestWith(JitTestCase):
# checkScript and checkScriptRaisesRegex cannot be used because the string frontend will
# not compile class types (of which Context, the context manager being used for this test
# is one).
with self.assertRaisesRegexWithHighlight(Exception, r"raised exception", "raise Exception(\"raised exception"):
with self.assertRaisesRegexWithHighlight(
Exception, r"raised exception", 'raise Exception("raised exception'
):
test_exception(torch.randn(2), c)
self.assertEqual(c.count, 1)
with self.assertRaisesRegexWithHighlight(Exception, r"raised exception", "raise Exception(\"raised exception"):
with self.assertRaisesRegexWithHighlight(
Exception, r"raised exception", 'raise Exception("raised exception'
):
test_exception_nested(torch.randn(2), c)
self.assertEqual(c.count, 1)
with self.assertRaisesRegexWithHighlight(Exception, r"raised exception", "raise Exception(\"raised exception"):
with self.assertRaisesRegexWithHighlight(
Exception, r"raised exception", 'raise Exception("raised exception'
):
test_exception_fn_call(torch.randn(2), c)
self.assertEqual(c.count, 1)
@ -505,7 +514,9 @@ class TestWith(JitTestCase):
return x
def test_exit_incorrect_types(x: torch.Tensor, cm: ExitIncorrectTypes) -> torch.Tensor:
def test_exit_incorrect_types(
x: torch.Tensor, cm: ExitIncorrectTypes
) -> torch.Tensor:
with cm as _:
pass
@ -523,7 +534,9 @@ class TestWith(JitTestCase):
self.checkScript(test_no_enter_no_exit, (test_tensor, NoEnterNoExit()))
with self.assertRaisesRegexWithHighlight(
RuntimeError, r"__enter__ must have only one argument and one return value", "cm"
RuntimeError,
r"__enter__ must have only one argument and one return value",
"cm",
):
self.checkScript(test_bad_enter, (test_tensor, BadEnter()))
@ -539,7 +552,9 @@ class TestWith(JitTestCase):
test_exit_incorrect_types, (test_tensor, ExitIncorrectTypes())
)
with self.assertRaisesRegexWithHighlight(RuntimeError, r"must return an object", "\"not_object\""):
with self.assertRaisesRegexWithHighlight(
RuntimeError, r"must return an object", '"not_object"'
):
self.checkScript(test_enter_without_object, ())
def test_with_no_grad(self):
@ -603,6 +618,7 @@ class TestWith(JitTestCase):
Check that torch.autograd.profiler.record_function context manager is
torchscriptable.
"""
def with_rf(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
with torch.autograd.profiler.record_function("foo"):
# Nested record_function.

View File

@ -7,6 +7,7 @@ import torch._C
torch.ops.load_library("//caffe2:xnnpack_backend")
class TestXNNPackBackend(unittest.TestCase):
def test_xnnpack_constant_data(self):
class Module(torch.nn.Module):
@ -24,17 +25,19 @@ class TestXNNPackBackend(unittest.TestCase):
scripted_module,
{
"forward": {
"inputs" : [torch.randn(4, 4, 4)],
"outputs": [torch.randn(4, 4, 4)]
"inputs": [torch.randn(4, 4, 4)],
"outputs": [torch.randn(4, 4, 4)],
}
}
},
)
for i in range(0, 20):
sample_input = torch.randn(4, 4, 4)
actual_output = scripted_module(sample_input)
expected_output = lowered_module(sample_input)
self.assertTrue(torch.allclose(actual_output, expected_output, atol=1e-03, rtol=1e-03))
self.assertTrue(
torch.allclose(actual_output, expected_output, atol=1e-03, rtol=1e-03)
)
def test_xnnpack_lowering(self):
class Module(torch.nn.Module):
@ -45,13 +48,11 @@ class TestXNNPackBackend(unittest.TestCase):
faulty_compile_spec = {
"backward": {
"inputs" : [torch.zeros(1)],
"inputs": [torch.zeros(1)],
"outputs": [torch.zeros(1)],
}
}
error_msg = (
"method_compile_spec does not contain the \"forward\" key."
)
error_msg = 'method_compile_spec does not contain the "forward" key.'
with self.assertRaisesRegex(
RuntimeError,
@ -64,21 +65,21 @@ class TestXNNPackBackend(unittest.TestCase):
)
mismatch_compile_spec = {
"forward" : {
"inputs" : [torch.zeros(1), torch.zeros(1)],
"outputs" : [torch.zeros(1)]
"forward": {
"inputs": [torch.zeros(1), torch.zeros(1)],
"outputs": [torch.zeros(1)],
}
}
error_msg = ("method_compile_spec inputs do not match expected number of forward inputs")
error_msg = (
"method_compile_spec inputs do not match expected number of forward inputs"
)
with self.assertRaisesRegex(
RuntimeError,
error_msg,
):
_ = torch._C._jit_to_backend(
"xnnpack",
scripted_module,
mismatch_compile_spec
"xnnpack", scripted_module, mismatch_compile_spec
)
lowered = torch._C._jit_to_backend(
@ -86,10 +87,10 @@ class TestXNNPackBackend(unittest.TestCase):
scripted_module,
{
"forward": {
"inputs" : [torch.zeros(1)],
"inputs": [torch.zeros(1)],
"outputs": [torch.zeros(1)],
}
}
},
)
lowered(torch.zeros(1))
@ -113,14 +114,16 @@ class TestXNNPackBackend(unittest.TestCase):
add_module,
{
"forward": {
"inputs" : [sample_inputs[0].clone(), sample_inputs[1].clone()],
"outputs": [sample_output]
"inputs": [sample_inputs[0].clone(), sample_inputs[1].clone()],
"outputs": [sample_output],
}
}
},
)
actual_output = lowered_add_module.forward(sample_inputs[0], sample_inputs[1])
self.assertTrue(torch.allclose(actual_output, expected_output, atol=1e-03, rtol=1e-03))
self.assertTrue(
torch.allclose(actual_output, expected_output, atol=1e-03, rtol=1e-03)
)
def test_xnnpack_broadcasting(self):
class AddModule(torch.nn.Module):
@ -139,14 +142,16 @@ class TestXNNPackBackend(unittest.TestCase):
add_module,
{
"forward": {
"inputs" : [sample_inputs[0], sample_inputs[1]],
"outputs": [sample_output]
"inputs": [sample_inputs[0], sample_inputs[1]],
"outputs": [sample_output],
}
}
},
)
actual_output = lowered_add_module.forward(sample_inputs[0], sample_inputs[1])
self.assertTrue(torch.allclose(actual_output, expected_output, atol=1e-03, rtol=1e-03))
self.assertTrue(
torch.allclose(actual_output, expected_output, atol=1e-03, rtol=1e-03)
)
def test_xnnpack_unsupported(self):
class AddSpliceModule(torch.nn.Module):
@ -173,8 +178,8 @@ class TestXNNPackBackend(unittest.TestCase):
add_module,
{
"forward": {
"inputs" : [sample_inputs[0], sample_inputs[1]],
"outputs": [sample_output]
"inputs": [sample_inputs[0], sample_inputs[1]],
"outputs": [sample_output],
}
}
},
)

View File

@ -1,23 +1,30 @@
import argparse
import os
import sys
import torch
# grab modules from test_jit_hooks.cpp
pytorch_test_dir = os.path.dirname(os.path.dirname(os.path.realpath(__file__)))
sys.path.append(pytorch_test_dir)
from jit.test_hooks_modules import (
create_forward_tuple_input, create_module_forward_multiple_inputs,
create_module_forward_single_input, create_module_hook_return_nothing,
create_forward_tuple_input,
create_module_forward_multiple_inputs,
create_module_forward_single_input,
create_module_hook_return_nothing,
create_module_multiple_hooks_multiple_inputs,
create_module_multiple_hooks_single_input, create_module_no_forward_input,
create_module_same_hook_repeated, create_submodule_forward_multiple_inputs,
create_module_multiple_hooks_single_input,
create_module_no_forward_input,
create_module_same_hook_repeated,
create_submodule_forward_multiple_inputs,
create_submodule_forward_single_input,
create_submodule_hook_return_nothing,
create_submodule_multiple_hooks_multiple_inputs,
create_submodule_multiple_hooks_single_input,
create_submodule_same_hook_repeated,
create_submodule_to_call_directly_with_hooks)
create_submodule_to_call_directly_with_hooks,
)
# Create saved modules for JIT forward hooks and pre-hooks
def main():
@ -30,23 +37,45 @@ def main():
save_name = options.export_script_module_to + "_"
tests = [
("test_submodule_forward_single_input", create_submodule_forward_single_input()),
("test_submodule_forward_multiple_inputs", create_submodule_forward_multiple_inputs()),
("test_submodule_multiple_hooks_single_input", create_submodule_multiple_hooks_single_input()),
("test_submodule_multiple_hooks_multiple_inputs", create_submodule_multiple_hooks_multiple_inputs()),
(
"test_submodule_forward_single_input",
create_submodule_forward_single_input(),
),
(
"test_submodule_forward_multiple_inputs",
create_submodule_forward_multiple_inputs(),
),
(
"test_submodule_multiple_hooks_single_input",
create_submodule_multiple_hooks_single_input(),
),
(
"test_submodule_multiple_hooks_multiple_inputs",
create_submodule_multiple_hooks_multiple_inputs(),
),
("test_submodule_hook_return_nothing", create_submodule_hook_return_nothing()),
("test_submodule_same_hook_repeated", create_submodule_same_hook_repeated()),
("test_module_forward_single_input", create_module_forward_single_input()),
("test_module_forward_multiple_inputs", create_module_forward_multiple_inputs()),
("test_module_multiple_hooks_single_input", create_module_multiple_hooks_single_input()),
("test_module_multiple_hooks_multiple_inputs", create_module_multiple_hooks_multiple_inputs()),
(
"test_module_forward_multiple_inputs",
create_module_forward_multiple_inputs(),
),
(
"test_module_multiple_hooks_single_input",
create_module_multiple_hooks_single_input(),
),
(
"test_module_multiple_hooks_multiple_inputs",
create_module_multiple_hooks_multiple_inputs(),
),
("test_module_hook_return_nothing", create_module_hook_return_nothing()),
("test_module_same_hook_repeated", create_module_same_hook_repeated()),
("test_module_no_forward_input", create_module_no_forward_input()),
("test_forward_tuple_input", create_forward_tuple_input()),
("test_submodule_to_call_directly_with_hooks", create_submodule_to_call_directly_with_hooks())
(
"test_submodule_to_call_directly_with_hooks",
create_submodule_to_call_directly_with_hooks(),
),
]
for name, model in tests: