Enable UFMT on all of test/jit (#123623)

Partially addresses #123062

Ran lintrunner on:

- `test/jit`

with command:

```bash
lintrunner -a --take UFMT --all-files
```

Pull Request resolved: https://github.com/pytorch/pytorch/pull/123623
Approved by: https://github.com/ezyang
This commit is contained in:
Yuanhao Ji
2024-04-11 23:45:05 +00:00
committed by PyTorch MergeBot
parent d0ccf599cc
commit 604c9c5601
82 changed files with 6281 additions and 3210 deletions

View File

@ -1162,94 +1162,6 @@ exclude_patterns = [
'test/functorch/test_vmap.py', 'test/functorch/test_vmap.py',
'test/functorch/test_vmap_registrations.py', 'test/functorch/test_vmap_registrations.py',
'test/functorch/xfail_suggester.py', 'test/functorch/xfail_suggester.py',
'test/jit/__init__.py',
'test/jit/_imported_class_test/__init__.py',
'test/jit/_imported_class_test/bar.py',
'test/jit/_imported_class_test/foo.py',
'test/jit/_imported_class_test/very/__init__.py',
'test/jit/_imported_class_test/very/very/__init__.py',
'test/jit/_imported_class_test/very/very/nested.py',
'test/jit/fixtures_srcs/__init__.py',
'test/jit/fixtures_srcs/fixtures_src.py',
'test/jit/fixtures_srcs/generate_models.py',
'test/jit/fixtures_srcs/test_upgrader_models_generation.py',
'test/jit/myexception.py',
'test/jit/test_alias_analysis.py',
'test/jit/test_async.py',
'test/jit/test_aten_pow.py',
'test/jit/test_attr.py',
'test/jit/test_autodiff.py',
'test/jit/test_autodiff_subgraph_slicing.py',
'test/jit/test_await.py',
'test/jit/test_backend_nnapi.py',
'test/jit/test_backends.py',
'test/jit/test_batch_mm.py',
'test/jit/test_builtins.py',
'test/jit/test_class_type.py',
'test/jit/test_complex.py',
'test/jit/test_complexity.py',
'test/jit/test_convert_activation.py',
'test/jit/test_cuda.py',
'test/jit/test_custom_operators.py',
'test/jit/test_data_parallel.py',
'test/jit/test_dataclasses.py',
'test/jit/test_dce.py',
'test/jit/test_device_analysis.py',
'test/jit/test_dtype_analysis.py',
'test/jit/test_enum.py',
'test/jit/test_exception.py',
'test/jit/test_freezing.py',
'test/jit/test_functional_blocks.py',
'test/jit/test_fuser_common.py',
'test/jit/test_graph_rewrite_passes.py',
'test/jit/test_hash.py',
'test/jit/test_hooks.py',
'test/jit/test_hooks_modules.py',
'test/jit/test_ignorable_args.py',
'test/jit/test_ignore_context_manager.py',
'test/jit/test_isinstance.py',
'test/jit/test_jit_utils.py',
'test/jit/test_list_dict.py',
'test/jit/test_logging.py',
'test/jit/test_misc.py',
'test/jit/test_models.py',
'test/jit/test_module_apis.py',
'test/jit/test_module_containers.py',
'test/jit/test_module_interface.py',
'test/jit/test_modules.py',
'test/jit/test_op_decompositions.py',
'test/jit/test_optimize_for_mobile_preserve_debug_info.py',
'test/jit/test_parametrization.py',
'test/jit/test_pdt.py',
'test/jit/test_peephole.py',
'test/jit/test_profiler.py',
'test/jit/test_python_bindings.py',
'test/jit/test_python_builtins.py',
'test/jit/test_python_ir.py',
'test/jit/test_recursive_script.py',
'test/jit/test_remove_mutation.py',
'test/jit/test_save_load.py',
'test/jit/test_save_load_for_op_version.py',
'test/jit/test_script_profile.py',
'test/jit/test_scriptmod_ann.py',
'test/jit/test_slice.py',
'test/jit/test_sparse.py',
'test/jit/test_string_formatting.py',
'test/jit/test_symbolic_shape_analysis.py',
'test/jit/test_tensor_creation_ops.py',
'test/jit/test_tensor_methods.py',
'test/jit/test_torchbind.py',
'test/jit/test_tracer.py',
'test/jit/test_type_sharing.py',
'test/jit/test_types.py',
'test/jit/test_typing.py',
'test/jit/test_union.py',
'test/jit/test_unsupported_ops.py',
'test/jit/test_upgraders.py',
'test/jit/test_warn.py',
'test/jit/test_with.py',
'test/jit/xnnpack/test_xnnpack_delegate.py',
'test/jit_hooks/model.py',
'test/lazy/__init__.py', 'test/lazy/__init__.py',
'test/lazy/test_bindings.py', 'test/lazy/test_bindings.py',
'test/lazy/test_debug_util.py', 'test/lazy/test_debug_util.py',

View File

@ -1,4 +1,5 @@
import torch import torch
# This file contains definitions of script classes. # This file contains definitions of script classes.
# They are used by test_jit.py to test ScriptClass imports # They are used by test_jit.py to test ScriptClass imports

View File

@ -1,5 +1,7 @@
import torch import torch
from . import bar from . import bar
# This file contains definitions of script classes. # This file contains definitions of script classes.
# They are used by test_jit.py to test ScriptClass imports # They are used by test_jit.py to test ScriptClass imports

View File

@ -1,4 +1,5 @@
import torch import torch
# This file contains definitions of script classes. # This file contains definitions of script classes.
# They are used by test_jit.py to test ScriptClass imports # They are used by test_jit.py to test ScriptClass imports

View File

@ -1,6 +1,8 @@
import torch
from typing import Union from typing import Union
import torch
class TestVersionedDivTensorExampleV7(torch.nn.Module): class TestVersionedDivTensorExampleV7(torch.nn.Module):
def forward(self, a, b): def forward(self, a, b):
result_0 = a / b result_0 = a / b
@ -8,35 +10,52 @@ class TestVersionedDivTensorExampleV7(torch.nn.Module):
result_2 = a.div(b) result_2 = a.div(b)
return result_0, result_1, result_2 return result_0, result_1, result_2
class TestVersionedLinspaceV7(torch.nn.Module): class TestVersionedLinspaceV7(torch.nn.Module):
def forward(self, a: Union[int, float, complex], b: Union[int, float, complex]): def forward(self, a: Union[int, float, complex], b: Union[int, float, complex]):
c = torch.linspace(a, b, steps=5) c = torch.linspace(a, b, steps=5)
d = torch.linspace(a, b) d = torch.linspace(a, b)
return c, d return c, d
class TestVersionedLinspaceOutV7(torch.nn.Module): class TestVersionedLinspaceOutV7(torch.nn.Module):
def forward(self, a: Union[int, float, complex], b: Union[int, float, complex], out: torch.Tensor): def forward(
self,
a: Union[int, float, complex],
b: Union[int, float, complex],
out: torch.Tensor,
):
return torch.linspace(a, b, out=out) return torch.linspace(a, b, out=out)
class TestVersionedLogspaceV8(torch.nn.Module): class TestVersionedLogspaceV8(torch.nn.Module):
def forward(self, a: Union[int, float, complex], b: Union[int, float, complex]): def forward(self, a: Union[int, float, complex], b: Union[int, float, complex]):
c = torch.logspace(a, b, steps=5) c = torch.logspace(a, b, steps=5)
d = torch.logspace(a, b) d = torch.logspace(a, b)
return c, d return c, d
class TestVersionedLogspaceOutV8(torch.nn.Module): class TestVersionedLogspaceOutV8(torch.nn.Module):
def forward(self, a: Union[int, float, complex], b: Union[int, float, complex], out: torch.Tensor): def forward(
self,
a: Union[int, float, complex],
b: Union[int, float, complex],
out: torch.Tensor,
):
return torch.logspace(a, b, out=out) return torch.logspace(a, b, out=out)
class TestVersionedGeluV9(torch.nn.Module): class TestVersionedGeluV9(torch.nn.Module):
def forward(self, x): def forward(self, x):
return torch._C._nn.gelu(x) return torch._C._nn.gelu(x)
class TestVersionedGeluOutV9(torch.nn.Module): class TestVersionedGeluOutV9(torch.nn.Module):
def forward(self, x): def forward(self, x):
out = torch.zeros_like(x) out = torch.zeros_like(x)
return torch._C._nn.gelu(x, out=out) return torch._C._nn.gelu(x, out=out)
class TestVersionedRandomV10(torch.nn.Module): class TestVersionedRandomV10(torch.nn.Module):
def forward(self, x): def forward(self, x):
out = torch.zeros_like(x) out = torch.zeros_like(x)

View File

@ -6,9 +6,11 @@ from pathlib import Path
from typing import Set from typing import Set
import torch import torch
# Use asterisk symbol so developer doesn't need to import here when they add tests for upgraders. # Use asterisk symbol so developer doesn't need to import here when they add tests for upgraders.
from test.jit.fixtures_srcs.fixtures_src import * # noqa: F403 from test.jit.fixtures_srcs.fixtures_src import * # noqa: F403
from torch.jit.mobile import _load_for_lite_interpreter, _export_operator_list from torch.jit.mobile import _export_operator_list, _load_for_lite_interpreter
logging.basicConfig(stream=sys.stdout, level=logging.INFO) logging.basicConfig(stream=sys.stdout, level=logging.INFO)
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
logger.setLevel(logging.DEBUG) logger.setLevel(logging.DEBUG)
@ -105,28 +107,41 @@ ALL_MODULES = {
Get the path to `test/jit/fixtures`, where all test models for operator changes Get the path to `test/jit/fixtures`, where all test models for operator changes
(upgrader/downgrader) are stored (upgrader/downgrader) are stored
""" """
def get_fixtures_path() -> Path: def get_fixtures_path() -> Path:
pytorch_dir = Path(__file__).resolve().parents[3] pytorch_dir = Path(__file__).resolve().parents[3]
fixtures_path = pytorch_dir / "test" / "jit" / "fixtures" fixtures_path = pytorch_dir / "test" / "jit" / "fixtures"
return fixtures_path return fixtures_path
""" """
Get all models' name in `test/jit/fixtures` Get all models' name in `test/jit/fixtures`
""" """
def get_all_models(model_directory_path: Path) -> Set[str]: def get_all_models(model_directory_path: Path) -> Set[str]:
files_in_fixtures = model_directory_path.glob('**/*') files_in_fixtures = model_directory_path.glob("**/*")
all_models_from_fixtures = [fixture.stem for fixture in files_in_fixtures if fixture.is_file()] all_models_from_fixtures = [
fixture.stem for fixture in files_in_fixtures if fixture.is_file()
]
return set(all_models_from_fixtures) return set(all_models_from_fixtures)
""" """
Check if a given model already exist in `test/jit/fixtures` Check if a given model already exist in `test/jit/fixtures`
""" """
def model_exist(model_file_name: str, all_models: Set[str]) -> bool: def model_exist(model_file_name: str, all_models: Set[str]) -> bool:
return model_file_name in all_models return model_file_name in all_models
""" """
Get the operator list given a module Get the operator list given a module
""" """
def get_operator_list(script_module: torch) -> Set[str]: def get_operator_list(script_module: torch) -> Set[str]:
buffer = io.BytesIO(script_module._save_to_buffer_for_lite_interpreter()) buffer = io.BytesIO(script_module._save_to_buffer_for_lite_interpreter())
buffer.seek(0) buffer.seek(0)
@ -134,21 +149,25 @@ def get_operator_list(script_module: torch) -> Set[str]:
operator_list = _export_operator_list(mobile_module) operator_list = _export_operator_list(mobile_module)
return operator_list return operator_list
""" """
Get the output model operator version, given a module Get the output model operator version, given a module
""" """
def get_output_model_version(script_module: torch.nn.Module) -> int: def get_output_model_version(script_module: torch.nn.Module) -> int:
buffer = io.BytesIO() buffer = io.BytesIO()
torch.jit.save(script_module, buffer) torch.jit.save(script_module, buffer)
buffer.seek(0) buffer.seek(0)
zipped_model = zipfile.ZipFile(buffer) zipped_model = zipfile.ZipFile(buffer)
try: try:
version = int(zipped_model.read('archive/version').decode("utf-8")) version = int(zipped_model.read("archive/version").decode("utf-8"))
return version return version
except KeyError: except KeyError:
version = int(zipped_model.read('archive/.data/version').decode("utf-8")) version = int(zipped_model.read("archive/.data/version").decode("utf-8"))
return version return version
""" """
Loop through all test modules. If the corresponding model doesn't exist in Loop through all test modules. If the corresponding model doesn't exist in
`test/jit/fixtures`, generate one. For the following reason, a model won't be exported: `test/jit/fixtures`, generate one. For the following reason, a model won't be exported:
@ -165,6 +184,8 @@ likely this script is running with the commit to make the change.
3. The model already exists in `test/jit/fixtures`. 3. The model already exists in `test/jit/fixtures`.
""" """
def generate_models(model_directory_path: Path): def generate_models(model_directory_path: Path):
all_models = get_all_models(model_directory_path) all_models = get_all_models(model_directory_path)
for a_module, expect_operator in ALL_MODULES.items(): for a_module, expect_operator in ALL_MODULES.items():
@ -176,13 +197,17 @@ def generate_models(model_directory_path: Path):
"The module %s " "The module %s "
"is not a torch.nn.module instance. " "is not a torch.nn.module instance. "
"Please ensure it's a subclass of torch.nn.module in fixtures_src.py" "Please ensure it's a subclass of torch.nn.module in fixtures_src.py"
"and it's registered as an instance in ALL_MODULES in generated_models.py", torch_module_name) "and it's registered as an instance in ALL_MODULES in generated_models.py",
torch_module_name,
)
# The corresponding model name is: test_versioned_div_tensor_example_v4 # The corresponding model name is: test_versioned_div_tensor_example_v4
model_name = ''.join([ model_name = "".join(
'_' + char.lower() if char.isupper() else char for char in torch_module_name [
]).lstrip('_') "_" + char.lower() if char.isupper() else char
for char in torch_module_name
]
).lstrip("_")
# Some models may not compile anymore, so skip the ones # Some models may not compile anymore, so skip the ones
# that already has pt file for them. # that already has pt file for them.
@ -199,7 +224,10 @@ def generate_models(model_directory_path: Path):
logger.error( logger.error(
"Actual model version %s " "Actual model version %s "
"is equal or larger than %s + 1. " "is equal or larger than %s + 1. "
"Please run the script before the commit to change operator.", actual_model_version, current_operator_version) "Please run the script before the commit to change operator.",
actual_model_version,
current_operator_version,
)
continue continue
actual_operator_list = get_operator_list(script_module) actual_operator_list = get_operator_list(script_module)
@ -207,16 +235,23 @@ def generate_models(model_directory_path: Path):
logger.error( logger.error(
"The model includes operator: %s, " "The model includes operator: %s, "
"however it doesn't cover the operator %s." "however it doesn't cover the operator %s."
"Please ensure the output model includes the tested operator.", actual_operator_list, expect_operator) "Please ensure the output model includes the tested operator.",
actual_operator_list,
expect_operator,
)
continue continue
export_model_path = str(model_directory_path / (str(model_name) + ".ptl")) export_model_path = str(model_directory_path / (str(model_name) + ".ptl"))
script_module._save_for_lite_interpreter(export_model_path) script_module._save_for_lite_interpreter(export_model_path)
logger.info("Generating model %s and it's save to %s", model_name, export_model_path) logger.info(
"Generating model %s and it's save to %s", model_name, export_model_path
)
def main() -> None: def main() -> None:
model_directory_path = get_fixtures_path() model_directory_path = get_fixtures_path()
generate_models(model_directory_path) generate_models(model_directory_path)
if __name__ == '__main__':
if __name__ == "__main__":
main() main()

View File

@ -2,7 +2,7 @@
import torch import torch
from test.jit.fixtures_srcs.generate_models import ALL_MODULES from test.jit.fixtures_srcs.generate_models import ALL_MODULES
from torch.testing._internal.common_utils import TestCase, run_tests from torch.testing._internal.common_utils import run_tests, TestCase
class TestUpgraderModelGeneration(TestCase): class TestUpgraderModelGeneration(TestCase):
@ -14,7 +14,9 @@ class TestUpgraderModelGeneration(TestCase):
f"The module {module_name} " f"The module {module_name} "
f"is not a torch.nn.module instance. " f"is not a torch.nn.module instance. "
f"Please ensure it's a subclass of torch.nn.module in fixtures_src.py" f"Please ensure it's a subclass of torch.nn.module in fixtures_src.py"
f"and it's registered as an instance in ALL_MODULES in generated_models.py") f"and it's registered as an instance in ALL_MODULES in generated_models.py",
)
if __name__ == '__main__':
if __name__ == "__main__":
run_tests() run_tests()

View File

@ -3,5 +3,7 @@ Define exceptions used in test_exception.py. We define them in a
separate file on purpose to make sure the fully qualified exception class name separate file on purpose to make sure the fully qualified exception class name
is captured correctly in suce cases. is captured correctly in suce cases.
""" """
class MyKeyError(KeyError): class MyKeyError(KeyError):
pass pass

View File

@ -1,15 +1,18 @@
# Owner(s): ["oncall: jit"] # Owner(s): ["oncall: jit"]
import torch
from torch._C import parse_ir
from torch.testing._internal.common_utils import TemporaryFileName from torch.testing._internal.common_utils import TemporaryFileName
from torch.testing._internal.jit_utils import JitTestCase from torch.testing._internal.jit_utils import JitTestCase
from torch._C import parse_ir
import torch
if __name__ == '__main__': if __name__ == "__main__":
raise RuntimeError("This test file is not meant to be run directly, use:\n\n" raise RuntimeError(
"\tpython test/test_jit.py TESTNAME\n\n" "This test file is not meant to be run directly, use:\n\n"
"instead.") "\tpython test/test_jit.py TESTNAME\n\n"
"instead."
)
class TestAliasAnalysis(JitTestCase): class TestAliasAnalysis(JitTestCase):
def test_becomes_wildcard_annotations(self): def test_becomes_wildcard_annotations(self):
@ -26,9 +29,13 @@ class TestAliasAnalysis(JitTestCase):
alias_db = graph.alias_db() alias_db = graph.alias_db()
split_node = graph.findNode("aten::split") split_node = graph.findNode("aten::split")
# split input enters wildcard set, list initalized as containing wildcard set # split input enters wildcard set, list initalized as containing wildcard set
self.assertTrue(alias_db.may_contain_alias(next(split_node.inputs()), split_node.output())) self.assertTrue(
alias_db.may_contain_alias(next(split_node.inputs()), split_node.output())
)
# because %x.1 enters wildcard set, it now aliases other members of wildcard set (graph inputs) # because %x.1 enters wildcard set, it now aliases other members of wildcard set (graph inputs)
self.assertTrue(alias_db.may_contain_alias(next(split_node.inputs()), next(graph.inputs()))) self.assertTrue(
alias_db.may_contain_alias(next(split_node.inputs()), next(graph.inputs()))
)
def test_nested_list_construct_not_wildcard(self): def test_nested_list_construct_not_wildcard(self):
@torch.jit.script @torch.jit.script
@ -42,7 +49,9 @@ class TestAliasAnalysis(JitTestCase):
ten_construct = graph.findNode("aten::rand").output() ten_construct = graph.findNode("aten::rand").output()
output = next(graph.outputs()) output = next(graph.outputs())
self.assertTrue(alias_db.may_contain_alias(ten_construct, output)) self.assertTrue(alias_db.may_contain_alias(ten_construct, output))
self.assertFalse(alias_db.may_contain_alias(next(graph.inputs()), ten_construct)) self.assertFalse(
alias_db.may_contain_alias(next(graph.inputs()), ten_construct)
)
def test_recursive_calls(self): def test_recursive_calls(self):
@torch.jit.script @torch.jit.script
@ -108,7 +117,9 @@ class TestAliasAnalysis(JitTestCase):
class MultiTmpFile: class MultiTmpFile:
def __init__(self, N): def __init__(self, N):
self.N = N self.N = N
self.ctxs = [TemporaryFileName(mode="w", suffix=".py") for _ in range(N)] self.ctxs = [
TemporaryFileName(mode="w", suffix=".py") for _ in range(N)
]
def __enter__(self): def __enter__(self):
return [x.__enter__() for x in self.ctxs] return [x.__enter__() for x in self.ctxs]

View File

@ -3,18 +3,20 @@
import os import os
import sys import sys
from typing import Any, Tuple
import torch import torch
import torch.nn as nn import torch.nn as nn
from typing import Any, Tuple
# Make the helper files in test/ importable # Make the helper files in test/ importable
pytorch_test_dir = os.path.dirname(os.path.dirname(os.path.realpath(__file__))) pytorch_test_dir = os.path.dirname(os.path.dirname(os.path.realpath(__file__)))
sys.path.append(pytorch_test_dir) sys.path.append(pytorch_test_dir)
from torch.testing._internal.jit_utils import JitTestCase, _inline_everything
from typing import List from typing import List
from torch import Tensor from torch import Tensor
from torch.jit import Future from torch.jit import Future
from torch.testing._internal.jit_utils import _inline_everything, JitTestCase
class TestAsync(JitTestCase): class TestAsync(JitTestCase):
def test_async_python(self): def test_async_python(self):
@ -51,8 +53,7 @@ class TestAsync(JitTestCase):
futures = torch.jit.annotate(List[Future[List[Tensor]]], []) futures = torch.jit.annotate(List[Future[List[Tensor]]], [])
for _ in range(3): for _ in range(3):
future = torch.jit.annotate( future = torch.jit.annotate(
Future[List[Tensor]], Future[List[Tensor]], torch.jit.fork(foo, x)
torch.jit.fork(foo, x)
) )
futures.append(future) futures.append(future)
@ -85,7 +86,7 @@ class TestAsync(JitTestCase):
def test_async_script_capture(self): def test_async_script_capture(self):
class Mod(torch.jit.ScriptModule): class Mod(torch.jit.ScriptModule):
__constants__ = ['const'] __constants__ = ["const"]
def __init__(self): def __init__(self):
super().__init__() super().__init__()
@ -139,7 +140,10 @@ class TestAsync(JitTestCase):
def test_async_script_no_script_mod(self): def test_async_script_no_script_mod(self):
x = torch.rand(3, 4) x = torch.rand(3, 4)
with self.assertRaisesRegexWithHighlight(RuntimeError, 'cannot call a value', 'torch.jit._fork(x'): with self.assertRaisesRegexWithHighlight(
RuntimeError, "cannot call a value", "torch.jit._fork(x"
):
@torch.jit.script @torch.jit.script
def wait_script(x): def wait_script(x):
fut = torch.jit._fork(x) fut = torch.jit._fork(x)
@ -213,7 +217,7 @@ class TestAsync(JitTestCase):
lambda x1, x2: torch.jit._wait(torch.jit._fork(foo, x1, x2)), lambda x1, x2: torch.jit._wait(torch.jit._fork(foo, x1, x2)),
lambda x1, x2: torch.jit._wait(torch.jit._fork(foo, x1, x2=x2)), lambda x1, x2: torch.jit._wait(torch.jit._fork(foo, x1, x2=x2)),
lambda x1, x2: torch.jit._wait(torch.jit._fork(foo, x1=x1, x2=x2)), lambda x1, x2: torch.jit._wait(torch.jit._fork(foo, x1=x1, x2=x2)),
lambda x1, x2: torch.jit._wait(torch.jit._fork(foo, x2=x2, x1=x1)) lambda x1, x2: torch.jit._wait(torch.jit._fork(foo, x2=x2, x1=x1)),
]: ]:
for wrapper in [ for wrapper in [
func, func,
@ -234,8 +238,8 @@ class TestAsync(JitTestCase):
return torch.jit._wait(torch.jit._fork(foo, x1=x1, x2=x2)) return torch.jit._wait(torch.jit._fork(foo, x1=x1, x2=x2))
for wrapper in [ for wrapper in [
foo_script_args, foo_script_args,
foo_script_kwargs, foo_script_kwargs,
]: ]:
self.assertEqual(wrapper(x1, x2), y_hat) self.assertEqual(wrapper(x1, x2), y_hat)
self.assertEqual(wrapper(x1, x2=x2), y_hat) self.assertEqual(wrapper(x1, x2=x2), y_hat)
@ -255,7 +259,9 @@ class TestAsync(JitTestCase):
self.traced = torch.jit.trace(Traced(), (x), _force_outplace=True) self.traced = torch.jit.trace(Traced(), (x), _force_outplace=True)
@torch.jit.script_method @torch.jit.script_method
def forward(self, x: Tensor) -> Tuple[List[Tensor], Tuple[Tensor, Tensor], Tensor]: def forward(
self, x: Tensor
) -> Tuple[List[Tensor], Tuple[Tensor, Tensor], Tensor]:
future1 = torch.jit._fork(self.traced, x) future1 = torch.jit._fork(self.traced, x)
future2 = torch.jit._fork(torch.neg, x) future2 = torch.jit._fork(torch.neg, x)
@ -284,10 +290,16 @@ class TestAsync(JitTestCase):
module = torch.jit.trace(TupleCl(), (x), _force_outplace=True) module = torch.jit.trace(TupleCl(), (x), _force_outplace=True)
# Make sure we have forks # Make sure we have forks
self.assertGraphContainsExactly(module.graph, kind='prim::fork', num_kind_nodes=2) self.assertGraphContainsExactly(
module.graph, kind="prim::fork", num_kind_nodes=2
)
# Make sure 1 ::neg is in the root graph and 2 ::negs are in the subgraphs # Make sure 1 ::neg is in the root graph and 2 ::negs are in the subgraphs
self.assertGraphContainsExactly(module.graph, kind='aten::neg', num_kind_nodes=1) self.assertGraphContainsExactly(
self.assertGraphContainsExactly(module.graph, kind='aten::neg', num_kind_nodes=3, consider_subgraphs=True) module.graph, kind="aten::neg", num_kind_nodes=1
)
self.assertGraphContainsExactly(
module.graph, kind="aten::neg", num_kind_nodes=3, consider_subgraphs=True
)
y = torch.neg(x) y = torch.neg(x)
self.assertEqual(module(x), (y, y, y, y, x, x)) self.assertEqual(module(x), (y, y, y, y, x, x))
@ -311,19 +323,23 @@ class TestAsync(JitTestCase):
return torch.jit._wait(fut) return torch.jit._wait(fut)
# no future # no future
error_msg = 'The size.*must match the size of tensor' error_msg = "The size.*must match the size of tensor"
with self.assertRaisesRegexWithHighlight(Exception, error_msg, 'x.t() + x'): with self.assertRaisesRegexWithHighlight(Exception, error_msg, "x.t() + x"):
foo(x) foo(x)
# one future # one future
with self.assertRaisesRegexWithHighlight(Exception, error_msg, 'torch.jit._fork(foo, x'): with self.assertRaisesRegexWithHighlight(
Exception, error_msg, "torch.jit._fork(foo, x"
):
wait_script(x) wait_script(x)
# two futures with a different error # two futures with a different error
x = torch.rand(3, 4, 5) x = torch.rand(3, 4, 5)
with self.assertRaisesRegexWithHighlight(Exception, with self.assertRaisesRegexWithHighlight(
'expects a tensor with <= 2 dimensions', Exception,
'torch.jit._fork(wait_script, x'): "expects a tensor with <= 2 dimensions",
"torch.jit._fork(wait_script, x",
):
wait_script_nest(x) wait_script_nest(x)
def test_async_grad_guard_with_grad(self): def test_async_grad_guard_with_grad(self):
@ -381,9 +397,15 @@ class TestAsync(JitTestCase):
x = torch.rand(3, 4) x = torch.rand(3, 4)
self.assertEqual(fn(x), traced(x)) self.assertEqual(fn(x), traced(x))
self.assertGraphContainsExactly(traced.graph, kind='prim::fork', num_kind_nodes=1) self.assertGraphContainsExactly(
self.assertGraphContainsExactly(traced.graph, kind='aten::wait', num_kind_nodes=1) traced.graph, kind="prim::fork", num_kind_nodes=1
self.assertGraphContainsExactly(traced.graph, kind='aten::neg', num_kind_nodes=2, consider_subgraphs=True) )
self.assertGraphContainsExactly(
traced.graph, kind="aten::wait", num_kind_nodes=1
)
self.assertGraphContainsExactly(
traced.graph, kind="aten::neg", num_kind_nodes=2, consider_subgraphs=True
)
def test_trace_fork_wait_leaking(self): def test_trace_fork_wait_leaking(self):
my_list = [] my_list = []
@ -397,9 +419,13 @@ class TestAsync(JitTestCase):
val = torch.jit._wait(fut) val = torch.jit._wait(fut)
return my_list[0] return my_list[0]
with self.assertRaisesRegexWithHighlight(RuntimeError, 'did not have observable data dependence with trace inputs; ' with self.assertRaisesRegexWithHighlight(
'this probably indicates your program cannot be understood ' RuntimeError,
'by the tracer.', ''): "did not have observable data dependence with trace inputs; "
"this probably indicates your program cannot be understood "
"by the tracer.",
"",
):
traced = torch.jit.trace(fn, (torch.rand(3, 4),), check_trace=False) traced = torch.jit.trace(fn, (torch.rand(3, 4),), check_trace=False)
def test_trace_fork_wait_inline(self): def test_trace_fork_wait_inline(self):
@ -413,9 +439,15 @@ class TestAsync(JitTestCase):
traced = torch.jit.trace(fn, (torch.rand(3, 4),)) traced = torch.jit.trace(fn, (torch.rand(3, 4),))
torch._C._jit_pass_inline_fork_wait(traced.graph) torch._C._jit_pass_inline_fork_wait(traced.graph)
self.assertGraphContainsExactly(traced.graph, kind='prim::fork', num_kind_nodes=0) self.assertGraphContainsExactly(
self.assertGraphContainsExactly(traced.graph, kind='aten::wait', num_kind_nodes=0) traced.graph, kind="prim::fork", num_kind_nodes=0
self.assertGraphContainsExactly(traced.graph, kind='aten::add', num_kind_nodes=2) )
self.assertGraphContainsExactly(
traced.graph, kind="aten::wait", num_kind_nodes=0
)
self.assertGraphContainsExactly(
traced.graph, kind="aten::add", num_kind_nodes=2
)
def test_trace_fork_wait_list_modulecalls(self): def test_trace_fork_wait_list_modulecalls(self):
def add_one(input): def add_one(input):
@ -472,7 +504,10 @@ class TestAsync(JitTestCase):
self.checkTrace(TestModule(), (torch.randn(5, 5),)) self.checkTrace(TestModule(), (torch.randn(5, 5),))
def test_no_future_subtype_message(self): def test_no_future_subtype_message(self):
with self.assertRaisesRegexWithHighlight(RuntimeError, 'Future without a contained type', ''): with self.assertRaisesRegexWithHighlight(
RuntimeError, "Future without a contained type", ""
):
@torch.jit.script @torch.jit.script
def forward(self, x): def forward(self, x):
futs = torch.jit.annotate(List[torch.jit.Future], []) futs = torch.jit.annotate(List[torch.jit.Future], [])
@ -481,6 +516,7 @@ class TestAsync(JitTestCase):
""" """
Test that futures subtype each other properly. Test that futures subtype each other properly.
""" """
# Successful subtyping. # Successful subtyping.
def returns_int(x: int) -> int: def returns_int(x: int) -> int:
return x + x + 1 return x + x + 1
@ -495,10 +531,11 @@ class TestAsync(JitTestCase):
# Unsuccessful subtyping. # Unsuccessful subtyping.
with self.assertRaisesRegexWithHighlight( with self.assertRaisesRegexWithHighlight(
RuntimeError, RuntimeError,
r"was annotated as having type Future\[float\] but is actually of type Future\[int\]", r"was annotated as having type Future\[float\] but is actually of type Future\[int\]",
"fut = returns_future_float(x" "fut = returns_future_float(x",
): ):
def returns_future_float(x: int) -> torch.jit.Future[float]: def returns_future_float(x: int) -> torch.jit.Future[float]:
return torch.jit._fork(returns_int, (x)) return torch.jit._fork(returns_int, (x))
@ -508,8 +545,9 @@ class TestAsync(JitTestCase):
return fut.wait() return fut.wait()
if __name__ == "__main__":
if __name__ == '__main__': raise RuntimeError(
raise RuntimeError("This test file is not meant to be run directly, use:\n\n" "This test file is not meant to be run directly, use:\n\n"
"\tpython test/test_jit.py TESTNAME\n\n" "\tpython test/test_jit.py TESTNAME\n\n"
"instead.") "instead."
)

View File

@ -3,35 +3,40 @@
import torch import torch
from torch.testing._internal.common_utils import TestCase from torch.testing._internal.common_utils import TestCase
class TestAtenPow(TestCase): class TestAtenPow(TestCase):
def test_aten_pow_zero_negative_exponent(self): def test_aten_pow_zero_negative_exponent(self):
''' """
1. Testing a = int, b = int 1. Testing a = int, b = int
''' """
@torch.jit.script @torch.jit.script
def fn_int_int(a: int, b: int): def fn_int_int(a: int, b: int):
return a ** b return a**b
# Existing correct behaviors of aten::pow # Existing correct behaviors of aten::pow
self.assertEqual(fn_int_int(2, 1), 2 ** 1) self.assertEqual(fn_int_int(2, 1), 2**1)
self.assertEqual(fn_int_int(2, 0), 2 ** 0) self.assertEqual(fn_int_int(2, 0), 2**0)
self.assertEqual(fn_int_int(2, -2), 2 ** (-2)) self.assertEqual(fn_int_int(2, -2), 2 ** (-2))
self.assertEqual(fn_int_int(-2, 2), (-2) ** 2) self.assertEqual(fn_int_int(-2, 2), (-2) ** 2)
self.assertEqual(fn_int_int(-2, 0), (-2) ** 0) self.assertEqual(fn_int_int(-2, 0), (-2) ** 0)
self.assertEqual(fn_int_int(-2, -2), (-2) ** (-2)) self.assertEqual(fn_int_int(-2, -2), (-2) ** (-2))
self.assertEqual(fn_int_int(-2, -1), (-2) ** (-1)) self.assertEqual(fn_int_int(-2, -1), (-2) ** (-1))
self.assertEqual(fn_int_int(0, 2), 0 ** 1) self.assertEqual(fn_int_int(0, 2), 0**1)
self.assertEqual(fn_int_int(0, 0), 0 ** 0) self.assertEqual(fn_int_int(0, 0), 0**0)
# zero base and negative exponent case that should trigger RunTimeError # zero base and negative exponent case that should trigger RunTimeError
self.assertRaises(RuntimeError, fn_int_int, 0, -2) self.assertRaises(RuntimeError, fn_int_int, 0, -2)
''' """
2. Testing a = int, b = float 2. Testing a = int, b = float
''' """
@torch.jit.script @torch.jit.script
def fn_int_float(a: int, b: float): def fn_int_float(a: int, b: float):
return a ** b return a**b
# Existing correct behaviors of aten::pow # Existing correct behaviors of aten::pow
self.assertEqual(fn_int_float(2, 2.5), 2 ** 2.5) self.assertEqual(fn_int_float(2, 2.5), 2**2.5)
self.assertEqual(fn_int_float(2, -2.5), 2 ** (-2.5)) self.assertEqual(fn_int_float(2, -2.5), 2 ** (-2.5))
self.assertEqual(fn_int_float(2, -0.0), 2 ** (-0.0)) self.assertEqual(fn_int_float(2, -0.0), 2 ** (-0.0))
self.assertEqual(fn_int_float(2, 0.0), 2 ** (0.0)) self.assertEqual(fn_int_float(2, 0.0), 2 ** (0.0))
@ -40,53 +45,57 @@ class TestAtenPow(TestCase):
self.assertEqual(fn_int_float(-2, -3.0), (-2) ** (-3.0)) self.assertEqual(fn_int_float(-2, -3.0), (-2) ** (-3.0))
self.assertEqual(fn_int_float(-2, -0.0), (-2) ** (-0.0)) self.assertEqual(fn_int_float(-2, -0.0), (-2) ** (-0.0))
self.assertEqual(fn_int_float(-2, 0.0), (-2) ** (0.0)) self.assertEqual(fn_int_float(-2, 0.0), (-2) ** (0.0))
self.assertEqual(fn_int_float(0, 2.0), 0 ** 2.0) self.assertEqual(fn_int_float(0, 2.0), 0**2.0)
self.assertEqual(fn_int_float(0, 0.5), 0 ** 0.5) self.assertEqual(fn_int_float(0, 0.5), 0**0.5)
self.assertEqual(fn_int_float(0, 0.0), 0 ** 0.0) self.assertEqual(fn_int_float(0, 0.0), 0**0.0)
self.assertEqual(fn_int_float(0, -0.0), 0 ** (-0.0)) self.assertEqual(fn_int_float(0, -0.0), 0 ** (-0.0))
# zero base and negative exponent case that should trigger RunTimeError # zero base and negative exponent case that should trigger RunTimeError
self.assertRaises(RuntimeError, fn_int_float, 0, -2.5) self.assertRaises(RuntimeError, fn_int_float, 0, -2.5)
''' """
3. Testing a = float, b = int 3. Testing a = float, b = int
''' """
@torch.jit.script @torch.jit.script
def fn_float_int(a: float, b: int): def fn_float_int(a: float, b: int):
return a ** b return a**b
# Existing correct behaviors of aten::pow # Existing correct behaviors of aten::pow
self.assertEqual(fn_float_int(2.5, 2), 2.5 ** 2) self.assertEqual(fn_float_int(2.5, 2), 2.5**2)
self.assertEqual(fn_float_int(2.5, -2), 2.5 ** (-2)) self.assertEqual(fn_float_int(2.5, -2), 2.5 ** (-2))
self.assertEqual(fn_float_int(2.5, -0), 2.5 ** (-0)) self.assertEqual(fn_float_int(2.5, -0), 2.5 ** (-0))
self.assertEqual(fn_float_int(2.5, 0), 2.5 ** 0) self.assertEqual(fn_float_int(2.5, 0), 2.5**0)
self.assertEqual(fn_float_int(-2.5, 2), 2.5 ** 2) self.assertEqual(fn_float_int(-2.5, 2), 2.5**2)
self.assertEqual(fn_float_int(-2.5, -2), (-2.5) ** (-2)) self.assertEqual(fn_float_int(-2.5, -2), (-2.5) ** (-2))
self.assertEqual(fn_float_int(-2.5, -3), (-2.5) ** (-3)) self.assertEqual(fn_float_int(-2.5, -3), (-2.5) ** (-3))
self.assertEqual(fn_float_int(-2.5, -0), (-2.5) ** (-0)) self.assertEqual(fn_float_int(-2.5, -0), (-2.5) ** (-0))
self.assertEqual(fn_float_int(-2.5, 0), (-2.5) ** 0) self.assertEqual(fn_float_int(-2.5, 0), (-2.5) ** 0)
self.assertEqual(fn_float_int(0.0, 2), 0 ** 2) self.assertEqual(fn_float_int(0.0, 2), 0**2)
self.assertEqual(fn_float_int(0.0, 0), 0 ** 0) self.assertEqual(fn_float_int(0.0, 0), 0**0)
self.assertEqual(fn_float_int(0.0, -0), 0 ** (-0)) self.assertEqual(fn_float_int(0.0, -0), 0 ** (-0))
# zero base and negative exponent case that should trigger RunTimeError # zero base and negative exponent case that should trigger RunTimeError
self.assertRaises(RuntimeError, fn_float_int, 0.0, -2) self.assertRaises(RuntimeError, fn_float_int, 0.0, -2)
''' """
4. Testing a = float, b = float 4. Testing a = float, b = float
''' """
@torch.jit.script @torch.jit.script
def fn_float_float(a: float, b: float): def fn_float_float(a: float, b: float):
return a ** b return a**b
# Existing correct behaviors of aten::pow # Existing correct behaviors of aten::pow
self.assertEqual(fn_float_float(2.5, 2.0), 2.5 ** 2.0) self.assertEqual(fn_float_float(2.5, 2.0), 2.5**2.0)
self.assertEqual(fn_float_float(2.5, -2.0), 2.5 ** (-2.0)) self.assertEqual(fn_float_float(2.5, -2.0), 2.5 ** (-2.0))
self.assertEqual(fn_float_float(2.5, -0.0), 2.5 ** (-0.0)) self.assertEqual(fn_float_float(2.5, -0.0), 2.5 ** (-0.0))
self.assertEqual(fn_float_float(2.5, 0.0), 2.5 ** 0.0) self.assertEqual(fn_float_float(2.5, 0.0), 2.5**0.0)
self.assertEqual(fn_float_float(-2.5, 2.0), 2.5 ** 2.0) self.assertEqual(fn_float_float(-2.5, 2.0), 2.5**2.0)
self.assertEqual(fn_float_float(-2.5, -2.0), (-2.5) ** (-2.0)) self.assertEqual(fn_float_float(-2.5, -2.0), (-2.5) ** (-2.0))
self.assertEqual(fn_float_float(-2.5, -3.0), (-2.5) ** (-3.0)) self.assertEqual(fn_float_float(-2.5, -3.0), (-2.5) ** (-3.0))
self.assertEqual(fn_float_float(-2.5, -0.0), (-2.5) ** (-0.0)) self.assertEqual(fn_float_float(-2.5, -0.0), (-2.5) ** (-0.0))
self.assertEqual(fn_float_float(-2.5, 0.0), (-2.5) ** 0.0) self.assertEqual(fn_float_float(-2.5, 0.0), (-2.5) ** 0.0)
self.assertEqual(fn_float_float(0.0, 2.0), 0.0 ** 2.0) self.assertEqual(fn_float_float(0.0, 2.0), 0.0**2.0)
self.assertEqual(fn_float_float(0.0, 0.0), 0.0 ** 0.0) self.assertEqual(fn_float_float(0.0, 0.0), 0.0**0.0)
self.assertEqual(fn_float_float(0.0, -0.0), 0.0 ** (-0.0)) self.assertEqual(fn_float_float(0.0, -0.0), 0.0 ** (-0.0))
# zero base and negative exponent case that should trigger RunTimeError # zero base and negative exponent case that should trigger RunTimeError
self.assertRaises(RuntimeError, fn_float_float, 0.0, -2.0) self.assertRaises(RuntimeError, fn_float_float, 0.0, -2.0)

View File

@ -1,20 +1,22 @@
# Owner(s): ["oncall: jit"] # Owner(s): ["oncall: jit"]
from torch.testing import FileCheck
from torch.testing._internal.jit_utils import JitTestCase
import torch
from typing import NamedTuple, Tuple from typing import NamedTuple, Tuple
import torch
from torch.testing import FileCheck
from torch.testing._internal.jit_utils import JitTestCase
if __name__ == '__main__':
raise RuntimeError("This test file is not meant to be run directly, use:\n\n" if __name__ == "__main__":
"\tpython test/test_jit.py TESTNAME\n\n" raise RuntimeError(
"instead.") "This test file is not meant to be run directly, use:\n\n"
"\tpython test/test_jit.py TESTNAME\n\n"
"instead."
)
class TestGetDefaultAttr(JitTestCase): class TestGetDefaultAttr(JitTestCase):
def test_getattr_with_default(self): def test_getattr_with_default(self):
class A(torch.nn.Module): class A(torch.nn.Module):
def __init__(self): def __init__(self):
super().__init__() super().__init__()
@ -22,7 +24,7 @@ class TestGetDefaultAttr(JitTestCase):
def forward(self, x): def forward(self, x):
y = getattr(self, "init_attr_val") # noqa: B009 y = getattr(self, "init_attr_val") # noqa: B009
w : list[float] = [1.0] w: list[float] = [1.0]
z = getattr(self, "missing", w) # noqa: B009 z = getattr(self, "missing", w) # noqa: B009
z.append(y) z.append(y)
return z return z
@ -32,7 +34,7 @@ class TestGetDefaultAttr(JitTestCase):
graph = torch.jit.script(A()).graph graph = torch.jit.script(A()).graph
# The "init_attr_val" attribute exists # The "init_attr_val" attribute exists
FileCheck().check("prim::GetAttr[name=\"init_attr_val\"]").run(graph) FileCheck().check('prim::GetAttr[name="init_attr_val"]').run(graph)
# The "missing" attribute does not exist, so there should be no corresponding GetAttr in AST # The "missing" attribute does not exist, so there should be no corresponding GetAttr in AST
FileCheck().check_not("missing").run(graph) FileCheck().check_not("missing").run(graph)
# instead the getattr call will emit the default value, which is a list with one float element # instead the getattr call will emit the default value, which is a list with one float element
@ -46,7 +48,11 @@ class TestGetDefaultAttr(JitTestCase):
y: torch.Tensor y: torch.Tensor
def fn(x: MyTuple) -> Tuple[str, torch.Tensor, int]: def fn(x: MyTuple) -> Tuple[str, torch.Tensor, int]:
return getattr(x, "x", "fdsa"), getattr(x, "y", torch.ones((3, 3))), getattr(x, "z", 7) return (
getattr(x, "x", "fdsa"),
getattr(x, "y", torch.ones((3, 3))),
getattr(x, "z", 7),
)
inp = MyTuple(x="test", y=torch.ones(3, 3) * 2) inp = MyTuple(x="test", y=torch.ones(3, 3) * 2)
ref = fn(inp) ref = fn(inp)

View File

@ -1,10 +1,11 @@
# Owner(s): ["oncall: jit"] # Owner(s): ["oncall: jit"]
from typing import List
import torch import torch
from torch.testing._internal.common_utils import skipIfTorchDynamo from torch.testing._internal.common_utils import skipIfTorchDynamo
from torch.testing._internal.jit_utils import JitTestCase from torch.testing._internal.jit_utils import JitTestCase
from typing import List
@skipIfTorchDynamo() @skipIfTorchDynamo()
@ -119,7 +120,6 @@ class TestAutodiffJit(JitTestCase):
self.assertEqual(y_s.requires_grad, y.requires_grad) self.assertEqual(y_s.requires_grad, y.requires_grad)
self.assertEqual(z_s.requires_grad, z.requires_grad) self.assertEqual(z_s.requires_grad, z.requires_grad)
def test_autodiff_requires_grad_nograd(self): def test_autodiff_requires_grad_nograd(self):
@torch.jit.ignore @torch.jit.ignore
def python_fn(x): def python_fn(x):

View File

@ -3,26 +3,38 @@
import os import os
import sys import sys
import unittest import unittest
from torch.testing._internal.common_utils import GRAPH_EXECUTOR, ProfilingMode, \
num_profiled_runs, enable_profiling_mode_for_profiling_tests
from torch.testing._internal.common_jit import check_against_reference
import torch import torch
from torch.testing._internal.common_jit import check_against_reference
from torch.testing._internal.common_utils import (
enable_profiling_mode_for_profiling_tests,
GRAPH_EXECUTOR,
num_profiled_runs,
ProfilingMode,
)
# Make the helper files in test/ importable # Make the helper files in test/ importable
pytorch_test_dir = os.path.dirname(os.path.dirname(os.path.realpath(__file__))) pytorch_test_dir = os.path.dirname(os.path.dirname(os.path.realpath(__file__)))
sys.path.append(pytorch_test_dir) sys.path.append(pytorch_test_dir)
from torch.testing._internal.jit_utils import JitTestCase, disable_autodiff_subgraph_inlining from typing import List, Optional, Tuple
from torch.testing import FileCheck from torch.testing import FileCheck
from torch.testing._internal.jit_utils import (
disable_autodiff_subgraph_inlining,
JitTestCase,
)
from typing import List, Tuple, Optional if __name__ == "__main__":
raise RuntimeError(
if __name__ == '__main__': "This test file is not meant to be run directly, use:\n\n"
raise RuntimeError("This test file is not meant to be run directly, use:\n\n" "\tpython test/test_jit.py TESTNAME\n\n"
"\tpython test/test_jit.py TESTNAME\n\n" "instead."
"instead.") )
@unittest.skipIf(GRAPH_EXECUTOR == ProfilingMode.SIMPLE, "Simple Executor doesn't support gradients") @unittest.skipIf(
GRAPH_EXECUTOR == ProfilingMode.SIMPLE, "Simple Executor doesn't support gradients"
)
class TestAutodiffSubgraphSlicing(JitTestCase): class TestAutodiffSubgraphSlicing(JitTestCase):
# TODO: It is better if we can test directly on graphs instead of the current # TODO: It is better if we can test directly on graphs instead of the current
# end-to-end fashion. # end-to-end fashion.
@ -35,11 +47,17 @@ class TestAutodiffSubgraphSlicing(JitTestCase):
return ge.graph_for(*inputs) return ge.graph_for(*inputs)
def assertGraphSize(self, graph, size): def assertGraphSize(self, graph, size):
nodes = list(filter(lambda n: (n.kind() != "prim::BailOut" and nodes = list(
n.kind() != "prim::BailoutTemplate" and filter(
n.kind() != "prim::TypeCheck" and lambda n: (
n.kind() != "prim::RequiresGradCheck"), n.kind() != "prim::BailOut"
graph.nodes())) and n.kind() != "prim::BailoutTemplate"
and n.kind() != "prim::TypeCheck"
and n.kind() != "prim::RequiresGradCheck"
),
graph.nodes(),
)
)
self.assertEqual(len(list(nodes)), size) self.assertEqual(len(list(nodes)), size)
def test_chunk_constant_script_ad(self): def test_chunk_constant_script_ad(self):
@ -52,16 +70,21 @@ class TestAutodiffSubgraphSlicing(JitTestCase):
with disable_autodiff_subgraph_inlining(): with disable_autodiff_subgraph_inlining():
with enable_profiling_mode_for_profiling_tests(): with enable_profiling_mode_for_profiling_tests():
output = func(input, profile_and_replay=True) output = func(input, profile_and_replay=True)
FileCheck().check_not("prim::DifferentiableGraph").run(func.graph_for(input)) FileCheck().check_not("prim::DifferentiableGraph").run(
func.graph_for(input)
)
@unittest.skipIf(GRAPH_EXECUTOR != ProfilingMode.PROFILING, "This threshold is only valid for Profiling Executor") @unittest.skipIf(
GRAPH_EXECUTOR != ProfilingMode.PROFILING,
"This threshold is only valid for Profiling Executor",
)
def test_diff_graph_inline_threshold(self): def test_diff_graph_inline_threshold(self):
with enable_profiling_mode_for_profiling_tests(): with enable_profiling_mode_for_profiling_tests():
NUM_RUNS = 1 NUM_RUNS = 1
with num_profiled_runs(NUM_RUNS): with num_profiled_runs(NUM_RUNS):
@torch.jit.script @torch.jit.script
def foo(x): def foo(x):
# two nodes should be fused # two nodes should be fused
# see https://github.com/pytorch/pytorch/blob/master/torch/csrc/jit/runtime/graph_executor_impl.h#L49 # see https://github.com/pytorch/pytorch/blob/master/torch/csrc/jit/runtime/graph_executor_impl.h#L49
return torch.sigmoid(torch.sigmoid(x)) return torch.sigmoid(torch.sigmoid(x))
@ -78,12 +101,16 @@ class TestAutodiffSubgraphSlicing(JitTestCase):
bar(input) bar(input)
bar(input) bar(input)
self.assertGraphContainsExactly(foo.graph_for(input), 'prim::DifferentiableGraph', 1) self.assertGraphContainsExactly(
self.assertGraphContainsExactly(bar.graph_for(input), 'prim::DifferentiableGraph', 0) foo.graph_for(input), "prim::DifferentiableGraph", 1
)
self.assertGraphContainsExactly(
bar.graph_for(input), "prim::DifferentiableGraph", 0
)
def test_bias_as_module_attr(self): def test_bias_as_module_attr(self):
with enable_profiling_mode_for_profiling_tests(): with enable_profiling_mode_for_profiling_tests():
class M(torch.nn.Module): class M(torch.nn.Module):
def __init__(self, has_bias): def __init__(self, has_bias):
super().__init__() super().__init__()
@ -99,19 +126,40 @@ class TestAutodiffSubgraphSlicing(JitTestCase):
scripted_no_bias(x, x) scripted_no_bias(x, x)
scripted_no_bias(x, x) scripted_no_bias(x, x)
has_bias = M(True) has_bias = M(True)
check_against_reference(self, scripted_no_bias, no_bias, lambda x: x, (x, x,), check_types=False) check_against_reference(
self,
scripted_no_bias,
no_bias,
lambda x: x,
(
x,
x,
),
check_types=False,
)
scripted_has_bias = torch.jit.script(has_bias) scripted_has_bias = torch.jit.script(has_bias)
scripted_has_bias(x, x) scripted_has_bias(x, x)
scripted_has_bias(x, x) scripted_has_bias(x, x)
scripted_has_bias(x, x) scripted_has_bias(x, x)
check_against_reference(self, scripted_has_bias, has_bias, lambda x: x, (x, x,), check_types=False) check_against_reference(
self,
scripted_has_bias,
has_bias,
lambda x: x,
(
x,
x,
),
check_types=False,
)
def test_constructed_bias(self): def test_constructed_bias(self):
with enable_profiling_mode_for_profiling_tests(): with enable_profiling_mode_for_profiling_tests():
def method1(x, weight, b1, b2): def method1(x, weight, b1, b2):
bias = b1 * b2 bias = b1 * b2
return torch.nn.functional.linear(x, weight, bias) return torch.nn.functional.linear(x, weight, bias)
N = 10 N = 10
x = torch.rand(N, N, requires_grad=True) x = torch.rand(N, N, requires_grad=True)
weight = torch.rand(N, N, requires_grad=True) weight = torch.rand(N, N, requires_grad=True)
@ -119,35 +167,58 @@ class TestAutodiffSubgraphSlicing(JitTestCase):
b2 = torch.rand(N, N, requires_grad=True) b2 = torch.rand(N, N, requires_grad=True)
scripted = self.checkScript(method1, (x, weight, b1, b2)) scripted = self.checkScript(method1, (x, weight, b1, b2))
# check_types requires last_graph on scripted to be set, so we just skip it # check_types requires last_graph on scripted to be set, so we just skip it
check_against_reference(self, scripted, method1, lambda x: x, (x, weight, b1, b2), check_types=False) check_against_reference(
self,
scripted,
method1,
lambda x: x,
(x, weight, b1, b2),
check_types=False,
)
def test_bias_as_arg(self): def test_bias_as_arg(self):
with enable_profiling_mode_for_profiling_tests(): with enable_profiling_mode_for_profiling_tests():
def method1(x, weight, bias: Optional[torch.Tensor]): def method1(x, weight, bias: Optional[torch.Tensor]):
return torch.nn.functional.linear(x, weight, bias).relu() + 2 return torch.nn.functional.linear(x, weight, bias).relu() + 2
N = 10 N = 10
x = torch.rand(N, N, requires_grad=True) x = torch.rand(N, N, requires_grad=True)
weight = torch.rand(N, N, requires_grad=True) weight = torch.rand(N, N, requires_grad=True)
bias = None bias = None
scripted = self.checkScript(method1, (x, weight, bias)) scripted = self.checkScript(method1, (x, weight, bias))
# check_types requires last_graph on scripted to be set, so we just skip it # check_types requires last_graph on scripted to be set, so we just skip it
check_against_reference(self, scripted, method1, lambda x: x, (x, weight, bias), check_types=False) check_against_reference(
self,
scripted,
method1,
lambda x: x,
(x, weight, bias),
check_types=False,
)
bias = torch.rand(N, N, requires_grad=True) bias = torch.rand(N, N, requires_grad=True)
scripted = self.checkScript(method1, (x, weight, bias)) scripted = self.checkScript(method1, (x, weight, bias))
# check_types requires last_graph on scripted to be set, so we just skip it # check_types requires last_graph on scripted to be set, so we just skip it
check_against_reference(self, scripted, method1, lambda x: x, (x, weight, bias), check_types=False) check_against_reference(
self,
scripted,
method1,
lambda x: x,
(x, weight, bias),
check_types=False,
)
def test_requires_grad_for_tensor_list(self): def test_requires_grad_for_tensor_list(self):
with enable_profiling_mode_for_profiling_tests(): with enable_profiling_mode_for_profiling_tests():
# output & var_list[0] should have requires_grad set to True # output & var_list[0] should have requires_grad set to True
def func(input0: torch.Tensor, input1: torch.Tensor) -> Tuple[torch.Tensor, List[torch.Tensor]]: def func(
input0: torch.Tensor, input1: torch.Tensor
) -> Tuple[torch.Tensor, List[torch.Tensor]]:
var_list = [input0, input1] var_list = [input0, input1]
var = torch.cat(var_list) var = torch.cat(var_list)
output = var + 1.0 output = var + 1.0
return output, var_list return output, var_list
jit_f = torch.jit.script(func) jit_f = torch.jit.script(func)
input0 = torch.randn((2,), requires_grad=True) input0 = torch.randn((2,), requires_grad=True)
input1 = torch.randn((2,)) input1 = torch.randn((2,))
@ -158,12 +229,14 @@ class TestAutodiffSubgraphSlicing(JitTestCase):
assert output_ref[1][0].requires_grad == output[1][0].requires_grad assert output_ref[1][0].requires_grad == output[1][0].requires_grad
assert output_ref[1][1].requires_grad == output[1][1].requires_grad assert output_ref[1][1].requires_grad == output[1][1].requires_grad
@unittest.skip("disable until we property handle tensor lists with undefined gradients") @unittest.skip(
"disable until we property handle tensor lists with undefined gradients"
)
def test_differentiable_graph_ops_requires_grad(self): def test_differentiable_graph_ops_requires_grad(self):
x = torch.randn(8, 2, dtype=torch.float).requires_grad_() x = torch.randn(8, 2, dtype=torch.float).requires_grad_()
y = torch.randn(8, 2, dtype=torch.float) y = torch.randn(8, 2, dtype=torch.float)
def t(x : torch.Tensor, y : torch.Tensor, flag : bool): def t(x: torch.Tensor, y: torch.Tensor, flag: bool):
o = x + 1.0 o = x + 1.0
o1 = torch.relu(o) o1 = torch.relu(o)
o = y + 1.5 o = y + 1.5
@ -186,13 +259,14 @@ class TestAutodiffSubgraphSlicing(JitTestCase):
return o1, o2, o3, oo1, oo2, oo3 return o1, o2, o3, oo1, oo2, oo3
with enable_profiling_mode_for_profiling_tests(): with enable_profiling_mode_for_profiling_tests():
t_jit = torch.jit.script(t) t_jit = torch.jit.script(t)
jit_o = t_jit(x, y, False) jit_o = t_jit(x, y, False)
jit_o = t_jit(x, y, False) jit_o = t_jit(x, y, False)
o = t(x, y, False) o = t(x, y, False)
FileCheck().check("prim::DifferentiableGraph").run(t_jit.graph_for(x, y, False)) FileCheck().check("prim::DifferentiableGraph").run(
t_jit.graph_for(x, y, False)
)
# validate the differentiableGraphOps are marking proper requires_grad # validate the differentiableGraphOps are marking proper requires_grad
for oo, jit_oo in zip(o, jit_o): for oo, jit_oo in zip(o, jit_o):
self.assertEqual(oo.requires_grad, jit_oo.requires_grad) self.assertEqual(oo.requires_grad, jit_oo.requires_grad)
@ -204,22 +278,28 @@ class TestAutodiffSubgraphSlicing(JitTestCase):
self.assertEqual(oo.requires_grad, jit_oo.requires_grad) self.assertEqual(oo.requires_grad, jit_oo.requires_grad)
self.assertEqual(oo, jit_oo) self.assertEqual(oo, jit_oo)
@unittest.skipIf(GRAPH_EXECUTOR == ProfilingMode.PROFILING, "Simple Executor doesn't support gradients") @unittest.skipIf(
GRAPH_EXECUTOR == ProfilingMode.PROFILING,
"Simple Executor doesn't support gradients",
)
def test_prune_grad(self): def test_prune_grad(self):
@torch.jit.script @torch.jit.script
def t(input, bias): def t(input, bias):
return torch.nn.functional.relu(input + bias) return torch.nn.functional.relu(input + bias)
input = torch.randn(2, 8, requires_grad=True) input = torch.randn(2, 8, requires_grad=True)
bias = torch.randn(8, requires_grad=False) # bias does NOT require grad bias = torch.randn(8, requires_grad=False) # bias does NOT require grad
NUM_PROFILED_RUNS = 1 NUM_PROFILED_RUNS = 1
with num_profiled_runs(NUM_PROFILED_RUNS): with num_profiled_runs(NUM_PROFILED_RUNS):
WARMUP = 3 # 2 runs to reach backward + 1 to optimize it WARMUP = 3 # 2 runs to reach backward + 1 to optimize it
for x in range(WARMUP): for x in range(WARMUP):
o = t(input, bias) o = t(input, bias)
o.sum().backward() o.sum().backward()
fwd_plan = list(t.get_debug_state().execution_plans.values())[0] fwd_plan = list(t.get_debug_state().execution_plans.values())[0]
bwd_graph = list(fwd_plan.code.grad_executor_states()[0].execution_plans.values())[0].graph bwd_graph = list(
fwd_plan.code.grad_executor_states()[0].execution_plans.values()
)[0].graph
tup = next(bwd_graph.outputs()) tup = next(bwd_graph.outputs())
self.assertEqual(len(list(tup.node().inputs())), 1) self.assertEqual(len(list(tup.node().inputs())), 1)
@ -233,7 +313,7 @@ class TestAutodiffSubgraphSlicing(JitTestCase):
graph = self._perform_ad_subgraph_slicing(fn, 1, 1, 1) graph = self._perform_ad_subgraph_slicing(fn, 1, 1, 1)
self.assertGraphSize(graph, 1) self.assertGraphSize(graph, 1)
self.assertGraphContainsExactly(graph, 'prim::DifferentiableGraph', 1) self.assertGraphContainsExactly(graph, "prim::DifferentiableGraph", 1)
def test_simple_no_merge(self): def test_simple_no_merge(self):
# o: autodiff supported. x: not autodiff supported. # o: autodiff supported. x: not autodiff supported.
@ -245,8 +325,10 @@ class TestAutodiffSubgraphSlicing(JitTestCase):
graph = self._perform_ad_subgraph_slicing(fn, 1, 1, 1) graph = self._perform_ad_subgraph_slicing(fn, 1, 1, 1)
g_str = str(graph) g_str = str(graph)
FileCheck().check("aten::Int").check("aten::zeros").check_not("aten::mul").run(g_str[0:g_str.find("return")]) FileCheck().check("aten::Int").check("aten::zeros").check_not("aten::mul").run(
self.assertGraphContainsExactly(graph, 'prim::DifferentiableGraph', 1) g_str[0 : g_str.find("return")]
)
self.assertGraphContainsExactly(graph, "prim::DifferentiableGraph", 1)
def test_does_not_merge_unrelated(self): def test_does_not_merge_unrelated(self):
# o o # o o
@ -258,7 +340,7 @@ class TestAutodiffSubgraphSlicing(JitTestCase):
graph = self._perform_ad_subgraph_slicing(fn, 1, 1, 1, 1) graph = self._perform_ad_subgraph_slicing(fn, 1, 1, 1, 1)
self.assertGraphSize(graph, 3) self.assertGraphSize(graph, 3)
self.assertGraphContainsExactly(graph, 'prim::DifferentiableGraph', 2) self.assertGraphContainsExactly(graph, "prim::DifferentiableGraph", 2)
def test_merges_without_cycles(self): def test_merges_without_cycles(self):
# o --> o --> o # o --> o --> o
@ -273,7 +355,7 @@ class TestAutodiffSubgraphSlicing(JitTestCase):
graph = self._perform_ad_subgraph_slicing(fn, 1, 1, 1) graph = self._perform_ad_subgraph_slicing(fn, 1, 1, 1)
self.assertGraphSize(graph, 1) self.assertGraphSize(graph, 1)
self.assertGraphContainsExactly(graph, 'prim::DifferentiableGraph', 1) self.assertGraphContainsExactly(graph, "prim::DifferentiableGraph", 1)
def test_merges_dense(self): def test_merges_dense(self):
# o o # o o
@ -290,7 +372,7 @@ class TestAutodiffSubgraphSlicing(JitTestCase):
graph = self._perform_ad_subgraph_slicing(fn, 2, 2) graph = self._perform_ad_subgraph_slicing(fn, 2, 2)
self.assertGraphSize(graph, 2) self.assertGraphSize(graph, 2)
self.assertGraphContainsExactly(graph, 'prim::DifferentiableGraph', 1) self.assertGraphContainsExactly(graph, "prim::DifferentiableGraph", 1)
def test_does_not_create_cycles(self): def test_does_not_create_cycles(self):
# o --> x --> o # o --> x --> o
@ -303,7 +385,7 @@ class TestAutodiffSubgraphSlicing(JitTestCase):
return c return c
graph = self._perform_ad_subgraph_slicing(fn, 1, 1, 1) graph = self._perform_ad_subgraph_slicing(fn, 1, 1, 1)
self.assertGraphContainsExactly(graph, 'prim::DifferentiableGraph', 2) self.assertGraphContainsExactly(graph, "prim::DifferentiableGraph", 2)
def test_merges_up(self): def test_merges_up(self):
# o --> x o # o --> x o
@ -317,8 +399,8 @@ class TestAutodiffSubgraphSlicing(JitTestCase):
graph = self._perform_ad_subgraph_slicing(fn, 1, 1, 1, 1) graph = self._perform_ad_subgraph_slicing(fn, 1, 1, 1, 1)
g_str = str(graph) g_str = str(graph)
FileCheck().check_not("aten::add").run(g_str[0:g_str.find("return")]) FileCheck().check_not("aten::add").run(g_str[0 : g_str.find("return")])
self.assertGraphContainsExactly(graph, 'prim::DifferentiableGraph', 1) self.assertGraphContainsExactly(graph, "prim::DifferentiableGraph", 1)
def test_merges_down(self): def test_merges_down(self):
# o x --> o # o x --> o
@ -335,8 +417,8 @@ class TestAutodiffSubgraphSlicing(JitTestCase):
num_nodes = 4 if GRAPH_EXECUTOR == ProfilingMode.PROFILING else 3 num_nodes = 4 if GRAPH_EXECUTOR == ProfilingMode.PROFILING else 3
# add moved down # add moved down
g_str = str(graph) g_str = str(graph)
FileCheck().check_not("aten::add").run(g_str[0:g_str.find("return")]) FileCheck().check_not("aten::add").run(g_str[0 : g_str.find("return")])
self.assertGraphContainsExactly(graph, 'prim::DifferentiableGraph', 1) self.assertGraphContainsExactly(graph, "prim::DifferentiableGraph", 1)
def test_respects_lexical_scoping(self): def test_respects_lexical_scoping(self):
def fn(x, k): def fn(x, k):
@ -346,12 +428,10 @@ class TestAutodiffSubgraphSlicing(JitTestCase):
z = y * k z = y * k
return z, k return z, k
graph = self._perform_ad_subgraph_slicing(fn, 1, 1) graph = self._perform_ad_subgraph_slicing(fn, 1, 1)
# We should not have combined the two multiplications into # We should not have combined the two multiplications into
# the same group; they should each be a separate DiffGraph # the same group; they should each be a separate DiffGraph
self.assertGraphContainsExactly(graph, 'prim::DifferentiableGraph', 3) self.assertGraphContainsExactly(graph, "prim::DifferentiableGraph", 3)
def test_merge_respects_aliasing(self): def test_merge_respects_aliasing(self):
def fn(x, k, cond): def fn(x, k, cond):
@ -368,15 +448,13 @@ class TestAutodiffSubgraphSlicing(JitTestCase):
graph = self._perform_ad_subgraph_slicing(fn, [2, 2], [2, 2], 1) graph = self._perform_ad_subgraph_slicing(fn, [2, 2], [2, 2], 1)
# z2 did did not get merged into the subgraph # z2 did did not get merged into the subgraph
FileCheck().check("prim::If").check("aten::select").check_next("aten::select")\ FileCheck().check("prim::If").check("aten::select").check_next(
.check_next("aten::add_").check("Differentiable").run(graph) "aten::select"
self.assertGraphContainsExactly(graph, 'prim::DifferentiableGraph', 2) ).check_next("aten::add_").check("Differentiable").run(graph)
self.assertGraphContainsExactly(graph, "prim::DifferentiableGraph", 2)
def test_aliased_outputs(self): def test_aliased_outputs(self):
with enable_profiling_mode_for_profiling_tests(): with enable_profiling_mode_for_profiling_tests():
# Case 1: aliasing between relu and t # Case 1: aliasing between relu and t
# is within a DifferentiableGraph. It should be valid # is within a DifferentiableGraph. It should be valid
# to merge both split_with_sizes in relu in one graph # to merge both split_with_sizes in relu in one graph
@ -389,9 +467,9 @@ class TestAutodiffSubgraphSlicing(JitTestCase):
graph = torch._C.parse_ir(input_str) graph = torch._C.parse_ir(input_str)
torch._C._jit_pass_create_autodiff_subgraphs(graph, 1) torch._C._jit_pass_create_autodiff_subgraphs(graph, 1)
FileCheck().check("with prim::DifferentiableGraph") \ FileCheck().check("with prim::DifferentiableGraph").check(
.check("aten::relu").check("aten::t") \ "aten::relu"
.run(graph) ).check("aten::t").run(graph)
# Case 2: aliasing between relu and split_with_sizes # Case 2: aliasing between relu and split_with_sizes
# are both outputs of a Diff graph. It should be invalid # are both outputs of a Diff graph. It should be invalid
@ -410,11 +488,11 @@ class TestAutodiffSubgraphSlicing(JitTestCase):
graph = torch._C.parse_ir(input_str) graph = torch._C.parse_ir(input_str)
torch._C._jit_pass_create_autodiff_subgraphs(graph, 1) torch._C._jit_pass_create_autodiff_subgraphs(graph, 1)
FileCheck().check("Tensor = prim::DifferentiableGraph") \ FileCheck().check("Tensor = prim::DifferentiableGraph").check(
.check("with prim::DifferentiableGraph") \ "with prim::DifferentiableGraph"
.check("Tensor = aten::relu") \ ).check("Tensor = aten::relu").check_not("aten::split_with_sizes").run(
.check_not("aten::split_with_sizes") \ graph
.run(graph) )
# Case 3: two aliased nodes in a graph. # Case 3: two aliased nodes in a graph.
# Both `split_with_sizes` should be unfused # Both `split_with_sizes` should be unfused
@ -432,11 +510,11 @@ class TestAutodiffSubgraphSlicing(JitTestCase):
graph = torch._C.parse_ir(input_str) graph = torch._C.parse_ir(input_str)
torch._C._jit_pass_create_autodiff_subgraphs(graph, 1) torch._C._jit_pass_create_autodiff_subgraphs(graph, 1)
FileCheck().check("Tensor = prim::DifferentiableGraph") \ FileCheck().check("Tensor = prim::DifferentiableGraph").check(
.check("with prim::DifferentiableGraph") \ "with prim::DifferentiableGraph"
.check("Tensor = aten::relu") \ ).check("Tensor = aten::relu").check_not("aten::split_with_sizes").run(
.check_not("aten::split_with_sizes") \ graph
.run(graph) )
# Case 4: the aliased output has a descendant # Case 4: the aliased output has a descendant
# Both should be unfused. Note, %3 comes before %2 # Both should be unfused. Note, %3 comes before %2
@ -454,11 +532,9 @@ class TestAutodiffSubgraphSlicing(JitTestCase):
graph = torch._C.parse_ir(input_str) graph = torch._C.parse_ir(input_str)
torch._C._jit_pass_create_autodiff_subgraphs(graph, 1) torch._C._jit_pass_create_autodiff_subgraphs(graph, 1)
FileCheck().check("Tensor = prim::DifferentiableGraph") \ FileCheck().check("Tensor = prim::DifferentiableGraph").check(
.check("with prim::DifferentiableGraph") \ "with prim::DifferentiableGraph"
.check("Tensor = aten::relu") \ ).check("Tensor = aten::relu").check_not("aten::t").run(graph)
.check_not("aten::t") \
.run(graph)
# Case 5: multiple aliased groups # Case 5: multiple aliased groups
# Both should be unfused. Note, %3 comes before %2 # Both should be unfused. Note, %3 comes before %2
@ -478,11 +554,9 @@ class TestAutodiffSubgraphSlicing(JitTestCase):
graph = torch._C.parse_ir(input_str) graph = torch._C.parse_ir(input_str)
torch._C._jit_pass_create_autodiff_subgraphs(graph, 1) torch._C._jit_pass_create_autodiff_subgraphs(graph, 1)
FileCheck().check("Tensor = prim::DifferentiableGraph") \ FileCheck().check("Tensor = prim::DifferentiableGraph").check(
.check("with prim::DifferentiableGraph") \ "with prim::DifferentiableGraph"
.check("Tensor = aten::relu") \ ).check("Tensor = aten::relu").check_not("aten::t").run(graph)
.check_not("aten::t") \
.run(graph)
def test_has_profiled_info_aliasing_outputs(self): def test_has_profiled_info_aliasing_outputs(self):
# The expectation is that CallFunction will prevent the final profile node from # The expectation is that CallFunction will prevent the final profile node from
@ -511,9 +585,6 @@ class TestAutodiffSubgraphSlicing(JitTestCase):
output = outputs[0] output = outputs[0]
self.assertEqual(False, output.requiresGrad()) self.assertEqual(False, output.requiresGrad())
FileCheck().check("= prim::DifferentiableGraph") \ FileCheck().check("= prim::DifferentiableGraph").check(
.check("with prim::DifferentiableGraph") \ "with prim::DifferentiableGraph"
.check(" = aten::relu") \ ).check(" = aten::relu").check("requires_grad=0").check("aten::relu").run(graph)
.check("requires_grad=0") \
.check("aten::relu") \
.run(graph)

View File

@ -1,18 +1,19 @@
# Owner(s): ["oncall: jit"] # Owner(s): ["oncall: jit"]
import io import io
import torch
from torch.testing._internal.jit_utils import JitTestCase
from torch.testing._internal.jit_utils import make_global
from typing import List, Optional, Tuple from typing import List, Optional, Tuple
import torch
from torch import Tensor from torch import Tensor
from torch._awaits import _Await as Await from torch._awaits import _Await as Await
from torch.testing._internal.jit_utils import JitTestCase, make_global
class TestAwait(JitTestCase): class TestAwait(JitTestCase):
def test_await_python(self): def test_await_python(self):
def foo(x: int) -> int: def foo(x: int) -> int:
return x + 13 return x + 13
aw: Await[int] = torch.jit._awaitable(foo, 13) aw: Await[int] = torch.jit._awaitable(foo, 13)
self.assertTrue(aw.fn()(*aw.args()) == torch.jit._awaitable_wait(aw)) self.assertTrue(aw.fn()(*aw.args()) == torch.jit._awaitable_wait(aw))
nw = torch.jit._awaitable_nowait(33) nw = torch.jit._awaitable_nowait(33)
@ -22,6 +23,7 @@ class TestAwait(JitTestCase):
def test_await_type_python(self): def test_await_type_python(self):
def foo() -> Tensor: def foo() -> Tensor:
return torch.randn() return torch.randn()
awaits = torch.jit.annotate(List[Await[Tensor]], []) awaits = torch.jit.annotate(List[Await[Tensor]], [])
awaits.append(torch.jit._awaitable(foo)) awaits.append(torch.jit._awaitable(foo))
@ -82,9 +84,7 @@ class TestAwait(JitTestCase):
self.assertTrue(torch.allclose(torch.eye(2), script_out)) self.assertTrue(torch.allclose(torch.eye(2), script_out))
self.assertTrue(torch.allclose(script_out, out)) self.assertTrue(torch.allclose(script_out, out))
def test_await_class_arg(self): def test_await_class_arg(self):
class C: class C:
def __init__(self, a: Tensor, b: Tensor): def __init__(self, a: Tensor, b: Tensor):
self.__a = a self.__a = a
@ -104,6 +104,7 @@ class TestAwait(JitTestCase):
_a = torch.eye(2) _a = torch.eye(2)
c2_t = torch.jit._awaitable_wait(aw) c2_t = torch.jit._awaitable_wait(aw)
return _a + c2_t + x return _a + c2_t + x
inp = torch.zeros(2) inp = torch.zeros(2)
sm = torch.jit.script(fn) sm = torch.jit.script(fn)
@ -120,7 +121,6 @@ class TestAwait(JitTestCase):
self._a = a self._a = a
self._b = b self._b = b
make_global(C) make_global(C)
# Can not stay in the class as Jit does not support Recursive annotations # Can not stay in the class as Jit does not support Recursive annotations
@ -143,7 +143,6 @@ class TestAwait(JitTestCase):
self.assertTrue(torch.allclose(script_out, out)) self.assertTrue(torch.allclose(script_out, out))
def test_await_class_return(self): def test_await_class_return(self):
class C: class C:
__slots__ = ["a", "b"] __slots__ = ["a", "b"]
@ -151,7 +150,6 @@ class TestAwait(JitTestCase):
self.a = a self.a = a
self.b = b self.b = b
make_global(C) make_global(C)
# Can not stay in the class as Jit does not support Recursive annotations # Can not stay in the class as Jit does not support Recursive annotations
@ -175,7 +173,9 @@ class TestAwait(JitTestCase):
script_out = sm(inp) script_out = sm(inp)
self.assertTrue(torch.allclose(torch.eye(2) + 6 * torch.ones(2), script_out)) self.assertTrue(torch.allclose(torch.eye(2) + 6 * torch.ones(2), script_out))
self.assertTrue(torch.allclose(script_out, out)) self.assertTrue(torch.allclose(script_out, out))
self.assertGraphContainsExactly(sm.graph, kind='prim::awaitable_wait', num_kind_nodes=1) self.assertGraphContainsExactly(
sm.graph, kind="prim::awaitable_wait", num_kind_nodes=1
)
def test_await_getattr_implicit_convertion(self): def test_await_getattr_implicit_convertion(self):
class C: class C:
@ -186,7 +186,6 @@ class TestAwait(JitTestCase):
def b(self): def b(self):
return self._b return self._b
make_global(C) make_global(C)
# Can not stay in the class as Jit does not support Recursive annotations # Can not stay in the class as Jit does not support Recursive annotations
@ -212,10 +211,11 @@ class TestAwait(JitTestCase):
script_out = sm(inp) script_out = sm(inp)
self.assertTrue(torch.allclose(torch.eye(2) + 7 * torch.ones(2), script_out)) self.assertTrue(torch.allclose(torch.eye(2) + 7 * torch.ones(2), script_out))
self.assertTrue(torch.allclose(script_out, out)) self.assertTrue(torch.allclose(script_out, out))
self.assertGraphContainsExactly(sm.graph, kind='prim::awaitable_wait', num_kind_nodes=2) self.assertGraphContainsExactly(
sm.graph, kind="prim::awaitable_wait", num_kind_nodes=2
)
def test_await_nested(self): def test_await_nested(self):
class C: class C:
def __init__(self, a: Tensor, b: Tensor): def __init__(self, a: Tensor, b: Tensor):
self.__a = a self.__a = a
@ -250,6 +250,7 @@ class TestAwait(JitTestCase):
def __init__(self, v): def __init__(self, v):
self.parent = torch.jit.annotate(Optional[Tree], None) self.parent = torch.jit.annotate(Optional[Tree], None)
self.v = v self.v = v
make_global(Tree) make_global(Tree)
def delayed(t: Tree): def delayed(t: Tree):
@ -275,12 +276,15 @@ class TestAwait(JitTestCase):
sm = torch.jit.script(main) sm = torch.jit.script(main)
out = main(inp) out = main(inp)
script_out = sm(inp) script_out = sm(inp)
self.assertTrue(torch.allclose(2 * torch.eye(2) + 2 * torch.ones(2), script_out)) self.assertTrue(
torch.allclose(2 * torch.eye(2) + 2 * torch.ones(2), script_out)
)
self.assertTrue(torch.allclose(script_out, out)) self.assertTrue(torch.allclose(script_out, out))
def test_await_eager_lazy(self): def test_await_eager_lazy(self):
def delayed(x: Tensor) -> Tensor: def delayed(x: Tensor) -> Tensor:
return 2 * (x + 1) return 2 * (x + 1)
t = torch.ones(2, dtype=torch.int64) t = torch.ones(2, dtype=torch.int64)
aw = torch.jit._awaitable(delayed, t) aw = torch.jit._awaitable(delayed, t)
self.assertTrue(isinstance(aw, torch._C._Await)) self.assertTrue(isinstance(aw, torch._C._Await))
@ -302,7 +306,9 @@ class TestAwait(JitTestCase):
script_out_aw = sm(inp) script_out_aw = sm(inp)
script_out = torch.jit._awaitable_wait(script_out_aw) script_out = torch.jit._awaitable_wait(script_out_aw)
self.assertTrue(torch.allclose(2 * torch.eye(2) + 2 * torch.ones(2), script_out)) self.assertTrue(
torch.allclose(2 * torch.eye(2) + 2 * torch.ones(2), script_out)
)
self.assertTrue(torch.allclose(script_out, out)) self.assertTrue(torch.allclose(script_out, out))
def test_jit_trace(self): def test_jit_trace(self):

View File

@ -3,10 +3,10 @@
import os import os
import sys import sys
import unittest import unittest
from pathlib import Path
import torch import torch
import torch._C import torch._C
from pathlib import Path
from torch.testing._internal.common_utils import IS_FBCODE, skipIfTorchDynamo from torch.testing._internal.common_utils import IS_FBCODE, skipIfTorchDynamo
# hacky way to skip these tests in fbcode: # hacky way to skip these tests in fbcode:
@ -15,9 +15,11 @@ from torch.testing._internal.common_utils import IS_FBCODE, skipIfTorchDynamo
# it sees tests but then fails when it tries to actuall run them. # it sees tests but then fails when it tries to actuall run them.
if not IS_FBCODE: if not IS_FBCODE:
from test_nnapi import TestNNAPI from test_nnapi import TestNNAPI
HAS_TEST_NNAPI = True HAS_TEST_NNAPI = True
else: else:
from torch.testing._internal.common_utils import TestCase as TestNNAPI from torch.testing._internal.common_utils import TestCase as TestNNAPI
HAS_TEST_NNAPI = False HAS_TEST_NNAPI = False
@ -39,10 +41,14 @@ without the delegate API.
""" """
# First skip is needed for IS_WINDOWS or IS_MACOS to skip the tests. # First skip is needed for IS_WINDOWS or IS_MACOS to skip the tests.
torch_root = Path(__file__).resolve().parent.parent.parent torch_root = Path(__file__).resolve().parent.parent.parent
lib_path = torch_root / 'build' / 'lib' / 'libnnapi_backend.so' lib_path = torch_root / "build" / "lib" / "libnnapi_backend.so"
@skipIfTorchDynamo("weird py38 failures") @skipIfTorchDynamo("weird py38 failures")
@unittest.skipIf(not os.path.exists(lib_path), @unittest.skipIf(
"Skipping the test as libnnapi_backend.so was not found") not os.path.exists(lib_path),
"Skipping the test as libnnapi_backend.so was not found",
)
@unittest.skipIf(IS_FBCODE, "test_nnapi.py not found") @unittest.skipIf(IS_FBCODE, "test_nnapi.py not found")
class TestNnapiBackend(TestNNAPI): class TestNnapiBackend(TestNNAPI):
def setUp(self): def setUp(self):
@ -89,35 +95,44 @@ method_compile_spec must use the following format:
# No forward key # No forward key
compile_spec = {"backward": {"inputs": args}} compile_spec = {"backward": {"inputs": args}}
with self.assertRaisesRegex(RuntimeError, "method_compile_spec does not contain the \"forward\" key." + errorMsgTail): with self.assertRaisesRegex(
RuntimeError,
'method_compile_spec does not contain the "forward" key.' + errorMsgTail,
):
torch._C._jit_to_backend("nnapi", traced, compile_spec) torch._C._jit_to_backend("nnapi", traced, compile_spec)
# No dictionary under the forward key # No dictionary under the forward key
compile_spec = {"forward": 1} compile_spec = {"forward": 1}
with self.assertRaisesRegex(RuntimeError, with self.assertRaisesRegex(
"method_compile_spec does not contain a dictionary with an \"inputs\" key, " RuntimeError,
"under it's \"forward\" key." 'method_compile_spec does not contain a dictionary with an "inputs" key, '
+ errorMsgTail): 'under it\'s "forward" key.' + errorMsgTail,
):
torch._C._jit_to_backend("nnapi", traced, compile_spec) torch._C._jit_to_backend("nnapi", traced, compile_spec)
# No inputs key (in the dictionary under the forward key) # No inputs key (in the dictionary under the forward key)
compile_spec = {"forward": {"not inputs": args}} compile_spec = {"forward": {"not inputs": args}}
with self.assertRaisesRegex(RuntimeError, with self.assertRaisesRegex(
"method_compile_spec does not contain a dictionary with an \"inputs\" key, " RuntimeError,
"under it's \"forward\" key." 'method_compile_spec does not contain a dictionary with an "inputs" key, '
+ errorMsgTail): 'under it\'s "forward" key.' + errorMsgTail,
):
torch._C._jit_to_backend("nnapi", traced, compile_spec) torch._C._jit_to_backend("nnapi", traced, compile_spec)
# No Tensor or TensorList under the inputs key # No Tensor or TensorList under the inputs key
compile_spec = {"forward": {"inputs": 1}} compile_spec = {"forward": {"inputs": 1}}
with self.assertRaisesRegex(RuntimeError, with self.assertRaisesRegex(
"method_compile_spec does not contain either a Tensor or TensorList, under it's \"inputs\" key." RuntimeError,
+ errorMsgTail): 'method_compile_spec does not contain either a Tensor or TensorList, under it\'s "inputs" key.'
+ errorMsgTail,
):
torch._C._jit_to_backend("nnapi", traced, compile_spec) torch._C._jit_to_backend("nnapi", traced, compile_spec)
compile_spec = {"forward": {"inputs": [1]}} compile_spec = {"forward": {"inputs": [1]}}
with self.assertRaisesRegex(RuntimeError, with self.assertRaisesRegex(
"method_compile_spec does not contain either a Tensor or TensorList, under it's \"inputs\" key." RuntimeError,
+ errorMsgTail): 'method_compile_spec does not contain either a Tensor or TensorList, under it\'s "inputs" key.'
+ errorMsgTail,
):
torch._C._jit_to_backend("nnapi", traced, compile_spec) torch._C._jit_to_backend("nnapi", traced, compile_spec)
def tearDown(self): def tearDown(self):

View File

@ -1,6 +1,5 @@
# Owner(s): ["oncall: jit"] # Owner(s): ["oncall: jit"]
from torch.testing._internal.jit_utils import JitTestCase
import io import io
import os import os
import sys import sys
@ -8,18 +7,20 @@ import unittest
import torch import torch
import torch._C import torch._C
from torch.testing import FileCheck
from torch.jit.mobile import _load_for_lite_interpreter from torch.jit.mobile import _load_for_lite_interpreter
from torch.testing import FileCheck
from torch.testing._internal.common_utils import ( from torch.testing._internal.common_utils import (
find_library_location,
IS_FBCODE, IS_FBCODE,
IS_MACOS, IS_MACOS,
IS_SANDCASTLE, IS_SANDCASTLE,
IS_WINDOWS, IS_WINDOWS,
TEST_WITH_ROCM,
skipIfRocm, skipIfRocm,
find_library_location, TEST_WITH_ROCM,
) )
from torch.testing._internal.jit_utils import JitTestCase
# Make the helper files in test/ importable # Make the helper files in test/ importable
pytorch_test_dir = os.path.dirname(os.path.dirname(os.path.realpath(__file__))) pytorch_test_dir = os.path.dirname(os.path.dirname(os.path.realpath(__file__)))
sys.path.append(pytorch_test_dir) sys.path.append(pytorch_test_dir)
@ -33,7 +34,9 @@ if __name__ == "__main__":
def to_test_backend(module, method_compile_spec): def to_test_backend(module, method_compile_spec):
return torch._C._jit_to_backend("test_backend", module, {"forward": method_compile_spec}) return torch._C._jit_to_backend(
"test_backend", module, {"forward": method_compile_spec}
)
def to_test_backend_multi(module, method_compile_spec): def to_test_backend_multi(module, method_compile_spec):
@ -63,8 +66,10 @@ class BasicModule(torch.nn.Module):
# This is ignored in IS_WINDOWS or IS_MACOS cases. Hence we need the one in TestBackends. # This is ignored in IS_WINDOWS or IS_MACOS cases. Hence we need the one in TestBackends.
@unittest.skipIf(TEST_WITH_ROCM or IS_SANDCASTLE or IS_WINDOWS or IS_MACOS or IS_FBCODE, @unittest.skipIf(
"Non-portable load_library call used in test") TEST_WITH_ROCM or IS_SANDCASTLE or IS_WINDOWS or IS_MACOS or IS_FBCODE,
"Non-portable load_library call used in test",
)
class JitBackendTestCase(JitTestCase): class JitBackendTestCase(JitTestCase):
""" """
A common base class for JIT backend tests that contains common utility A common base class for JIT backend tests that contains common utility
@ -73,7 +78,7 @@ class JitBackendTestCase(JitTestCase):
def setUp(self): def setUp(self):
super().setUp() super().setUp()
lib_file_path = find_library_location('libjitbackend_test.so') lib_file_path = find_library_location("libjitbackend_test.so")
torch.ops.load_library(str(lib_file_path)) torch.ops.load_library(str(lib_file_path))
# Subclasses are expected to set up three variables in their setUp methods: # Subclasses are expected to set up three variables in their setUp methods:
# module - a regular, Python version of the module being tested # module - a regular, Python version of the module being tested
@ -154,13 +159,17 @@ class BasicModuleTest(JitBackendTestCase):
self.test_execution() self.test_execution()
# Save the compile spec to compare against the version retrieved after loading. # Save the compile spec to compare against the version retrieved after loading.
pre_compile_spec = self.lowered_module.__getattr__("__loweredModule__").__getattr__("__method_compile_spec") pre_compile_spec = self.lowered_module.__getattr__(
"__loweredModule__"
).__getattr__("__method_compile_spec")
# Save and load the lowered module. # Save and load the lowered module.
self.save_load() self.save_load()
# Get the compile spec after loading. # Get the compile spec after loading.
post_compile_spec = self.lowered_module.__getattr__("__loweredModule__").__getattr__("__method_compile_spec") post_compile_spec = self.lowered_module.__getattr__(
"__loweredModule__"
).__getattr__("__method_compile_spec")
# Compile specs should match. # Compile specs should match.
self.assertEqual(pre_compile_spec, post_compile_spec) self.assertEqual(pre_compile_spec, post_compile_spec)
@ -195,9 +204,11 @@ class BasicModuleUnavailableTest(JitBackendTestCase):
input = torch.randn(5) input = torch.randn(5)
# Test exception is thrown. # Test exception is thrown.
with self.assertRaisesRegexWithHighlight(Exception, with self.assertRaisesRegexWithHighlight(
r"Backend is not available.", Exception,
"raise Exception(\"Backend is not available.\""): r"Backend is not available.",
'raise Exception("Backend is not available."',
):
backend_method = self.lowered_module.__getattr__("forward") backend_method = self.lowered_module.__getattr__("forward")
backend_output = backend_method(*(input, input)) backend_output = backend_method(*(input, input))
@ -207,9 +218,11 @@ class BasicModuleUnavailableTest(JitBackendTestCase):
buffer = io.BytesIO() buffer = io.BytesIO()
torch.jit.save(self.lowered_module, buffer) torch.jit.save(self.lowered_module, buffer)
buffer.seek(0) buffer.seek(0)
with self.assertRaisesRegexWithHighlight(Exception, with self.assertRaisesRegexWithHighlight(
r"Backend is not available.", Exception,
"raise Exception(\"Backend is not available.\""): r"Backend is not available.",
'raise Exception("Backend is not available."',
):
imported = torch.jit.load(buffer) imported = torch.jit.load(buffer)
@ -218,6 +231,7 @@ class NestedModuleTest(JitBackendTestCase):
Tests for NestedModule that check that a module lowered to a backend can be used Tests for NestedModule that check that a module lowered to a backend can be used
as a submodule. as a submodule.
""" """
class NestedModule(torch.nn.Module): class NestedModule(torch.nn.Module):
""" """
A Module with one submodule that is used to test that lowered Modules A Module with one submodule that is used to test that lowered Modules
@ -237,7 +251,9 @@ class NestedModuleTest(JitBackendTestCase):
# Both modules in self.module are regular Python modules. # Both modules in self.module are regular Python modules.
self.module = NestedModuleTest.NestedModule(BasicModule()) self.module = NestedModuleTest.NestedModule(BasicModule())
# Both modules in self.scripted_module are ScriptModules. # Both modules in self.scripted_module are ScriptModules.
self.scripted_module = torch.jit.script(NestedModuleTest.NestedModule(BasicModule())) self.scripted_module = torch.jit.script(
NestedModuleTest.NestedModule(BasicModule())
)
# First, script another instance of NestedModule with share_types=False so that it can be # First, script another instance of NestedModule with share_types=False so that it can be
# selectively lowered without modifying the type of self.scripted_module. # selectively lowered without modifying the type of self.scripted_module.
@ -246,7 +262,9 @@ class NestedModuleTest(JitBackendTestCase):
{"accum": {"": ""}, "sub_accum": {"": ""}, "forward": {"": ""}}, {"accum": {"": ""}, "sub_accum": {"": ""}, "forward": {"": ""}},
) )
# self.lowered_module is a ScriptModule, but its submodule is a lowered module. # self.lowered_module is a ScriptModule, but its submodule is a lowered module.
self.lowered_module = torch.jit.script(NestedModuleTest.NestedModule(lowered_module)) self.lowered_module = torch.jit.script(
NestedModuleTest.NestedModule(lowered_module)
)
def test_execution(self): def test_execution(self):
# Test execution with backend against Python and JIT. # Test execution with backend against Python and JIT.
@ -270,6 +288,7 @@ class SelectiveLoweringTest(JitBackendTestCase):
""" """
Tests for the selective lowering API. Tests for the selective lowering API.
""" """
class OuterModule(torch.nn.Module): class OuterModule(torch.nn.Module):
def __init__(self, sub1, sub2, other): def __init__(self, sub1, sub2, other):
super().__init__() super().__init__()
@ -299,7 +318,10 @@ class SelectiveLoweringTest(JitBackendTestCase):
MiddleModule = SelectiveLoweringTest.MiddleModule MiddleModule = SelectiveLoweringTest.MiddleModule
def script_without_type_sharing(mod): def script_without_type_sharing(mod):
return torch.jit._recursive.create_script_module(mod, torch.jit._recursive.infer_methods_to_compile, share_types=False) return torch.jit._recursive.create_script_module(
mod, torch.jit._recursive.infer_methods_to_compile, share_types=False
)
# Create Python, JIT and backend versions of a hierarchy that looks like this: # Create Python, JIT and backend versions of a hierarchy that looks like this:
# --------- OuterModule -------- # --------- OuterModule --------
# | | | # | | |
@ -308,13 +330,28 @@ class SelectiveLoweringTest(JitBackendTestCase):
# BasicModule BasicModule BasicModule # BasicModule BasicModule BasicModule
# #
# Two BasicModules will be lowered and the third will not. # Two BasicModules will be lowered and the third will not.
self.module = OuterModule(MiddleModule(BasicModule()), MiddleModule(BasicModule()), MiddleModule(BasicModule())) self.module = OuterModule(
self.scripted_module = script_without_type_sharing(OuterModule(MiddleModule( MiddleModule(BasicModule()),
BasicModule()), MiddleModule(BasicModule()), MiddleModule(BasicModule()))) MiddleModule(BasicModule()),
self.lowered_module = script_without_type_sharing(OuterModule(MiddleModule( MiddleModule(BasicModule()),
BasicModule()), MiddleModule(BasicModule()), MiddleModule(BasicModule()))) )
self.lowered_module = to_test_backend_selective(self.lowered_module, {"forward": ""}, [ self.scripted_module = script_without_type_sharing(
"sub1.submodule", "sub2.submodule"]) OuterModule(
MiddleModule(BasicModule()),
MiddleModule(BasicModule()),
MiddleModule(BasicModule()),
)
)
self.lowered_module = script_without_type_sharing(
OuterModule(
MiddleModule(BasicModule()),
MiddleModule(BasicModule()),
MiddleModule(BasicModule()),
)
)
self.lowered_module = to_test_backend_selective(
self.lowered_module, {"forward": ""}, ["sub1.submodule", "sub2.submodule"]
)
def test_execution(self): def test_execution(self):
input = torch.randn(5) input = torch.randn(5)
@ -335,93 +372,93 @@ class SelectiveLoweringTest(JitBackendTestCase):
""" """
# Check that self.lowered_module was not lowered, but that it does contain test_backendLoweredModule due to it # Check that self.lowered_module was not lowered, but that it does contain test_backendLoweredModule due to it
# calling the lowered module directly. # calling the lowered module directly.
FileCheck() \ FileCheck().check("OuterModule").check("BasicModule").run(
.check("OuterModule") \ self.scripted_module.graph
.check("BasicModule") \ )
.run(self.scripted_module.graph) FileCheck().check("OuterModule").check_not(
FileCheck() \ "__torch__.torch.classes.__backends__.test_backend"
.check("OuterModule") \ ).check("LoweredWrapper.test_backend").run(self.lowered_module.graph)
.check_not("__torch__.torch.classes.__backends__.test_backend") \
.check("LoweredWrapper.test_backend") \
.run(self.lowered_module.graph)
# Check that self.lowered_module.sub1/sub2 were not lowered but that BasicModule has been replaced in their graphs. # Check that self.lowered_module.sub1/sub2 were not lowered but that BasicModule has been replaced in their graphs.
FileCheck() \ FileCheck().check("MiddleModule").check("BasicModule").check_not(
.check("MiddleModule") \ "LoweredWrapper.test_backend"
.check("BasicModule") \ ).run(self.scripted_module.sub1.graph)
.check_not("LoweredWrapper.test_backend") \ FileCheck().check("MiddleModule").check_not(
.run(self.scripted_module.sub1.graph) "__torch__.torch.classes.__backends__.test_backend"
FileCheck() \ ).check("LoweredWrapper.test_backend").run(self.lowered_module.sub1.graph)
.check("MiddleModule") \
.check_not("__torch__.torch.classes.__backends__.test_backend") \
.check("LoweredWrapper.test_backend") \
.run(self.lowered_module.sub1.graph)
FileCheck() \ FileCheck().check("MiddleModule").check("BasicModule").check_not(
.check("MiddleModule") \ "LoweredWrapper.test_backend"
.check("BasicModule") \ ).run(self.scripted_module.sub2.graph)
.check_not("LoweredWrapper.test_backend") \ FileCheck().check("MiddleModule").check_not(
.run(self.scripted_module.sub2.graph) "__torch__.torch.classes.__backends__.test_backend"
FileCheck() \ ).check("LoweredWrapper.test_backend").run(self.lowered_module.sub2.graph)
.check("MiddleModule") \
.check_not("__torch__.torch.classes.__backends__.test_backend") \
.check("LoweredWrapper.test_backend") \
.run(self.lowered_module.sub2.graph)
# Check that self.lowered_module.sub1/sub2.submodule were lowered. They should have a new attribute # Check that self.lowered_module.sub1/sub2.submodule were lowered. They should have a new attribute
# __loweredModule__ whose graph should mention __torch__.torch.classes.__backends__.test_backend, # __loweredModule__ whose graph should mention __torch__.torch.classes.__backends__.test_backend,
# the TorchBind class for executing functions on the test JIT backend. # the TorchBind class for executing functions on the test JIT backend.
FileCheck() \ FileCheck().check("LoweredModule.test_backend").check(
.check("LoweredModule.test_backend") \ "__torch__.torch.classes.__backends__.test_backend"
.check("__torch__.torch.classes.__backends__.test_backend") \ ).run(self.lowered_module.sub1.submodule.__loweredModule__.graph)
.run(self.lowered_module.sub1.submodule.__loweredModule__.graph)
FileCheck() \ FileCheck().check("LoweredModule.test_backend").check(
.check("LoweredModule.test_backend") \ "__torch__.torch.classes.__backends__.test_backend"
.check("__torch__.torch.classes.__backends__.test_backend") \ ).run(self.lowered_module.sub2.submodule.__loweredModule__.graph)
.run(self.lowered_module.sub2.submodule.__loweredModule__.graph)
# Check that self.other and self.other.submodule have been left untouched by the selective lowering process. # Check that self.other and self.other.submodule have been left untouched by the selective lowering process.
FileCheck() \ FileCheck().check("MiddleModule").check("BasicModule").check_not(
.check("MiddleModule") \ "__torch__.torch.classes.__backends__.test_backend"
.check("BasicModule") \ ).check_not("LoweredWrapper.test_backend").run(self.scripted_module.other.graph)
.check_not("__torch__.torch.classes.__backends__.test_backend") \ FileCheck().check("BasicModule").check_not(
.check_not("LoweredWrapper.test_backend") \ "__torch__.torch.classes.__backends__.test_backend"
.run(self.scripted_module.other.graph) ).check_not("LoweredModule.test_backend").run(
FileCheck() \ self.scripted_module.other.submodule.graph
.check("BasicModule") \ )
.check_not("__torch__.torch.classes.__backends__.test_backend") \
.check_not("LoweredModule.test_backend") \
.run(self.scripted_module.other.submodule.graph)
def test_errors(self): def test_errors(self):
""" """
Check errors associated with selective lowering. Check errors associated with selective lowering.
""" """
# Check error messages thrown when attempting to lower something that is not a ScriptModule. # Check error messages thrown when attempting to lower something that is not a ScriptModule.
with self.assertRaisesRegexWithHighlight(RuntimeError, r"Object .* is not a ScriptModule", ""): with self.assertRaisesRegexWithHighlight(
RuntimeError, r"Object .* is not a ScriptModule", ""
):
to_test_backend_selective(torch.nn.ReLU(), {"forward": ""}, ["submodule"]) to_test_backend_selective(torch.nn.ReLU(), {"forward": ""}, ["submodule"])
MiddleModule = SelectiveLoweringTest.MiddleModule MiddleModule = SelectiveLoweringTest.MiddleModule
mod = MiddleModule(BasicModule()) mod = MiddleModule(BasicModule())
mod.new_attr = 3 mod.new_attr = 3
with self.assertRaisesRegexWithHighlight(RuntimeError, r"Attribute named new_attr is not a Module", ""): with self.assertRaisesRegexWithHighlight(
to_test_backend_selective(torch.jit.script(mod), {"forward": ""}, ["new_attr"]) RuntimeError, r"Attribute named new_attr is not a Module", ""
):
to_test_backend_selective(
torch.jit.script(mod), {"forward": ""}, ["new_attr"]
)
# Check error message thrown when module hierarchy doesn't have unique types. # Check error message thrown when module hierarchy doesn't have unique types.
OuterModule = SelectiveLoweringTest.OuterModule OuterModule = SelectiveLoweringTest.OuterModule
mod = OuterModule(MiddleModule(BasicModule()), MiddleModule(BasicModule()), MiddleModule(BasicModule())) mod = OuterModule(
MiddleModule(BasicModule()),
MiddleModule(BasicModule()),
MiddleModule(BasicModule()),
)
with self.assertRaisesRegexWithHighlight(RuntimeError, with self.assertRaisesRegexWithHighlight(
r"Selective lowering is only supported for module hierarchies with unique types", RuntimeError,
""): r"Selective lowering is only supported for module hierarchies with unique types",
to_test_backend_selective(torch.jit.script(mod), {"forward": ""}, ["sub1.submodule"]) "",
):
to_test_backend_selective(
torch.jit.script(mod), {"forward": ""}, ["sub1.submodule"]
)
# This is needed for IS_WINDOWS or IS_MACOS to skip the tests. # This is needed for IS_WINDOWS or IS_MACOS to skip the tests.
@unittest.skipIf(TEST_WITH_ROCM or IS_SANDCASTLE or IS_WINDOWS or IS_MACOS or IS_FBCODE, @unittest.skipIf(
"Non-portable load_library call used in test") TEST_WITH_ROCM or IS_SANDCASTLE or IS_WINDOWS or IS_MACOS or IS_FBCODE,
"Non-portable load_library call used in test",
)
class TestBackends(JitTestCase): class TestBackends(JitTestCase):
""" """
This class wraps and invokes all subclasses of JitBackendTestCase so that each one This class wraps and invokes all subclasses of JitBackendTestCase so that each one
@ -461,6 +498,7 @@ class TestBackends(JitTestCase):
def test_errors(self): def test_errors(self):
self.selective_lowering_test.test_errors() self.selective_lowering_test.test_errors()
""" """
Unit Tests for backend with compiler Unit Tests for backend with compiler
This test case and the existing TestBackends are separate because they cover different aspects. This test case and the existing TestBackends are separate because they cover different aspects.
@ -468,6 +506,8 @@ The actual backend implementation in this test is different.
It has a simple demo compiler to test the end-to-end flow in mobile. It has a simple demo compiler to test the end-to-end flow in mobile.
However, this test cannot cover the selective_lowering for now, which is covered in TestBackends. However, this test cannot cover the selective_lowering for now, which is covered in TestBackends.
""" """
class BasicModuleAdd(torch.nn.Module): class BasicModuleAdd(torch.nn.Module):
""" """
A simple add Module used to test to_backend lowering machinery. A simple add Module used to test to_backend lowering machinery.
@ -476,9 +516,12 @@ class BasicModuleAdd(torch.nn.Module):
def forward(self, x, h): def forward(self, x, h):
return x + h return x + h
# This is ignored in IS_WINDOWS or IS_MACOS cases. Hence we need the one in TestBackends. # This is ignored in IS_WINDOWS or IS_MACOS cases. Hence we need the one in TestBackends.
@unittest.skipIf(TEST_WITH_ROCM or IS_SANDCASTLE or IS_WINDOWS or IS_MACOS or IS_FBCODE, @unittest.skipIf(
"Non-portable load_library call used in test") TEST_WITH_ROCM or IS_SANDCASTLE or IS_WINDOWS or IS_MACOS or IS_FBCODE,
"Non-portable load_library call used in test",
)
class JitBackendTestCaseWithCompiler(JitTestCase): class JitBackendTestCaseWithCompiler(JitTestCase):
""" """
A common base class for JIT backend tests with compilers that contains common utility A common base class for JIT backend tests with compilers that contains common utility
@ -487,7 +530,7 @@ class JitBackendTestCaseWithCompiler(JitTestCase):
def setUp(self): def setUp(self):
super().setUp() super().setUp()
lib_file_path = find_library_location('libbackend_with_compiler.so') lib_file_path = find_library_location("libbackend_with_compiler.so")
torch.ops.load_library(str(lib_file_path)) torch.ops.load_library(str(lib_file_path))
# Subclasses are expected to set up four variables in their setUp methods: # Subclasses are expected to set up four variables in their setUp methods:
# module - a regular, Python version of the module being tested # module - a regular, Python version of the module being tested
@ -524,6 +567,7 @@ class JitBackendTestCaseWithCompiler(JitTestCase):
""" """
pass pass
class BasicModuleTestWithCompiler(JitBackendTestCaseWithCompiler): class BasicModuleTestWithCompiler(JitBackendTestCaseWithCompiler):
""" """
Tests for BasicModuleAdd. Tests for BasicModuleAdd.
@ -541,7 +585,8 @@ class BasicModuleTestWithCompiler(JitBackendTestCaseWithCompiler):
}, },
} }
self.lowered_module = torch._C._jit_to_backend( self.lowered_module = torch._C._jit_to_backend(
"backend_with_compiler_demo", self.scripted_module, compile_spec) "backend_with_compiler_demo", self.scripted_module, compile_spec
)
# Create mobile version of BasicModuleAdd # Create mobile version of BasicModuleAdd
buffer = io.BytesIO(self.lowered_module._save_to_buffer_for_lite_interpreter()) buffer = io.BytesIO(self.lowered_module._save_to_buffer_for_lite_interpreter())
buffer.seek(0) buffer.seek(0)
@ -552,6 +597,7 @@ class BasicModuleTestWithCompiler(JitBackendTestCaseWithCompiler):
input = torch.ones(1, dtype=torch.float) input = torch.ones(1, dtype=torch.float)
self.check_forward((input, input)) self.check_forward((input, input))
class ErrorMessagesWithCompiler(JitBackendTestCase): class ErrorMessagesWithCompiler(JitBackendTestCase):
""" """
Tests for errors that occur with compiler, specifically: Tests for errors that occur with compiler, specifically:
@ -562,22 +608,31 @@ class ErrorMessagesWithCompiler(JitBackendTestCase):
""" """
A module with an operator that is not supported. A module with an operator that is not supported.
""" """
def forward(self, x, h): def forward(self, x, h):
return x * h return x * h
self._loweredmodule.forward() self._loweredmodule.forward()
def test_errors(self): def test_errors(self):
scripted_module_n = torch.jit.script(ErrorMessagesWithCompiler.ModuleNotSupported()) scripted_module_n = torch.jit.script(
ErrorMessagesWithCompiler.ModuleNotSupported()
)
# Test exception is thrown when lowering a module with an unsupported operator # Test exception is thrown when lowering a module with an unsupported operator
with self.assertRaisesRegexWithHighlight(RuntimeError, with self.assertRaisesRegexWithHighlight(
# Special escape characters are replaced with '.' RuntimeError,
r"""The node of aten::mul is not supported in this compiler. .* # Special escape characters are replaced with '.'
r"""The node of aten::mul is not supported in this compiler. .*
def forward.self, x, h.: def forward.self, x, h.:
return x . h return x . h
~~~~~ <--- HERE ~~~~~ <--- HERE
self._loweredmodule.forward.. self._loweredmodule.forward..
""", ""): """,
lowered_module_n = torch._C._jit_to_backend("backend_with_compiler_demo", scripted_module_n, {"forward": {"": ""}}) "",
):
lowered_module_n = torch._C._jit_to_backend(
"backend_with_compiler_demo", scripted_module_n, {"forward": {"": ""}}
)
class CompModuleTestWithCompiler(JitBackendTestCase): class CompModuleTestWithCompiler(JitBackendTestCase):
""" """
@ -588,6 +643,7 @@ class CompModuleTestWithCompiler(JitBackendTestCase):
""" """
A simple subtraction Module to be used in CompModule. A simple subtraction Module to be used in CompModule.
""" """
def forward(self, x, h): def forward(self, x, h):
return x - h return x - h
@ -617,14 +673,19 @@ class CompModuleTestWithCompiler(JitBackendTestCase):
}, },
} }
lowered_add = torch._C._jit_to_backend( lowered_add = torch._C._jit_to_backend(
"backend_with_compiler_demo", torch.jit.script(BasicModuleAdd()), compile_spec) "backend_with_compiler_demo",
torch.jit.script(BasicModuleAdd()),
compile_spec,
)
lowered_sub = torch._C._jit_to_backend( lowered_sub = torch._C._jit_to_backend(
"backend_with_compiler_demo", "backend_with_compiler_demo",
torch.jit.script(CompModuleTestWithCompiler.BasicModuleSub()), torch.jit.script(CompModuleTestWithCompiler.BasicModuleSub()),
{"forward": {"": ""}} {"forward": {"": ""}},
) )
self.module = CompModuleTestWithCompiler.CompModule(lowered_add, lowered_sub) self.module = CompModuleTestWithCompiler.CompModule(lowered_add, lowered_sub)
self.scripted_module = torch.jit.script(CompModuleTestWithCompiler.CompModule(lowered_add, lowered_sub)) self.scripted_module = torch.jit.script(
CompModuleTestWithCompiler.CompModule(lowered_add, lowered_sub)
)
# No backend version of CompModule currently, so this is filler. # No backend version of CompModule currently, so this is filler.
self.lowered_module = self.scripted_module self.lowered_module = self.scripted_module
# Create a mobile version of CompModule from JIT version # Create a mobile version of CompModule from JIT version
@ -640,9 +701,12 @@ class CompModuleTestWithCompiler(JitBackendTestCase):
# Test forward. # Test forward.
self.check_function("forward", (input1, input2, input2)) self.check_function("forward", (input1, input2, input2))
# This is needed for IS_WINDOWS or IS_MACOS to skip the tests. # This is needed for IS_WINDOWS or IS_MACOS to skip the tests.
@unittest.skipIf(IS_SANDCASTLE or IS_WINDOWS or IS_MACOS or IS_FBCODE, @unittest.skipIf(
"Non-portable load_library call used in test") IS_SANDCASTLE or IS_WINDOWS or IS_MACOS or IS_FBCODE,
"Non-portable load_library call used in test",
)
class TestBackendsWithCompiler(JitTestCase): class TestBackendsWithCompiler(JitTestCase):
""" """
This class wraps and invokes all subclasses of JitBackendTestCaseWithCompiler This class wraps and invokes all subclasses of JitBackendTestCaseWithCompiler
@ -711,7 +775,6 @@ class CompModuleTestSameNameWithCompiler(JitBackendTestCase):
y = s * (c * d) y = s * (c * d)
return y return y
def setUp(self): def setUp(self):
super().setUp() super().setUp()
@ -728,6 +791,7 @@ class CompModuleTestSameNameWithCompiler(JitBackendTestCase):
# Test forward. # Test forward.
self.check_function("forward", (a, b, s)) self.check_function("forward", (a, b, s))
class AddedAttributesTest(JitBackendTestCase): class AddedAttributesTest(JitBackendTestCase):
""" """
Tests for adding attributes to a model after lowering. Tests for adding attributes to a model after lowering.
@ -747,11 +811,19 @@ class AddedAttributesTest(JitBackendTestCase):
input = [(torch.ones(5),)] input = [(torch.ones(5),)]
pre_bundled = self.lowered_module(*input[0]) pre_bundled = self.lowered_module(*input[0])
# Attach bundled inputs which adds several attributes and functions to the model # Attach bundled inputs which adds several attributes and functions to the model
self.lowered_module = torch.utils.bundled_inputs.augment_model_with_bundled_inputs(lowered_module, input) # noqa: F821 self.lowered_module = (
post_bundled = self.lowered_module(*self.lowered_module.get_all_bundled_inputs()[0]) torch.utils.bundled_inputs.augment_model_with_bundled_inputs(
lowered_module, input # noqa: F821
)
)
post_bundled = self.lowered_module(
*self.lowered_module.get_all_bundled_inputs()[0]
)
# Save and load the lowered module. # Save and load the lowered module.
self.save_load() self.save_load()
# Use bundled after save and load to prove its preserved # Use bundled after save and load to prove its preserved
post_load = self.lowered_module(*self.lowered_module.get_all_bundled_inputs()[0]) post_load = self.lowered_module(
*self.lowered_module.get_all_bundled_inputs()[0]
)
self.assertEqual(pre_bundled, post_bundled) self.assertEqual(pre_bundled, post_bundled)
self.assertEqual(post_bundled, post_load) self.assertEqual(post_bundled, post_load)

View File

@ -56,7 +56,6 @@ class TestBatchMM(JitTestCase):
actual = test_batch_mm_scripted(*tensors) actual = test_batch_mm_scripted(*tensors)
self.assertEqual(expected, actual, atol=1e-9, rtol=1e-9) self.assertEqual(expected, actual, atol=1e-9, rtol=1e-9)
def test_batch_mm_permitted_mutation(self): def test_batch_mm_permitted_mutation(self):
def test_batch_mm( def test_batch_mm(
T1: torch.Tensor, T1: torch.Tensor,

View File

@ -1,8 +1,8 @@
# Owner(s): ["oncall: jit"] # Owner(s): ["oncall: jit"]
import inspect
import os import os
import sys import sys
import inspect
import unittest import unittest
from typing import Dict, List from typing import Dict, List
@ -14,10 +14,12 @@ pytorch_test_dir = os.path.dirname(os.path.dirname(os.path.realpath(__file__)))
sys.path.append(pytorch_test_dir) sys.path.append(pytorch_test_dir)
from torch.testing._internal.jit_utils import JitTestCase, RUN_CUDA from torch.testing._internal.jit_utils import JitTestCase, RUN_CUDA
if __name__ == '__main__': if __name__ == "__main__":
raise RuntimeError("This test file is not meant to be run directly, use:\n\n" raise RuntimeError(
"\tpython test/test_jit.py TESTNAME\n\n" "This test file is not meant to be run directly, use:\n\n"
"instead.") "\tpython test/test_jit.py TESTNAME\n\n"
"instead."
)
class TestBuiltins(JitTestCase): class TestBuiltins(JitTestCase):
@ -86,24 +88,27 @@ class TestBuiltins(JitTestCase):
self.checkScript(fn, ([1, 2, 3],)) self.checkScript(fn, ([1, 2, 3],))
with self.assertRaisesRegexWithHighlight(RuntimeError, "undefined value", "a"): with self.assertRaisesRegexWithHighlight(RuntimeError, "undefined value", "a"):
@torch.jit.script @torch.jit.script
def fn(x): def fn(x):
a = x ** 2 a = x**2
del a del a
return a # noqa: F821 return a # noqa: F821
with self.assertRaisesRegexWithHighlight(RuntimeError, "undefined value", "a"): with self.assertRaisesRegexWithHighlight(RuntimeError, "undefined value", "a"):
@torch.jit.script @torch.jit.script
def fn(x): def fn(x):
a = x ** 2 a = x**2
if a: if a:
del a del a
return a return a
with self.assertRaisesRegexWithHighlight(RuntimeError, "undefined value", "b"): with self.assertRaisesRegexWithHighlight(RuntimeError, "undefined value", "b"):
@torch.jit.script @torch.jit.script
def fn(x): def fn(x):
a = x ** 2 a = x**2
del b # noqa: F821 del b # noqa: F821
return a return a
@ -124,7 +129,7 @@ class TestBuiltins(JitTestCase):
self.assertEqual(py_out, jit_out) self.assertEqual(py_out, jit_out)
def del_dict_multiple_operands(x: Dict[str, int]) -> Dict[str, int]: def del_dict_multiple_operands(x: Dict[str, int]) -> Dict[str, int]:
del x['hi'], x['there'] del x["hi"], x["there"]
return x return x
py_out = del_dict_multiple_operands({"hi": 5, "there": 6}) py_out = del_dict_multiple_operands({"hi": 5, "there": 6})
@ -137,7 +142,7 @@ class TestTensorBuiltins(JitTestCase):
def should_keep(tensor, name): def should_keep(tensor, name):
if inspect.isroutine(getattr(tensor, name)): if inspect.isroutine(getattr(tensor, name)):
return False return False
if name.startswith('_'): if name.startswith("_"):
return False return False
return True return True
@ -145,8 +150,8 @@ class TestTensorBuiltins(JitTestCase):
keys = dir(tensor) keys = dir(tensor)
# real and imag are only implemented for complex tensors. # real and imag are only implemented for complex tensors.
self.assertRaises(RuntimeError, lambda: should_keep(tensor, 'imag')) self.assertRaises(RuntimeError, lambda: should_keep(tensor, "imag"))
keys.remove('imag') keys.remove("imag")
properties = [p for p in keys if should_keep(tensor, p)] properties = [p for p in keys if should_keep(tensor, p)]
@ -158,16 +163,16 @@ class TestTensorBuiltins(JitTestCase):
EQUALITY_MISMATCH = { EQUALITY_MISMATCH = {
# TorchScript doesn't have real enums so they return an int instead # TorchScript doesn't have real enums so they return an int instead
# of the actual value # of the actual value
'dtype', "dtype",
'layout', "layout",
} }
MISSING_PROPERTIES = { MISSING_PROPERTIES = {
'grad_fn', "grad_fn",
# This is an undocumented property so it's not included # This is an undocumented property so it's not included
"output_nr", "output_nr",
# This has a longer implementation, maybe not worth copying to # This has a longer implementation, maybe not worth copying to
# TorchScript if named tensors don't work there anyways # TorchScript if named tensors don't work there anyways
'names', "names",
} }
for p in properties: for p in properties:
@ -232,7 +237,8 @@ class TestTensorBuiltins(JitTestCase):
def func(): def func():
c = 1 c = 1
return c.add(1) return c.add(1)
with self.assertRaisesRegex(RuntimeError, 'object has no attribute or method'):
with self.assertRaisesRegex(RuntimeError, "object has no attribute or method"):
torch.jit.script(func) torch.jit.script(func)
# testing implicit conversion of tensors to scalars to match function arguments # testing implicit conversion of tensors to scalars to match function arguments
@ -265,10 +271,12 @@ class TestTensorBuiltins(JitTestCase):
x = torch.zeros(10) x = torch.zeros(10)
# float tensor, float tensor with grad, int tensor (can't set grad on int tensor) # float tensor, float tensor with grad, int tensor (can't set grad on int tensor)
tensors = [torch.tensor(1.1), tensors = [
torch.tensor(1.1, requires_grad=True), torch.tensor(1.1),
torch.tensor(0), torch.tensor(1.1, requires_grad=True),
torch.tensor([2])] torch.tensor(0),
torch.tensor([2]),
]
script_funs = [tensor_to_int_script, tensor_to_float_script] script_funs = [tensor_to_int_script, tensor_to_float_script]
funs = [tensor_to_int, tensor_to_float] funs = [tensor_to_int, tensor_to_float]
@ -286,4 +294,6 @@ class TestTensorBuiltins(JitTestCase):
# assert result or exception equal for each (function, inputs) # assert result or exception equal for each (function, inputs)
for tensor in tensors: for tensor in tensors:
for i in range(len(script_funs)): for i in range(len(script_funs)):
self.assertEqual(test_func(script_funs[i], x, tensor), test_func(funs[i], x, tensor)) self.assertEqual(
test_func(script_funs[i], x, tensor), test_func(funs[i], x, tensor)
)

View File

@ -4,24 +4,28 @@ import io
import os import os
import sys import sys
import unittest import unittest
from typing import Any
import torch import torch
import torch.nn as nn import torch.nn as nn
from torch.testing import FileCheck from torch.testing import FileCheck
from typing import Any
# Make the helper files in test/ importable # Make the helper files in test/ importable
pytorch_test_dir = os.path.dirname(os.path.dirname(os.path.realpath(__file__))) pytorch_test_dir = os.path.dirname(os.path.dirname(os.path.realpath(__file__)))
sys.path.append(pytorch_test_dir) sys.path.append(pytorch_test_dir)
from torch.testing._internal.jit_utils import JitTestCase, make_global from typing import Dict, Iterable, List, Optional, Tuple
import torch.testing._internal.jit_utils import torch.testing._internal.jit_utils
from torch.testing._internal.common_utils import IS_SANDCASTLE, skipIfTorchDynamo from torch.testing._internal.common_utils import IS_SANDCASTLE, skipIfTorchDynamo
from typing import List, Tuple, Iterable, Optional, Dict from torch.testing._internal.jit_utils import JitTestCase, make_global
if __name__ == "__main__":
raise RuntimeError(
"This test file is not meant to be run directly, use:\n\n"
"\tpython test/test_jit.py TESTNAME\n\n"
"instead."
)
if __name__ == '__main__':
raise RuntimeError("This test file is not meant to be run directly, use:\n\n"
"\tpython test/test_jit.py TESTNAME\n\n"
"instead.")
class TestClassType(JitTestCase): class TestClassType(JitTestCase):
def test_reference_semantics(self): def test_reference_semantics(self):
@ -29,6 +33,7 @@ class TestClassType(JitTestCase):
Test that modifications made to a class instance in TorchScript Test that modifications made to a class instance in TorchScript
are visible in eager. are visible in eager.
""" """
class Foo: class Foo:
def __init__(self, a: int): def __init__(self, a: int):
self.a = a self.a = a
@ -92,12 +97,12 @@ class TestClassType(JitTestCase):
pass pass
def __contains__(self, key: str) -> bool: def __contains__(self, key: str) -> bool:
return key == 'hi' return key == "hi"
@torch.jit.script @torch.jit.script
def fn(): def fn():
foo = FooTest() foo = FooTest()
return 'hi' in foo, 'no' in foo return "hi" in foo, "no" in foo
self.assertEqual(fn(), (True, False)) self.assertEqual(fn(), (True, False))
@ -118,7 +123,10 @@ class TestClassType(JitTestCase):
self.assertEqual(fn(1), 3) self.assertEqual(fn(1), 3)
def test_set_attr_type_mismatch(self): def test_set_attr_type_mismatch(self):
with self.assertRaisesRegexWithHighlight(RuntimeError, "Wrong type for attribute assignment", "self.foo = 10"): with self.assertRaisesRegexWithHighlight(
RuntimeError, "Wrong type for attribute assignment", "self.foo = 10"
):
@torch.jit.script @torch.jit.script
class FooTest: class FooTest:
def __init__(self, x): def __init__(self, x):
@ -126,7 +134,10 @@ class TestClassType(JitTestCase):
self.foo = 10 # should error since int != Tensor self.foo = 10 # should error since int != Tensor
def test_get_attr_not_initialized(self): def test_get_attr_not_initialized(self):
with self.assertRaisesRegexWithHighlight(RuntimeError, "object has no attribute or method", "self.asdf"): with self.assertRaisesRegexWithHighlight(
RuntimeError, "object has no attribute or method", "self.asdf"
):
@torch.jit.script @torch.jit.script
class FooTest: class FooTest:
def __init__(self, x): def __init__(self, x):
@ -136,7 +147,10 @@ class TestClassType(JitTestCase):
return self.asdf # asdf isn't an attr return self.asdf # asdf isn't an attr
def test_set_attr_non_initialized(self): def test_set_attr_non_initialized(self):
with self.assertRaisesRegexWithHighlight(RuntimeError, "Tried to set nonexistent attribute", "self.bar = y"): with self.assertRaisesRegexWithHighlight(
RuntimeError, "Tried to set nonexistent attribute", "self.bar = y"
):
@torch.jit.script @torch.jit.script
class FooTest: class FooTest:
def __init__(self, x): def __init__(self, x):
@ -153,12 +167,16 @@ class TestClassType(JitTestCase):
Expected a value of type 'Optional[int]' for argument 'size' but instead found type 'Tensor'. Expected a value of type 'Optional[int]' for argument 'size' but instead found type 'Tensor'.
""" """
with self.assertRaisesRegexWithHighlight(RuntimeError, "nearest", ""): with self.assertRaisesRegexWithHighlight(RuntimeError, "nearest", ""):
@torch.jit.script @torch.jit.script
def FooTest(x): def FooTest(x):
return torch.nn.functional.interpolate(x, 'bad') return torch.nn.functional.interpolate(x, "bad")
def test_type_annotations(self): def test_type_annotations(self):
with self.assertRaisesRegexWithHighlight(RuntimeError, "Expected a value of type \'bool", ""): with self.assertRaisesRegexWithHighlight(
RuntimeError, "Expected a value of type 'bool", ""
):
@torch.jit.script # noqa: B903 @torch.jit.script # noqa: B903
class FooTest: # noqa: B903 class FooTest: # noqa: B903
def __init__(self, x: bool) -> None: def __init__(self, x: bool) -> None:
@ -171,7 +189,10 @@ class TestClassType(JitTestCase):
fn(2) fn(2)
def test_conditional_set_attr(self): def test_conditional_set_attr(self):
with self.assertRaisesRegexWithHighlight(RuntimeError, "assignment cannot be in a control-flow block", ""): with self.assertRaisesRegexWithHighlight(
RuntimeError, "assignment cannot be in a control-flow block", ""
):
@torch.jit.script @torch.jit.script
class FooTest: class FooTest:
def __init__(self, x): def __init__(self, x):
@ -236,7 +257,6 @@ class TestClassType(JitTestCase):
# classes are globally registered for now, so we need to clear the JIT # classes are globally registered for now, so we need to clear the JIT
# registry to simulate loading a new model # registry to simulate loading a new model
buffer.seek(0) buffer.seek(0)
m_loaded = torch.jit.load(buffer) m_loaded = torch.jit.load(buffer)
@ -320,7 +340,7 @@ class TestClassType(JitTestCase):
self.x = x self.x = x
self.y = y self.y = y
make_global(Foo) # see [local resolution in python] make_global(Foo) # see [local resolution in python]
@torch.jit.script @torch.jit.script
def use_foo(foo: Foo) -> Foo: def use_foo(foo: Foo) -> Foo:
@ -419,15 +439,22 @@ class TestClassType(JitTestCase):
self.assertEqual(test_nested_inside_tuple(), [(1, 11), (1, 12)]) self.assertEqual(test_nested_inside_tuple(), [(1, 11), (1, 12)])
with self.assertRaisesRegexWithHighlight(RuntimeError, "bool\' for argument \'reverse", ""): with self.assertRaisesRegexWithHighlight(
RuntimeError, "bool' for argument 'reverse", ""
):
@torch.jit.script @torch.jit.script
def test(): def test():
li = [Foo(1)] li = [Foo(1)]
li.sort(li) li.sort(li)
return li return li
test() test()
with self.assertRaisesRegexWithHighlight(RuntimeError, "must define a __lt__", ""): with self.assertRaisesRegexWithHighlight(
RuntimeError, "must define a __lt__", ""
):
@torch.jit.script @torch.jit.script
class NoMethod: class NoMethod:
def __init__(self): def __init__(self):
@ -438,6 +465,7 @@ class TestClassType(JitTestCase):
li = [NoMethod(), NoMethod()] li = [NoMethod(), NoMethod()]
li.sort() li.sort()
return li return li
test() test()
@torch.jit.script @torch.jit.script
@ -449,12 +477,16 @@ class TestClassType(JitTestCase):
def __lt__(self, other): def __lt__(self, other):
pass pass
with self.assertRaisesRegexWithHighlight(RuntimeError, "must define a __lt__", ""): with self.assertRaisesRegexWithHighlight(
RuntimeError, "must define a __lt__", ""
):
@torch.jit.script @torch.jit.script
def test(): def test():
li = [WrongLt(), WrongLt()] li = [WrongLt(), WrongLt()]
li.sort() li.sort()
return li return li
test() test()
def test_class_inheritance(self): def test_class_inheritance(self):
@ -466,18 +498,21 @@ class TestClassType(JitTestCase):
def two(self, x): def two(self, x):
return x + self.b return x + self.b
with self.assertRaisesRegexWithHighlight(RuntimeError, "does not support inheritance", ""): with self.assertRaisesRegexWithHighlight(
RuntimeError, "does not support inheritance", ""
):
@torch.jit.script @torch.jit.script
class Derived(Base): class Derived(Base):
def two(self, x): def two(self, x):
return x + self.b + 2 return x + self.b + 2
def test_class_inheritance_implicit(self): def test_class_inheritance_implicit(self):
""" """
Test that inheritance is detected in Test that inheritance is detected in
implicit scripting codepaths (e.g. try_ann_to_type). implicit scripting codepaths (e.g. try_ann_to_type).
""" """
class A: class A:
def __init__(self, t): def __init__(self, t):
self.t = t self.t = t
@ -502,14 +537,16 @@ class TestClassType(JitTestCase):
else: else:
return B.f(x.t) return B.f(x.t)
with self.assertRaisesRegexWithHighlight(RuntimeError, "object has no attribute or method", ""): with self.assertRaisesRegexWithHighlight(
RuntimeError, "object has no attribute or method", ""
):
sc = torch.jit.script(fun) sc = torch.jit.script(fun)
@skipIfTorchDynamo("Test does not work with TorchDynamo") @skipIfTorchDynamo("Test does not work with TorchDynamo")
@unittest.skipIf(IS_SANDCASTLE, "Importing like this doesn't work in fbcode") @unittest.skipIf(IS_SANDCASTLE, "Importing like this doesn't work in fbcode")
def test_imported_classes(self): def test_imported_classes(self):
import jit._imported_class_test.foo
import jit._imported_class_test.bar import jit._imported_class_test.bar
import jit._imported_class_test.foo
import jit._imported_class_test.very.very.nested import jit._imported_class_test.very.very.nested
class MyMod(torch.jit.ScriptModule): class MyMod(torch.jit.ScriptModule):
@ -593,6 +630,7 @@ class TestClassType(JitTestCase):
def one(self, x, y): def one(self, x, y):
return x + y return x + y
# missing two # missing two
@torch.jit.script @torch.jit.script
@ -616,6 +654,7 @@ class TestClassType(JitTestCase):
x = c[i].one(x, x) x = c[i].one(x, x)
x = c[i].two(x) x = c[i].two(x)
return x return x
self.checkScript(use_them, (torch.rand(3, 4),)) self.checkScript(use_them, (torch.rand(3, 4),))
@torch.jit.script @torch.jit.script
@ -626,22 +665,33 @@ class TestClassType(JitTestCase):
def inherit(x: OneTwoThree) -> OneTwo: def inherit(x: OneTwoThree) -> OneTwo:
return as_interface(x) return as_interface(x)
with self.assertRaisesRegexWithHighlight(RuntimeError, "does not have method", ""): with self.assertRaisesRegexWithHighlight(
RuntimeError, "does not have method", ""
):
@torch.jit.script @torch.jit.script
def wrong1(): def wrong1():
return as_interface(NotMember()) return as_interface(NotMember())
with self.assertRaisesRegexWithHighlight(RuntimeError, "is not compatible with interface", ""): with self.assertRaisesRegexWithHighlight(
RuntimeError, "is not compatible with interface", ""
):
@torch.jit.script @torch.jit.script
def wrong2(): def wrong2():
return as_interface(NotMember2()) return as_interface(NotMember2())
with self.assertRaisesRegexWithHighlight(RuntimeError, "does not have method", ""): with self.assertRaisesRegexWithHighlight(
RuntimeError, "does not have method", ""
):
@torch.jit.script @torch.jit.script
def wrong3(): def wrong3():
return inherit(as_interface(Foo())) return inherit(as_interface(Foo()))
with self.assertRaisesRegexWithHighlight(RuntimeError, "is not compatible with interface", ""): with self.assertRaisesRegexWithHighlight(
RuntimeError, "is not compatible with interface", ""
):
@torch.jit.script @torch.jit.script
def wrong4(x: OneTwoWrong) -> int: def wrong4(x: OneTwoWrong) -> int:
@ -656,7 +706,7 @@ class TestClassType(JitTestCase):
def forward(self, x): def forward(self, x):
return self.proxy_mod.two(x) return self.proxy_mod.two(x)
TestPyAssign.__annotations__ = {'proxy_mod': OneTwo} TestPyAssign.__annotations__ = {"proxy_mod": OneTwo}
input = torch.rand(3, 4) input = torch.rand(3, 4)
scripted_pyassign_mod = torch.jit.script(TestPyAssign()) scripted_pyassign_mod = torch.jit.script(TestPyAssign())
@ -671,10 +721,11 @@ class TestClassType(JitTestCase):
def forward(self, x): def forward(self, x):
return self.proxy_mod.two(x) return self.proxy_mod.two(x)
TestPyAssignError.__annotations__ = {'proxy_mod': OneTwoThree} TestPyAssignError.__annotations__ = {"proxy_mod": OneTwoThree}
with self.assertRaisesRegexWithHighlight(RuntimeError, with self.assertRaisesRegexWithHighlight(
"is not compatible with interface __torch__", ""): RuntimeError, "is not compatible with interface __torch__", ""
):
torch.jit.script(TestPyAssignError(Foo())) torch.jit.script(TestPyAssignError(Foo()))
# test pure python object assignment to interface fails # test pure python object assignment to interface fails
@ -682,8 +733,9 @@ class TestClassType(JitTestCase):
def __init__(self): def __init__(self):
pass pass
with self.assertRaisesRegexWithHighlight(RuntimeError, with self.assertRaisesRegexWithHighlight(
"the value is not a TorchScript compatible type", ""): RuntimeError, "the value is not a TorchScript compatible type", ""
):
torch.jit.script(TestPyAssignError(PyClass())) torch.jit.script(TestPyAssignError(PyClass()))
# TODO test: interface-interface class-interface inheritance errors, # TODO test: interface-interface class-interface inheritance errors,
# NamedTuple inheritance errors # NamedTuple inheritance errors
@ -729,7 +781,7 @@ class TestClassType(JitTestCase):
return self.x * other return self.x * other
def __pow__(self, other: int) -> int: def __pow__(self, other: int) -> int:
return int(self.x ** other) return int(self.x**other)
def __truediv__(self, other: int) -> float: def __truediv__(self, other: int) -> float:
return self.x / other return self.x / other
@ -773,54 +825,89 @@ class TestClassType(JitTestCase):
def __call__(self, val: int) -> int: def __call__(self, val: int) -> int:
return self.x * val * 3 return self.x * val * 3
make_global(Foo) # see [local resolution in python] make_global(Foo) # see [local resolution in python]
def add(): def add():
return MyClass(4) + 3 return MyClass(4) + 3
def sub(): # noqa: E306 def sub(): # noqa: E306
return MyClass(4) - 3 return MyClass(4) - 3
def mul(): # noqa: E306 def mul(): # noqa: E306
return MyClass(4) * 3 return MyClass(4) * 3
def pow(): # noqa: E306 def pow(): # noqa: E306
return MyClass(4) ** 3 return MyClass(4) ** 3
def truediv(): # noqa: E306 def truediv(): # noqa: E306
return MyClass(4) / 3 return MyClass(4) / 3
def ne(): # noqa: E306 def ne(): # noqa: E306
return MyClass(4) != 3 return MyClass(4) != 3
def eq(): # noqa: E306 def eq(): # noqa: E306
return MyClass(4) == 3 return MyClass(4) == 3
def lt(): # noqa: E306 def lt(): # noqa: E306
return MyClass(4) < 3 return MyClass(4) < 3
def gt(): # noqa: E306 def gt(): # noqa: E306
return MyClass(4) > 3 return MyClass(4) > 3
def le(): # noqa: E306 def le(): # noqa: E306
return MyClass(4) <= 3 return MyClass(4) <= 3
def ge(): # noqa: E306 def ge(): # noqa: E306
return MyClass(4) >= 3 return MyClass(4) >= 3
def _and(): # noqa: E306 def _and(): # noqa: E306
return MyClass(4) & 3 return MyClass(4) & 3
def _or(): # noqa: E306 def _or(): # noqa: E306
return MyClass(4) | 3 return MyClass(4) | 3
def _xor(): # noqa: E306 def _xor(): # noqa: E306
return MyClass(4) ^ 3 return MyClass(4) ^ 3
def getitem(): # noqa: E306 def getitem(): # noqa: E306
return MyClass(4)[1] return MyClass(4)[1]
def setitem(): # noqa: E306 def setitem(): # noqa: E306
a = MyClass(4) a = MyClass(4)
a[1] = 5 a[1] = 5
return a.x return a.x
def call(): # noqa: E306 def call(): # noqa: E306
a = MyClass(5) a = MyClass(5)
return a(2) return a(2)
ops = [add, sub, mul, pow, ne, eq, lt, gt, le, ge, _and, _or, _xor, getitem, setitem, call] ops = [
add,
sub,
mul,
pow,
ne,
eq,
lt,
gt,
le,
ge,
_and,
_or,
_xor,
getitem,
setitem,
call,
]
ops.append(truediv) ops.append(truediv)
for func in ops: for func in ops:
self.checkScript(func, ()) self.checkScript(func, ())
with self.assertRaisesRegexWithHighlight(RuntimeError, "object has no attribute or method", ""): with self.assertRaisesRegexWithHighlight(
RuntimeError, "object has no attribute or method", ""
):
@torch.jit.script @torch.jit.script
def test(): def test():
return Foo(torch.tensor(1)) + Foo(torch.tensor(1)) return Foo(torch.tensor(1)) + Foo(torch.tensor(1))
@ -852,7 +939,7 @@ class TestClassType(JitTestCase):
fn = torch.jit.script(test) fn = torch.jit.script(test)
self.assertEqual(fn(Foo(0.5)), test(0.5)) self.assertEqual(fn(Foo(0.5)), test(0.5))
self.assertEqual(fn(Foo(0.)), test(0.0)) self.assertEqual(fn(Foo(0.0)), test(0.0))
# str has slightly different formatting # str has slightly different formatting
self.assertTrue("0.5" in (str(Foo(0.5)))) self.assertTrue("0.5" in (str(Foo(0.5))))
self.assertTrue("0." in (str(Foo(0.0)))) self.assertTrue("0." in (str(Foo(0.0))))
@ -865,7 +952,10 @@ class TestClassType(JitTestCase):
def __bool__(self): def __bool__(self):
return (1, 2) return (1, 2)
with self.assertRaisesRegexWithHighlight(RuntimeError, "expected a bool expression for condition", ""): with self.assertRaisesRegexWithHighlight(
RuntimeError, "expected a bool expression for condition", ""
):
@torch.jit.script @torch.jit.script
def test(): def test():
if BadBool(): if BadBool():
@ -921,6 +1011,7 @@ class TestClassType(JitTestCase):
Recursive class types not yet supported. We should give a good error message. Recursive class types not yet supported. We should give a good error message.
""" """
with self.assertRaises(RuntimeError): with self.assertRaises(RuntimeError):
@torch.jit.script # noqa: B903 @torch.jit.script # noqa: B903
class Tree: # noqa: B903 class Tree: # noqa: B903
def __init__(self): def __init__(self):
@ -940,7 +1031,7 @@ class TestClassType(JitTestCase):
return x, y return x, y
# Test serialization/deserialization of class constant # Test serialization/deserialization of class constant
for c in (2, 1.0, None, True, 'str', (2, 3), [5.9, 7.3]): for c in (2, 1.0, None, True, "str", (2, 3), [5.9, 7.3]):
m = torch.jit.script(M(c)) m = torch.jit.script(M(c))
buffer = io.BytesIO() buffer = io.BytesIO()
torch.jit.save(m, buffer) torch.jit.save(m, buffer)
@ -954,28 +1045,31 @@ class TestClassType(JitTestCase):
def test_py_class_to_ivalue_missing_attribute(self): def test_py_class_to_ivalue_missing_attribute(self):
class Foo: class Foo:
i : int i: int
f : float f: float
def __init__(self, i : int, f : float): def __init__(self, i: int, f: float):
self.i = i self.i = i
self.f = f self.f = f
make_global(Foo) # see [local resolution in python] make_global(Foo) # see [local resolution in python]
@torch.jit.script @torch.jit.script
def test_fn(x : Foo) -> float: def test_fn(x: Foo) -> float:
return x.i + x.f return x.i + x.f
test_fn(Foo(3, 4.0)) test_fn(Foo(3, 4.0))
with self.assertRaisesRegexWithHighlight(RuntimeError, 'missing attribute i', ""): with self.assertRaisesRegexWithHighlight(
RuntimeError, "missing attribute i", ""
):
test_fn(torch.rand(3, 4)) test_fn(torch.rand(3, 4))
def test_unused_method(self): def test_unused_method(self):
""" """
Test unused methods on scripted classes. Test unused methods on scripted classes.
""" """
@torch.jit.script @torch.jit.script
class Unused: class Unused:
def __init__(self): def __init__(self):
@ -1028,12 +1122,13 @@ class TestClassType(JitTestCase):
Test that a scripted class can have a method that refers to the class itself Test that a scripted class can have a method that refers to the class itself
in its type annotations. in its type annotations.
""" """
@torch.jit.script @torch.jit.script
class Meta: class Meta:
def __init__(self, a: int): def __init__(self, a: int):
self.a = a self.a = a
def method(self, other: List['Meta']) -> 'Meta': def method(self, other: List["Meta"]) -> "Meta":
return Meta(len(other)) return Meta(len(other))
class ModuleWithMeta(torch.nn.Module): class ModuleWithMeta(torch.nn.Module):
@ -1051,19 +1146,20 @@ class TestClassType(JitTestCase):
""" """
Test that annotating container attributes with types works correctly Test that annotating container attributes with types works correctly
""" """
@torch.jit.script @torch.jit.script
class CompetitiveLinkingTokenReplacementUtils: class CompetitiveLinkingTokenReplacementUtils:
def __init__(self): def __init__(self):
self.my_list : List[Tuple[float, int, int]] = [] self.my_list: List[Tuple[float, int, int]] = []
self.my_dict : Dict[int, int] = {} self.my_dict: Dict[int, int] = {}
@torch.jit.script @torch.jit.script
def foo(): def foo():
y = CompetitiveLinkingTokenReplacementUtils() y = CompetitiveLinkingTokenReplacementUtils()
new_dict : Dict[int, int] = {1: 1, 2: 2} new_dict: Dict[int, int] = {1: 1, 2: 2}
y.my_dict = new_dict y.my_dict = new_dict
new_list : List[Tuple[float, int, int]] = [(1.0, 1, 1)] new_list: List[Tuple[float, int, int]] = [(1.0, 1, 1)]
y.my_list = new_list y.my_list = new_list
return y return y
@ -1071,6 +1167,7 @@ class TestClassType(JitTestCase):
""" """
Test that methods on class types can have default arguments. Test that methods on class types can have default arguments.
""" """
@torch.jit.script @torch.jit.script
class ClassWithDefaultArgs: class ClassWithDefaultArgs:
def __init__( def __init__(
@ -1105,7 +1202,9 @@ class TestClassType(JitTestCase):
return obj.int + obj.list[2] + obj.dict[1] return obj.int + obj.list[2] + obj.dict[1]
def override_defaults() -> int: def override_defaults() -> int:
obj: ClassWithDefaultArgs = ClassWithDefaultArgs(3, [9, 10, 11], (12, 13, 14), {3: 4}, "str") obj: ClassWithDefaultArgs = ClassWithDefaultArgs(
3, [9, 10, 11], (12, 13, 14), {3: 4}, "str"
)
s: int = obj.int s: int = obj.int
for x in obj.list: for x in obj.list:
@ -1154,7 +1253,7 @@ class TestClassType(JitTestCase):
# The constructor of this class below has mutable arguments. This should throw # The constructor of this class below has mutable arguments. This should throw
# an error. # an error.
class ClassWithMutableArgs: # noqa: B903 class ClassWithMutableArgs: # noqa: B903
def __init__( def __init__(
self, self,
a: List[int] = [1, 2, 3], # noqa: B006 a: List[int] = [1, 2, 3], # noqa: B006
@ -1164,13 +1263,16 @@ class TestClassType(JitTestCase):
def should_fail(): def should_fail():
obj: ClassWithMutableArgs = ClassWithMutableArgs() obj: ClassWithMutableArgs = ClassWithMutableArgs()
with self.assertRaisesRegexWithHighlight(RuntimeError, "Mutable default parameters are not supported", ""): with self.assertRaisesRegexWithHighlight(
RuntimeError, "Mutable default parameters are not supported", ""
):
torch.jit.script(should_fail) torch.jit.script(should_fail)
def test_staticmethod(self): def test_staticmethod(self):
""" """
Test static methods on class types. Test static methods on class types.
""" """
@torch.jit.script @torch.jit.script
class ClassWithStaticMethod: class ClassWithStaticMethod:
def __init__(self, a: int, b: int): def __init__(self, a: int, b: int):
@ -1183,22 +1285,22 @@ class TestClassType(JitTestCase):
def get_b(self): def get_b(self):
return self.b return self.b
def __eq__(self, other: 'ClassWithStaticMethod'): def __eq__(self, other: "ClassWithStaticMethod"):
return self.a == other.a and self.b == other.b return self.a == other.a and self.b == other.b
# staticmethod that calls constructor. # staticmethod that calls constructor.
@staticmethod @staticmethod
def create(args: List['ClassWithStaticMethod']) -> 'ClassWithStaticMethod': def create(args: List["ClassWithStaticMethod"]) -> "ClassWithStaticMethod":
return ClassWithStaticMethod(args[0].a, args[0].b) return ClassWithStaticMethod(args[0].a, args[0].b)
# staticmethod that calls another staticmethod. # staticmethod that calls another staticmethod.
@staticmethod @staticmethod
def create_from(a: int, b: int) -> 'ClassWithStaticMethod': def create_from(a: int, b: int) -> "ClassWithStaticMethod":
a = ClassWithStaticMethod(a, b) a = ClassWithStaticMethod(a, b)
return ClassWithStaticMethod.create([a]) return ClassWithStaticMethod.create([a])
# Script function that calls staticmethod. # Script function that calls staticmethod.
def test_function(a: int, b: int) -> 'ClassWithStaticMethod': def test_function(a: int, b: int) -> "ClassWithStaticMethod":
return ClassWithStaticMethod.create_from(a, b) return ClassWithStaticMethod.create_from(a, b)
make_global(ClassWithStaticMethod) make_global(ClassWithStaticMethod)
@ -1209,21 +1311,22 @@ class TestClassType(JitTestCase):
""" """
Test classmethods on class types. Test classmethods on class types.
""" """
@torch.jit.script @torch.jit.script
class ClassWithClassMethod: class ClassWithClassMethod:
def __init__(self, a: int): def __init__(self, a: int):
self.a: int = a self.a: int = a
def __eq__(self, other: 'ClassWithClassMethod'): def __eq__(self, other: "ClassWithClassMethod"):
return self.a == other.a return self.a == other.a
@classmethod @classmethod
def create(cls, a: int) -> 'ClassWithClassMethod': def create(cls, a: int) -> "ClassWithClassMethod":
return cls(a) return cls(a)
make_global(ClassWithClassMethod) make_global(ClassWithClassMethod)
def test_function(a: int) -> 'ClassWithClassMethod': def test_function(a: int) -> "ClassWithClassMethod":
x = ClassWithClassMethod(a) x = ClassWithClassMethod(a)
# Support calling classmethod with an instance # Support calling classmethod with an instance
# Calling with the class is not supported. # Calling with the class is not supported.
@ -1236,6 +1339,7 @@ class TestClassType(JitTestCase):
""" """
Test that a scripted class can make use of the @property decorator. Test that a scripted class can make use of the @property decorator.
""" """
def free_function(x: int) -> int: def free_function(x: int) -> int:
return x + 1 return x + 1
@ -1308,13 +1412,22 @@ class TestClassType(JitTestCase):
return self.props.attr + no_setter.attr + method_uses_property.forward() return self.props.attr + no_setter.attr + method_uses_property.forward()
self.checkModule(ModuleWithProperties(5), (5, 6, 7, 8,)) self.checkModule(
ModuleWithProperties(5),
(
5,
6,
7,
8,
),
)
def test_custom_delete(self): def test_custom_delete(self):
""" """
Test that del can be called on an instance of a class that Test that del can be called on an instance of a class that
overrides __delitem__. overrides __delitem__.
""" """
class Example: class Example:
def __init__(self): def __init__(self):
self._data: Dict[str, torch.Tensor] = {"1": torch.tensor(1.0)} self._data: Dict[str, torch.Tensor] = {"1": torch.tensor(1.0)}
@ -1346,7 +1459,9 @@ class TestClassType(JitTestCase):
del example[key] del example[key]
return example.check(key) return example.check(key)
with self.assertRaisesRegexWithHighlight(RuntimeError, r"Class does not define __delitem__", "example[key]"): with self.assertRaisesRegexWithHighlight(
RuntimeError, r"Class does not define __delitem__", "example[key]"
):
self.checkScript(fn, ()) self.checkScript(fn, ())
def test_recursive_script_builtin_type_resolution(self): def test_recursive_script_builtin_type_resolution(self):
@ -1369,7 +1484,7 @@ class TestClassType(JitTestCase):
def g(self, x: device_t) -> device_ty: def g(self, x: device_t) -> device_ty:
return x return x
def h(self, a: 'A') -> 'A': def h(self, a: "A") -> "A":
return A() return A()
def i(self, a: List[int]) -> int: def i(self, a: List[int]) -> int:
@ -1404,14 +1519,14 @@ class TestClassType(JitTestCase):
Test resolution of built-in torch types(e.g. torch.Tensor, torch.device) when a class is recursively compiled Test resolution of built-in torch types(e.g. torch.Tensor, torch.device) when a class is recursively compiled
when compiling a module. when compiling a module.
""" """
class Wrapper():
class Wrapper:
def __init__(self, t): def __init__(self, t):
self.t = t self.t = t
def to(self, l: List[torch.device], device: Optional[torch.device] = None): def to(self, l: List[torch.device], device: Optional[torch.device] = None):
return self.t.to(device=device) return self.t.to(device=device)
class A(nn.Module): class A(nn.Module):
def forward(self): def forward(self):
return Wrapper(torch.rand(4, 4)) return Wrapper(torch.rand(4, 4))
@ -1424,6 +1539,7 @@ class TestClassType(JitTestCase):
Test that the error message displayed when convering a class type Test that the error message displayed when convering a class type
to an IValue that has an attribute of the wrong type. to an IValue that has an attribute of the wrong type.
""" """
@torch.jit.script # noqa: B903 @torch.jit.script # noqa: B903
class ValHolder: # noqa: B903 class ValHolder: # noqa: B903
def __init__(self, val): def __init__(self, val):
@ -1442,7 +1558,9 @@ class TestClassType(JitTestCase):
mod = self.mod2 mod = self.mod2
return mod.val return mod.val
with self.assertRaisesRegexWithHighlight(RuntimeError, "Could not cast attribute 'val' to type Tensor", ""): with self.assertRaisesRegexWithHighlight(
RuntimeError, "Could not cast attribute 'val' to type Tensor", ""
):
torch.jit.script(Mod()) torch.jit.script(Mod())
def test_recursive_scripting(self): def test_recursive_scripting(self):
@ -1450,6 +1568,7 @@ class TestClassType(JitTestCase):
Test that class types are recursively scripted when an Python instance of one Test that class types are recursively scripted when an Python instance of one
is encountered as a module attribute. is encountered as a module attribute.
""" """
class Class: class Class:
def __init__(self, a: int): def __init__(self, a: int):
self.a = a self.a = a
@ -1473,6 +1592,7 @@ class TestClassType(JitTestCase):
are added as failed attributes and do not cause compilation itself are added as failed attributes and do not cause compilation itself
to fail unless they are used in scripted code. to fail unless they are used in scripted code.
""" """
class UnscriptableClass: class UnscriptableClass:
def __init__(self, a: int): def __init__(self, a: int):
self.a = a self.a = a
@ -1490,7 +1610,9 @@ class TestClassType(JitTestCase):
def forward(self) -> bool: def forward(self) -> bool:
return self.obj.get_a() return self.obj.get_a()
with self.assertRaisesRegexWithHighlight(RuntimeError, "failed to convert Python type", ""): with self.assertRaisesRegexWithHighlight(
RuntimeError, "failed to convert Python type", ""
):
torch.jit.script(ShouldNotCompile(UnscriptableClass(4))) torch.jit.script(ShouldNotCompile(UnscriptableClass(4)))
# This Module has an attribute of type UnscriptableClass # This Module has an attribute of type UnscriptableClass
@ -1509,7 +1631,6 @@ class TestClassType(JitTestCase):
self.checkModule(ShouldCompile(UnscriptableClass(4)), (4,)) self.checkModule(ShouldCompile(UnscriptableClass(4)), (4,))
def test_unresolved_class_attributes(self): def test_unresolved_class_attributes(self):
class UnresolvedAttrClass: class UnresolvedAttrClass:
def __init__(self): def __init__(self):
@ -1538,7 +1659,9 @@ class TestClassType(JitTestCase):
u = UnresolvedAttrClass() u = UnresolvedAttrClass()
return u.attr_e return u.attr_e
error_message_regex = "object has no attribute or method.*is defined as a class attribute" error_message_regex = (
"object has no attribute or method.*is defined as a class attribute"
)
for fn in (fn_a, fn_b, fn_c, fn_d, fn_e): for fn in (fn_a, fn_b, fn_c, fn_d, fn_e):
with self.assertRaisesRegex(RuntimeError, error_message_regex): with self.assertRaisesRegex(RuntimeError, error_message_regex):
torch.jit.script(fn) torch.jit.script(fn)

View File

@ -1,19 +1,21 @@
# Owner(s): ["oncall: jit"] # Owner(s): ["oncall: jit"]
import torch import cmath
import os import os
import sys import sys
from torch.testing._internal.jit_utils import JitTestCase, execWrapper
from torch.testing._internal.common_utils import IS_MACOS
from typing import List, Dict
from itertools import product from itertools import product
from textwrap import dedent from textwrap import dedent
import cmath from typing import Dict, List
import torch
from torch.testing._internal.common_utils import IS_MACOS
from torch.testing._internal.jit_utils import execWrapper, JitTestCase
# Make the helper files in test/ importable # Make the helper files in test/ importable
pytorch_test_dir = os.path.dirname(os.path.dirname(os.path.realpath(__file__))) pytorch_test_dir = os.path.dirname(os.path.dirname(os.path.realpath(__file__)))
sys.path.append(pytorch_test_dir) sys.path.append(pytorch_test_dir)
class TestComplex(JitTestCase): class TestComplex(JitTestCase):
def test_script(self): def test_script(self):
def fn(a: complex): def fn(a: complex):
@ -32,7 +34,7 @@ class TestComplex(JitTestCase):
def fn(a: Dict[complex, complex], key: complex) -> complex: def fn(a: Dict[complex, complex], key: complex) -> complex:
return a[key] return a[key]
input = {2 + 3j : 2 - 3j, -4.3 - 2j: 3j} input = {2 + 3j: 2 - 3j, -4.3 - 2j: 3j}
self.checkScript(fn, (input, -4.3 - 2j)) self.checkScript(fn, (input, -4.3 - 2j))
def test_pickle(self): def test_pickle(self):
@ -41,7 +43,7 @@ class TestComplex(JitTestCase):
super().__init__() super().__init__()
self.a = 3 + 5j self.a = 3 + 5j
self.b = [2 + 3j, 3 + 4j, 0 - 3j, -4 + 0j] self.b = [2 + 3j, 3 + 4j, 0 - 3j, -4 + 0j]
self.c = {2 + 3j : 2 - 3j, -4.3 - 2j: 3j} self.c = {2 + 3j: 2 - 3j, -4.3 - 2j: 3j}
@torch.jit.script_method @torch.jit.script_method
def forward(self, b: int): def forward(self, b: int):
@ -50,7 +52,7 @@ class TestComplex(JitTestCase):
loaded = self.getExportImportCopy(ComplexModule()) loaded = self.getExportImportCopy(ComplexModule())
self.assertEqual(loaded.a, 3 + 5j) self.assertEqual(loaded.a, 3 + 5j)
self.assertEqual(loaded.b, [2 + 3j, 3 + 4j, -3j, -4]) self.assertEqual(loaded.b, [2 + 3j, 3 + 4j, -3j, -4])
self.assertEqual(loaded.c, {2 + 3j : 2 - 3j, -4.3 - 2j: 3j}) self.assertEqual(loaded.c, {2 + 3j: 2 - 3j, -4.3 - 2j: 3j})
self.assertEqual(loaded(2), 2 + 2j) self.assertEqual(loaded(2), 2 + 2j)
def test_complex_parse(self): def test_complex_parse(self):
@ -65,14 +67,19 @@ class TestComplex(JitTestCase):
self.checkScript(fn, (t1, t2, 2)) self.checkScript(fn, (t1, t2, 2))
def test_complex_constants_and_ops(self): def test_complex_constants_and_ops(self):
vals = ([0.0, 1.0, 2.2, -1.0, -0.0, -2.2, 1, 0, 2] vals = (
+ [10.0 ** i for i in range(2)] + [-(10.0 ** i) for i in range(2)]) [0.0, 1.0, 2.2, -1.0, -0.0, -2.2, 1, 0, 2]
+ [10.0**i for i in range(2)]
+ [-(10.0**i) for i in range(2)]
)
complex_vals = tuple(complex(x, y) for x, y in product(vals, vals)) complex_vals = tuple(complex(x, y) for x, y in product(vals, vals))
funcs_template = dedent(''' funcs_template = dedent(
"""
def func(a: complex): def func(a: complex):
return cmath.{func_or_const}(a) return cmath.{func_or_const}(a)
''') """
)
def checkCmath(func_name, funcs_template=funcs_template): def checkCmath(func_name, funcs_template=funcs_template):
funcs_str = funcs_template.format(func_or_const=func_name) funcs_str = funcs_template.format(func_or_const=func_name)
@ -80,11 +87,13 @@ class TestComplex(JitTestCase):
execWrapper(funcs_str, globals(), scope) execWrapper(funcs_str, globals(), scope)
cu = torch.jit.CompilationUnit(funcs_str) cu = torch.jit.CompilationUnit(funcs_str)
f_script = cu.func f_script = cu.func
f = scope['func'] f = scope["func"]
if func_name in ['isinf', 'isnan', 'isfinite']: if func_name in ["isinf", "isnan", "isfinite"]:
new_vals = vals + ([float('inf'), float('nan'), -1 * float('inf')]) new_vals = vals + ([float("inf"), float("nan"), -1 * float("inf")])
final_vals = tuple(complex(x, y) for x, y in product(new_vals, new_vals)) final_vals = tuple(
complex(x, y) for x, y in product(new_vals, new_vals)
)
else: else:
final_vals = complex_vals final_vals = complex_vals
@ -107,8 +116,27 @@ class TestComplex(JitTestCase):
msg = f"Failed on {func_name} with input {a}. Python: {res_python}, Script: {res_script}" msg = f"Failed on {func_name} with input {a}. Python: {res_python}, Script: {res_script}"
self.assertEqual(res_python, res_script, msg=msg) self.assertEqual(res_python, res_script, msg=msg)
unary_ops = ['log', 'log10', 'sqrt', 'exp', 'sin', 'cos', 'asin', 'acos', 'atan', 'sinh', 'cosh', unary_ops = [
'tanh', 'asinh', 'acosh', 'atanh', 'phase', 'isinf', 'isnan', 'isfinite'] "log",
"log10",
"sqrt",
"exp",
"sin",
"cos",
"asin",
"acos",
"atan",
"sinh",
"cosh",
"tanh",
"asinh",
"acosh",
"atanh",
"phase",
"isinf",
"isnan",
"isfinite",
]
# --- Unary ops --- # --- Unary ops ---
for op in unary_ops: for op in unary_ops:
@ -118,7 +146,7 @@ class TestComplex(JitTestCase):
return abs(x) return abs(x)
for val in complex_vals: for val in complex_vals:
self.checkScript(fn, (val, )) self.checkScript(fn, (val,))
def pow_complex_float(x: complex, y: float): def pow_complex_float(x: complex, y: float):
return pow(x, y) return pow(x, y)
@ -126,7 +154,6 @@ class TestComplex(JitTestCase):
def pow_float_complex(x: float, y: complex): def pow_float_complex(x: float, y: complex):
return pow(x, y) return pow(x, y)
self.checkScript(pow_float_complex, (2, 3j)) self.checkScript(pow_float_complex, (2, 3j))
self.checkScript(pow_complex_float, (3j, 2)) self.checkScript(pow_complex_float, (3j, 2))
@ -135,7 +162,7 @@ class TestComplex(JitTestCase):
for x, y in zip(complex_vals, complex_vals): for x, y in zip(complex_vals, complex_vals):
# Reference: https://github.com/pytorch/pytorch/issues/54622 # Reference: https://github.com/pytorch/pytorch/issues/54622
if (x == 0): if x == 0:
continue continue
self.checkScript(pow_complex_complex, (x, y)) self.checkScript(pow_complex_complex, (x, y))
@ -143,16 +170,25 @@ class TestComplex(JitTestCase):
# --- Binary op --- # --- Binary op ---
def rect_fn(x: float, y: float): def rect_fn(x: float, y: float):
return cmath.rect(x, y) return cmath.rect(x, y)
for x, y in product(vals, vals):
self.checkScript(rect_fn, (x, y, ))
func_constants_template = dedent(''' for x, y in product(vals, vals):
self.checkScript(
rect_fn,
(
x,
y,
),
)
func_constants_template = dedent(
"""
def func(): def func():
return cmath.{func_or_const} return cmath.{func_or_const}
''') """
float_consts = ['pi', 'e', 'tau', 'inf', 'nan'] )
complex_consts = ['infj', 'nanj'] float_consts = ["pi", "e", "tau", "inf", "nan"]
for x in (float_consts + complex_consts): complex_consts = ["infj", "nanj"]
for x in float_consts + complex_consts:
checkCmath(x, funcs_template=func_constants_template) checkCmath(x, funcs_template=func_constants_template)
def test_infj_nanj_pickle(self): def test_infj_nanj_pickle(self):
@ -177,77 +213,293 @@ class TestComplex(JitTestCase):
def fn_int(real: int, img: int): def fn_int(real: int, img: int):
return complex(real, img) return complex(real, img)
self.checkScript(fn_int, (0, 0, )) self.checkScript(
self.checkScript(fn_int, (-1234, 0, )) fn_int,
self.checkScript(fn_int, (0, -1256, )) (
self.checkScript(fn_int, (-167, -1256, )) 0,
0,
),
)
self.checkScript(
fn_int,
(
-1234,
0,
),
)
self.checkScript(
fn_int,
(
0,
-1256,
),
)
self.checkScript(
fn_int,
(
-167,
-1256,
),
)
def fn_float(real: float, img: float): def fn_float(real: float, img: float):
return complex(real, img) return complex(real, img)
self.checkScript(fn_float, (0.0, 0.0, )) self.checkScript(
self.checkScript(fn_float, (-1234.78, 0, )) fn_float,
self.checkScript(fn_float, (0, 56.18, )) (
self.checkScript(fn_float, (-1.9, -19.8, )) 0.0,
0.0,
),
)
self.checkScript(
fn_float,
(
-1234.78,
0,
),
)
self.checkScript(
fn_float,
(
0,
56.18,
),
)
self.checkScript(
fn_float,
(
-1.9,
-19.8,
),
)
def fn_bool(real: bool, img: bool): def fn_bool(real: bool, img: bool):
return complex(real, img) return complex(real, img)
self.checkScript(fn_bool, (True, True, )) self.checkScript(
self.checkScript(fn_bool, (False, False, )) fn_bool,
self.checkScript(fn_bool, (False, True, )) (
self.checkScript(fn_bool, (True, False, )) True,
True,
),
)
self.checkScript(
fn_bool,
(
False,
False,
),
)
self.checkScript(
fn_bool,
(
False,
True,
),
)
self.checkScript(
fn_bool,
(
True,
False,
),
)
def fn_bool_int(real: bool, img: int): def fn_bool_int(real: bool, img: int):
return complex(real, img) return complex(real, img)
self.checkScript(fn_bool_int, (True, 0, )) self.checkScript(
self.checkScript(fn_bool_int, (False, 0, )) fn_bool_int,
self.checkScript(fn_bool_int, (False, -1, )) (
self.checkScript(fn_bool_int, (True, 3, )) True,
0,
),
)
self.checkScript(
fn_bool_int,
(
False,
0,
),
)
self.checkScript(
fn_bool_int,
(
False,
-1,
),
)
self.checkScript(
fn_bool_int,
(
True,
3,
),
)
def fn_int_bool(real: int, img: bool): def fn_int_bool(real: int, img: bool):
return complex(real, img) return complex(real, img)
self.checkScript(fn_int_bool, (0, True, )) self.checkScript(
self.checkScript(fn_int_bool, (0, False, )) fn_int_bool,
self.checkScript(fn_int_bool, (-3, True, )) (
self.checkScript(fn_int_bool, (6, False, )) 0,
True,
),
)
self.checkScript(
fn_int_bool,
(
0,
False,
),
)
self.checkScript(
fn_int_bool,
(
-3,
True,
),
)
self.checkScript(
fn_int_bool,
(
6,
False,
),
)
def fn_bool_float(real: bool, img: float): def fn_bool_float(real: bool, img: float):
return complex(real, img) return complex(real, img)
self.checkScript(fn_bool_float, (True, 0.0, )) self.checkScript(
self.checkScript(fn_bool_float, (False, 0.0, )) fn_bool_float,
self.checkScript(fn_bool_float, (False, -1.0, )) (
self.checkScript(fn_bool_float, (True, 3.0, )) True,
0.0,
),
)
self.checkScript(
fn_bool_float,
(
False,
0.0,
),
)
self.checkScript(
fn_bool_float,
(
False,
-1.0,
),
)
self.checkScript(
fn_bool_float,
(
True,
3.0,
),
)
def fn_float_bool(real: float, img: bool): def fn_float_bool(real: float, img: bool):
return complex(real, img) return complex(real, img)
self.checkScript(fn_float_bool, (0.0, True, )) self.checkScript(
self.checkScript(fn_float_bool, (0.0, False, )) fn_float_bool,
self.checkScript(fn_float_bool, (-3.0, True, )) (
self.checkScript(fn_float_bool, (6.0, False, )) 0.0,
True,
),
)
self.checkScript(
fn_float_bool,
(
0.0,
False,
),
)
self.checkScript(
fn_float_bool,
(
-3.0,
True,
),
)
self.checkScript(
fn_float_bool,
(
6.0,
False,
),
)
def fn_float_int(real: float, img: int): def fn_float_int(real: float, img: int):
return complex(real, img) return complex(real, img)
self.checkScript(fn_float_int, (0.0, 1, )) self.checkScript(
self.checkScript(fn_float_int, (0.0, -1, )) fn_float_int,
self.checkScript(fn_float_int, (1.8, -3, )) (
self.checkScript(fn_float_int, (2.7, 8, )) 0.0,
1,
),
)
self.checkScript(
fn_float_int,
(
0.0,
-1,
),
)
self.checkScript(
fn_float_int,
(
1.8,
-3,
),
)
self.checkScript(
fn_float_int,
(
2.7,
8,
),
)
def fn_int_float(real: int, img: float): def fn_int_float(real: int, img: float):
return complex(real, img) return complex(real, img)
self.checkScript(fn_int_float, (1, 0.0, )) self.checkScript(
self.checkScript(fn_int_float, (-1, 1.7, )) fn_int_float,
self.checkScript(fn_int_float, (-3, 0.0, )) (
self.checkScript(fn_int_float, (2, -8.9, )) 1,
0.0,
),
)
self.checkScript(
fn_int_float,
(
-1,
1.7,
),
)
self.checkScript(
fn_int_float,
(
-3,
0.0,
),
)
self.checkScript(
fn_int_float,
(
2,
-8.9,
),
)
def test_torch_complex_constructor_with_tensor(self): def test_torch_complex_constructor_with_tensor(self):
tensors = ([torch.rand(1), torch.randint(-5, 5, (1, )), torch.tensor([False])]) tensors = [torch.rand(1), torch.randint(-5, 5, (1,)), torch.tensor([False])]
def fn_tensor_float(real, img: float): def fn_tensor_float(real, img: float):
return complex(real, img) return complex(real, img)
@ -280,7 +532,13 @@ class TestComplex(JitTestCase):
return complex(real, img) + complex(2) return complex(real, img) + complex(2)
for x, y in product(tensors, tensors): for x, y in product(tensors, tensors):
self.checkScript(fn_tensor_tensor, (x, y, )) self.checkScript(
fn_tensor_tensor,
(
x,
y,
),
)
def test_comparison_ops(self): def test_comparison_ops(self):
def fn1(a: complex, b: complex): def fn1(a: complex, b: complex):
@ -316,7 +574,7 @@ class TestComplex(JitTestCase):
def fn(x: List[complex]): def fn(x: List[complex]):
return sum(x) return sum(x)
self.checkScript(fn, (torch.randn(4, dtype=torch.cdouble).tolist(), )) self.checkScript(fn, (torch.randn(4, dtype=torch.cdouble).tolist(),))
def test_tensor_attributes(self): def test_tensor_attributes(self):
def tensor_real(x): def tensor_real(x):
@ -326,8 +584,8 @@ class TestComplex(JitTestCase):
return x.imag return x.imag
t = torch.randn(2, 3, dtype=torch.cdouble) t = torch.randn(2, 3, dtype=torch.cdouble)
self.checkScript(tensor_real, (t, )) self.checkScript(tensor_real, (t,))
self.checkScript(tensor_imag, (t, )) self.checkScript(tensor_imag, (t,))
def test_binary_op_complex_tensor(self): def test_binary_op_complex_tensor(self):
def mul(x: complex, y: torch.Tensor): def mul(x: complex, y: torch.Tensor):
@ -350,7 +608,7 @@ class TestComplex(JitTestCase):
ops = [mul, add, eq, ne, sub, div] ops = [mul, add, eq, ne, sub, div]
for shape in [(1, ), (2, 2)]: for shape in [(1,), (2, 2)]:
x = 0.71 + 0.71j x = 0.71 + 0.71j
y = torch.randn(shape, dtype=torch.cfloat) y = torch.randn(shape, dtype=torch.cfloat)
for op in ops: for op in ops:

View File

@ -10,18 +10,29 @@ import torch
# Make the helper files in test/ importable # Make the helper files in test/ importable
pytorch_test_dir = os.path.dirname(os.path.dirname(os.path.realpath(__file__))) pytorch_test_dir = os.path.dirname(os.path.dirname(os.path.realpath(__file__)))
sys.path.append(pytorch_test_dir) sys.path.append(pytorch_test_dir)
from torch.testing._internal.jit_utils import JitTestCase, enable_profiling_mode from torch.testing._internal.common_utils import (
from torch.testing._internal.jit_metaprogramming_utils import try_get_nn_module_compiled_mod_and_inputs, \ IS_FBCODE,
get_nn_mod_test_name, get_all_nn_module_tests, nn_functional_tests, get_nn_functional_compiled_fn_and_inputs run_tests,
from torch.testing._internal.common_utils import run_tests, set_default_dtype, suppress_warnings, IS_FBCODE set_default_dtype,
suppress_warnings,
)
from torch.testing._internal.jit_metaprogramming_utils import (
get_all_nn_module_tests,
get_nn_functional_compiled_fn_and_inputs,
get_nn_mod_test_name,
nn_functional_tests,
try_get_nn_module_compiled_mod_and_inputs,
)
from torch.testing._internal.jit_utils import enable_profiling_mode, JitTestCase
def num_ifs_loops(graph): def num_ifs_loops(graph):
graph_str = str(graph) graph_str = str(graph)
# only look at body of graph # only look at body of graph
graph_body = graph_str[0:graph_str.find("return")] graph_body = graph_str[0 : graph_str.find("return")]
return graph_body.count("prim::Loop") + graph_body.count("prim::If") return graph_body.count("prim::Loop") + graph_body.count("prim::If")
def num_non_tensor_nodes(block): def num_non_tensor_nodes(block):
num_non_tensor = 0 num_non_tensor = 0
for node in block.nodes(): for node in block.nodes():
@ -40,6 +51,7 @@ def num_non_tensor_nodes(block):
num_non_tensor += int(not tensor_out) num_non_tensor += int(not tensor_out)
return num_non_tensor return num_non_tensor
class TestComplexity(JitTestCase): class TestComplexity(JitTestCase):
def setUp(self): def setUp(self):
super().setUp() super().setUp()
@ -90,5 +102,6 @@ class TestComplexity(JitTestCase):
for line in stats: for line in stats:
print(line) print(line)
if __name__ == '__main__':
if __name__ == "__main__":
run_tests() run_tests()

View File

@ -2,16 +2,18 @@
import os import os
import sys import sys
import unittest
from itertools import product from itertools import product
import torch import torch
import torch.nn as nn import torch.nn as nn
import torch.nn.functional as F import torch.nn.functional as F
from torch.testing import FileCheck from torch.testing import FileCheck
import unittest
try: try:
import torchvision import torchvision
HAS_TORCHVISION = True HAS_TORCHVISION = True
except ImportError: except ImportError:
HAS_TORCHVISION = False HAS_TORCHVISION = False
@ -22,10 +24,12 @@ pytorch_test_dir = os.path.dirname(os.path.dirname(os.path.realpath(__file__)))
sys.path.append(pytorch_test_dir) sys.path.append(pytorch_test_dir)
from torch.testing._internal.jit_utils import JitTestCase from torch.testing._internal.jit_utils import JitTestCase
if __name__ == '__main__': if __name__ == "__main__":
raise RuntimeError("This test file is not meant to be run directly, use:\n\n" raise RuntimeError(
"\tpython test/test_jit.py TESTNAME\n\n" "This test file is not meant to be run directly, use:\n\n"
"instead.") "\tpython test/test_jit.py TESTNAME\n\n"
"instead."
)
activations = [ activations = [
F.celu, F.celu,
@ -41,6 +45,7 @@ activations = [
F.silu, F.silu,
] ]
class TestFunctionalToInplaceActivation(JitTestCase): class TestFunctionalToInplaceActivation(JitTestCase):
def test_check_no_type_promotion(self): def test_check_no_type_promotion(self):
dtypes = [ dtypes = [
@ -67,6 +72,7 @@ class TestFunctionalToInplaceActivation(JitTestCase):
def test_functional_to_inplace_activation(self): def test_functional_to_inplace_activation(self):
for activation in activations: for activation in activations:
def test_basic(x): def test_basic(x):
y = x + 1 y = x + 1
z = activation(y) z = activation(y)
@ -76,7 +82,7 @@ class TestFunctionalToInplaceActivation(JitTestCase):
self.run_pass("inline", fn.graph) self.run_pass("inline", fn.graph)
self.run_pass("constant_propagation", fn.graph) self.run_pass("constant_propagation", fn.graph)
FileCheck().check(f"aten::{activation.__name__}(").run(fn.graph) FileCheck().check(f"aten::{activation.__name__}(").run(fn.graph)
self.run_pass('functional_to_inplace_activation', fn.graph) self.run_pass("functional_to_inplace_activation", fn.graph)
FileCheck().check_not(f"aten::{activation.__name__}(").run(fn.graph) FileCheck().check_not(f"aten::{activation.__name__}(").run(fn.graph)
FileCheck().check(f"aten::{activation.__name__}_").run(fn.graph) FileCheck().check(f"aten::{activation.__name__}_").run(fn.graph)
inp = torch.rand([2, 2]) inp = torch.rand([2, 2])
@ -91,7 +97,7 @@ class TestFunctionalToInplaceActivation(JitTestCase):
return z return z
fn = torch.jit.script(test1) fn = torch.jit.script(test1)
self.run_pass('functional_to_inplace_activation', fn.graph) self.run_pass("functional_to_inplace_activation", fn.graph)
FileCheck().check_not("aten::sigmoid_").run(fn.graph) FileCheck().check_not("aten::sigmoid_").run(fn.graph)
# inplace conversion should not happen because y is alias # inplace conversion should not happen because y is alias
@ -102,7 +108,7 @@ class TestFunctionalToInplaceActivation(JitTestCase):
return z return z
fn = torch.jit.script(test2) fn = torch.jit.script(test2)
self.run_pass('functional_to_inplace_activation', fn.graph) self.run_pass("functional_to_inplace_activation", fn.graph)
FileCheck().check_not("aten::relu_").run(fn.graph) FileCheck().check_not("aten::relu_").run(fn.graph)
# inplace conversion should not happen because self.x is # inplace conversion should not happen because self.x is
@ -117,22 +123,33 @@ class TestFunctionalToInplaceActivation(JitTestCase):
return y return y
fn = torch.jit.script(Test3(torch.rand([2, 2])).eval()) fn = torch.jit.script(Test3(torch.rand([2, 2])).eval())
self.run_pass('functional_to_inplace_activation', fn.graph) self.run_pass("functional_to_inplace_activation", fn.graph)
FileCheck().check_not("aten::relu_").run(fn.graph) FileCheck().check_not("aten::relu_").run(fn.graph)
@skipIfNoTorchVision @skipIfNoTorchVision
def test_resnet18_correctness(self): def test_resnet18_correctness(self):
model = torchvision.models.resnet18() model = torchvision.models.resnet18()
frozen_model = torch.jit.freeze(torch.jit.script(model.eval())) frozen_model = torch.jit.freeze(torch.jit.script(model.eval()))
N, C, H, W, = 10, 3, 224, 224 (
N,
C,
H,
W,
) = (
10,
3,
224,
224,
)
inp = torch.randn(N, C, H, W) inp = torch.randn(N, C, H, W)
self.run_pass('functional_to_inplace_activation', frozen_model.graph) self.run_pass("functional_to_inplace_activation", frozen_model.graph)
self.assertEqual(model(inp), frozen_model(inp)) self.assertEqual(model(inp), frozen_model(inp))
class TestInplaceToFunctionalActivation(JitTestCase): class TestInplaceToFunctionalActivation(JitTestCase):
def test_inplace_to_functional_activation(self): def test_inplace_to_functional_activation(self):
for activation in activations: for activation in activations:
def test_basic(x): def test_basic(x):
y = x + 1 y = x + 1
activation(y, inplace=True) activation(y, inplace=True)
@ -142,7 +159,7 @@ class TestInplaceToFunctionalActivation(JitTestCase):
self.run_pass("inline", fn.graph) self.run_pass("inline", fn.graph)
self.run_pass("constant_propagation", fn.graph) self.run_pass("constant_propagation", fn.graph)
FileCheck().check(f"aten::{activation.__name__}_").run(fn.graph) FileCheck().check(f"aten::{activation.__name__}_").run(fn.graph)
self.run_pass('inplace_to_functional_activation', fn.graph) self.run_pass("inplace_to_functional_activation", fn.graph)
FileCheck().check_not(f"aten::{activation.__name__}_").run(fn.graph) FileCheck().check_not(f"aten::{activation.__name__}_").run(fn.graph)
FileCheck().check(f"aten::{activation.__name__}(").run(fn.graph) FileCheck().check(f"aten::{activation.__name__}(").run(fn.graph)
@ -151,6 +168,7 @@ class TestInplaceToFunctionalActivation(JitTestCase):
torch.sigmoid_, torch.sigmoid_,
torch.tanh_, torch.tanh_,
]: ]:
def test_basic(x): def test_basic(x):
y = x + 1 y = x + 1
activation(y) activation(y)
@ -160,7 +178,7 @@ class TestInplaceToFunctionalActivation(JitTestCase):
self.run_pass("inline", fn.graph) self.run_pass("inline", fn.graph)
self.run_pass("constant_propagation", fn.graph) self.run_pass("constant_propagation", fn.graph)
FileCheck().check(f"aten::{activation.__name__}").run(fn.graph) FileCheck().check(f"aten::{activation.__name__}").run(fn.graph)
self.run_pass('inplace_to_functional_activation', fn.graph) self.run_pass("inplace_to_functional_activation", fn.graph)
FileCheck().check_not(f"aten::{activation.__name__}").run(fn.graph) FileCheck().check_not(f"aten::{activation.__name__}").run(fn.graph)
FileCheck().check(f"aten::{activation.__name__[:-1]}(").run(fn.graph) FileCheck().check(f"aten::{activation.__name__[:-1]}(").run(fn.graph)
@ -171,7 +189,17 @@ class TestInplaceToFunctionalActivation(JitTestCase):
def test_resnet18_correctness(self): def test_resnet18_correctness(self):
model = torchvision.models.resnet18() model = torchvision.models.resnet18()
frozen_model = torch.jit.freeze(torch.jit.script(model.eval())) frozen_model = torch.jit.freeze(torch.jit.script(model.eval()))
N, C, H, W, = 10, 3, 224, 224 (
N,
C,
H,
W,
) = (
10,
3,
224,
224,
)
inp = torch.randn(N, C, H, W) inp = torch.randn(N, C, H, W)
self.run_pass('inplace_to_functional_activation', frozen_model.graph) self.run_pass("inplace_to_functional_activation", frozen_model.graph)
self.assertEqual(model(inp), frozen_model(inp)) self.assertEqual(model(inp), frozen_model(inp))

View File

@ -1,16 +1,21 @@
# Owner(s): ["oncall: jit"] # Owner(s): ["oncall: jit"]
import gc
import os import os
import sys import sys
import gc
import unittest import unittest
from typing import NamedTuple
import torch import torch
from typing import NamedTuple
from torch.testing import FileCheck from torch.testing import FileCheck
from torch.testing._internal.jit_utils import JitTestCase
from torch.testing._internal.common_utils import skipIfRocm, skipCUDANonDefaultStreamIf, NoTest, TEST_CUDA
from torch.testing._internal.common_cuda import TEST_MULTIGPU from torch.testing._internal.common_cuda import TEST_MULTIGPU
from torch.testing._internal.common_utils import (
NoTest,
skipCUDANonDefaultStreamIf,
skipIfRocm,
TEST_CUDA,
)
from torch.testing._internal.jit_utils import JitTestCase
# Make the helper files in test/ importable # Make the helper files in test/ importable
pytorch_test_dir = os.path.dirname(os.path.dirname(os.path.realpath(__file__))) pytorch_test_dir = os.path.dirname(os.path.dirname(os.path.realpath(__file__)))
@ -18,7 +23,7 @@ sys.path.append(pytorch_test_dir)
# If GPU is not available, then do not run the tests # If GPU is not available, then do not run the tests
if not TEST_CUDA: if not TEST_CUDA:
print('CUDA not available, skipping tests', file=sys.stderr) print("CUDA not available, skipping tests", file=sys.stderr)
JitTestCase = NoTest # noqa: F811 JitTestCase = NoTest # noqa: F811
TEST_LARGE_TENSOR = TEST_CUDA TEST_LARGE_TENSOR = TEST_CUDA
@ -36,10 +41,12 @@ if __name__ == "__main__":
"instead." "instead."
) )
class TestCUDA(JitTestCase): class TestCUDA(JitTestCase):
""" """
A suite of tests for the CUDA API in TorchScript. A suite of tests for the CUDA API in TorchScript.
""" """
def tearDown(self): def tearDown(self):
gc.collect() gc.collect()
torch.cuda.empty_cache() torch.cuda.empty_cache()
@ -54,10 +61,10 @@ class TestCUDA(JitTestCase):
def test_device_synchronize(): def test_device_synchronize():
prev_current_device_index = torch.cuda.current_device() prev_current_device_index = torch.cuda.current_device()
torch.cuda.synchronize() torch.cuda.synchronize()
torch.cuda.synchronize('cuda') torch.cuda.synchronize("cuda")
torch.cuda.synchronize('cuda:0') torch.cuda.synchronize("cuda:0")
torch.cuda.synchronize(0) torch.cuda.synchronize(0)
torch.cuda.synchronize(torch.device('cuda:1')) torch.cuda.synchronize(torch.device("cuda:1"))
after_current_device_index = torch.cuda.current_device() after_current_device_index = torch.cuda.current_device()
# Check if the current device index is same as the device index before # Check if the current device index is same as the device index before
@ -66,7 +73,7 @@ class TestCUDA(JitTestCase):
@torch.jit.script @torch.jit.script
def test_multi_device_synchronize(): def test_multi_device_synchronize():
torch.cuda.synchronize(torch.device('cuda:0')) torch.cuda.synchronize(torch.device("cuda:0"))
prev_current_device_index = torch.cuda.current_device() prev_current_device_index = torch.cuda.current_device()
torch.cuda.synchronize(1) torch.cuda.synchronize(1)
after_current_device_index = torch.cuda.current_device() after_current_device_index = torch.cuda.current_device()
@ -76,11 +83,9 @@ class TestCUDA(JitTestCase):
return prev_current_device_index == after_current_device_index return prev_current_device_index == after_current_device_index
self.assertTrue(test_device_synchronize) self.assertTrue(test_device_synchronize)
FileCheck().check("cuda::synchronize(") \ FileCheck().check("cuda::synchronize(").run(test_device_synchronize.graph)
.run(test_device_synchronize.graph)
self.assertTrue(test_multi_device_synchronize) self.assertTrue(test_multi_device_synchronize)
FileCheck().check("cuda::synchronize(") \ FileCheck().check("cuda::synchronize(").run(test_multi_device_synchronize.graph)
.run(test_multi_device_synchronize.graph)
def test_stream_args(self): def test_stream_args(self):
# Test stream creation with default arguments # Test stream creation with default arguments
@ -165,7 +170,6 @@ class TestCUDA(JitTestCase):
@unittest.skipIf(not TEST_LARGE_TENSOR, "not enough memory") @unittest.skipIf(not TEST_LARGE_TENSOR, "not enough memory")
@skipCUDANonDefaultStreamIf(True) @skipCUDANonDefaultStreamIf(True)
def test_streams_and_events(self): def test_streams_and_events(self):
# Test default_stream API by passing device ID as an argument and # Test default_stream API by passing device ID as an argument and
# and check if the stream device index matches with the device ID # and check if the stream device index matches with the device ID
@torch.jit.script @torch.jit.script
@ -182,14 +186,14 @@ class TestCUDA(JitTestCase):
# This test checks for the default stream ID is set to 0 on the device # This test checks for the default stream ID is set to 0 on the device
@torch.jit.script @torch.jit.script
def test_default_streams(): def test_default_streams():
s0 = torch.cuda.default_stream(torch.device('cuda:0')) s0 = torch.cuda.default_stream(torch.device("cuda:0"))
s1 = torch.cuda.default_stream(torch.device('cuda:1')) s1 = torch.cuda.default_stream(torch.device("cuda:1"))
d = torch.device('cuda:1') d = torch.device("cuda:1")
# Check the current stream id and default id are same # Check the current stream id and default id are same
# on the current device. The current device id by default is 0 # on the current device. The current device id by default is 0
s2 = torch.cuda.current_stream(torch.device('cuda:0')) s2 = torch.cuda.current_stream(torch.device("cuda:0"))
check_s2 = s2.id() == s0.id() check_s2 = s2.id() == s0.id()
check_d0 = torch.cuda.current_device() == s2.device_index() check_d0 = torch.cuda.current_device() == s2.device_index()
@ -203,9 +207,25 @@ class TestCUDA(JitTestCase):
# Check if the current device was reset to 0 # Check if the current device was reset to 0
is_device_d0 = torch.cuda.current_device() == s2.device_index() is_device_d0 = torch.cuda.current_device() == s2.device_index()
return s0.device_index(), s1.device_index(), check_s2, check_s3, check_d0, check_d1, is_device_d0 return (
s0.device_index(),
s1.device_index(),
check_s2,
check_s3,
check_d0,
check_d1,
is_device_d0,
)
d0, d1, check_s2, check_s3, check_d0, check_d1, is_device_d0 = test_default_streams() (
d0,
d1,
check_s2,
check_s3,
check_d0,
check_d1,
is_device_d0,
) = test_default_streams()
self.assertEqual(d0, 0) self.assertEqual(d0, 0)
self.assertEqual(d1, 1) self.assertEqual(d1, 1)
@ -228,12 +248,21 @@ class TestCUDA(JitTestCase):
with torch.cuda.stream(None): with torch.cuda.stream(None):
cur_device_index = torch.cuda.current_device() cur_device_index = torch.cuda.current_device()
is_device_index_same = cur_device_index == device_index is_device_index_same = cur_device_index == device_index
is_current_stream_same = torch.cuda.current_stream(device).id() == current_stream.id() is_current_stream_same = (
is_default_stream_same = torch.cuda.default_stream(device).id() == default_stream.id() torch.cuda.current_stream(device).id() == current_stream.id()
)
is_default_stream_same = (
torch.cuda.default_stream(device).id() == default_stream.id()
)
# Check if the device index, current stream and default streams have not changed # Check if the device index, current stream and default streams have not changed
are_streams_same = is_device_index_same and is_current_stream_same and is_default_stream_same are_streams_same = (
is_device_index_same
and is_current_stream_same
and is_default_stream_same
)
return are_streams_same return are_streams_same
self.assertTrue(test_set_none_stream()) self.assertTrue(test_set_none_stream())
# This test checks if the Device Context manager is a no op # This test checks if the Device Context manager is a no op
@ -246,6 +275,7 @@ class TestCUDA(JitTestCase):
# Check if the current device is the same # Check if the current device is the same
is_device_same = torch.cuda.current_device() == device_index is_device_same = torch.cuda.current_device() == device_index
return is_device_same return is_device_same
self.assertTrue(test_set_device_none()) self.assertTrue(test_set_device_none())
# Check if a CUDA JIT stream is created # Check if a CUDA JIT stream is created
@ -260,15 +290,15 @@ class TestCUDA(JitTestCase):
# Class used to store results for the test: test_get_stream. # Class used to store results for the test: test_get_stream.
class Result(NamedTuple): class Result(NamedTuple):
t1 : torch.Tensor t1: torch.Tensor
t2 : torch.Tensor t2: torch.Tensor
is_current_and_default_stream_same : bool is_current_and_default_stream_same: bool
is_default_and_user_stream_not_same : bool is_default_and_user_stream_not_same: bool
is_stream_set : bool is_stream_set: bool
is_stream_reset : bool is_stream_reset: bool
default_stream_query : bool default_stream_query: bool
default_stream_id : int default_stream_id: int
user_stream_id : int user_stream_id: int
# The test aims at checking different stream proporties. # The test aims at checking different stream proporties.
@torch.jit.script @torch.jit.script
@ -280,15 +310,23 @@ class TestCUDA(JitTestCase):
user_stream = torch.cuda.Stream() user_stream = torch.cuda.Stream()
# Check if the current and default streams are the same on the device # Check if the current and default streams are the same on the device
is_current_and_default_stream_same = current_stream.id() == default_stream.id() is_current_and_default_stream_same = (
current_stream.id() == default_stream.id()
)
# Check if user stream and default stream are not the same on the device # Check if user stream and default stream are not the same on the device
is_default_and_user_stream_not_same = default_stream.id() != user_stream.id() is_default_and_user_stream_not_same = (
default_stream.id() != user_stream.id()
)
with torch.cuda.stream(user_stream): with torch.cuda.stream(user_stream):
is_stream_set = torch.cuda.current_stream(device).id() == user_stream.id() is_stream_set = (
torch.cuda.current_stream(device).id() == user_stream.id()
)
# Check if the stream was reset to current_stream # Check if the stream was reset to current_stream
is_stream_reset = torch.cuda.current_stream(device).id() == current_stream.id() is_stream_reset = (
torch.cuda.current_stream(device).id() == current_stream.id()
)
tensor1 = torch.rand(10000, 10000, device="cuda") tensor1 = torch.rand(10000, 10000, device="cuda")
tensor2 = torch.mm(tensor1, tensor1).to("cuda") tensor2 = torch.mm(tensor1, tensor1).to("cuda")
@ -297,9 +335,16 @@ class TestCUDA(JitTestCase):
# Capture all the results in the class Result # Capture all the results in the class Result
res = Result( res = Result(
tensor1, tensor2, is_current_and_default_stream_same, tensor1,
is_default_and_user_stream_not_same, is_stream_set, tensor2,
is_stream_reset, default_stream_query, default_stream.id(), user_stream.id()) is_current_and_default_stream_same,
is_default_and_user_stream_not_same,
is_stream_set,
is_stream_reset,
default_stream_query,
default_stream.id(),
user_stream.id(),
)
return res return res
result = test_get_stream() result = test_get_stream()
@ -310,8 +355,12 @@ class TestCUDA(JitTestCase):
self.assertTrue(result.is_stream_set) self.assertTrue(result.is_stream_set)
self.assertTrue(result.is_stream_reset) self.assertTrue(result.is_stream_reset)
self.assertTrue(result.default_stream_query) self.assertTrue(result.default_stream_query)
self.assertEqual(result.default_stream_id, 0) # Check if the default stream ID is always 0 self.assertEqual(
self.assertNotEqual(result.user_stream_id, 0) # Check if the user stream is always non zero result.default_stream_id, 0
) # Check if the default stream ID is always 0
self.assertNotEqual(
result.user_stream_id, 0
) # Check if the user stream is always non zero
# Test the stream context manager. This test checks if the stream is switched # Test the stream context manager. This test checks if the stream is switched
# to the user stream on using the stream context manager. # to the user stream on using the stream context manager.
@ -329,14 +378,20 @@ class TestCUDA(JitTestCase):
# Wait for B to be computed # Wait for B to be computed
user_stream.synchronize() user_stream.synchronize()
# Check if the stream has been reset on the current device # Check if the stream has been reset on the current device
is_stream_reset = torch.cuda.current_stream(device).id() == current_stream.id() is_stream_reset = (
torch.cuda.current_stream(device).id() == current_stream.id()
)
return A, B, check, is_stream_reset return A, B, check, is_stream_reset
A, B, is_stream_set, is_stream_reset = test_stream_context() A, B, is_stream_set, is_stream_reset = test_stream_context()
self.assertEqual(torch.matmul(A, A), B) self.assertEqual(torch.matmul(A, A), B)
self.assertTrue(is_stream_set, "Error: Current stream was not set to user stream!") self.assertTrue(
self.assertTrue(is_stream_reset, "Error: The stream was not restored to previous stream!") is_stream_set, "Error: Current stream was not set to user stream!"
)
self.assertTrue(
is_stream_reset, "Error: The stream was not restored to previous stream!"
)
# Test multiple nested streams. Check if the operations are computed as expected on the streams # Test multiple nested streams. Check if the operations are computed as expected on the streams
# This test has been adapted from the eager mode tests available at test/test_cuda.py # This test has been adapted from the eager mode tests available at test/test_cuda.py
@ -372,11 +427,24 @@ class TestCUDA(JitTestCase):
# Check if the stream and device has been restored to previous stream and device # Check if the stream and device has been restored to previous stream and device
is_device_current = torch.cuda.current_device() == prev_device_index is_device_current = torch.cuda.current_device() == prev_device_index
is_stream_current = torch.cuda.current_stream(device).id() == prev_current_stream.id() is_stream_current = (
torch.cuda.current_stream(device).id() == prev_current_stream.id()
)
check_stream = is_stream_s1 and is_stream_s2 and is_stream_s1_after and is_stream_current check_stream = (
check_device = is_device_s1 and is_device_s2 and is_device_s1_after and is_device_current is_stream_s1
and is_stream_s2
and is_stream_s1_after
and is_stream_current
)
check_device = (
is_device_s1
and is_device_s2
and is_device_s1_after
and is_device_current
)
return A, B, C, D, check_stream, check_device return A, B, C, D, check_stream, check_device
A, B, C, D, check_stream, check_device = test_multiple_stream() A, B, C, D, check_stream, check_device = test_multiple_stream()
self.assertEqual(torch.matmul(A, A), C) self.assertEqual(torch.matmul(A, A), C)
@ -401,7 +469,9 @@ class TestCUDA(JitTestCase):
B = torch.mm(A, A).to("cuda") B = torch.mm(A, A).to("cuda")
s1.record_event(event) s1.record_event(event)
# Check if the current_stream is reset # Check if the current_stream is reset
is_current_stream_1 = torch.cuda.current_stream(device).id() == prev_current_stream.id() is_current_stream_1 = (
torch.cuda.current_stream(device).id() == prev_current_stream.id()
)
# Wait for ops on s1 to be computed # Wait for ops on s1 to be computed
s2.wait_event(event) s2.wait_event(event)
with torch.cuda.stream(s2): with torch.cuda.stream(s2):
@ -410,9 +480,16 @@ class TestCUDA(JitTestCase):
# Wait for C to be computed # Wait for C to be computed
s2.synchronize() s2.synchronize()
# Check if the current_stream is reset # Check if the current_stream is reset
is_current_stream_2 = torch.cuda.current_stream(device).id() == prev_current_stream.id() is_current_stream_2 = (
torch.cuda.current_stream(device).id() == prev_current_stream.id()
)
check_stream = is_current_stream_1 and is_current_stream_2 and is_stream_s1 and is_stream_s2 check_stream = (
is_current_stream_1
and is_current_stream_2
and is_stream_s1
and is_stream_s2
)
return A, B, C, check_stream return A, B, C, check_stream
A, B, C, check_stream = test_data_dependency_between_streams() A, B, C, check_stream = test_data_dependency_between_streams()
@ -425,6 +502,7 @@ class TestCUDA(JitTestCase):
def test_simple_event(): def test_simple_event():
e = torch.cuda.Event(True, False, False) e = torch.cuda.Event(True, False, False)
return e is not None return e is not None
self.assertTrue(test_simple_event(), "Could not create CUDA Event!") self.assertTrue(test_simple_event(), "Could not create CUDA Event!")
# Record the CUDA event for operation torch.mm on the current stream # Record the CUDA event for operation torch.mm on the current stream
@ -474,6 +552,7 @@ class TestCUDA(JitTestCase):
# not necessary to check e_tik and e_tok, as elapsed_time would throw # not necessary to check e_tik and e_tok, as elapsed_time would throw
# exception if otherwise. # exception if otherwise.
return e_tik.elapsed_time(e_tok) return e_tik.elapsed_time(e_tok)
self.assertGreater(test_stream_synchronize(), 0) self.assertGreater(test_stream_synchronize(), 0)
# Test event synchronization for the event that records a stream doing # Test event synchronization for the event that records a stream doing
@ -536,12 +615,13 @@ class TestCUDA(JitTestCase):
# not necessary to check e_tik and e_tok, as elapsed_time would throw # not necessary to check e_tik and e_tok, as elapsed_time would throw
# exception if otherwise. # exception if otherwise.
return e_tik.elapsed_time(e_tok) return e_tik.elapsed_time(e_tok)
self.assertGreater(test_event_wait(), 0) self.assertGreater(test_event_wait(), 0)
# Test for stream wait_event. Checks if the stream waits on the event # Test for stream wait_event. Checks if the stream waits on the event
@torch.jit.script @torch.jit.script
def test_wait_event(): def test_wait_event():
d1 = torch.device('cuda:1') d1 = torch.device("cuda:1")
with torch.cuda.device(d1): with torch.cuda.device(d1):
s0 = torch.cuda.current_stream(d1) s0 = torch.cuda.current_stream(d1)
@ -550,11 +630,12 @@ class TestCUDA(JitTestCase):
e0 = torch.cuda.Event(False, False, False) e0 = torch.cuda.Event(False, False, False)
s0.record_event(e0) s0.record_event(e0)
s1 = torch.cuda.current_stream(torch.device('cuda:0')) s1 = torch.cuda.current_stream(torch.device("cuda:0"))
s1.wait_event(e0) s1.wait_event(e0)
s1.synchronize() s1.synchronize()
return e0.query() and s0.query() and s1.query() return e0.query() and s0.query() and s1.query()
self.assertTrue(test_wait_event()) self.assertTrue(test_wait_event())
# Test if a scripted module with cuda streams can be saved, loaded and executed # Test if a scripted module with cuda streams can be saved, loaded and executed

View File

@ -11,33 +11,37 @@ pytorch_test_dir = os.path.dirname(os.path.dirname(os.path.realpath(__file__)))
sys.path.append(pytorch_test_dir) sys.path.append(pytorch_test_dir)
from torch.testing._internal.jit_utils import JitTestCase from torch.testing._internal.jit_utils import JitTestCase
if __name__ == '__main__': if __name__ == "__main__":
raise RuntimeError("This test file is not meant to be run directly, use:\n\n" raise RuntimeError(
"\tpython test/test_jit.py TESTNAME\n\n" "This test file is not meant to be run directly, use:\n\n"
"instead.") "\tpython test/test_jit.py TESTNAME\n\n"
"instead."
)
def canonical(graph): def canonical(graph):
return torch._C._jit_pass_canonicalize(graph).str(False) return torch._C._jit_pass_canonicalize(graph).str(False)
class TestCustomOperators(JitTestCase):
class TestCustomOperators(JitTestCase):
def test_dynamic_op_registry(self): def test_dynamic_op_registry(self):
from torch._ops import _OpNamespace from torch._ops import _OpNamespace
self.assertTrue(hasattr(torch, 'ops'))
if '_test' in torch.ops.__dict__: self.assertTrue(hasattr(torch, "ops"))
torch.ops.__dict__.pop('_test')
if "_test" in torch.ops.__dict__:
torch.ops.__dict__.pop("_test")
# Don't use `hasattr()` because it will call `__getattr__`. # Don't use `hasattr()` because it will call `__getattr__`.
self.assertNotIn('_test', torch.ops.__dict__) self.assertNotIn("_test", torch.ops.__dict__)
torch.ops._test torch.ops._test
self.assertIn('_test', torch.ops.__dict__) self.assertIn("_test", torch.ops.__dict__)
self.assertEqual(type(torch.ops._test), _OpNamespace) self.assertEqual(type(torch.ops._test), _OpNamespace)
self.assertNotIn('leaky_relu', torch.ops._test.__dict__) self.assertNotIn("leaky_relu", torch.ops._test.__dict__)
op = torch.ops._test.leaky_relu op = torch.ops._test.leaky_relu
self.assertTrue(callable(op)) self.assertTrue(callable(op))
self.assertIn('leaky_relu', torch.ops._test.__dict__) self.assertIn("leaky_relu", torch.ops._test.__dict__)
op2 = torch.ops._test.leaky_relu op2 = torch.ops._test.leaky_relu
self.assertEqual(op, op2) self.assertEqual(op, op2)
@ -46,7 +50,7 @@ class TestCustomOperators(JitTestCase):
with self.assertRaisesRegexWithHighlight( with self.assertRaisesRegexWithHighlight(
AttributeError, AttributeError,
f"Invalid attribute '{attr}' for '_OpNamespace' '_test'", f"Invalid attribute '{attr}' for '_OpNamespace' '_test'",
"" "",
): ):
getattr(torch.ops._test, attr) getattr(torch.ops._test, attr)
@ -63,15 +67,13 @@ class TestCustomOperators(JitTestCase):
with self.assertRaisesRegexWithHighlight( with self.assertRaisesRegexWithHighlight(
RuntimeError, RuntimeError,
r"aten::relu\(\) expected at most 1 argument\(s\) but received 2 argument\(s\)", r"aten::relu\(\) expected at most 1 argument\(s\) but received 2 argument\(s\)",
"" "",
): ):
torch.ops.aten.relu(1, 2) torch.ops.aten.relu(1, 2)
def test_passing_too_few_args(self): def test_passing_too_few_args(self):
with self.assertRaisesRegexWithHighlight( with self.assertRaisesRegexWithHighlight(
RuntimeError, RuntimeError, r"aten::relu\(\) is missing value for argument 'self'.", ""
r"aten::relu\(\) is missing value for argument 'self'.",
""
): ):
torch.ops.aten.relu() torch.ops.aten.relu()
@ -79,7 +81,7 @@ class TestCustomOperators(JitTestCase):
with self.assertRaisesRegexWithHighlight( with self.assertRaisesRegexWithHighlight(
RuntimeError, RuntimeError,
r"aten::type_as\(\) is missing value for argument 'other'.", r"aten::type_as\(\) is missing value for argument 'other'.",
"" "",
): ):
torch.ops.aten.type_as(torch.ones(5, 5)) torch.ops.aten.type_as(torch.ones(5, 5))
@ -87,7 +89,7 @@ class TestCustomOperators(JitTestCase):
with self.assertRaisesRegexWithHighlight( with self.assertRaisesRegexWithHighlight(
RuntimeError, RuntimeError,
"Unknown keyword argument 'foo' for operator '_test::leaky_relu'", "Unknown keyword argument 'foo' for operator '_test::leaky_relu'",
"" "",
): ):
torch.ops._test.leaky_relu(torch.ones(5), foo=torch.ones(5)) torch.ops._test.leaky_relu(torch.ones(5), foo=torch.ones(5))
@ -102,6 +104,7 @@ class TestCustomOperators(JitTestCase):
@torch.jit.script @torch.jit.script
def func(x): def func(x):
return torch.ops.aten.relu(x) return torch.ops.aten.relu(x)
input = torch.ones(5, 5) input = torch.ones(5, 5)
self.assertEqual(func(input), input.relu()) self.assertEqual(func(input), input.relu())
@ -110,28 +113,37 @@ class TestCustomOperators(JitTestCase):
func = torch.jit.trace(torch.ops.aten.relu, [input]) func = torch.jit.trace(torch.ops.aten.relu, [input])
self.assertEqual(func(input), input.relu()) self.assertEqual(func(input), input.relu())
@unittest.skip("Need to figure out default dtype differences between fbcode and oss") @unittest.skip(
"Need to figure out default dtype differences between fbcode and oss"
)
def test_script_graph_for_custom_ops_matches_traced_graph(self): def test_script_graph_for_custom_ops_matches_traced_graph(self):
input = torch.ones(5, 5) input = torch.ones(5, 5)
trace = torch.jit.trace(torch.ops.aten.relu, [input]) trace = torch.jit.trace(torch.ops.aten.relu, [input])
self.assertExpectedInline(canonical(trace.graph), '''\ self.assertExpectedInline(
canonical(trace.graph),
"""\
graph(%0 : Float(5, 5)): graph(%0 : Float(5, 5)):
%1 : Float(5, 5) = aten::relu(%0) %1 : Float(5, 5) = aten::relu(%0)
return (%1) return (%1)
''') """,
)
def test_script_graph_contains_custom_op(self): def test_script_graph_contains_custom_op(self):
@torch.jit.script @torch.jit.script
def func(x): def func(x):
return torch.ops.aten.relu(x) return torch.ops.aten.relu(x)
self.assertExpectedInline(canonical(func.graph), '''\
self.assertExpectedInline(
canonical(func.graph),
"""\
graph(%x.1 : Tensor): graph(%x.1 : Tensor):
%1 : Tensor = aten::relu(%x.1) %1 : Tensor = aten::relu(%x.1)
return (%1) return (%1)
''') """,
)
def test_generic_list(self): def test_generic_list(self):
self.assertEqual(torch.ops._test.get_first([['hello']]), 'hello') self.assertEqual(torch.ops._test.get_first([["hello"]]), "hello")
# https://github.com/pytorch/pytorch/issues/80508 # https://github.com/pytorch/pytorch/issues/80508
def test_where_no_scalar(self): def test_where_no_scalar(self):

View File

@ -13,17 +13,21 @@ pytorch_test_dir = os.path.dirname(os.path.dirname(os.path.realpath(__file__)))
sys.path.append(pytorch_test_dir) sys.path.append(pytorch_test_dir)
from torch.testing._internal.jit_utils import JitTestCase, RUN_CUDA_MULTI_GPU from torch.testing._internal.jit_utils import JitTestCase, RUN_CUDA_MULTI_GPU
if __name__ == '__main__': if __name__ == "__main__":
raise RuntimeError("This test file is not meant to be run directly, use:\n\n" raise RuntimeError(
"\tpython test/test_jit.py TESTNAME\n\n" "This test file is not meant to be run directly, use:\n\n"
"instead.") "\tpython test/test_jit.py TESTNAME\n\n"
"instead."
)
class TestDataParallel(JitTestCase): class TestDataParallel(JitTestCase):
class Mpy(torch.nn.Module): class Mpy(torch.nn.Module):
def __init__(self): def __init__(self):
super(TestDataParallel.Mpy, self).__init__() super(TestDataParallel.Mpy, self).__init__()
self.m = nn.Sequential(nn.Linear(2, 2), nn.BatchNorm1d(2), self.m = nn.Sequential(
nn.ReLU(), nn.Linear(2, 2)) nn.Linear(2, 2), nn.BatchNorm1d(2), nn.ReLU(), nn.Linear(2, 2)
)
@torch.jit.ignore @torch.jit.ignore
def forward(self, input): def forward(self, input):
@ -50,13 +54,13 @@ class TestDataParallel(JitTestCase):
return self.m2(x) return self.m2(x)
class Msm(torch.jit.ScriptModule): class Msm(torch.jit.ScriptModule):
__constants__ = ["m"]
__constants__ = ['m']
def __init__(self): def __init__(self):
super(TestDataParallel.Msm, self).__init__() super(TestDataParallel.Msm, self).__init__()
self.m = nn.Sequential(nn.Linear(2, 2), nn.BatchNorm1d(2), self.m = nn.Sequential(
nn.ReLU(), nn.Linear(2, 2)) nn.Linear(2, 2), nn.BatchNorm1d(2), nn.ReLU(), nn.Linear(2, 2)
)
@torch.jit.script_method @torch.jit.script_method
def forward(self, input): def forward(self, input):
@ -140,7 +144,7 @@ class TestDataParallel(JitTestCase):
# Use .data here to avoid version counter bump. # Use .data here to avoid version counter bump.
# The graph created by the following forward will be wrong but # The graph created by the following forward will be wrong but
# we never backward through them so it's fine # we never backward through them so it's fine
p.data -= 1. * p.grad p.data -= 1.0 * p.grad
second_forward = module(x) second_forward = module(x)
# replica which is on the same GPU has a shallow copy of the original # replica which is on the same GPU has a shallow copy of the original

View File

@ -1,14 +1,16 @@
# Owner(s): ["oncall: jit"] # Owner(s): ["oncall: jit"]
# flake8: noqa # flake8: noqa
import sys
import unittest
from dataclasses import dataclass, field, InitVar from dataclasses import dataclass, field, InitVar
from enum import Enum
from typing import List, Optional
import torch
from hypothesis import given, settings, strategies as st from hypothesis import given, settings, strategies as st
from torch.testing._internal.jit_utils import JitTestCase from torch.testing._internal.jit_utils import JitTestCase
from typing import List, Optional
import sys
import torch
import unittest
from enum import Enum
# Example jittable dataclass # Example jittable dataclass
@dataclass(order=True) @dataclass(order=True)
@ -20,8 +22,8 @@ class Point:
def __post_init__(self): def __post_init__(self):
self.norm = (torch.tensor(self.x) ** 2 + torch.tensor(self.y) ** 2) ** 0.5 self.norm = (torch.tensor(self.x) ** 2 + torch.tensor(self.y) ** 2) ** 0.5
class MixupScheme(Enum):
class MixupScheme(Enum):
INPUT = ["input"] INPUT = ["input"]
MANIFOLD = [ MANIFOLD = [
@ -38,6 +40,7 @@ class MixupParams:
self.alpha = alpha self.alpha = alpha
self.scheme = scheme self.scheme = scheme
class MixupScheme2(Enum): class MixupScheme2(Enum):
A = 1 A = 1
B = 2 B = 2
@ -49,6 +52,7 @@ class MixupParams2:
self.alpha = alpha self.alpha = alpha
self.scheme = scheme self.scheme = scheme
@dataclass @dataclass
class MixupParams3: class MixupParams3:
def __init__(self, alpha: float = 0.125, scheme: MixupScheme2 = MixupScheme2.A): def __init__(self, alpha: float = 0.125, scheme: MixupScheme2 = MixupScheme2.A):
@ -59,11 +63,11 @@ class MixupParams3:
# Make sure the Meta internal tooling doesn't raise an overflow error # Make sure the Meta internal tooling doesn't raise an overflow error
NonHugeFloats = st.floats(min_value=-1e4, max_value=1e4, allow_nan=False) NonHugeFloats = st.floats(min_value=-1e4, max_value=1e4, allow_nan=False)
class TestDataclasses(JitTestCase):
class TestDataclasses(JitTestCase):
@classmethod @classmethod
def tearDownClass(cls): def tearDownClass(cls):
torch._C._jit_clear_class_registry() torch._C._jit_clear_class_registry()
def test_init_vars(self): def test_init_vars(self):
@torch.jit.script @torch.jit.script
@ -75,7 +79,9 @@ class TestDataclasses(JitTestCase):
norm: Optional[torch.Tensor] = None norm: Optional[torch.Tensor] = None
def __post_init__(self, norm_p: int): def __post_init__(self, norm_p: int):
self.norm = (torch.tensor(self.x) ** norm_p + torch.tensor(self.y) ** norm_p) ** (1 / norm_p) self.norm = (
torch.tensor(self.x) ** norm_p + torch.tensor(self.y) ** norm_p
) ** (1 / norm_p)
def fn(x: float, y: float, p: int): def fn(x: float, y: float, p: int):
pt = Point2(x, y, p) pt = Point2(x, y, p)
@ -88,6 +94,7 @@ class TestDataclasses(JitTestCase):
@given(NonHugeFloats, NonHugeFloats) @given(NonHugeFloats, NonHugeFloats)
def test__post_init__(self, x, y): def test__post_init__(self, x, y):
P = torch.jit.script(Point) P = torch.jit.script(Point)
def fn(x: float, y: float): def fn(x: float, y: float):
pt = P(x, y) pt = P(x, y)
return pt.norm return pt.norm
@ -95,7 +102,9 @@ class TestDataclasses(JitTestCase):
self.checkScript(fn, [x, y]) self.checkScript(fn, [x, y])
@settings(deadline=None) @settings(deadline=None)
@given(st.tuples(NonHugeFloats, NonHugeFloats), st.tuples(NonHugeFloats, NonHugeFloats)) @given(
st.tuples(NonHugeFloats, NonHugeFloats), st.tuples(NonHugeFloats, NonHugeFloats)
)
def test_comparators(self, pt1, pt2): def test_comparators(self, pt1, pt2):
x1, y1 = pt1 x1, y1 = pt1
x2, y2 = pt2 x2, y2 = pt2
@ -122,6 +131,7 @@ class TestDataclasses(JitTestCase):
with self.assertRaises(NotImplementedError): with self.assertRaises(NotImplementedError):
torch.jit.script(Foo) torch.jit.script(Foo)
def fn(): def fn():
foo = Foo() foo = Foo()
return foo.x return foo.x
@ -137,7 +147,7 @@ class TestDataclasses(JitTestCase):
a: int a: int
b: int b: int
def __eq__(self, other: 'CustomEq') -> bool: def __eq__(self, other: "CustomEq") -> bool:
return self.a == other.a # ignore the b field return self.a == other.a # ignore the b field
def fn(a: int, b1: int, b2: int): def fn(a: int, b1: int, b2: int):
@ -154,9 +164,7 @@ class TestDataclasses(JitTestCase):
torch.jit.script(MixupParams2) # don't throw torch.jit.script(MixupParams2) # don't throw
def test_use_unregistered_dataclass_raises(self): def test_use_unregistered_dataclass_raises(self):
def f(a: MixupParams3): def f(a: MixupParams3):
return 0 return 0

View File

@ -1,12 +1,12 @@
# Owner(s): ["oncall: jit"] # Owner(s): ["oncall: jit"]
from itertools import product
import unittest import unittest
from itertools import product
import torch import torch
from torch.jit._passes._property_propagation import apply_input_props_using_example
from torch.testing._internal.common_utils import TEST_CUDA from torch.testing._internal.common_utils import TEST_CUDA
from torch.testing._internal.jit_utils import JitTestCase from torch.testing._internal.jit_utils import JitTestCase
from torch.jit._passes._property_propagation import apply_input_props_using_example
try: try:
from torchvision import models from torchvision import models

View File

@ -7,19 +7,19 @@ from unittest.case import expectedFailure
import torch import torch
from torch import complex32, float32, float64, int32, int64 from torch import complex32, float32, float64, int32, int64
from torch.jit._passes import _property_propagation from torch.jit._passes import _property_propagation
from torch.testing._internal.common_device_type import (
instantiate_device_type_tests,
ops,
)
from torch.testing._internal.common_methods_invocations import ( from torch.testing._internal.common_methods_invocations import (
SampleInput, op_db,
sample_inputs_adaptive_avg_pool2d, sample_inputs_adaptive_avg_pool2d,
sample_inputs_conv2d, sample_inputs_conv2d,
SampleInput,
) )
from torch.testing._internal.common_utils import set_default_dtype, first_sample from torch.testing._internal.common_utils import first_sample, set_default_dtype
from torch.testing._internal.jit_utils import JitTestCase
from torch.testing._internal.jit_metaprogramming_utils import create_traced_fn from torch.testing._internal.jit_metaprogramming_utils import create_traced_fn
from torch.testing._internal.common_device_type import ( from torch.testing._internal.jit_utils import JitTestCase
ops,
instantiate_device_type_tests,
)
from torch.testing._internal.common_methods_invocations import op_db
""" """
Dtype Analysis relies on symbolic shape analysis, which is still in beta Dtype Analysis relies on symbolic shape analysis, which is still in beta
@ -274,7 +274,9 @@ class TestDtypeAnalysis(TestDtypeBase):
): ):
for dtype in (torch.int8, torch.float64): for dtype in (torch.int8, torch.float64):
# Gets default version for conv2d # Gets default version for conv2d
sample_input: SampleInput = list(inputs_fn(None, "cpu", dtype, False))[-1] sample_input: SampleInput = list(inputs_fn(None, "cpu", dtype, False))[
-1
]
input_args = [sample_input.input, *sample_input.args] input_args = [sample_input.input, *sample_input.args]
self.assert_dtype_equal_custom_args(fn, input_args) self.assert_dtype_equal_custom_args(fn, input_args)
@ -352,7 +354,9 @@ class TestDtypeCustomRules(TestDtypeBase):
# Run the Dtype Analysis # Run the Dtype Analysis
graph = traced_fn.graph # Note this is a cached graph graph = traced_fn.graph # Note this is a cached graph
input_tensors = [t for t in input_args if isinstance(t, torch.Tensor)] input_tensors = [t for t in input_args if isinstance(t, torch.Tensor)]
input_tensors += [v for v in sample_input.kwargs.values() if isinstance(v, torch.Tensor)] input_tensors += [
v for v in sample_input.kwargs.values() if isinstance(v, torch.Tensor)
]
self.prop_dtype_on_graph(graph, input_tensors) self.prop_dtype_on_graph(graph, input_tensors)
self.assert_output_dtype_equal(expected_res, graph) self.assert_output_dtype_equal(expected_res, graph)

View File

@ -2,21 +2,24 @@
import os import os
import sys import sys
from enum import Enum
from typing import Any, List
import torch import torch
from torch.testing import FileCheck from torch.testing import FileCheck
from enum import Enum
from typing import Any, List
# Make the helper files in test/ importable # Make the helper files in test/ importable
pytorch_test_dir = os.path.dirname(os.path.dirname(os.path.realpath(__file__))) pytorch_test_dir = os.path.dirname(os.path.dirname(os.path.realpath(__file__)))
sys.path.append(pytorch_test_dir) sys.path.append(pytorch_test_dir)
from torch.testing._internal.jit_utils import JitTestCase, make_global from torch.testing._internal.jit_utils import JitTestCase, make_global
if __name__ == '__main__': if __name__ == "__main__":
raise RuntimeError("This test file is not meant to be run directly, use:\n\n" raise RuntimeError(
"\tpython test/test_jit.py TESTNAME\n\n" "This test file is not meant to be run directly, use:\n\n"
"instead.") "\tpython test/test_jit.py TESTNAME\n\n"
"instead."
)
class TestEnum(JitTestCase): class TestEnum(JitTestCase):
def test_enum_value_types(self): def test_enum_value_types(self):
@ -38,11 +41,9 @@ class TestEnum(JitTestCase):
def supported_enum_types(a: IntEnum, b: FloatEnum, c: StringEnum): def supported_enum_types(a: IntEnum, b: FloatEnum, c: StringEnum):
return (a.name, b.name, c.name) return (a.name, b.name, c.name)
FileCheck() \ FileCheck().check("IntEnum").check("FloatEnum").check("StringEnum").run(
.check("IntEnum") \ str(supported_enum_types.graph)
.check("FloatEnum") \ )
.check("StringEnum") \
.run(str(supported_enum_types.graph))
class TensorEnum(Enum): class TensorEnum(Enum):
FOO = torch.tensor(0) FOO = torch.tensor(0)
@ -54,7 +55,9 @@ class TestEnum(JitTestCase):
return a.name return a.name
# TODO: rewrite code so that the highlight is not empty. # TODO: rewrite code so that the highlight is not empty.
with self.assertRaisesRegexWithHighlight(RuntimeError, "Cannot create Enum with value type 'Tensor'", ""): with self.assertRaisesRegexWithHighlight(
RuntimeError, "Cannot create Enum with value type 'Tensor'", ""
):
torch.jit.script(unsupported_enum_types) torch.jit.script(unsupported_enum_types)
def test_enum_comp(self): def test_enum_comp(self):
@ -88,11 +91,9 @@ class TestEnum(JitTestCase):
def enum_comp(x: Foo) -> bool: def enum_comp(x: Foo) -> bool:
return x == Bar.ITEM1 return x == Bar.ITEM1
FileCheck() \ FileCheck().check("prim::Constant").check_same("Bar.ITEM1").check(
.check("prim::Constant") \ "aten::eq"
.check_same("Bar.ITEM1") \ ).run(str(enum_comp.graph))
.check("aten::eq") \
.run(str(enum_comp.graph))
self.assertEqual(enum_comp(Foo.ITEM1), False) self.assertEqual(enum_comp(Foo.ITEM1), False)
@ -107,7 +108,9 @@ class TestEnum(JitTestCase):
return x == y return x == y
# TODO: rewrite code so that the highlight is not empty. # TODO: rewrite code so that the highlight is not empty.
with self.assertRaisesRegexWithHighlight(RuntimeError, "Could not unify type list", ""): with self.assertRaisesRegexWithHighlight(
RuntimeError, "Could not unify type list", ""
):
torch.jit.script(enum_comp) torch.jit.script(enum_comp)
def test_enum_name(self): def test_enum_name(self):
@ -121,11 +124,9 @@ class TestEnum(JitTestCase):
def enum_name(x: Color) -> str: def enum_name(x: Color) -> str:
return x.name return x.name
FileCheck() \ FileCheck().check("Color").check_next("prim::EnumName").check_next(
.check("Color") \ "return"
.check_next("prim::EnumName") \ ).run(str(enum_name.graph))
.check_next("return") \
.run(str(enum_name.graph))
self.assertEqual(enum_name(Color.RED), Color.RED.name) self.assertEqual(enum_name(Color.RED), Color.RED.name)
self.assertEqual(enum_name(Color.GREEN), Color.GREEN.name) self.assertEqual(enum_name(Color.GREEN), Color.GREEN.name)
@ -141,11 +142,9 @@ class TestEnum(JitTestCase):
def enum_value(x: Color) -> int: def enum_value(x: Color) -> int:
return x.value return x.value
FileCheck() \ FileCheck().check("Color").check_next("prim::EnumValue").check_next(
.check("Color") \ "return"
.check_next("prim::EnumValue") \ ).run(str(enum_value.graph))
.check_next("return") \
.run(str(enum_value.graph))
self.assertEqual(enum_value(Color.RED), Color.RED.value) self.assertEqual(enum_value(Color.RED), Color.RED.value)
self.assertEqual(enum_value(Color.GREEN), Color.GREEN.value) self.assertEqual(enum_value(Color.GREEN), Color.GREEN.value)
@ -161,11 +160,9 @@ class TestEnum(JitTestCase):
def enum_const(x: Color) -> bool: def enum_const(x: Color) -> bool:
return x == Color.RED return x == Color.RED
FileCheck() \ FileCheck().check(
.check("prim::Constant[value=__torch__.jit.test_enum.Color.RED]") \ "prim::Constant[value=__torch__.jit.test_enum.Color.RED]"
.check_next("aten::eq") \ ).check_next("aten::eq").check_next("return").run(str(enum_const.graph))
.check_next("return") \
.run(str(enum_const.graph))
self.assertEqual(enum_const(Color.RED), True) self.assertEqual(enum_const(Color.RED), True)
self.assertEqual(enum_const(Color.GREEN), False) self.assertEqual(enum_const(Color.GREEN), False)
@ -183,7 +180,9 @@ class TestEnum(JitTestCase):
else: else:
return False return False
with self.assertRaisesRegexWithHighlight(RuntimeError, "has no attribute 'PURPLE'", "Color.PURPLE"): with self.assertRaisesRegexWithHighlight(
RuntimeError, "has no attribute 'PURPLE'", "Color.PURPLE"
):
torch.jit.script(enum_const) torch.jit.script(enum_const)
def test_enum_ivalue_type(self): def test_enum_ivalue_type(self):
@ -197,10 +196,9 @@ class TestEnum(JitTestCase):
def is_color_enum(x: Any): def is_color_enum(x: Any):
return isinstance(x, Color) return isinstance(x, Color)
FileCheck() \ FileCheck().check(
.check("prim::isinstance[types=[Enum<__torch__.jit.test_enum.Color>]]") \ "prim::isinstance[types=[Enum<__torch__.jit.test_enum.Color>]]"
.check_next("return") \ ).check_next("return").run(str(is_color_enum.graph))
.run(str(is_color_enum.graph))
self.assertEqual(is_color_enum(Color.RED), True) self.assertEqual(is_color_enum(Color.RED), True)
self.assertEqual(is_color_enum(Color.GREEN), True) self.assertEqual(is_color_enum(Color.GREEN), True)
@ -217,10 +215,9 @@ class TestEnum(JitTestCase):
def closed_over_aliased_type(): def closed_over_aliased_type():
return a.RED.value return a.RED.value
FileCheck() \ FileCheck().check("prim::Constant[value={}]".format(a.RED.value)).check_next(
.check("prim::Constant[value={}]".format(a.RED.value)) \ "return"
.check_next("return") \ ).run(str(closed_over_aliased_type.graph))
.run(str(closed_over_aliased_type.graph))
self.assertEqual(closed_over_aliased_type(), Color.RED.value) self.assertEqual(closed_over_aliased_type(), Color.RED.value)
@ -230,10 +227,9 @@ class TestEnum(JitTestCase):
def closed_over_aliased_value(): def closed_over_aliased_value():
return b.value return b.value
FileCheck() \ FileCheck().check("prim::Constant[value={}]".format(b.value)).check_next(
.check("prim::Constant[value={}]".format(b.value)) \ "return"
.check_next("return") \ ).run(str(closed_over_aliased_value.graph))
.run(str(closed_over_aliased_value.graph))
self.assertEqual(closed_over_aliased_value(), Color.RED.value) self.assertEqual(closed_over_aliased_value(), Color.RED.value)
@ -253,13 +249,9 @@ class TestEnum(JitTestCase):
m = TestModule(Color.RED) m = TestModule(Color.RED)
scripted = torch.jit.script(m) scripted = torch.jit.script(m)
FileCheck() \ FileCheck().check("TestModule").check_next("Color").check_same(
.check("TestModule") \ 'prim::GetAttr[name="e"]'
.check_next("Color") \ ).check_next("prim::EnumValue").check_next("return").run(str(scripted.graph))
.check_same("prim::GetAttr[name=\"e\"]") \
.check_next("prim::EnumValue") \
.check_next("return") \
.run(str(scripted.graph))
self.assertEqual(scripted(), Color.RED.value) self.assertEqual(scripted(), Color.RED.value)
@ -316,16 +308,12 @@ class TestEnum(JitTestCase):
m = TestModule(Color.RED) m = TestModule(Color.RED)
scripted = torch.jit.script(m) scripted = torch.jit.script(m)
FileCheck() \ FileCheck().check("TestModule").check_next("Color").check_same(
.check("TestModule") \ 'prim::GetAttr[name="e"]'
.check_next("Color") \ ).check_next("return").run(str(scripted.graph))
.check_same("prim::GetAttr[name=\"e\"]") \
.check_next("return") \
.run(str(scripted.graph))
self.assertEqual(scripted(), Color.RED) self.assertEqual(scripted(), Color.RED)
def test_enum_iterate(self): def test_enum_iterate(self):
class Color(Enum): class Color(Enum):
RED = 1 RED = 1
@ -342,12 +330,9 @@ class TestEnum(JitTestCase):
make_global(Color) make_global(Color)
scripted = torch.jit.script(iterate_enum) scripted = torch.jit.script(iterate_enum)
FileCheck() \ FileCheck().check("Enum<__torch__.jit.test_enum.Color>[]").check_same(
.check("Enum<__torch__.jit.test_enum.Color>[]") \ "Color.RED"
.check_same("Color.RED") \ ).check_same("Color.GREEN").check_same("Color.BLUE").run(str(scripted.graph))
.check_same("Color.GREEN") \
.check_same("Color.BLUE") \
.run(str(scripted.graph))
# PURPLE always appears last because we follow Python's Enum definition order. # PURPLE always appears last because we follow Python's Enum definition order.
self.assertEqual(scripted(Color.RED), [Color.GREEN.value, Color.BLUE.value]) self.assertEqual(scripted(Color.RED), [Color.GREEN.value, Color.BLUE.value])
@ -355,7 +340,6 @@ class TestEnum(JitTestCase):
# Tests that explicitly and/or repeatedly scripting an Enum class is permitted. # Tests that explicitly and/or repeatedly scripting an Enum class is permitted.
def test_enum_explicit_script(self): def test_enum_explicit_script(self):
@torch.jit.script @torch.jit.script
class Color(Enum): class Color(Enum):
RED = 1 RED = 1

View File

@ -1,11 +1,13 @@
# Owner(s): ["oncall: jit"] # Owner(s): ["oncall: jit"]
from torch.testing._internal.common_utils import TestCase
import torch import torch
from torch import nn from torch import nn
from torch.testing._internal.common_utils import TestCase
r""" r"""
Test TorchScript exception handling. Test TorchScript exception handling.
""" """
class TestException(TestCase): class TestException(TestCase):
def test_pyop_exception_message(self): def test_pyop_exception_message(self):
class Foo(torch.jit.ScriptModule): class Foo(torch.jit.ScriptModule):
@ -16,31 +18,40 @@ class TestException(TestCase):
@torch.jit.script_method @torch.jit.script_method
def forward(self, x): def forward(self, x):
return self.conv(x) return self.conv(x)
foo = Foo() foo = Foo()
# testing that the correct error message propagates # testing that the correct error message propagates
with self.assertRaisesRegex(RuntimeError, r"Expected 3D \(unbatched\) or 4D \(batched\) input to conv2d"): with self.assertRaisesRegex(
RuntimeError, r"Expected 3D \(unbatched\) or 4D \(batched\) input to conv2d"
):
foo(torch.ones([123])) # wrong size foo(torch.ones([123])) # wrong size
def test_builtin_error_messsage(self): def test_builtin_error_messsage(self):
with self.assertRaisesRegex(RuntimeError, "Arguments for call are not valid"): with self.assertRaisesRegex(RuntimeError, "Arguments for call are not valid"):
@torch.jit.script @torch.jit.script
def close_match(x): def close_match(x):
return x.masked_fill(True) return x.masked_fill(True)
with self.assertRaisesRegex(RuntimeError, "This op may not exist or may not be currently " with self.assertRaisesRegex(
"supported in TorchScript"): RuntimeError,
"This op may not exist or may not be currently " "supported in TorchScript",
):
@torch.jit.script @torch.jit.script
def unknown_op(x): def unknown_op(x):
torch.set_anomaly_enabled(True) torch.set_anomaly_enabled(True)
return x return x
def test_exceptions(self): def test_exceptions(self):
cu = torch.jit.CompilationUnit(''' cu = torch.jit.CompilationUnit(
"""
def foo(cond): def foo(cond):
if bool(cond): if bool(cond):
raise ValueError(3) raise ValueError(3)
return 1 return 1
''') """
)
cu.foo(torch.tensor(0)) cu.foo(torch.tensor(0))
with self.assertRaisesRegex(torch.jit.Error, "3"): with self.assertRaisesRegex(torch.jit.Error, "3"):
@ -97,6 +108,7 @@ class TestException(TestCase):
else: else:
raise Exception("Hi") raise Exception("Hi")
return a return a
self.assertEqual(foo(), 1) self.assertEqual(foo(), 1)
@torch.jit.script @torch.jit.script
@ -114,11 +126,13 @@ class TestException(TestCase):
no_message() no_message()
def test_assertions(self): def test_assertions(self):
cu = torch.jit.CompilationUnit(''' cu = torch.jit.CompilationUnit(
"""
def foo(cond): def foo(cond):
assert bool(cond), "hi" assert bool(cond), "hi"
return 0 return 0
''') """
)
cu.foo(torch.tensor(1)) cu.foo(torch.tensor(1))
with self.assertRaisesRegex(torch.jit.Error, "AssertionError: hi"): with self.assertRaisesRegex(torch.jit.Error, "AssertionError: hi"):
@ -142,7 +156,9 @@ class TestException(TestCase):
def fn(x): def fn(x):
return python_op(x) return python_op(x)
with self.assertRaisesRegex(RuntimeError, "operation failed in the TorchScript interpreter"): with self.assertRaisesRegex(
RuntimeError, "operation failed in the TorchScript interpreter"
):
fn(torch.tensor(4)) fn(torch.tensor(4))
def test_dict_expansion_raises_error(self): def test_dict_expansion_raises_error(self):
@ -150,8 +166,9 @@ class TestException(TestCase):
d = {"foo": 1, "bar": 2, "baz": 3} d = {"foo": 1, "bar": 2, "baz": 3}
return {**d} return {**d}
with self.assertRaisesRegex(torch.jit.frontend.NotSupportedError, with self.assertRaisesRegex(
"Dict expansion "): torch.jit.frontend.NotSupportedError, "Dict expansion "
):
torch.jit.script(fn) torch.jit.script(fn)
def test_custom_python_exception(self): def test_custom_python_exception(self):
@ -162,7 +179,9 @@ class TestException(TestCase):
def fn(): def fn():
raise MyValueError("test custom exception") raise MyValueError("test custom exception")
with self.assertRaisesRegex(torch.jit.Error, "jit.test_exception.MyValueError: test custom exception"): with self.assertRaisesRegex(
torch.jit.Error, "jit.test_exception.MyValueError: test custom exception"
):
fn() fn()
def test_custom_python_exception_defined_elsewhere(self): def test_custom_python_exception_defined_elsewhere(self):
@ -171,5 +190,9 @@ class TestException(TestCase):
@torch.jit.script @torch.jit.script
def fn(): def fn():
raise MyKeyError("This is a user defined key error") raise MyKeyError("This is a user defined key error")
with self.assertRaisesRegex(torch.jit.Error, "jit.myexception.MyKeyError: This is a user defined key error"):
with self.assertRaisesRegex(
torch.jit.Error,
"jit.myexception.MyKeyError: This is a user defined key error",
):
fn() fn()

File diff suppressed because it is too large Load Diff

View File

@ -11,10 +11,13 @@ pytorch_test_dir = os.path.dirname(os.path.dirname(os.path.realpath(__file__)))
sys.path.append(pytorch_test_dir) sys.path.append(pytorch_test_dir)
from torch.testing._internal.jit_utils import JitTestCase from torch.testing._internal.jit_utils import JitTestCase
if __name__ == '__main__': if __name__ == "__main__":
raise RuntimeError("This test file is not meant to be run directly, use:\n\n" raise RuntimeError(
"\tpython test/test_jit.py TESTNAME\n\n" "This test file is not meant to be run directly, use:\n\n"
"instead.") "\tpython test/test_jit.py TESTNAME\n\n"
"instead."
)
class TestFunctionalBlocks(JitTestCase): class TestFunctionalBlocks(JitTestCase):
def test_subgraph_creation(self): def test_subgraph_creation(self):
@ -30,14 +33,22 @@ class TestFunctionalBlocks(JitTestCase):
return x + y + z return x + y + z
graph = torch.jit.script(fn).graph graph = torch.jit.script(fn).graph
self.run_pass('create_functional_graphs', graph) self.run_pass("create_functional_graphs", graph)
# all uses of x and y should be sunk # all uses of x and y should be sunk
FileCheck().check(r"%x").check_not(r"%x").check("FunctionalGraph").check(r"%x").run(graph) FileCheck().check(r"%x").check_not(r"%x").check("FunctionalGraph").check(
FileCheck().check(r"%y").check_not(r"%y").check("FunctionalGraph").check(r"%y").run(graph) r"%x"
).run(graph)
FileCheck().check(r"%y").check_not(r"%y").check("FunctionalGraph").check(
r"%y"
).run(graph)
# Don't allow any outputs which escape scope, so there is one final addition in the graph # Don't allow any outputs which escape scope, so there is one final addition in the graph
FileCheck().check("Tensor = prim::Functional").check_next("aten::add").run(graph) FileCheck().check("Tensor = prim::Functional").check_next("aten::add").run(
graph
)
# z + 1, z.add_(2) considered non functional, z = z * z should be considered functional # z + 1, z.add_(2) considered non functional, z = z * z should be considered functional
FileCheck().check("add").check("add_").check_not("mul").check("FunctionalGraph").run(graph) FileCheck().check("add").check("add_").check_not("mul").check(
"FunctionalGraph"
).run(graph)

View File

@ -3,9 +3,11 @@
import torch import torch
from torch.testing._internal.jit_utils import JitTestCase from torch.testing._internal.jit_utils import JitTestCase
class TestFuserCommon(JitTestCase): class TestFuserCommon(JitTestCase):
def test_autodiff_fallback(self): def test_autodiff_fallback(self):
for rq in [True, False]: for rq in [True, False]:
@torch.jit.script @torch.jit.script
def fn(x): def fn(x):
return torch.max(x**2.0, x**3.0) return torch.max(x**2.0, x**3.0)

View File

@ -1,9 +1,9 @@
# Owner(s): ["oncall: jit"] # Owner(s): ["oncall: jit"]
from torch.testing._internal.jit_utils import JitTestCase
import torch import torch
import torch._C import torch._C
from torch.testing import FileCheck from torch.testing import FileCheck
from torch.testing._internal.jit_utils import JitTestCase
class TestGraphRewritePasses(JitTestCase): class TestGraphRewritePasses(JitTestCase):

View File

@ -3,9 +3,9 @@
import os import os
import sys import sys
import torch from typing import List, Tuple
from typing import Tuple, List import torch
# Make the helper files in test/ importable # Make the helper files in test/ importable
pytorch_test_dir = os.path.dirname(os.path.dirname(os.path.realpath(__file__))) pytorch_test_dir = os.path.dirname(os.path.dirname(os.path.realpath(__file__)))
@ -13,9 +13,12 @@ sys.path.append(pytorch_test_dir)
from torch.testing._internal.jit_utils import JitTestCase from torch.testing._internal.jit_utils import JitTestCase
if __name__ == "__main__": if __name__ == "__main__":
raise RuntimeError("This test file is not meant to be run directly, use:\n\n" raise RuntimeError(
"\tpython test/test_jit.py TESTNAME\n\n" "This test file is not meant to be run directly, use:\n\n"
"instead.") "\tpython test/test_jit.py TESTNAME\n\n"
"instead."
)
class TestHash(JitTestCase): class TestHash(JitTestCase):
def test_hash_tuple(self): def test_hash_tuple(self):
@ -38,6 +41,7 @@ class TestHash(JitTestCase):
def test_hash_tensor(self): def test_hash_tensor(self):
"""Tensors should hash by identity""" """Tensors should hash by identity"""
def fn(t1, t2): def fn(t1, t2):
return hash(t1) == hash(t2) return hash(t1) == hash(t2)
@ -74,7 +78,7 @@ class TestHash(JitTestCase):
self.checkScript(fn, (1.2345, 6.789)) self.checkScript(fn, (1.2345, 6.789))
self.checkScript(fn, (1.2345, float("inf"))) self.checkScript(fn, (1.2345, float("inf")))
self.checkScript(fn, (float("inf"), float("inf"))) self.checkScript(fn, (float("inf"), float("inf")))
self.checkScript(fn, (1.2345, float('nan'))) self.checkScript(fn, (1.2345, float("nan")))
if sys.version_info < (3, 10): if sys.version_info < (3, 10):
# Hash of two nans are not guaranteed to be equal. From https://docs.python.org/3/whatsnew/3.10.html : # Hash of two nans are not guaranteed to be equal. From https://docs.python.org/3/whatsnew/3.10.html :
# Hashes of NaN values of both float type and decimal.Decimal type now depend on object identity. # Hashes of NaN values of both float type and decimal.Decimal type now depend on object identity.
@ -103,9 +107,9 @@ class TestHash(JitTestCase):
def fn(d1: torch.device, d2: torch.device): def fn(d1: torch.device, d2: torch.device):
return hash(d1) == hash(d2) return hash(d1) == hash(d2)
gpu0 = torch.device('cuda:0') gpu0 = torch.device("cuda:0")
gpu1 = torch.device('cuda:1') gpu1 = torch.device("cuda:1")
cpu = torch.device('cpu') cpu = torch.device("cpu")
self.checkScript(fn, (gpu0, gpu0)) self.checkScript(fn, (gpu0, gpu0))
self.checkScript(fn, (gpu0, gpu1)) self.checkScript(fn, (gpu0, gpu1))
self.checkScript(fn, (gpu0, cpu)) self.checkScript(fn, (gpu0, cpu))

View File

@ -6,21 +6,29 @@ import unittest
from typing import Tuple from typing import Tuple
import torch import torch
from jit.test_hooks_modules import ( from jit.test_hooks_modules import (
ModuleDirectforwardSubmodCall, ModuleForwardSingleInput, create_forward_tuple_input,
ModuleForwardTupleInput, create_forward_tuple_input, create_module_forward_multiple_inputs,
create_module_forward_multiple_inputs, create_module_forward_single_input, create_module_forward_single_input,
create_module_hook_return_nothing, create_module_hook_return_nothing,
create_module_multiple_hooks_multiple_inputs, create_module_multiple_hooks_multiple_inputs,
create_module_multiple_hooks_single_input, create_module_no_forward_input, create_module_multiple_hooks_single_input,
create_module_same_hook_repeated, create_submodule_forward_multiple_inputs, create_module_no_forward_input,
create_module_same_hook_repeated,
create_submodule_forward_multiple_inputs,
create_submodule_forward_single_input, create_submodule_forward_single_input,
create_submodule_forward_single_input_return_not_tupled, create_submodule_forward_single_input_return_not_tupled,
create_submodule_hook_return_nothing, create_submodule_hook_return_nothing,
create_submodule_multiple_hooks_multiple_inputs, create_submodule_multiple_hooks_multiple_inputs,
create_submodule_multiple_hooks_single_input, create_submodule_multiple_hooks_single_input,
create_submodule_no_forward_input, create_submodule_same_hook_repeated, create_submodule_no_forward_input,
create_submodule_to_call_directly_with_hooks) create_submodule_same_hook_repeated,
create_submodule_to_call_directly_with_hooks,
ModuleDirectforwardSubmodCall,
ModuleForwardSingleInput,
ModuleForwardTupleInput,
)
# Make the helper files in test/ importable # Make the helper files in test/ importable
pytorch_test_dir = os.path.dirname(os.path.dirname(os.path.realpath(__file__))) pytorch_test_dir = os.path.dirname(os.path.dirname(os.path.realpath(__file__)))
@ -37,7 +45,6 @@ if __name__ == "__main__":
# Tests for JIT forward hooks and pre-hooks # Tests for JIT forward hooks and pre-hooks
class TestHooks(JitTestCase): class TestHooks(JitTestCase):
def test_module_no_forward_input(self): def test_module_no_forward_input(self):
self.checkModule(create_module_no_forward_input(), ()) self.checkModule(create_module_no_forward_input(), ())
@ -73,7 +80,8 @@ class TestHooks(JitTestCase):
def test_submodule_multiple_hooks_multiple_inputs(self): def test_submodule_multiple_hooks_multiple_inputs(self):
self.checkModule( self.checkModule(
create_submodule_multiple_hooks_multiple_inputs(), (["a"], "no_pre_hook"), create_submodule_multiple_hooks_multiple_inputs(),
(["a"], "no_pre_hook"),
) )
def test_submodule_forward_single_input(self): def test_submodule_forward_single_input(self):
@ -242,7 +250,8 @@ class TestHooks(JitTestCase):
m.register_forward_pre_hook(pre_hook_wrong_input1) m.register_forward_pre_hook(pre_hook_wrong_input1)
with self.assertRaisesRegex( with self.assertRaisesRegex(
RuntimeError, "has the wrong inner types for the input tuple argument", RuntimeError,
"has the wrong inner types for the input tuple argument",
): ):
torch.jit.script(m) torch.jit.script(m)
@ -278,7 +287,8 @@ class TestHooks(JitTestCase):
m.register_forward_pre_hook(pre_hook_wrong_output) m.register_forward_pre_hook(pre_hook_wrong_output)
with self.assertRaisesRegex( with self.assertRaisesRegex(
RuntimeError, "returned the wrong type of: 'int'", RuntimeError,
"returned the wrong type of: 'int'",
): ):
torch.jit.script(m) torch.jit.script(m)

View File

@ -1,8 +1,9 @@
# Owner(s): ["oncall: jit"] # Owner(s): ["oncall: jit"]
import torch
from typing import List, Tuple from typing import List, Tuple
import torch
class SubmoduleNoForwardInputs(torch.nn.Module): class SubmoduleNoForwardInputs(torch.nn.Module):
def __init__(self, name): def __init__(self, name):

View File

@ -2,6 +2,7 @@
import os import os
import sys import sys
import torch import torch
from torch._C import parse_ir from torch._C import parse_ir
from torch.testing import FileCheck from torch.testing import FileCheck
@ -11,10 +12,13 @@ pytorch_test_dir = os.path.dirname(os.path.dirname(os.path.realpath(__file__)))
sys.path.append(pytorch_test_dir) sys.path.append(pytorch_test_dir)
from torch.testing._internal.jit_utils import JitTestCase from torch.testing._internal.jit_utils import JitTestCase
if __name__ == '__main__': if __name__ == "__main__":
raise RuntimeError("This test file is not meant to be run directly, use:\n\n" raise RuntimeError(
"\tpython test/test_jit.py TESTNAME\n\n" "This test file is not meant to be run directly, use:\n\n"
"instead.") "\tpython test/test_jit.py TESTNAME\n\n"
"instead."
)
# Tests that Python slice class is supported in TorchScript # Tests that Python slice class is supported in TorchScript
class TestIgnorableArgs(JitTestCase): class TestIgnorableArgs(JitTestCase):
@ -44,11 +48,14 @@ class TestIgnorableArgs(JitTestCase):
# We ignore trailing arguments after start=2 for dim 0 # We ignore trailing arguments after start=2 for dim 0
# and after end=1 for dim 1 # and after end=1 for dim 1
# because in %16, %15 and %0 are default values for the schema. # because in %16, %15 and %0 are default values for the schema.
FileCheck().check("torch.slice(torch.slice(torch.tensor(_0), 0, 2), 1, None, 1)").run(src) FileCheck().check(
"torch.slice(torch.slice(torch.tensor(_0), 0, 2), 1, None, 1)"
).run(src)
self.assertEqual(function(), function_copy()) self.assertEqual(function(), function_copy())
def test_add_out_ignorable_args(self): def test_add_out_ignorable_args(self):
@torch.jit.script @torch.jit.script
def fn(x: torch.Tensor, y: torch.Tensor): def fn(x: torch.Tensor, y: torch.Tensor):
torch.add(x, y, out=y) torch.add(x, y, out=y)
FileCheck().check("torch.add(x, y, out=y)").run(fn.code) FileCheck().check("torch.add(x, y, out=y)").run(fn.code)

View File

@ -9,13 +9,16 @@ import torch
# Make the helper files in test/ importable # Make the helper files in test/ importable
pytorch_test_dir = os.path.dirname(os.path.dirname(os.path.realpath(__file__))) pytorch_test_dir = os.path.dirname(os.path.dirname(os.path.realpath(__file__)))
sys.path.append(pytorch_test_dir) sys.path.append(pytorch_test_dir)
from torch.testing._internal.jit_utils import JitTestCase
from torch.jit.frontend import _IS_ASTUNPARSE_INSTALLED from torch.jit.frontend import _IS_ASTUNPARSE_INSTALLED
from torch.testing._internal.jit_utils import JitTestCase
if __name__ == "__main__": if __name__ == "__main__":
raise RuntimeError("This test file is not meant to be run directly, use:\n\n" raise RuntimeError(
"\tpython test/test_jit.py TESTNAME\n\n" "This test file is not meant to be run directly, use:\n\n"
"instead.") "\tpython test/test_jit.py TESTNAME\n\n"
"instead."
)
class TestIgnoreContextManager(JitTestCase): class TestIgnoreContextManager(JitTestCase):
@unittest.skipUnless(_IS_ASTUNPARSE_INSTALLED, "astunparse package is required") @unittest.skipUnless(_IS_ASTUNPARSE_INSTALLED, "astunparse package is required")
@ -26,11 +29,14 @@ class TestIgnoreContextManager(JitTestCase):
b: int = 5 b: int = 5
c: int = 0 c: int = 0
d: int = 6 d: int = 6
with torch.jit._IgnoreContextManager(a="inp:int", b="inp:int", c="out:int", d="out:int"): with torch.jit._IgnoreContextManager(
a="inp:int", b="inp:int", c="out:int", d="out:int"
):
l = [2 for i in range(a) if i > 2] l = [2 for i in range(a) if i > 2]
c = l[0] + a + b c = l[0] + a + b
d = 9 d = 9
return c + d return c + d
model = A() model = A()
s = torch.jit.script(model) s = torch.jit.script(model)
self.assertEqual(s(), model()) self.assertEqual(s(), model())
@ -41,10 +47,13 @@ class TestIgnoreContextManager(JitTestCase):
a: int = 4 a: int = 4
b: int = 5 b: int = 5
c: int = 0 c: int = 0
with torch.jit._IgnoreContextManager(a="inp:int", b="inp:int", c="out:int"): with torch.jit._IgnoreContextManager(
a="inp:int", b="inp:int", c="out:int"
):
l = [2 for i in range(a) if i > 2] l = [2 for i in range(a) if i > 2]
c = l[0] + a + b c = l[0] + a + b
return c return c
model = B() model = B()
s = torch.jit.script(model) s = torch.jit.script(model)
self.assertEqual(s(), 11) self.assertEqual(s(), 11)
@ -58,6 +67,7 @@ class TestIgnoreContextManager(JitTestCase):
l = [2 for i in range(a) if i > 2] l = [2 for i in range(a) if i > 2]
b = l[0] + a b = l[0] + a
return b return b
model = C() model = C()
s = torch.jit.script(model) s = torch.jit.script(model)
self.assertEqual(s(), 6) self.assertEqual(s(), 6)
@ -72,6 +82,7 @@ class TestIgnoreContextManager(JitTestCase):
with torch.jit._IgnoreContextManager(a="inp:int", b="inp:int"): with torch.jit._IgnoreContextManager(a="inp:int", b="inp:int"):
l = [2 + b for i in range(a) if i > 2] l = [2 + b for i in range(a) if i > 2]
return a return a
model = A() model = A()
s = torch.jit.script(model) s = torch.jit.script(model)
self.assertEqual(s(), 4) self.assertEqual(s(), 4)
@ -85,6 +96,7 @@ class TestIgnoreContextManager(JitTestCase):
c = [2 for i in range(7) if i > 2] c = [2 for i in range(7) if i > 2]
c[0] = 3 c[0] = 3
return c[0] + c[1] return c[0] + c[1]
model = A() model = A()
s = torch.jit.script(model) s = torch.jit.script(model)
self.assertEqual(s(), 5) self.assertEqual(s(), 5)

View File

@ -2,10 +2,10 @@
import os import os
import sys import sys
import warnings
from typing import Any, Dict, List, Optional, Tuple
import torch import torch
import warnings
from typing import List, Any, Dict, Tuple, Optional
# Make the helper files in test/ importable # Make the helper files in test/ importable
pytorch_test_dir = os.path.dirname(os.path.dirname(os.path.realpath(__file__))) pytorch_test_dir = os.path.dirname(os.path.dirname(os.path.realpath(__file__)))
@ -19,6 +19,7 @@ if __name__ == "__main__":
"instead." "instead."
) )
# Tests for torch.jit.isinstance # Tests for torch.jit.isinstance
class TestIsinstance(JitTestCase): class TestIsinstance(JitTestCase):
def test_int(self): def test_int(self):
@ -223,28 +224,42 @@ class TestIsinstance(JitTestCase):
x = ["1", "2", "3"] x = ["1", "2", "3"]
err_msg = "Attempted to use List without a contained type. " \ err_msg = (
"Attempted to use List without a contained type. "
r"Please add a contained type, e.g. List\[int\]" r"Please add a contained type, e.g. List\[int\]"
)
with self.assertRaisesRegex(RuntimeError, err_msg,): with self.assertRaisesRegex(
RuntimeError,
err_msg,
):
torch.jit.script(list_no_contained_type) torch.jit.script(list_no_contained_type)
with self.assertRaisesRegex(RuntimeError, err_msg,): with self.assertRaisesRegex(
RuntimeError,
err_msg,
):
list_no_contained_type(x) list_no_contained_type(x)
def test_tuple_no_contained_type(self): def test_tuple_no_contained_type(self):
def tuple_no_contained_type(x: Any): def tuple_no_contained_type(x: Any):
assert torch.jit.isinstance(x, Tuple) assert torch.jit.isinstance(x, Tuple)
x = ("1", "2", "3") x = ("1", "2", "3")
err_msg = "Attempted to use Tuple without a contained type. " \ err_msg = (
"Attempted to use Tuple without a contained type. "
r"Please add a contained type, e.g. Tuple\[int\]" r"Please add a contained type, e.g. Tuple\[int\]"
)
with self.assertRaisesRegex(RuntimeError, err_msg,): with self.assertRaisesRegex(
RuntimeError,
err_msg,
):
torch.jit.script(tuple_no_contained_type) torch.jit.script(tuple_no_contained_type)
with self.assertRaisesRegex(RuntimeError, err_msg,): with self.assertRaisesRegex(
RuntimeError,
err_msg,
):
tuple_no_contained_type(x) tuple_no_contained_type(x)
def test_optional_no_contained_type(self): def test_optional_no_contained_type(self):
@ -253,12 +268,20 @@ class TestIsinstance(JitTestCase):
x = ("1", "2", "3") x = ("1", "2", "3")
err_msg = "Attempted to use Optional without a contained type. " \ err_msg = (
"Attempted to use Optional without a contained type. "
r"Please add a contained type, e.g. Optional\[int\]" r"Please add a contained type, e.g. Optional\[int\]"
)
with self.assertRaisesRegex(RuntimeError, err_msg,): with self.assertRaisesRegex(
RuntimeError,
err_msg,
):
torch.jit.script(optional_no_contained_type) torch.jit.script(optional_no_contained_type)
with self.assertRaisesRegex(RuntimeError, err_msg,): with self.assertRaisesRegex(
RuntimeError,
err_msg,
):
optional_no_contained_type(x) optional_no_contained_type(x)
def test_dict_no_contained_type(self): def test_dict_no_contained_type(self):
@ -267,12 +290,20 @@ class TestIsinstance(JitTestCase):
x = {"a": "aa"} x = {"a": "aa"}
err_msg = "Attempted to use Dict without contained types. " \ err_msg = (
"Attempted to use Dict without contained types. "
r"Please add contained type, e.g. Dict\[int, int\]" r"Please add contained type, e.g. Dict\[int, int\]"
)
with self.assertRaisesRegex(RuntimeError, err_msg,): with self.assertRaisesRegex(
RuntimeError,
err_msg,
):
torch.jit.script(dict_no_contained_type) torch.jit.script(dict_no_contained_type)
with self.assertRaisesRegex(RuntimeError, err_msg,): with self.assertRaisesRegex(
RuntimeError,
err_msg,
):
dict_no_contained_type(x) dict_no_contained_type(x)
def test_tuple_rhs(self): def test_tuple_rhs(self):

View File

@ -13,10 +13,13 @@ pytorch_test_dir = os.path.dirname(os.path.dirname(os.path.realpath(__file__)))
sys.path.append(pytorch_test_dir) sys.path.append(pytorch_test_dir)
from torch.testing._internal.jit_utils import JitTestCase from torch.testing._internal.jit_utils import JitTestCase
if __name__ == '__main__': if __name__ == "__main__":
raise RuntimeError("This test file is not meant to be run directly, use:\n\n" raise RuntimeError(
"\tpython test/test_jit.py TESTNAME\n\n" "This test file is not meant to be run directly, use:\n\n"
"instead.") "\tpython test/test_jit.py TESTNAME\n\n"
"instead."
)
# Tests various JIT-related utility functions. # Tests various JIT-related utility functions.
class TestJitUtils(JitTestCase): class TestJitUtils(JitTestCase):
@ -24,58 +27,71 @@ class TestJitUtils(JitTestCase):
def test_get_callable_argument_names_positional_or_keyword(self): def test_get_callable_argument_names_positional_or_keyword(self):
def fn_positional_or_keyword_args_only(x, y): def fn_positional_or_keyword_args_only(x, y):
return x + y return x + y
self.assertEqual( self.assertEqual(
["x", "y"], ["x", "y"],
torch._jit_internal.get_callable_argument_names(fn_positional_or_keyword_args_only)) torch._jit_internal.get_callable_argument_names(
fn_positional_or_keyword_args_only
),
)
# Tests that POSITIONAL_ONLY arguments are ignored. # Tests that POSITIONAL_ONLY arguments are ignored.
def test_get_callable_argument_names_positional_only(self): def test_get_callable_argument_names_positional_only(self):
code = dedent(''' code = dedent(
"""
def fn_positional_only_arg(x, /, y): def fn_positional_only_arg(x, /, y):
return x + y return x + y
''') """
)
fn_positional_only_arg = jit_utils._get_py3_code(code, 'fn_positional_only_arg') fn_positional_only_arg = jit_utils._get_py3_code(code, "fn_positional_only_arg")
self.assertEqual( self.assertEqual(
["y"], ["y"],
torch._jit_internal.get_callable_argument_names(fn_positional_only_arg)) torch._jit_internal.get_callable_argument_names(fn_positional_only_arg),
)
# Tests that VAR_POSITIONAL arguments are ignored. # Tests that VAR_POSITIONAL arguments are ignored.
def test_get_callable_argument_names_var_positional(self): def test_get_callable_argument_names_var_positional(self):
# Tests that VAR_POSITIONAL arguments are ignored. # Tests that VAR_POSITIONAL arguments are ignored.
def fn_var_positional_arg(x, *arg): def fn_var_positional_arg(x, *arg):
return x + arg[0] return x + arg[0]
self.assertEqual( self.assertEqual(
["x"], ["x"],
torch._jit_internal.get_callable_argument_names(fn_var_positional_arg)) torch._jit_internal.get_callable_argument_names(fn_var_positional_arg),
)
# Tests that KEYWORD_ONLY arguments are ignored. # Tests that KEYWORD_ONLY arguments are ignored.
def test_get_callable_argument_names_keyword_only(self): def test_get_callable_argument_names_keyword_only(self):
def fn_keyword_only_arg(x, *, y): def fn_keyword_only_arg(x, *, y):
return x + y return x + y
self.assertEqual( self.assertEqual(
["x"], ["x"], torch._jit_internal.get_callable_argument_names(fn_keyword_only_arg)
torch._jit_internal.get_callable_argument_names(fn_keyword_only_arg)) )
# Tests that VAR_KEYWORD arguments are ignored. # Tests that VAR_KEYWORD arguments are ignored.
def test_get_callable_argument_names_var_keyword(self): def test_get_callable_argument_names_var_keyword(self):
def fn_var_keyword_arg(**args): def fn_var_keyword_arg(**args):
return args['x'] + args['y'] return args["x"] + args["y"]
self.assertEqual( self.assertEqual(
[], [], torch._jit_internal.get_callable_argument_names(fn_var_keyword_arg)
torch._jit_internal.get_callable_argument_names(fn_var_keyword_arg)) )
# Tests that a function signature containing various different types of # Tests that a function signature containing various different types of
# arguments are ignored. # arguments are ignored.
def test_get_callable_argument_names_hybrid(self): def test_get_callable_argument_names_hybrid(self):
code = dedent(''' code = dedent(
"""
def fn_hybrid_args(x, /, y, *args, **kwargs): def fn_hybrid_args(x, /, y, *args, **kwargs):
return x + y + args[0] + kwargs['z'] return x + y + args[0] + kwargs['z']
''') """
fn_hybrid_args = jit_utils._get_py3_code(code, 'fn_hybrid_args') )
fn_hybrid_args = jit_utils._get_py3_code(code, "fn_hybrid_args")
self.assertEqual( self.assertEqual(
["y"], ["y"], torch._jit_internal.get_callable_argument_names(fn_hybrid_args)
torch._jit_internal.get_callable_argument_names(fn_hybrid_args)) )
def test_checkscriptassertraisesregex(self): def test_checkscriptassertraisesregex(self):
def fn(): def fn():
@ -84,22 +100,18 @@ class TestJitUtils(JitTestCase):
self.checkScriptRaisesRegex(fn, (), Exception, "range", name="fn") self.checkScriptRaisesRegex(fn, (), Exception, "range", name="fn")
s = dedent(""" s = dedent(
"""
def fn(): def fn():
tup = (1, 2) tup = (1, 2)
return tup[2] return tup[2]
""") """
)
self.checkScriptRaisesRegex(s, (), Exception, "range", name="fn") self.checkScriptRaisesRegex(s, (), Exception, "range", name="fn")
def test_no_tracer_warn_context_manager(self): def test_no_tracer_warn_context_manager(self):
torch._C._jit_set_tracer_state_warn(True) torch._C._jit_set_tracer_state_warn(True)
with jit_utils.NoTracerWarnContextManager() as no_warn: with jit_utils.NoTracerWarnContextManager() as no_warn:
self.assertEqual( self.assertEqual(False, torch._C._jit_get_tracer_state_warn())
False, self.assertEqual(True, torch._C._jit_get_tracer_state_warn())
torch._C._jit_get_tracer_state_warn()
)
self.assertEqual(
True,
torch._C._jit_get_tracer_state_warn()
)

File diff suppressed because it is too large Load Diff

View File

@ -10,10 +10,13 @@ pytorch_test_dir = os.path.dirname(os.path.dirname(os.path.realpath(__file__)))
sys.path.append(pytorch_test_dir) sys.path.append(pytorch_test_dir)
from torch.testing._internal.jit_utils import JitTestCase from torch.testing._internal.jit_utils import JitTestCase
if __name__ == '__main__': if __name__ == "__main__":
raise RuntimeError("This test file is not meant to be run directly, use:\n\n" raise RuntimeError(
"\tpython test/test_jit.py TESTNAME\n\n" "This test file is not meant to be run directly, use:\n\n"
"instead.") "\tpython test/test_jit.py TESTNAME\n\n"
"instead."
)
class TestLogging(JitTestCase): class TestLogging(JitTestCase):
def test_bump_numeric_counter(self): def test_bump_numeric_counter(self):
@ -22,30 +25,29 @@ class TestLogging(JitTestCase):
def forward(self, x): def forward(self, x):
for i in range(x.size(0)): for i in range(x.size(0)):
x += 1.0 x += 1.0
torch.jit._logging.add_stat_value('foo', 1) torch.jit._logging.add_stat_value("foo", 1)
if bool(x.sum() > 0.0): if bool(x.sum() > 0.0):
torch.jit._logging.add_stat_value('positive', 1) torch.jit._logging.add_stat_value("positive", 1)
else: else:
torch.jit._logging.add_stat_value('negative', 1) torch.jit._logging.add_stat_value("negative", 1)
return x return x
logger = torch.jit._logging.LockingLogger() logger = torch.jit._logging.LockingLogger()
old_logger = torch.jit._logging.set_logger(logger) old_logger = torch.jit._logging.set_logger(logger)
try: try:
mtl = ModuleThatLogs() mtl = ModuleThatLogs()
for i in range(5): for i in range(5):
mtl(torch.rand(3, 4, 5)) mtl(torch.rand(3, 4, 5))
self.assertEqual(logger.get_counter_val('foo'), 15) self.assertEqual(logger.get_counter_val("foo"), 15)
self.assertEqual(logger.get_counter_val('positive'), 5) self.assertEqual(logger.get_counter_val("positive"), 5)
finally: finally:
torch.jit._logging.set_logger(old_logger) torch.jit._logging.set_logger(old_logger)
def test_trace_numeric_counter(self): def test_trace_numeric_counter(self):
def foo(x): def foo(x):
torch.jit._logging.add_stat_value('foo', 1) torch.jit._logging.add_stat_value("foo", 1)
return x + 1.0 return x + 1.0
traced = torch.jit.trace(foo, torch.rand(3, 4)) traced = torch.jit.trace(foo, torch.rand(3, 4))
@ -54,7 +56,7 @@ class TestLogging(JitTestCase):
try: try:
traced(torch.rand(3, 4)) traced(torch.rand(3, 4))
self.assertEqual(logger.get_counter_val('foo'), 1) self.assertEqual(logger.get_counter_val("foo"), 1)
finally: finally:
torch.jit._logging.set_logger(old_logger) torch.jit._logging.set_logger(old_logger)
@ -65,7 +67,7 @@ class TestLogging(JitTestCase):
for i in range(30): for i in range(30):
x += 1.0 x += 1.0
tp_end = torch.jit._logging.time_point() tp_end = torch.jit._logging.time_point()
torch.jit._logging.add_stat_value('mytimer', tp_end - tp_start) torch.jit._logging.add_stat_value("mytimer", tp_end - tp_start)
return x return x
mtm = ModuleThatTimes() mtm = ModuleThatTimes()
@ -73,7 +75,7 @@ class TestLogging(JitTestCase):
old_logger = torch.jit._logging.set_logger(logger) old_logger = torch.jit._logging.set_logger(logger)
try: try:
mtm(torch.rand(3, 4)) mtm(torch.rand(3, 4))
self.assertGreater(logger.get_counter_val('mytimer'), 0) self.assertGreater(logger.get_counter_val("mytimer"), 0)
finally: finally:
torch.jit._logging.set_logger(old_logger) torch.jit._logging.set_logger(old_logger)
@ -85,7 +87,7 @@ class TestLogging(JitTestCase):
for i in range(30): for i in range(30):
x += 1.0 x += 1.0
tp_end = torch.jit._logging.time_point() tp_end = torch.jit._logging.time_point()
torch.jit._logging.add_stat_value('mytimer', tp_end - tp_start) torch.jit._logging.add_stat_value("mytimer", tp_end - tp_start)
return x return x
mtm = ModuleThatTimes() mtm = ModuleThatTimes()
@ -93,27 +95,27 @@ class TestLogging(JitTestCase):
old_logger = torch.jit._logging.set_logger(logger) old_logger = torch.jit._logging.set_logger(logger)
try: try:
mtm(torch.rand(3, 4)) mtm(torch.rand(3, 4))
self.assertGreater(logger.get_counter_val('mytimer'), 0) self.assertGreater(logger.get_counter_val("mytimer"), 0)
finally: finally:
torch.jit._logging.set_logger(old_logger) torch.jit._logging.set_logger(old_logger)
def test_counter_aggregation(self): def test_counter_aggregation(self):
def foo(x): def foo(x):
for i in range(3): for i in range(3):
torch.jit._logging.add_stat_value('foo', 1) torch.jit._logging.add_stat_value("foo", 1)
return x + 1.0 return x + 1.0
traced = torch.jit.trace(foo, torch.rand(3, 4)) traced = torch.jit.trace(foo, torch.rand(3, 4))
logger = torch.jit._logging.LockingLogger() logger = torch.jit._logging.LockingLogger()
logger.set_aggregation_type('foo', torch.jit._logging.AggregationType.AVG) logger.set_aggregation_type("foo", torch.jit._logging.AggregationType.AVG)
old_logger = torch.jit._logging.set_logger(logger) old_logger = torch.jit._logging.set_logger(logger)
try: try:
traced(torch.rand(3, 4)) traced(torch.rand(3, 4))
self.assertEqual(logger.get_counter_val('foo'), 1) self.assertEqual(logger.get_counter_val("foo"), 1)
finally: finally:
torch.jit._logging.set_logger(old_logger) torch.jit._logging.set_logger(old_logger)
def test_logging_levels_set(self): def test_logging_levels_set(self):
torch._C._jit_set_logging_option('foo') torch._C._jit_set_logging_option("foo")
self.assertEqual('foo', torch._C._jit_get_logging_option()) self.assertEqual("foo", torch._C._jit_get_logging_option())

View File

@ -1,28 +1,32 @@
# Owner(s): ["oncall: jit"] # Owner(s): ["oncall: jit"]
from typing import Any, Dict, List, Optional, Tuple
from torch.testing._internal.jit_utils import JitTestCase, make_global
from torch.testing import FileCheck
from torch import jit
from jit.test_module_interface import TestModuleInterface # noqa: F401
import os import os
import sys import sys
import torch
import torch.testing._internal.jit_utils
import torch.nn as nn
import unittest import unittest
from typing import Any, Dict, List, Optional, Tuple
import torch
import torch.nn as nn
import torch.testing._internal.jit_utils
from torch import jit
from torch.testing import FileCheck
from torch.testing._internal.common_utils import freeze_rng_state from torch.testing._internal.common_utils import freeze_rng_state
from torch.testing._internal.jit_utils import RUN_CUDA_HALF
from torch.testing._internal.jit_utils import JitTestCase, make_global, RUN_CUDA_HALF
from jit.test_module_interface import TestModuleInterface # noqa: F401
# Make the helper files in test/ importable # Make the helper files in test/ importable
pytorch_test_dir = os.path.dirname(os.path.dirname(os.path.realpath(__file__))) pytorch_test_dir = os.path.dirname(os.path.dirname(os.path.realpath(__file__)))
sys.path.append(pytorch_test_dir) sys.path.append(pytorch_test_dir)
if __name__ == '__main__': if __name__ == "__main__":
raise RuntimeError("This test file is not meant to be run directly, use:\n\n" raise RuntimeError(
"\tpython test/test_jit.py TESTNAME\n\n" "This test file is not meant to be run directly, use:\n\n"
"instead.") "\tpython test/test_jit.py TESTNAME\n\n"
"instead."
)
class TestMisc(JitTestCase): class TestMisc(JitTestCase):
def test_joined_str(self): def test_joined_str(self):
@ -30,12 +34,12 @@ class TestMisc(JitTestCase):
hello, test = "Hello", "test" hello, test = "Hello", "test"
print(f"{hello + ' ' + test}, I'm a {test}") print(f"{hello + ' ' + test}, I'm a {test}")
print("format blank") print("format blank")
hi = 'hi' hi = "hi"
print(f"stuff before {hi}") print(f"stuff before {hi}")
print(f"{hi} stuff after") print(f"{hi} stuff after")
return x + 1 return x + 1
x = torch.arange(4., requires_grad=True) x = torch.arange(4.0, requires_grad=True)
# TODO: Add support for f-strings in string parser frontend # TODO: Add support for f-strings in string parser frontend
# self.checkScript(func, [x], optimize=True, capture_output=True) # self.checkScript(func, [x], optimize=True, capture_output=True)
@ -50,10 +54,14 @@ class TestMisc(JitTestCase):
self.assertEqual(captured, captured_script) self.assertEqual(captured, captured_script)
def test_kwarg_support(self): def test_kwarg_support(self):
with self.assertRaisesRegex(torch.jit.frontend.NotSupportedError, "variable number of arguments"): with self.assertRaisesRegex(
torch.jit.frontend.NotSupportedError, "variable number of arguments"
):
class M(torch.nn.Module): class M(torch.nn.Module):
def forward(self, *, n_tokens: int, device_name: str = 2): def forward(self, *, n_tokens: int, device_name: str = 2):
pass pass
torch.jit.script(M()) torch.jit.script(M())
class M(torch.nn.Module): class M(torch.nn.Module):
@ -62,32 +70,35 @@ class TestMisc(JitTestCase):
sm = torch.jit.script(M()) sm = torch.jit.script(M())
with self.assertRaisesRegex(RuntimeError, "missing value for argument 'n_tokens'"): with self.assertRaisesRegex(
RuntimeError, "missing value for argument 'n_tokens'"
):
sm() sm()
with self.assertRaisesRegex(RuntimeError, "positional arg"): with self.assertRaisesRegex(RuntimeError, "positional arg"):
sm(3, 'hello') sm(3, "hello")
self.assertEqual(sm(n_tokens=3, device_name='hello'), (3, 'hello')) self.assertEqual(sm(n_tokens=3, device_name="hello"), (3, "hello"))
def test_tuple_subscripted_assign(self): def test_tuple_subscripted_assign(self):
with self.assertRaisesRegex(RuntimeError, "subscripted assignment"): with self.assertRaisesRegex(RuntimeError, "subscripted assignment"):
@torch.jit.script @torch.jit.script
def foo(a: Tuple[int, int]) -> None: def foo(a: Tuple[int, int]) -> None:
a[0] = a[1] a[0] = a[1]
with self.assertRaisesRegex(RuntimeError, "augmented assignment"): with self.assertRaisesRegex(RuntimeError, "augmented assignment"):
@torch.jit.script @torch.jit.script
def bar(a: Tuple[int, int]) -> None: def bar(a: Tuple[int, int]) -> None:
a[0] += a[1] a[0] += a[1]
def test_subexpression_List_Future(self): def test_subexpression_List_Future(self):
@torch.jit.script @torch.jit.script
def fn(x: List[torch.jit.Future[int]]) -> torch.jit.Future[int]: def fn(x: List[torch.jit.Future[int]]) -> torch.jit.Future[int]:
return x[0] return x[0]
FileCheck().check('Future[int]').check('Future[int]').run(fn.graph) FileCheck().check("Future[int]").check("Future[int]").run(fn.graph)
def test_subexpression_Future_annotate(self): def test_subexpression_Future_annotate(self):
@torch.jit.script @torch.jit.script
@ -110,36 +121,40 @@ class TestMisc(JitTestCase):
if isinstance(x, str): if isinstance(x, str):
return x return x
return "foo" return "foo"
forward = torch.jit.script(forward) forward = torch.jit.script(forward)
self.assertEqual(forward(1), "foo") self.assertEqual(forward(1), "foo")
self.assertEqual(forward("bar"), "bar") self.assertEqual(forward("bar"), "bar")
def test_subexpression_Tuple_int_int_Future(self): def test_subexpression_Tuple_int_int_Future(self):
@torch.jit.script @torch.jit.script
def fn(x: Tuple[int, int, torch.jit.Future[int]]) -> Tuple[int, torch.jit.Future[int]]: def fn(
x: Tuple[int, int, torch.jit.Future[int]]
) -> Tuple[int, torch.jit.Future[int]]:
return x[0], x[2] return x[0], x[2]
FileCheck().check('(int, int, Future[int])').check('(int, Future[int])').run(fn.graph) FileCheck().check("(int, int, Future[int])").check("(int, Future[int])").run(
fn.graph
)
def test_subexpression_Dict_int_Future(self): def test_subexpression_Dict_int_Future(self):
@torch.jit.script @torch.jit.script
def fn(x: Dict[int, torch.jit.Future[int]], y: int) -> torch.jit.Future[int]: def fn(x: Dict[int, torch.jit.Future[int]], y: int) -> torch.jit.Future[int]:
return x[y] return x[y]
FileCheck().check('Dict(int, Future(int))').check('Future[int]').run(fn.graph) FileCheck().check("Dict(int, Future(int))").check("Future[int]").run(fn.graph)
def test_subexpression_Optional(self): def test_subexpression_Optional(self):
@torch.jit.script @torch.jit.script
def fn(x: Optional[Dict[int, torch.jit.Future[int]]]) -> Optional[torch.jit.Future[int]]: def fn(
x: Optional[Dict[int, torch.jit.Future[int]]]
) -> Optional[torch.jit.Future[int]]:
if x is not None: if x is not None:
return x[0] return x[0]
else: else:
return None return None
FileCheck().check('Dict(int, Future(int))?').run(fn.graph) FileCheck().check("Dict(int, Future(int))?").run(fn.graph)
def test_if_returning_any(self): def test_if_returning_any(self):
""" """
@ -147,6 +162,7 @@ class TestMisc(JitTestCase):
types early from each branch when the return types early from each branch when the return
type of the function is Any. type of the function is Any.
""" """
def if_function(inp: torch.Tensor) -> Any: def if_function(inp: torch.Tensor) -> Any:
if inp.shape[0] == 1: if inp.shape[0] == 1:
return inp * inp return inp * inp
@ -156,14 +172,23 @@ class TestMisc(JitTestCase):
self.checkScript(if_function, (torch.randn(5),)) self.checkScript(if_function, (torch.randn(5),))
def test_hacked_twin(self): def test_hacked_twin(self):
def gen_data(): def gen_data():
with freeze_rng_state(): with freeze_rng_state():
return torch.randn(10), torch.randint(10, (20,)), torch.randn(20) return torch.randn(10), torch.randint(10, (20,)), torch.randn(20)
input, index, value, = gen_data() (
input1, index1, value1, = gen_data() input,
out1 = torch.ops.aten.index_put.hacked_twin(input, [index], value, accumulate=False) index,
value,
) = gen_data()
(
input1,
index1,
value1,
) = gen_data()
out1 = torch.ops.aten.index_put.hacked_twin(
input, [index], value, accumulate=False
)
out2 = torch.index_put(input1, [index1], value1, accumulate=False) out2 = torch.index_put(input1, [index1], value1, accumulate=False)
self.assertEqual(out1, out2) self.assertEqual(out1, out2)
@ -172,14 +197,23 @@ class TestMisc(JitTestCase):
self.assertEqual(input, input1) self.assertEqual(input, input1)
def test_unsafe_hacked_twin(self): def test_unsafe_hacked_twin(self):
def gen_data(): def gen_data():
with freeze_rng_state(): with freeze_rng_state():
return torch.randn(10), torch.randint(10, (20,)), torch.randn(20) return torch.randn(10), torch.randint(10, (20,)), torch.randn(20)
input, index, value, = gen_data() (
input1, index1, value1, = gen_data() input,
out1 = torch.ops.aten._unsafe_index_put.hacked_twin(input, [index], value, accumulate=False) index,
value,
) = gen_data()
(
input1,
index1,
value1,
) = gen_data()
out1 = torch.ops.aten._unsafe_index_put.hacked_twin(
input, [index], value, accumulate=False
)
out2 = torch.index_put(input1, [index1], value1, accumulate=False) out2 = torch.index_put(input1, [index1], value1, accumulate=False)
self.assertEqual(out1, out2) self.assertEqual(out1, out2)
@ -188,7 +222,9 @@ class TestMisc(JitTestCase):
self.assertEqual(input, input1) self.assertEqual(input, input1)
def index_put_fn(input, index, value): def index_put_fn(input, index, value):
return torch.ops.aten._unsafe_index_put(input, [index], value, accumulate=False) return torch.ops.aten._unsafe_index_put(
input, [index], value, accumulate=False
)
input2, index2, value2 = gen_data() input2, index2, value2 = gen_data()
script_index_put_fn = torch.jit.script(index_put_fn) script_index_put_fn = torch.jit.script(index_put_fn)
@ -197,7 +233,9 @@ class TestMisc(JitTestCase):
self.assertEqual(expect, actual) self.assertEqual(expect, actual)
def index_fn(input, index, value): def index_fn(input, index, value):
return torch.ops.aten._unsafe_index_put(input, [index], value, accumulate=False) return torch.ops.aten._unsafe_index_put(
input, [index], value, accumulate=False
)
script_index_fn = torch.jit.script(index_fn) script_index_fn = torch.jit.script(index_fn)
expect = index_fn(input2.clone(), index2, value2) expect = index_fn(input2.clone(), index2, value2)
@ -205,7 +243,6 @@ class TestMisc(JitTestCase):
self.assertEqual(expect, actual) self.assertEqual(expect, actual)
def test_export_opnames_interface(self): def test_export_opnames_interface(self):
@torch.jit.interface @torch.jit.interface
class OneTwoModule(nn.Module): class OneTwoModule(nn.Module):
def one(self, x: torch.Tensor, y: torch.Tensor) -> torch.Tensor: def one(self, x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
@ -240,7 +277,7 @@ class TestMisc(JitTestCase):
make_global(OneTwoModule) make_global(OneTwoModule)
class M(nn.Module): class M(nn.Module):
sub : OneTwoModule sub: OneTwoModule
def __init__(self): def __init__(self):
super().__init__() super().__init__()
@ -254,12 +291,18 @@ class TestMisc(JitTestCase):
torch._C._enable_mobile_interface_call_export() torch._C._enable_mobile_interface_call_export()
scripted_M_mod = torch.jit.script(M()) scripted_M_mod = torch.jit.script(M())
self.assertTrue({'aten::mul.Scalar', 'aten::mul.Tensor', 'aten::reciprocal'}.issubset( self.assertTrue(
set(torch.jit.export_opnames(scripted_M_mod)))) {"aten::mul.Scalar", "aten::mul.Tensor", "aten::reciprocal"}.issubset(
set(torch.jit.export_opnames(scripted_M_mod))
)
)
scripted_M_mod.sub = torch.jit.script(FooMod()) scripted_M_mod.sub = torch.jit.script(FooMod())
self.assertTrue({'aten::add.Tensor', 'aten::mul.Scalar'}.issubset( self.assertTrue(
set(torch.jit.export_opnames(scripted_M_mod)))) {"aten::add.Tensor", "aten::mul.Scalar"}.issubset(
set(torch.jit.export_opnames(scripted_M_mod))
)
)
def test_math_inf(self): def test_math_inf(self):
from math import inf from math import inf
@ -292,7 +335,6 @@ class TestMisc(JitTestCase):
with self.assertRaises(RuntimeError): with self.assertRaises(RuntimeError):
torch.jit.script(non_temporary_fail) torch.jit.script(non_temporary_fail)
@torch.jit.script @torch.jit.script
def test_return(): def test_return():
return [] return []
@ -335,7 +377,9 @@ class TestMisc(JitTestCase):
def multiple_args(): def multiple_args():
return torch.LongTensor(1, [2]) return torch.LongTensor(1, [2])
with self.assertRaisesRegex(RuntimeError, "multiple positional arguments that were not all integers"): with self.assertRaisesRegex(
RuntimeError, "multiple positional arguments that were not all integers"
):
torch.jit.script(multiple_args) torch.jit.script(multiple_args)
# kwarg bad schema # kwarg bad schema
@ -345,7 +389,6 @@ class TestMisc(JitTestCase):
with self.assertRaisesRegex(RuntimeError, "hello"): with self.assertRaisesRegex(RuntimeError, "hello"):
torch.jit.script(bad_kwarg) torch.jit.script(bad_kwarg)
def test_broadcasting_list(self): def test_broadcasting_list(self):
""" """
Test BroadcastingList and torch.nn._size_N_t alias Test BroadcastingList and torch.nn._size_N_t alias
@ -360,7 +403,7 @@ class TestMisc(JitTestCase):
return x[0] + x[1] return x[0] + x[1]
self.assertTrue(torch.jit.script(sum_i)(4) == 8) self.assertTrue(torch.jit.script(sum_i)(4) == 8)
self.assertTrue(torch.jit.script(sum_f)(4.5) == 9.) self.assertTrue(torch.jit.script(sum_f)(4.5) == 9.0)
def test_parse_ir_annotate(self): def test_parse_ir_annotate(self):
ir = """ ir = """
@ -397,7 +440,6 @@ class TestMisc(JitTestCase):
self.assertTrue(ret.numel() == 1) self.assertTrue(ret.numel() == 1)
self.assertTrue(len(ret.size()) == 1) self.assertTrue(len(ret.size()) == 1)
def test_script_many_decorators(self): def test_script_many_decorators(self):
def no_op_decorator(f): def no_op_decorator(f):
return f return f
@ -410,7 +452,9 @@ class TestMisc(JitTestCase):
def foo(x, dim: int): def foo(x, dim: int):
return x.unsqueeze(dim) return x.unsqueeze(dim)
x = torch.randn(1,) x = torch.randn(
1,
)
expected = foo(x, 0) expected = foo(x, 0)
scripted = torch.jit.script(foo) scripted = torch.jit.script(foo)
actual = scripted(x, 0) actual = scripted(x, 0)
@ -421,10 +465,10 @@ class TestMisc(JitTestCase):
# https://github.com/pytorch/pytorch/issues/75476 # https://github.com/pytorch/pytorch/issues/75476
def fn(p: torch.Tensor, gamma: float = 2.0) -> torch.Tensor: def fn(p: torch.Tensor, gamma: float = 2.0) -> torch.Tensor:
p = torch.sigmoid(p) p = torch.sigmoid(p)
result = p ** gamma result = p**gamma
return result return result
x = torch.rand((2, 2), dtype=torch.half, device='cuda') x = torch.rand((2, 2), dtype=torch.half, device="cuda")
ref = fn(x) ref = fn(x)
@ -450,8 +494,12 @@ class TestMisc(JitTestCase):
# We want "Scalar" to come before "complex". # We want "Scalar" to come before "complex".
op, override_names = torch._C._jit_get_operation("aten::add") op, override_names = torch._C._jit_get_operation("aten::add")
print(override_names) print(override_names)
complex_indices = [i for i, name in enumerate(override_names) if name == "complex"] complex_indices = [
Scalar_indices = [i for i, name in enumerate(override_names) if name == "Scalar"] i for i, name in enumerate(override_names) if name == "complex"
]
Scalar_indices = [
i for i, name in enumerate(override_names) if name == "Scalar"
]
self.assertTrue(len(complex_indices) > 0) self.assertTrue(len(complex_indices) > 0)
self.assertTrue(len(Scalar_indices) > 0) self.assertTrue(len(Scalar_indices) > 0)

View File

@ -3,27 +3,33 @@
import os import os
import sys import sys
import unittest import unittest
from torch.testing._internal.common_utils import (
enable_profiling_mode_for_profiling_tests, GRAPH_EXECUTOR, ProfilingMode,
set_default_dtype,
)
import torch import torch
import torch.nn as nn import torch.nn as nn
import torch.nn.functional as F import torch.nn.functional as F
from torch.testing._internal.common_utils import (
enable_profiling_mode_for_profiling_tests,
GRAPH_EXECUTOR,
ProfilingMode,
set_default_dtype,
)
# Make the helper files in test/ importable # Make the helper files in test/ importable
pytorch_test_dir = os.path.dirname(os.path.dirname(os.path.realpath(__file__))) pytorch_test_dir = os.path.dirname(os.path.dirname(os.path.realpath(__file__)))
sys.path.append(pytorch_test_dir) sys.path.append(pytorch_test_dir)
from torch.testing._internal.jit_utils import JitTestCase, RUN_CUDA
from torch.testing._internal.common_utils import slowTest, suppress_warnings from torch.testing._internal.common_utils import slowTest, suppress_warnings
from torch.testing._internal.jit_utils import JitTestCase, RUN_CUDA
if __name__ == '__main__': if __name__ == "__main__":
raise RuntimeError("This test file is not meant to be run directly, use:\n\n" raise RuntimeError(
"\tpython test/test_jit.py TESTNAME\n\n" "This test file is not meant to be run directly, use:\n\n"
"instead.") "\tpython test/test_jit.py TESTNAME\n\n"
"instead."
)
try: try:
import torchvision import torchvision
HAS_TORCHVISION = True HAS_TORCHVISION = True
except ImportError: except ImportError:
HAS_TORCHVISION = False HAS_TORCHVISION = False
@ -31,6 +37,7 @@ except RuntimeError:
HAS_TORCHVISION = False HAS_TORCHVISION = False
skipIfNoTorchVision = unittest.skipIf(not HAS_TORCHVISION, "no torchvision") skipIfNoTorchVision = unittest.skipIf(not HAS_TORCHVISION, "no torchvision")
class MnistNet(nn.Module): class MnistNet(nn.Module):
def __init__(self): def __init__(self):
super().__init__() super().__init__()
@ -49,6 +56,7 @@ class MnistNet(nn.Module):
x = self.fc2(x) x = self.fc2(x)
return F.log_softmax(x, dim=1) return F.log_softmax(x, dim=1)
class TestModels(JitTestCase): class TestModels(JitTestCase):
@staticmethod @staticmethod
def _test_dcgan_models(self, device, check_export_import=True): def _test_dcgan_models(self, device, check_export_import=True):
@ -102,31 +110,38 @@ class TestModels(JitTestCase):
nn.LeakyReLU(0.2, inplace=True), nn.LeakyReLU(0.2, inplace=True),
# state size. (ndf*8) x 4 x 4 # state size. (ndf*8) x 4 x 4
nn.Conv2d(ndf * 8, 1, 4, 1, 0, bias=False), nn.Conv2d(ndf * 8, 1, 4, 1, 0, bias=False),
nn.Sigmoid() nn.Sigmoid(),
) )
def forward(self, input): def forward(self, input):
return self.main(input).view(-1, 1).squeeze(1) return self.main(input).view(-1, 1).squeeze(1)
bs, nz, ngf, nc, ndf = 5, 6, 9, 3, 10 bs, nz, ngf, nc, ndf = 5, 6, 9, 3, 10
self.checkTrace(DCGANGenerator(nz, ngf, nc).to(device), self.checkTrace(
(torch.rand(bs, nz, 1, 1, device=device),), DCGANGenerator(nz, ngf, nc).to(device),
export_import=check_export_import) (torch.rand(bs, nz, 1, 1, device=device),),
example_input = DCGANGenerator(nz, ngf, nc).to(device)(torch.rand(bs, nz, 1, 1, device=device)) export_import=check_export_import,
self.checkTrace(DCGANDiscriminator(nc, ndf).to(device), (example_input,), )
export_import=check_export_import) example_input = DCGANGenerator(nz, ngf, nc).to(device)(
torch.rand(bs, nz, 1, 1, device=device)
)
self.checkTrace(
DCGANDiscriminator(nc, ndf).to(device),
(example_input,),
export_import=check_export_import,
)
def test_dcgan_models(self): def test_dcgan_models(self):
# Note: Can sometimes fail with low precision if run with float dtype # Note: Can sometimes fail with low precision if run with float dtype
with set_default_dtype(torch.double): with set_default_dtype(torch.double):
self._test_dcgan_models(self, device='cpu') self._test_dcgan_models(self, device="cpu")
@unittest.skipIf(not RUN_CUDA, "no CUDA") @unittest.skipIf(not RUN_CUDA, "no CUDA")
def test_dcgan_models_cuda(self): def test_dcgan_models_cuda(self):
# Note: Can sometimes fail with low precision if run with float dtype # Note: Can sometimes fail with low precision if run with float dtype
with set_default_dtype(torch.double): with set_default_dtype(torch.double):
# XXX: export_import on CUDA modules doesn't work (#11480) # XXX: export_import on CUDA modules doesn't work (#11480)
self._test_dcgan_models(self, device='cuda', check_export_import=False) self._test_dcgan_models(self, device="cuda", check_export_import=False)
@staticmethod @staticmethod
def _test_neural_style(self, device, check_export_import=True): def _test_neural_style(self, device, check_export_import=True):
@ -147,9 +162,13 @@ class TestModels(JitTestCase):
self.res4 = ResidualBlock(128) self.res4 = ResidualBlock(128)
self.res5 = ResidualBlock(128) self.res5 = ResidualBlock(128)
# Upsampling Layers # Upsampling Layers
self.deconv1 = UpsampleConvLayer(128, 64, kernel_size=3, stride=1, upsample=2) self.deconv1 = UpsampleConvLayer(
128, 64, kernel_size=3, stride=1, upsample=2
)
self.in4 = torch.nn.InstanceNorm2d(64, affine=True) self.in4 = torch.nn.InstanceNorm2d(64, affine=True)
self.deconv2 = UpsampleConvLayer(64, 32, kernel_size=3, stride=1, upsample=2) self.deconv2 = UpsampleConvLayer(
64, 32, kernel_size=3, stride=1, upsample=2
)
self.in5 = torch.nn.InstanceNorm2d(32, affine=True) self.in5 = torch.nn.InstanceNorm2d(32, affine=True)
self.deconv3 = ConvLayer(32, 3, kernel_size=9, stride=1) self.deconv3 = ConvLayer(32, 3, kernel_size=9, stride=1)
# Non-linearities # Non-linearities
@ -174,7 +193,9 @@ class TestModels(JitTestCase):
super().__init__() super().__init__()
reflection_padding = kernel_size // 2 reflection_padding = kernel_size // 2
self.reflection_pad = torch.nn.ReflectionPad2d(reflection_padding) self.reflection_pad = torch.nn.ReflectionPad2d(reflection_padding)
self.conv2d = torch.nn.Conv2d(in_channels, out_channels, kernel_size, stride) self.conv2d = torch.nn.Conv2d(
in_channels, out_channels, kernel_size, stride
)
def forward(self, x): def forward(self, x):
out = self.reflection_pad(x) out = self.reflection_pad(x)
@ -209,14 +230,20 @@ class TestModels(JitTestCase):
ref: http://distill.pub/2016/deconv-checkerboard/ ref: http://distill.pub/2016/deconv-checkerboard/
""" """
def __init__(self, in_channels, out_channels, kernel_size, stride, upsample=None): def __init__(
self, in_channels, out_channels, kernel_size, stride, upsample=None
):
super().__init__() super().__init__()
self.upsample = upsample self.upsample = upsample
if upsample: if upsample:
self.upsample_layer = torch.nn.Upsample(mode='nearest', scale_factor=upsample) self.upsample_layer = torch.nn.Upsample(
mode="nearest", scale_factor=upsample
)
reflection_padding = kernel_size // 2 reflection_padding = kernel_size // 2
self.reflection_pad = torch.nn.ReflectionPad2d(reflection_padding) self.reflection_pad = torch.nn.ReflectionPad2d(reflection_padding)
self.conv2d = torch.nn.Conv2d(in_channels, out_channels, kernel_size, stride) self.conv2d = torch.nn.Conv2d(
in_channels, out_channels, kernel_size, stride
)
def forward(self, x): def forward(self, x):
x_in = x x_in = x
@ -226,44 +253,54 @@ class TestModels(JitTestCase):
out = self.conv2d(out) out = self.conv2d(out)
return out return out
self.checkTrace(TransformerNet(), (torch.rand(5, 3, 16, 16),), export_import=check_export_import) self.checkTrace(
TransformerNet(),
(torch.rand(5, 3, 16, 16),),
export_import=check_export_import,
)
@slowTest @slowTest
def test_neural_style(self): def test_neural_style(self):
self._test_neural_style(self, device='cpu') self._test_neural_style(self, device="cpu")
@unittest.skipIf(not RUN_CUDA, "no CUDA") @unittest.skipIf(not RUN_CUDA, "no CUDA")
def test_neural_style_cuda(self): def test_neural_style_cuda(self):
# XXX: export_import on CUDA modules doesn't work (#11480) # XXX: export_import on CUDA modules doesn't work (#11480)
self._test_neural_style(self, device='cuda', check_export_import=False) self._test_neural_style(self, device="cuda", check_export_import=False)
@unittest.skipIf(GRAPH_EXECUTOR == ProfilingMode.LEGACY, "Bug found in deprecated executor") @unittest.skipIf(
GRAPH_EXECUTOR == ProfilingMode.LEGACY, "Bug found in deprecated executor"
)
@staticmethod @staticmethod
def _test_mnist(self, device, check_export_import=True): def _test_mnist(self, device, check_export_import=True):
# eval() is present because dropout makes this nondeterministic # eval() is present because dropout makes this nondeterministic
with enable_profiling_mode_for_profiling_tests(): with enable_profiling_mode_for_profiling_tests():
self.checkTrace(MnistNet().to(device).eval(), (torch.rand(5, 1, 28, 28, device=device),), self.checkTrace(
export_import=check_export_import) MnistNet().to(device).eval(),
(torch.rand(5, 1, 28, 28, device=device),),
export_import=check_export_import,
)
def test_mnist(self): def test_mnist(self):
self._test_mnist(self, device='cpu') self._test_mnist(self, device="cpu")
@unittest.skipIf(not RUN_CUDA, "no CUDA") @unittest.skipIf(not RUN_CUDA, "no CUDA")
def test_mnist_cuda(self): def test_mnist_cuda(self):
# XXX: export_import on CUDA modules doesn't work (#11480) # XXX: export_import on CUDA modules doesn't work (#11480)
self._test_mnist(self, device='cuda', check_export_import=False) self._test_mnist(self, device="cuda", check_export_import=False)
@unittest.skipIf(not RUN_CUDA, "no CUDA") @unittest.skipIf(not RUN_CUDA, "no CUDA")
def test_mnist_training_leaks_no_memory_cuda(self): def test_mnist_training_leaks_no_memory_cuda(self):
net = MnistNet().cuda() net = MnistNet().cuda()
# MnistNet uses dropout, don't check its trace # MnistNet uses dropout, don't check its trace
traced_net = torch.jit.trace(net, [torch.randn(5, 1, 28, 28, device='cuda')], traced_net = torch.jit.trace(
check_trace=False) net, [torch.randn(5, 1, 28, 28, device="cuda")], check_trace=False
)
def train(iters): def train(iters):
for _ in range(iters): for _ in range(iters):
# Get some fake data # Get some fake data
inp = torch.randn(5, 1, 28, 28, device='cuda') inp = torch.randn(5, 1, 28, 28, device="cuda")
out = traced_net(inp) out = traced_net(inp)
# Here's some fake loss # Here's some fake loss
@ -292,21 +329,23 @@ class TestModels(JitTestCase):
return F.softmax(action_scores, dim=1) return F.softmax(action_scores, dim=1)
with enable_profiling_mode_for_profiling_tests(): with enable_profiling_mode_for_profiling_tests():
self.checkTrace(Policy().to(device), (torch.rand(1, 4, device=device),), self.checkTrace(
export_import=test_export_import) Policy().to(device),
(torch.rand(1, 4, device=device),),
export_import=test_export_import,
)
def test_reinforcement_learning(self): def test_reinforcement_learning(self):
self._test_reinforcement_learning(self, device='cpu') self._test_reinforcement_learning(self, device="cpu")
@unittest.skipIf(not RUN_CUDA, "no CUDA") @unittest.skipIf(not RUN_CUDA, "no CUDA")
def test_reinforcement_learning_cuda(self): def test_reinforcement_learning_cuda(self):
# XXX: export_import on CUDA modules doesn't work (#11480) # XXX: export_import on CUDA modules doesn't work (#11480)
self._test_reinforcement_learning(self, device='cuda', test_export_import=False) self._test_reinforcement_learning(self, device="cuda", test_export_import=False)
@staticmethod @staticmethod
def _test_snli(self, device, check_export_import=True): def _test_snli(self, device, check_export_import=True):
class Bottle(nn.Module): class Bottle(nn.Module):
def forward(self, input): def forward(self, input):
if len(input.size()) <= 2: if len(input.size()) <= 2:
return super().forward(input) return super().forward(input)
@ -318,25 +357,31 @@ class TestModels(JitTestCase):
pass pass
class Encoder(nn.Module): class Encoder(nn.Module):
def __init__(self, config): def __init__(self, config):
super().__init__() super().__init__()
self.config = config self.config = config
input_size = config.d_proj if config.projection else config.d_embed input_size = config.d_proj if config.projection else config.d_embed
dropout = 0 if config.n_layers == 1 else config.dp_ratio dropout = 0 if config.n_layers == 1 else config.dp_ratio
self.rnn = nn.LSTM(input_size=input_size, hidden_size=config.d_hidden, self.rnn = nn.LSTM(
num_layers=config.n_layers, dropout=dropout, input_size=input_size,
bidirectional=config.birnn) hidden_size=config.d_hidden,
num_layers=config.n_layers,
dropout=dropout,
bidirectional=config.birnn,
)
def forward(self, inputs): def forward(self, inputs):
batch_size = inputs.size()[1] batch_size = inputs.size()[1]
state_shape = self.config.n_cells, batch_size, self.config.d_hidden state_shape = self.config.n_cells, batch_size, self.config.d_hidden
h0 = c0 = inputs.new_zeros(state_shape) h0 = c0 = inputs.new_zeros(state_shape)
outputs, (ht, ct) = self.rnn(inputs, (h0, c0)) outputs, (ht, ct) = self.rnn(inputs, (h0, c0))
return ht[-1] if not self.config.birnn else ht[-2:].transpose(0, 1).contiguous().view(batch_size, -1) return (
ht[-1]
if not self.config.birnn
else ht[-2:].transpose(0, 1).contiguous().view(batch_size, -1)
)
class SNLIClassifier(nn.Module): class SNLIClassifier(nn.Module):
def __init__(self, config): def __init__(self, config):
super().__init__() super().__init__()
self.config = config self.config = config
@ -359,7 +404,8 @@ class TestModels(JitTestCase):
Linear(*lin_config), Linear(*lin_config),
self.relu, self.relu,
self.dropout, self.dropout,
Linear(seq_in_size, config.d_out)) Linear(seq_in_size, config.d_out),
)
def forward(self, premise, hypothesis): def forward(self, premise, hypothesis):
prem_embed = self.embed(premise) prem_embed = self.embed(premise)
@ -391,22 +437,25 @@ class TestModels(JitTestCase):
premise = torch.LongTensor(48, 64).random_(0, 100).to(device) premise = torch.LongTensor(48, 64).random_(0, 100).to(device)
hypothesis = torch.LongTensor(24, 64).random_(0, 100).to(device) hypothesis = torch.LongTensor(24, 64).random_(0, 100).to(device)
self.checkTrace(SNLIClassifier(Config()).to(device), (premise, hypothesis), self.checkTrace(
inputs_require_grads=False, export_import=check_export_import) SNLIClassifier(Config()).to(device),
(premise, hypothesis),
inputs_require_grads=False,
export_import=check_export_import,
)
@slowTest @slowTest
def test_snli(self): def test_snli(self):
self._test_snli(self, device='cpu') self._test_snli(self, device="cpu")
@unittest.skipIf(not RUN_CUDA, "no CUDA") @unittest.skipIf(not RUN_CUDA, "no CUDA")
def test_snli_cuda(self): def test_snli_cuda(self):
# XXX: export_import on CUDA modules doesn't work (#11480) # XXX: export_import on CUDA modules doesn't work (#11480)
self._test_snli(self, device='cuda', check_export_import=False) self._test_snli(self, device="cuda", check_export_import=False)
@staticmethod @staticmethod
def _test_super_resolution(self, device, check_export_import=True): def _test_super_resolution(self, device, check_export_import=True):
class Net(nn.Module): class Net(nn.Module):
def __init__(self, upscale_factor): def __init__(self, upscale_factor):
super().__init__() super().__init__()
@ -414,7 +463,7 @@ class TestModels(JitTestCase):
self.conv1 = nn.Conv2d(1, 64, (5, 5), (1, 1), (2, 2)) self.conv1 = nn.Conv2d(1, 64, (5, 5), (1, 1), (2, 2))
self.conv2 = nn.Conv2d(64, 64, (3, 3), (1, 1), (1, 1)) self.conv2 = nn.Conv2d(64, 64, (3, 3), (1, 1), (1, 1))
self.conv3 = nn.Conv2d(64, 32, (3, 3), (1, 1), (1, 1)) self.conv3 = nn.Conv2d(64, 32, (3, 3), (1, 1), (1, 1))
self.conv4 = nn.Conv2d(32, upscale_factor ** 2, (3, 3), (1, 1), (1, 1)) self.conv4 = nn.Conv2d(32, upscale_factor**2, (3, 3), (1, 1), (1, 1))
self.pixel_shuffle = nn.PixelShuffle(upscale_factor) self.pixel_shuffle = nn.PixelShuffle(upscale_factor)
def forward(self, x): def forward(self, x):
@ -425,17 +474,20 @@ class TestModels(JitTestCase):
return x return x
net = Net(upscale_factor=4).to(device) net = Net(upscale_factor=4).to(device)
self.checkTrace(net, (torch.rand(5, 1, 32, 32, device=device),), self.checkTrace(
export_import=check_export_import) net,
(torch.rand(5, 1, 32, 32, device=device),),
export_import=check_export_import,
)
@slowTest @slowTest
def test_super_resolution(self): def test_super_resolution(self):
self._test_super_resolution(self, device='cpu') self._test_super_resolution(self, device="cpu")
@unittest.skipIf(not RUN_CUDA, 'no CUDA') @unittest.skipIf(not RUN_CUDA, "no CUDA")
def test_super_resolution_cuda(self): def test_super_resolution_cuda(self):
# XXX: export_import on CUDA modules doesn't work (#11480) # XXX: export_import on CUDA modules doesn't work (#11480)
self._test_super_resolution(self, device='cuda', check_export_import=False) self._test_super_resolution(self, device="cuda", check_export_import=False)
@suppress_warnings @suppress_warnings
def test_time_sequence_prediction(self): def test_time_sequence_prediction(self):
@ -485,8 +537,7 @@ class TestModels(JitTestCase):
# disabled due to a jitter issues that will be fixed by using load/store in the compiler # disabled due to a jitter issues that will be fixed by using load/store in the compiler
with torch._jit_internal._disable_emit_hooks(): with torch._jit_internal._disable_emit_hooks():
# TODO: toggle export_import once above issues are fixed # TODO: toggle export_import once above issues are fixed
self.checkTrace(Traced(), (torch.rand(3, 4),), self.checkTrace(Traced(), (torch.rand(3, 4),), export_import=False)
export_import=False)
@staticmethod @staticmethod
def _test_vae(self, device, check_export_import=True): def _test_vae(self, device, check_export_import=True):
@ -523,22 +574,27 @@ class TestModels(JitTestCase):
with enable_profiling_mode_for_profiling_tests(): with enable_profiling_mode_for_profiling_tests():
# eval() is present because randn_like makes this nondeterministic # eval() is present because randn_like makes this nondeterministic
self.checkTrace(VAE().to(device).eval(), (torch.rand(128, 1, 28, 28, device=device),), self.checkTrace(
export_import=check_export_import) VAE().to(device).eval(),
(torch.rand(128, 1, 28, 28, device=device),),
export_import=check_export_import,
)
def test_vae(self): def test_vae(self):
self._test_vae(self, device='cpu') self._test_vae(self, device="cpu")
@unittest.skipIf(not RUN_CUDA, "no CUDA") @unittest.skipIf(not RUN_CUDA, "no CUDA")
def test_vae_cuda(self): def test_vae_cuda(self):
# XXX: export_import on CUDA modules doesn't work (#11480) # XXX: export_import on CUDA modules doesn't work (#11480)
self._test_vae(self, device='cuda', check_export_import=False) self._test_vae(self, device="cuda", check_export_import=False)
@slowTest @slowTest
@skipIfNoTorchVision @skipIfNoTorchVision
def test_script_module_trace_resnet18(self): def test_script_module_trace_resnet18(self):
x = torch.ones(1, 3, 224, 224) x = torch.ones(1, 3, 224, 224)
m_orig = torch.jit.trace(torchvision.models.resnet18(), torch.ones(1, 3, 224, 224)) m_orig = torch.jit.trace(
torchvision.models.resnet18(), torch.ones(1, 3, 224, 224)
)
m_import = self.getExportImportCopy(m_orig) m_import = self.getExportImportCopy(m_orig)
input = torch.randn(1, 3, 224, 224, requires_grad=True) input = torch.randn(1, 3, 224, 224, requires_grad=True)
@ -559,16 +615,24 @@ class TestModels(JitTestCase):
def test_script_module_script_resnet(self): def test_script_module_script_resnet(self):
def conv1x1(in_planes, out_planes, stride=1): def conv1x1(in_planes, out_planes, stride=1):
"""1x1 convolution""" """1x1 convolution"""
return nn.Conv2d(in_planes, out_planes, kernel_size=1, stride=stride, bias=False) return nn.Conv2d(
in_planes, out_planes, kernel_size=1, stride=stride, bias=False
)
def conv3x3(in_planes, out_planes, stride=1): def conv3x3(in_planes, out_planes, stride=1):
"""3x3 convolution with padding""" """3x3 convolution with padding"""
return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride, return nn.Conv2d(
padding=1, bias=False) in_planes,
out_planes,
kernel_size=3,
stride=stride,
padding=1,
bias=False,
)
class BasicBlock(torch.jit.ScriptModule): class BasicBlock(torch.jit.ScriptModule):
expansion = 1 expansion = 1
__constants__ = ['downsample'] __constants__ = ["downsample"]
def __init__(self, inplanes, planes, stride=1, downsample=None): def __init__(self, inplanes, planes, stride=1, downsample=None):
super().__init__() super().__init__()
@ -600,13 +664,14 @@ class TestModels(JitTestCase):
return out return out
class ResNet(torch.jit.ScriptModule): class ResNet(torch.jit.ScriptModule):
__constants__ = ['layer1', 'layer2', 'layer3', 'layer4'] __constants__ = ["layer1", "layer2", "layer3", "layer4"]
def __init__(self, block, layers, num_classes=1000): def __init__(self, block, layers, num_classes=1000):
super().__init__() super().__init__()
self.inplanes = 64 self.inplanes = 64
self.conv1 = nn.Conv2d(3, 64, kernel_size=7, stride=2, padding=3, self.conv1 = nn.Conv2d(
bias=False) 3, 64, kernel_size=7, stride=2, padding=3, bias=False
)
self.bn1 = nn.BatchNorm2d(64) self.bn1 = nn.BatchNorm2d(64)
self.relu = nn.ReLU(inplace=True) self.relu = nn.ReLU(inplace=True)
self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1) self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
@ -619,7 +684,9 @@ class TestModels(JitTestCase):
for m in self.modules(): for m in self.modules():
if isinstance(m, nn.Conv2d): if isinstance(m, nn.Conv2d):
nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu') nn.init.kaiming_normal_(
m.weight, mode="fan_out", nonlinearity="relu"
)
elif isinstance(m, nn.BatchNorm2d): elif isinstance(m, nn.BatchNorm2d):
nn.init.constant_(m.weight, 1) nn.init.constant_(m.weight, 1)
nn.init.constant_(m.bias, 0) nn.init.constant_(m.bias, 0)
@ -679,8 +746,10 @@ class TestModels(JitTestCase):
x = torch.ones(1, 3, 224, 224) x = torch.ones(1, 3, 224, 224)
model = torchvision.models.AlexNet() model = torchvision.models.AlexNet()
with torch.random.fork_rng(devices=[]): with torch.random.fork_rng(devices=[]):
g, outputs, inputs = torch.jit._get_trace_graph(model, x, return_inputs=True) g, outputs, inputs = torch.jit._get_trace_graph(
self.run_pass('cse', g) model, x, return_inputs=True
)
self.run_pass("cse", g)
m = self.createFunctionFromGraph(g) m = self.createFunctionFromGraph(g)
with torch.random.fork_rng(devices=[]): with torch.random.fork_rng(devices=[]):
self.assertEqual(outputs, m(*inputs)) self.assertEqual(outputs, m(*inputs))

View File

@ -1,19 +1,23 @@
# Owner(s): ["oncall: jit"] # Owner(s): ["oncall: jit"]
import torch
import os import os
import sys import sys
from typing import Any, Dict, List
import torch
from torch.testing._internal.jit_utils import JitTestCase from torch.testing._internal.jit_utils import JitTestCase
from typing import Dict, Any, List
# Make the helper files in test/ importable # Make the helper files in test/ importable
pytorch_test_dir = os.path.dirname(os.path.dirname(os.path.realpath(__file__))) pytorch_test_dir = os.path.dirname(os.path.dirname(os.path.realpath(__file__)))
sys.path.append(pytorch_test_dir) sys.path.append(pytorch_test_dir)
if __name__ == '__main__': if __name__ == "__main__":
raise RuntimeError("This test file is not meant to be run directly, use:\n\n" raise RuntimeError(
"\tpython test/test_jit.py TESTNAME\n\n" "This test file is not meant to be run directly, use:\n\n"
"instead.") "\tpython test/test_jit.py TESTNAME\n\n"
"instead."
)
class TestModuleAPIs(JitTestCase): class TestModuleAPIs(JitTestCase):
def test_default_state_dict_methods(self): def test_default_state_dict_methods(self):
@ -52,18 +56,23 @@ class TestModuleAPIs(JitTestCase):
return x return x
@torch.jit.export @torch.jit.export
def _save_to_state_dict(self, destination: Dict[str, torch.Tensor], def _save_to_state_dict(
prefix: str, keep_vars: bool): self, destination: Dict[str, torch.Tensor], prefix: str, keep_vars: bool
):
self.customized_save_state_dict_called = True self.customized_save_state_dict_called = True
return {"dummy": torch.ones(1)} return {"dummy": torch.ones(1)}
@torch.jit.export @torch.jit.export
def _load_from_state_dict(self, def _load_from_state_dict(
state_dict: Dict[str, torch.Tensor], self,
prefix: str, local_metadata: Any, state_dict: Dict[str, torch.Tensor],
strict: bool, missing_keys: List[str], prefix: str,
unexpected_keys: List[str], local_metadata: Any,
error_msgs: List[str]): strict: bool,
missing_keys: List[str],
unexpected_keys: List[str],
error_msgs: List[str],
):
self.customized_load_state_dict_called = True self.customized_load_state_dict_called = True
return return
@ -94,18 +103,23 @@ class TestModuleAPIs(JitTestCase):
return x return x
@torch.jit.export @torch.jit.export
def _save_to_state_dict(self, destination: Dict[str, torch.Tensor], def _save_to_state_dict(
prefix: str, keep_vars: bool): self, destination: Dict[str, torch.Tensor], prefix: str, keep_vars: bool
):
self.customized_save_state_dict_called = True self.customized_save_state_dict_called = True
return {"dummy": torch.ones(1)} return {"dummy": torch.ones(1)}
@torch.jit.export @torch.jit.export
def _load_from_state_dict(self, def _load_from_state_dict(
state_dict: Dict[str, torch.Tensor], self,
prefix: str, local_metadata: Any, state_dict: Dict[str, torch.Tensor],
strict: bool, missing_keys: List[str], prefix: str,
unexpected_keys: List[str], local_metadata: Any,
error_msgs: List[str]): strict: bool,
missing_keys: List[str],
unexpected_keys: List[str],
error_msgs: List[str],
):
self.customized_load_state_dict_called = True self.customized_load_state_dict_called = True
return return

View File

@ -2,9 +2,10 @@
import os import os
import sys import sys
from collections import OrderedDict
from typing import Any, List, Tuple from typing import Any, List, Tuple
from collections import OrderedDict
import torch import torch
import torch.nn as nn import torch.nn as nn
from torch.testing._internal.jit_utils import JitTestCase from torch.testing._internal.jit_utils import JitTestCase
@ -13,10 +14,13 @@ from torch.testing._internal.jit_utils import JitTestCase
pytorch_test_dir = os.path.dirname(os.path.dirname(os.path.realpath(__file__))) pytorch_test_dir = os.path.dirname(os.path.dirname(os.path.realpath(__file__)))
sys.path.append(pytorch_test_dir) sys.path.append(pytorch_test_dir)
if __name__ == '__main__': if __name__ == "__main__":
raise RuntimeError("This test file is not meant to be run directly, use:\n\n" raise RuntimeError(
"\tpython test/test_jit.py TESTNAME\n\n" "This test file is not meant to be run directly, use:\n\n"
"instead.") "\tpython test/test_jit.py TESTNAME\n\n"
"instead."
)
class TestModuleContainers(JitTestCase): class TestModuleContainers(JitTestCase):
def test_sequential_intermediary_types(self): def test_sequential_intermediary_types(self):
@ -54,11 +58,13 @@ class TestModuleContainers(JitTestCase):
class M(torch.nn.Module): class M(torch.nn.Module):
def __init__(self): def __init__(self):
super().__init__() super().__init__()
modules = OrderedDict([ modules = OrderedDict(
('one', Inner()), [
('two', Inner2()), ("one", Inner()),
('three', Inner3()), ("two", Inner2()),
]) ("three", Inner3()),
]
)
self.moduledict = nn.ModuleDict(modules) self.moduledict = nn.ModuleDict(modules)
def forward(self, x, skip_name): def forward(self, x, skip_name):
@ -115,7 +121,6 @@ class TestModuleContainers(JitTestCase):
return x, x2, names, iter return x, x2, names, iter
for name in ["", "one", "two", "three"]: for name in ["", "one", "two", "three"]:
inp = torch.tensor(1) inp = torch.tensor(1)
self.checkModule(M(), (inp, name)) self.checkModule(M(), (inp, name))
@ -136,7 +141,7 @@ class TestModuleContainers(JitTestCase):
x = mod(x) x = mod(x)
return x - 5 return x - 5
self.checkModule(CustomSequential(), (torch.tensor(.5),)) self.checkModule(CustomSequential(), (torch.tensor(0.5),))
class CustomModuleList(nn.ModuleList): class CustomModuleList(nn.ModuleList):
def __init__(self): def __init__(self):
@ -148,16 +153,19 @@ class TestModuleContainers(JitTestCase):
x = mod(x) x = mod(x)
return x - 5 return x - 5
self.checkModule(CustomModuleList(), (torch.tensor(.5),)) self.checkModule(CustomModuleList(), (torch.tensor(0.5),))
class CustomModuleDict(nn.ModuleDict): class CustomModuleDict(nn.ModuleDict):
def __init__(self): def __init__(self):
super().__init__( super().__init__(
OrderedDict([ OrderedDict(
('one', Inner()), [
('two', nn.ReLU()), ("one", Inner()),
('three', Inner()), ("two", nn.ReLU()),
])) ("three", Inner()),
]
)
)
def forward(self, x): def forward(self, x):
x = x + 3 x = x + 3
@ -167,7 +175,7 @@ class TestModuleContainers(JitTestCase):
names.append(name) names.append(name)
return names, x - 5 return names, x - 5
self.checkModule(CustomModuleDict(), (torch.tensor(.5),)) self.checkModule(CustomModuleDict(), (torch.tensor(0.5),))
def test_script_module_list_sequential(self): def test_script_module_list_sequential(self):
class M(torch.jit.ScriptModule): class M(torch.jit.ScriptModule):
@ -225,7 +233,9 @@ class TestModuleContainers(JitTestCase):
def forward(self, v): def forward(self, v):
return self.mods[-11].forward(v) return self.mods[-11].forward(v)
with self.assertRaisesRegexWithHighlight(Exception, "Index -11 out of range", "self.mods[-11]"): with self.assertRaisesRegexWithHighlight(
Exception, "Index -11 out of range", "self.mods[-11]"
):
torch.jit.script(M2()) torch.jit.script(M2())
class M3(M): class M3(M):
@ -233,7 +243,9 @@ class TestModuleContainers(JitTestCase):
i = 3 i = 3
return self.mods[i].forward(v) return self.mods[i].forward(v)
with self.assertRaisesRegexWithHighlight(Exception, "Enumeration is supported", "self.mods[i]"): with self.assertRaisesRegexWithHighlight(
Exception, "Enumeration is supported", "self.mods[i]"
):
torch.jit.script(M3()) torch.jit.script(M3())
class M4(M): class M4(M):
@ -273,17 +285,23 @@ class TestModuleContainers(JitTestCase):
self.moduledict = CustomModuleDict({"submod": self.submod}) self.moduledict = CustomModuleDict({"submod": self.submod})
def forward(self, inputs): def forward(self, inputs):
assert self.modulelist[0] is self.submod, "__getitem__ failing for ModuleList" assert (
self.modulelist[0] is self.submod
), "__getitem__ failing for ModuleList"
assert len(self.modulelist) == 1, "__len__ failing for ModuleList" assert len(self.modulelist) == 1, "__len__ failing for ModuleList"
for module in self.modulelist: for module in self.modulelist:
assert module is self.submod, "__iter__ failing for ModuleList" assert module is self.submod, "__iter__ failing for ModuleList"
assert self.sequential[0] is self.submod, "__getitem__ failing for Sequential" assert (
self.sequential[0] is self.submod
), "__getitem__ failing for Sequential"
assert len(self.sequential) == 1, "__len__ failing for Sequential" assert len(self.sequential) == 1, "__len__ failing for Sequential"
for module in self.sequential: for module in self.sequential:
assert module is self.submod, "__iter__ failing for Sequential" assert module is self.submod, "__iter__ failing for Sequential"
assert self.moduledict["submod"] is self.submod, "__getitem__ failing for ModuleDict" assert (
self.moduledict["submod"] is self.submod
), "__getitem__ failing for ModuleDict"
assert len(self.moduledict) == 1, "__len__ failing for ModuleDict" assert len(self.moduledict) == 1, "__len__ failing for ModuleDict"
# note: unable to index moduledict with a string variable currently # note: unable to index moduledict with a string variable currently
@ -345,12 +363,13 @@ class TestModuleContainers(JitTestCase):
super().__init__() super().__init__()
self.relu = torch.jit.script(torch.nn.ReLU()) self.relu = torch.jit.script(torch.nn.ReLU())
self.tanh = torch.jit.script(torch.nn.Tanh()) self.tanh = torch.jit.script(torch.nn.Tanh())
self.moduledict = torch.nn.ModuleDict({"relu": self.relu, self.moduledict = torch.nn.ModuleDict(
"tanh": self.tanh}) {"relu": self.relu, "tanh": self.tanh}
)
def forward(self, input): def forward(self, input):
assert self.moduledict['relu'] is self.relu assert self.moduledict["relu"] is self.relu
assert self.moduledict['tanh'] is self.tanh assert self.moduledict["tanh"] is self.tanh
return input return input
m = MyModule() m = MyModule()
@ -360,31 +379,34 @@ class TestModuleContainers(JitTestCase):
class BadModule(torch.nn.Module): class BadModule(torch.nn.Module):
def __init__(self): def __init__(self):
super().__init__() super().__init__()
self.moduledict = torch.nn.ModuleDict({"foo": None, self.moduledict = torch.nn.ModuleDict({"foo": None, "bar": None})
"bar": None})
def forward(self, input): def forward(self, input):
assert self.moduledict['blah'] == "blah", "this is a keyerror" assert self.moduledict["blah"] == "blah", "this is a keyerror"
with self.assertRaisesRegexWithHighlight(RuntimeError, "Key Error, blah", "self.moduledict['blah'"): with self.assertRaisesRegexWithHighlight(
RuntimeError, "Key Error, blah", 'self.moduledict["blah"'
):
b = BadModule() b = BadModule()
torch.jit.script(b) torch.jit.script(b)
class AnotherBadModule(torch.nn.Module): class AnotherBadModule(torch.nn.Module):
def __init__(self): def __init__(self):
super().__init__() super().__init__()
self.moduledict = torch.nn.ModuleDict({"foo": None, self.moduledict = torch.nn.ModuleDict({"foo": None, "bar": None})
"bar": None})
def forward(self, input): def forward(self, input):
idx = 'blah' idx = "blah"
assert self.moduledict[idx] == "blah", "this is a string literal error" assert self.moduledict[idx] == "blah", "this is a string literal error"
with self.assertRaisesRegexWithHighlight(RuntimeError, "Unable to extract string literal index. " with self.assertRaisesRegexWithHighlight(
"ModuleDict indexing is only supported with string literals. " RuntimeError,
"For example, 'i = \"a\"; self.layers\\[i\\]\\(x\\)' will fail " "Unable to extract string literal index. "
"because i is not a literal.", "ModuleDict indexing is only supported with string literals. "
"self.moduledict[idx]"): "For example, 'i = \"a\"; self.layers\\[i\\]\\(x\\)' will fail "
"because i is not a literal.",
"self.moduledict[idx]",
):
b = AnotherBadModule() b = AnotherBadModule()
torch.jit.script(b) torch.jit.script(b)
@ -393,6 +415,7 @@ class TestModuleContainers(JitTestCase):
Test that an attempt to script a module with a regular list attribute Test that an attempt to script a module with a regular list attribute
containing other modules fails with a relevant error message. containing other modules fails with a relevant error message.
""" """
class Mod(torch.nn.Module): class Mod(torch.nn.Module):
def __init__(self): def __init__(self):
super().__init__() super().__init__()
@ -422,7 +445,9 @@ class TestModuleContainers(JitTestCase):
self.moduledict = CustomModuleDict() self.moduledict = CustomModuleDict()
def forward(self, inputs): def forward(self, inputs):
assert "submod" not in self.moduledict, "__contains__ fails for ModuleDict" assert (
"submod" not in self.moduledict
), "__contains__ fails for ModuleDict"
return inputs return inputs
m = MyModule() m = MyModule()
@ -433,6 +458,7 @@ class TestModuleContainers(JitTestCase):
Test that a type annotation can be provided for a ModuleDict that allows Test that a type annotation can be provided for a ModuleDict that allows
non-static indexing. non-static indexing.
""" """
@torch.jit.interface @torch.jit.interface
class ModuleInterface(torch.nn.Module): class ModuleInterface(torch.nn.Module):
def forward(self, inp: Any) -> Any: def forward(self, inp: Any) -> Any:
@ -485,7 +511,9 @@ class TestModuleContainers(JitTestCase):
submodule: ModuleInterface = self.d[key] submodule: ModuleInterface = self.d[key]
return submodule.forward(x) return submodule.forward(x)
with self.assertRaisesRegexWithHighlight(RuntimeError, r"Attribute module is not of annotated type", "self.d[key]"): with self.assertRaisesRegexWithHighlight(
RuntimeError, r"Attribute module is not of annotated type", "self.d[key]"
):
torch.jit.script(ModWithWrongAnnotation()) torch.jit.script(ModWithWrongAnnotation())
def test_typed_module_list(self): def test_typed_module_list(self):
@ -493,6 +521,7 @@ class TestModuleContainers(JitTestCase):
Test that a type annotation can be provided for a ModuleList that allows Test that a type annotation can be provided for a ModuleList that allows
non-static indexing. non-static indexing.
""" """
@torch.jit.interface @torch.jit.interface
class ModuleInterface(torch.nn.Module): class ModuleInterface(torch.nn.Module):
def forward(self, inp: Any) -> Any: def forward(self, inp: Any) -> Any:
@ -545,7 +574,9 @@ class TestModuleContainers(JitTestCase):
submodule: ModuleInterface = self.l[idx] submodule: ModuleInterface = self.l[idx]
return submodule.forward(x) return submodule.forward(x)
with self.assertRaisesRegexWithHighlight(RuntimeError, r"Attribute 0 is not of annotated type", "self.l[idx]"): with self.assertRaisesRegexWithHighlight(
RuntimeError, r"Attribute 0 is not of annotated type", "self.l[idx]"
):
torch.jit.script(ModWithWrongAnnotation()) torch.jit.script(ModWithWrongAnnotation())
def test_module_properties(self): def test_module_properties(self):
@ -596,10 +627,34 @@ class TestModuleContainers(JitTestCase):
def attr(self): def attr(self):
return self.a + 1 return self.a + 1
self.checkModule(ModuleWithProperties(5), (5, 6,)) self.checkModule(
self.checkModule(ModuleWithProperties(5), (-5, -6,)) ModuleWithProperties(5),
self.checkModule(ModuleWithNoSetter(5), (5, 6,)) (
self.checkModule(ModuleWithNoSetter(5), (-5, -6,)) 5,
6,
),
)
self.checkModule(
ModuleWithProperties(5),
(
-5,
-6,
),
)
self.checkModule(
ModuleWithNoSetter(5),
(
5,
6,
),
)
self.checkModule(
ModuleWithNoSetter(5),
(
-5,
-6,
),
)
mod = ModuleWithProperties(3) mod = ModuleWithProperties(3)
scripted_mod = torch.jit.script(mod) scripted_mod = torch.jit.script(mod)
@ -625,7 +680,6 @@ class TestModuleContainers(JitTestCase):
def forward(self, x): def forward(self, x):
return self.linear(self.linear(x)) return self.linear(self.linear(x))
class N(nn.Module): class N(nn.Module):
def __init__(self): def __init__(self):
super().__init__() super().__init__()
@ -659,7 +713,9 @@ class TestModuleContainers(JitTestCase):
def __init__(self): def __init__(self):
super().__init__() super().__init__()
self.module_list = nn.ModuleList([nn.Linear(1, 1) for _ in range(10)]) self.module_list = nn.ModuleList([nn.Linear(1, 1) for _ in range(10)])
self.parameter_list = nn.ParameterList([nn.Parameter(torch.zeros(1)) for _ in range(10)]) self.parameter_list = nn.ParameterList(
[nn.Parameter(torch.zeros(1)) for _ in range(10)]
)
def forward(self, x): def forward(self, x):
self.module_list[0] self.module_list[0]
@ -673,7 +729,9 @@ class TestModuleContainers(JitTestCase):
def __init__(self): def __init__(self):
super().__init__() super().__init__()
self.module_list = nn.ModuleList([nn.Linear(1, 1) for _ in range(10)]) self.module_list = nn.ModuleList([nn.Linear(1, 1) for _ in range(10)])
self.parameter_list = nn.ParameterList([nn.Parameter(torch.zeros(1)) for _ in range(10)]) self.parameter_list = nn.ParameterList(
[nn.Parameter(torch.zeros(1)) for _ in range(10)]
)
def forward(self, x): def forward(self, x):
r = x r = x
@ -687,9 +745,14 @@ class TestModuleContainers(JitTestCase):
class MyModule(nn.Module): class MyModule(nn.Module):
def __init__(self): def __init__(self):
super().__init__() super().__init__()
self.parameter_dict = nn.ParameterDict({k: nn.Parameter(torch.zeros(1)) for k in ['a', 'b', 'c']}) self.parameter_dict = nn.ParameterDict(
{k: nn.Parameter(torch.zeros(1)) for k in ["a", "b", "c"]}
)
def forward(self, x): def forward(self, x):
return self.parameter_dict['a'] * x + self.parameter_dict['b'] * self.parameter_dict['c'] return (
self.parameter_dict["a"] * x
+ self.parameter_dict["b"] * self.parameter_dict["c"]
)
self.checkModule(MyModule(), (torch.ones(1),)) self.checkModule(MyModule(), (torch.ones(1),))

View File

@ -1,10 +1,11 @@
# Owner(s): ["oncall: jit"] # Owner(s): ["oncall: jit"]
from typing import List, Any
import torch
import torch.nn as nn
import os import os
import sys import sys
from typing import Any, List
import torch
import torch.nn as nn
from torch import Tensor from torch import Tensor
from torch.testing._internal.jit_utils import JitTestCase, make_global from torch.testing._internal.jit_utils import JitTestCase, make_global
@ -12,10 +13,13 @@ from torch.testing._internal.jit_utils import JitTestCase, make_global
pytorch_test_dir = os.path.dirname(os.path.dirname(os.path.realpath(__file__))) pytorch_test_dir = os.path.dirname(os.path.dirname(os.path.realpath(__file__)))
sys.path.append(pytorch_test_dir) sys.path.append(pytorch_test_dir)
if __name__ == '__main__': if __name__ == "__main__":
raise RuntimeError("This test file is not meant to be run directly, use:\n\n" raise RuntimeError(
"\tpython test/test_jit.py TESTNAME\n\n" "This test file is not meant to be run directly, use:\n\n"
"instead.") "\tpython test/test_jit.py TESTNAME\n\n"
"instead."
)
class OrigModule(nn.Module): class OrigModule(nn.Module):
def one(self, inp1: Tensor, inp2: Tensor) -> Tensor: def one(self, inp1: Tensor, inp2: Tensor) -> Tensor:
@ -27,6 +31,7 @@ class OrigModule(nn.Module):
def forward(self, input: Tensor) -> Tensor: def forward(self, input: Tensor) -> Tensor:
return input + self.one(input, input) + 1 return input + self.one(input, input) + 1
class NewModule(nn.Module): class NewModule(nn.Module):
def one(self, inp1: Tensor, inp2: Tensor) -> Tensor: def one(self, inp1: Tensor, inp2: Tensor) -> Tensor:
return inp1 * inp2 + 1 return inp1 * inp2 + 1
@ -34,6 +39,7 @@ class NewModule(nn.Module):
def forward(self, input: Tensor) -> Tensor: def forward(self, input: Tensor) -> Tensor:
return self.one(input, input + 1) return self.one(input, input + 1)
class TestModuleInterface(JitTestCase): class TestModuleInterface(JitTestCase):
def test_not_submodule_interface_call(self): def test_not_submodule_interface_call(self):
@torch.jit.interface @torch.jit.interface
@ -42,7 +48,7 @@ class TestModuleInterface(JitTestCase):
pass pass
class TestNotModuleInterfaceCall(nn.Module): class TestNotModuleInterfaceCall(nn.Module):
proxy_mod : ModuleInterface proxy_mod: ModuleInterface
def __init__(self): def __init__(self):
super().__init__() super().__init__()
@ -51,7 +57,9 @@ class TestModuleInterface(JitTestCase):
def forward(self, input: Tensor) -> Tensor: def forward(self, input: Tensor) -> Tensor:
return self.proxy_mod.two(input) return self.proxy_mod.two(input)
with self.assertRaisesRegexWithHighlight(RuntimeError, "object has no attribute or method", "self.proxy_mod.two"): with self.assertRaisesRegexWithHighlight(
RuntimeError, "object has no attribute or method", "self.proxy_mod.two"
):
torch.jit.script(TestNotModuleInterfaceCall()) torch.jit.script(TestNotModuleInterfaceCall())
def test_module_interface(self): def test_module_interface(self):
@ -108,17 +116,37 @@ class TestModuleInterface(JitTestCase):
scripted_foo_mod = torch.jit.script(FooMod()) scripted_foo_mod = torch.jit.script(FooMod())
scripted_bar_mod = torch.jit.script(BarMod()) scripted_bar_mod = torch.jit.script(BarMod())
self.checkScript(use_module_interface, self.checkScript(
([scripted_foo_mod, scripted_bar_mod], torch.rand(3, 4),)) use_module_interface,
self.checkScript(use_class_interface, (
([scripted_foo_mod, scripted_bar_mod], torch.rand(3, 4),)) [scripted_foo_mod, scripted_bar_mod],
torch.rand(3, 4),
),
)
self.checkScript(
use_class_interface,
(
[scripted_foo_mod, scripted_bar_mod],
torch.rand(3, 4),
),
)
def call_module_interface_on_other_method(mod_interface: OneTwoModule, x: Tensor) -> Tensor: def call_module_interface_on_other_method(
mod_interface: OneTwoModule, x: Tensor
) -> Tensor:
return mod_interface.forward2(x) return mod_interface.forward2(x)
# ensure error out when we call the module on the method other than the interface specified. # ensure error out when we call the module on the method other than the interface specified.
with self.assertRaisesRegexWithHighlight(RuntimeError, "object has no attribute or method", "mod_interface.forward2"): with self.assertRaisesRegexWithHighlight(
self.checkScript(call_module_interface_on_other_method, (scripted_bar_mod, torch.rand(3, 4),)) RuntimeError, "object has no attribute or method", "mod_interface.forward2"
):
self.checkScript(
call_module_interface_on_other_method,
(
scripted_bar_mod,
torch.rand(3, 4),
),
)
def test_module_doc_string(self): def test_module_doc_string(self):
@torch.jit.interface @torch.jit.interface
@ -135,7 +163,7 @@ class TestModuleInterface(JitTestCase):
r"""stuff 3""" r"""stuff 3"""
class TestModule(nn.Module): class TestModule(nn.Module):
proxy_mod : TestInterface proxy_mod: TestInterface
def __init__(self): def __init__(self):
super().__init__() super().__init__()
@ -178,7 +206,9 @@ class TestModuleInterface(JitTestCase):
return self.one(self.two(x), x) return self.one(self.two(x), x)
# check class object is not a subtype of module interface # check class object is not a subtype of module interface
with self.assertRaisesRegex(RuntimeError, "ScriptModule class can be subtype of module interface"): with self.assertRaisesRegex(
RuntimeError, "ScriptModule class can be subtype of module interface"
):
as_module_interface(Foo()) as_module_interface(Foo())
class WrongMod(nn.Module): class WrongMod(nn.Module):
@ -233,9 +263,11 @@ class TestModuleInterface(JitTestCase):
as_tensor_to_any(torch.jit.script(TensorToAnyImplB())) as_tensor_to_any(torch.jit.script(TensorToAnyImplB()))
as_any_to_any(torch.jit.script(AnyToAnyImpl())) as_any_to_any(torch.jit.script(AnyToAnyImpl()))
def test_module_interface_inheritance(self): def test_module_interface_inheritance(self):
with self.assertRaisesRegex(RuntimeError, "does not support inheritance yet. Please directly"): with self.assertRaisesRegex(
RuntimeError, "does not support inheritance yet. Please directly"
):
@torch.jit.interface @torch.jit.interface
class InheritMod(nn.ReLU): class InheritMod(nn.ReLU):
def three(self, x: Tensor) -> Tensor: def three(self, x: Tensor) -> Tensor:
@ -251,7 +283,7 @@ class TestModuleInterface(JitTestCase):
pass pass
class TestModule(nn.Module): class TestModule(nn.Module):
proxy_mod : ModuleInterface proxy_mod: ModuleInterface
def __init__(self): def __init__(self):
super().__init__() super().__init__()
@ -269,7 +301,9 @@ class TestModuleInterface(JitTestCase):
self.assertEqual(scripted_mod(input), input * (input + 1) + 1) self.assertEqual(scripted_mod(input), input * (input + 1) + 1)
# module swap with non-scripted module should throw error # module swap with non-scripted module should throw error
with self.assertRaisesRegex(RuntimeError, "a ScriptModule with non-scripted module"): with self.assertRaisesRegex(
RuntimeError, "a ScriptModule with non-scripted module"
):
scripted_mod.proxy_mod = NewModule() scripted_mod.proxy_mod = NewModule()
def test_module_swap_wrong_module(self): def test_module_swap_wrong_module(self):
@ -286,7 +320,7 @@ class TestModuleInterface(JitTestCase):
return input + 1 return input + 1
class TestModule(nn.Module): class TestModule(nn.Module):
proxy_mod : ModuleInterface proxy_mod: ModuleInterface
def __init__(self): def __init__(self):
super().__init__() super().__init__()
@ -310,7 +344,7 @@ class TestModuleInterface(JitTestCase):
pass pass
class TestModule(nn.Module): class TestModule(nn.Module):
proxy_mod : ModuleInterface proxy_mod: ModuleInterface
def __init__(self): def __init__(self):
super().__init__() super().__init__()
@ -358,9 +392,11 @@ class TestModuleInterface(JitTestCase):
# proxy mod is swapped with the new ScriptModule that share the same JIT type, should succeed. # proxy mod is swapped with the new ScriptModule that share the same JIT type, should succeed.
scripted_no_module_interface.proxy_mod = torch.jit.script(OrigModule()) scripted_no_module_interface.proxy_mod = torch.jit.script(OrigModule())
# proxy_mod is neither a module interface or have the same JIT type, should fail # proxy_mod is neither a module interface or have the same JIT type, should fail
with self.assertRaisesRegex(RuntimeError, with self.assertRaisesRegex(
r"Expected a value of type '__torch__.jit.test_module_interface.OrigModule \(.*\)' " + RuntimeError,
r"for field 'proxy_mod', but found '__torch__.jit.test_module_interface.NewModule \(.*\)'"): r"Expected a value of type '__torch__.jit.test_module_interface.OrigModule \(.*\)' "
+ r"for field 'proxy_mod', but found '__torch__.jit.test_module_interface.NewModule \(.*\)'",
):
scripted_no_module_interface.proxy_mod = torch.jit.script(NewModule()) scripted_no_module_interface.proxy_mod = torch.jit.script(NewModule())
def test_script_module_as_interface_swap(self): def test_script_module_as_interface_swap(self):
@ -391,7 +427,7 @@ class TestModuleInterface(JitTestCase):
return self.one(input, input + 1) return self.one(input, input + 1)
class TestNNModuleWithScriptModule(nn.Module): class TestNNModuleWithScriptModule(nn.Module):
proxy_mod : ModuleInterface proxy_mod: ModuleInterface
def __init__(self): def __init__(self):
super().__init__() super().__init__()
@ -432,7 +468,7 @@ class TestModuleInterface(JitTestCase):
pass pass
class TestModule(torch.nn.Module): class TestModule(torch.nn.Module):
proxy_mod : ModInterface proxy_mod: ModInterface
def __init__(self): def __init__(self):
super().__init__() super().__init__()
@ -480,7 +516,7 @@ class TestModuleInterface(JitTestCase):
pass pass
class TestModule(torch.nn.Module): class TestModule(torch.nn.Module):
proxy_mod : ModInterface proxy_mod: ModInterface
def __init__(self): def __init__(self):
super().__init__() super().__init__()
@ -523,7 +559,7 @@ class TestModuleInterface(JitTestCase):
pass pass
class TestModule(torch.nn.Module): class TestModule(torch.nn.Module):
proxy_mod : ModInterface proxy_mod: ModInterface
def __init__(self): def __init__(self):
super().__init__() super().__init__()
@ -568,7 +604,7 @@ class TestModuleInterface(JitTestCase):
pass pass
class TestModule(torch.nn.Module): class TestModule(torch.nn.Module):
proxy_mod : ModInterface proxy_mod: ModInterface
def __init__(self): def __init__(self):
super().__init__() super().__init__()
@ -583,7 +619,9 @@ class TestModuleInterface(JitTestCase):
m = torch.jit.script(TestModule()) m = torch.jit.script(TestModule())
m.eval() m.eval()
with self.assertRaisesRegex(RuntimeError, "Freezing does not support SetAttr on an interface type."): with self.assertRaisesRegex(
RuntimeError, "Freezing does not support SetAttr on an interface type."
):
mf = torch._C._freeze_module(m._c, freezeInterfaces=True) mf = torch._C._freeze_module(m._c, freezeInterfaces=True)
def test_freeze_module_with_interface_and_fork(self): def test_freeze_module_with_interface_and_fork(self):
@ -610,7 +648,7 @@ class TestModuleInterface(JitTestCase):
pass pass
class TestModule(torch.nn.Module): class TestModule(torch.nn.Module):
proxy_mod : ModInterface proxy_mod: ModInterface
def __init__(self): def __init__(self):
super().__init__() super().__init__()
@ -644,7 +682,7 @@ class TestModuleInterface(JitTestCase):
pass pass
class TestModule(nn.Module): class TestModule(nn.Module):
proxy_mod : ModuleInterface proxy_mod: ModuleInterface
def __init__(self): def __init__(self):
super().__init__() super().__init__()

View File

@ -1,18 +1,22 @@
# Owner(s): ["oncall: jit"] # Owner(s): ["oncall: jit"]
import torch
import os import os
import sys import sys
import torch
from torch.testing._internal.jit_utils import JitTestCase from torch.testing._internal.jit_utils import JitTestCase
# Make the helper files in test/ importable # Make the helper files in test/ importable
pytorch_test_dir = os.path.dirname(os.path.dirname(os.path.realpath(__file__))) pytorch_test_dir = os.path.dirname(os.path.dirname(os.path.realpath(__file__)))
sys.path.append(pytorch_test_dir) sys.path.append(pytorch_test_dir)
if __name__ == '__main__': if __name__ == "__main__":
raise RuntimeError("This test file is not meant to be run directly, use:\n\n" raise RuntimeError(
"\tpython test/test_jit.py TESTNAME\n\n" "This test file is not meant to be run directly, use:\n\n"
"instead.") "\tpython test/test_jit.py TESTNAME\n\n"
"instead."
)
class TestModules(JitTestCase): class TestModules(JitTestCase):
def test_script_module_with_constants_list(self): def test_script_module_with_constants_list(self):

View File

@ -4,10 +4,13 @@ import torch
from torch.testing import FileCheck from torch.testing import FileCheck
from torch.testing._internal.jit_utils import JitTestCase from torch.testing._internal.jit_utils import JitTestCase
if __name__ == '__main__': if __name__ == "__main__":
raise RuntimeError("This test file is not meant to be run directly, use:\n\n" raise RuntimeError(
"\tpython test/test_jit.py TESTNAME\n\n" "This test file is not meant to be run directly, use:\n\n"
"instead.") "\tpython test/test_jit.py TESTNAME\n\n"
"instead."
)
class TestOpDecompositions(JitTestCase): class TestOpDecompositions(JitTestCase):
def test_op_decomposition(self): def test_op_decomposition(self):
@ -31,7 +34,9 @@ class TestOpDecompositions(JitTestCase):
def square_decomp(x): def square_decomp(x):
return torch.pow(x, 2) return torch.pow(x, 2)
torch.jit._register_decomposition(torch.ops.aten.square.default, square_decomp.graph) torch.jit._register_decomposition(
torch.ops.aten.square.default, square_decomp.graph
)
torch._C._jit_pass_run_decompositions(foo.graph) torch._C._jit_pass_run_decompositions(foo.graph)
FileCheck().check_not("aten::square").check("aten::pow").run(foo.graph) FileCheck().check_not("aten::square").check("aten::pow").run(foo.graph)
x = torch.rand([4]) x = torch.rand([4])

View File

@ -3,8 +3,9 @@
import torch import torch
import torch._C import torch._C
import torch.nn.functional as F import torch.nn.functional as F
from torch.testing._internal.jit_utils import JitTestCase
from torch.testing._internal.common_utils import skipIfNoXNNPACK from torch.testing._internal.common_utils import skipIfNoXNNPACK
from torch.testing._internal.jit_utils import JitTestCase
class TestOptimizeForMobilePreserveDebugInfo(JitTestCase): class TestOptimizeForMobilePreserveDebugInfo(JitTestCase):
def check_replacement( def check_replacement(
@ -133,10 +134,8 @@ class TestOptimizeForMobilePreserveDebugInfo(JitTestCase):
"prepacked::linear_clamp_run": "aten::linear", "prepacked::linear_clamp_run": "aten::linear",
"prepacked::conv2d_clamp_prepack": "aten::conv2d", "prepacked::conv2d_clamp_prepack": "aten::conv2d",
"prepacked::conv2d_clamp_run": "aten::conv2d", "prepacked::conv2d_clamp_run": "aten::conv2d",
"prepacked::conv2d_transpose_clamp_prepack": "prepacked::conv2d_transpose_clamp_prepack": "aten::conv_transpose2d",
"aten::conv_transpose2d", "prepacked::conv2d_transpose_clamp_run": "aten::conv_transpose2d",
"prepacked::conv2d_transpose_clamp_run":
"aten::conv_transpose2d",
}, },
jit_pass=torch._C._jit_pass_insert_prepacked_ops, jit_pass=torch._C._jit_pass_insert_prepacked_ops,
) )
@ -147,7 +146,7 @@ class TestOptimizeForMobilePreserveDebugInfo(JitTestCase):
model=torch.jit.trace(torch.nn.Linear(5, 4), torch.rand(3, 2, 5)), model=torch.jit.trace(torch.nn.Linear(5, 4), torch.rand(3, 2, 5)),
replacements={ replacements={
"prepacked::linear_clamp_prepack": "aten::linear", "prepacked::linear_clamp_prepack": "aten::linear",
"prepacked::linear_clamp_run": "aten::linear" "prepacked::linear_clamp_run": "aten::linear",
}, },
jit_pass=torch._C._jit_pass_insert_prepacked_ops, jit_pass=torch._C._jit_pass_insert_prepacked_ops,
) )
@ -223,11 +222,9 @@ class TestOptimizeForMobilePreserveDebugInfo(JitTestCase):
self.check_replacement( self.check_replacement(
model=model, model=model,
replacements={ replacements={
"prepacked::linear_clamp_prepack": "prepacked::linear_clamp_prepack": "prepacked::linear_clamp_prepack",
"prepacked::linear_clamp_prepack",
"prepacked::linear_clamp_run": linear_activation_kind, "prepacked::linear_clamp_run": linear_activation_kind,
"prepacked::conv2d_clamp_prepack": "prepacked::conv2d_clamp_prepack": "prepacked::conv2d_clamp_prepack",
"prepacked::conv2d_clamp_prepack",
"prepacked::conv2d_clamp_run": conv2d_activation_kind, "prepacked::conv2d_clamp_run": conv2d_activation_kind,
}, },
jit_pass=torch._C._jit_pass_fuse_clamp_w_prepacked_linear_conv, jit_pass=torch._C._jit_pass_fuse_clamp_w_prepacked_linear_conv,
@ -239,7 +236,7 @@ class TestOptimizeForMobilePreserveDebugInfo(JitTestCase):
linear_activation=F.hardtanh, linear_activation=F.hardtanh,
linear_activation_kind="aten::hardtanh", linear_activation_kind="aten::hardtanh",
conv2d_activation=F.hardtanh_, conv2d_activation=F.hardtanh_,
conv2d_activation_kind="aten::hardtanh_" conv2d_activation_kind="aten::hardtanh_",
) )
@skipIfNoXNNPACK @skipIfNoXNNPACK
@ -248,7 +245,7 @@ class TestOptimizeForMobilePreserveDebugInfo(JitTestCase):
linear_activation=F.hardtanh_, linear_activation=F.hardtanh_,
linear_activation_kind="aten::hardtanh_", linear_activation_kind="aten::hardtanh_",
conv2d_activation=F.hardtanh, conv2d_activation=F.hardtanh,
conv2d_activation_kind="aten::hardtanh" conv2d_activation_kind="aten::hardtanh",
) )
@skipIfNoXNNPACK @skipIfNoXNNPACK
@ -257,7 +254,7 @@ class TestOptimizeForMobilePreserveDebugInfo(JitTestCase):
linear_activation=F.relu, linear_activation=F.relu,
linear_activation_kind="aten::relu", linear_activation_kind="aten::relu",
conv2d_activation=F.relu_, conv2d_activation=F.relu_,
conv2d_activation_kind="aten::relu_" conv2d_activation_kind="aten::relu_",
) )
@skipIfNoXNNPACK @skipIfNoXNNPACK
@ -266,5 +263,5 @@ class TestOptimizeForMobilePreserveDebugInfo(JitTestCase):
linear_activation=F.relu_, linear_activation=F.relu_,
linear_activation_kind="aten::relu_", linear_activation_kind="aten::relu_",
conv2d_activation=F.relu, conv2d_activation=F.relu,
conv2d_activation_kind="aten::relu" conv2d_activation_kind="aten::relu",
) )

View File

@ -2,16 +2,19 @@
import torch import torch
from torch import nn
import torch.nn.utils.parametrize as parametrize import torch.nn.utils.parametrize as parametrize
from torch import nn
from torch.testing._internal.jit_utils import JitTestCase from torch.testing._internal.jit_utils import JitTestCase
if __name__ == '__main__': if __name__ == "__main__":
raise RuntimeError("This test file is not meant to be run directly, use:\n\n" raise RuntimeError(
"\tpython test/test_jit.py TESTNAME\n\n" "This test file is not meant to be run directly, use:\n\n"
"instead.") "\tpython test/test_jit.py TESTNAME\n\n"
"instead."
)
class TestParametrization(JitTestCase): class TestParametrization(JitTestCase):
# Define some parametrization # Define some parametrization
@ -29,7 +32,7 @@ class TestParametrization(JitTestCase):
# Check the tracing works. Because traced functions cannot be called # Check the tracing works. Because traced functions cannot be called
# directly, we run the comparison on the activations. # directly, we run the comparison on the activations.
traced_model = torch.jit.trace_module(model, {'forward': x}) traced_model = torch.jit.trace_module(model, {"forward": x})
y_hat = traced_model(x) y_hat = traced_model(x)
self.assertEqual(y, y_hat) self.assertEqual(y, y_hat)
@ -39,10 +42,9 @@ class TestParametrization(JitTestCase):
self.assertEqual(y, y_hat) self.assertEqual(y, y_hat)
# Check the tracing throws an error when caching # Check the tracing throws an error when caching
with self.assertRaisesRegex(RuntimeError, with self.assertRaisesRegex(RuntimeError, "Cannot trace a model while caching"):
'Cannot trace a model while caching'):
with parametrize.cached(): with parametrize.cached():
traced_model = torch.jit.trace_module(model, {'forward': x}) traced_model = torch.jit.trace_module(model, {"forward": x})
def test_scriptable(self): def test_scriptable(self):
# TODO: Need to fix the scripting in parametrizations # TODO: Need to fix the scripting in parametrizations
@ -65,5 +67,5 @@ class TestParametrization(JitTestCase):
self.assertEqual(y, y_hat) self.assertEqual(y, y_hat)
# Check the scripting process throws an error when caching # Check the scripting process throws an error when caching
with self.assertRaisesRegex(RuntimeError, 'Caching is not implemented'): with self.assertRaisesRegex(RuntimeError, "Caching is not implemented"):
scripted_model = torch.jit.trace_module(model) scripted_model = torch.jit.trace_module(model)

View File

@ -2,18 +2,22 @@
import os import os
import sys import sys
from typing import Any, Dict, List, NamedTuple, Optional, Tuple # noqa: F401
import torch import torch
from torch.testing._internal.jit_utils import JitTestCase, make_global
from torch.jit._monkeytype_config import _IS_MONKEYTYPE_INSTALLED from torch.jit._monkeytype_config import _IS_MONKEYTYPE_INSTALLED
from typing import List, Dict, Tuple, Any, Optional, NamedTuple # noqa: F401
from torch.testing._internal.common_utils import NoTest from torch.testing._internal.common_utils import NoTest
from torch.testing._internal.jit_utils import JitTestCase, make_global
# Make the helper files in test/ importable # Make the helper files in test/ importable
pytorch_test_dir = os.path.dirname(os.path.dirname(os.path.realpath(__file__))) pytorch_test_dir = os.path.dirname(os.path.dirname(os.path.realpath(__file__)))
sys.path.append(pytorch_test_dir) sys.path.append(pytorch_test_dir)
if not _IS_MONKEYTYPE_INSTALLED: if not _IS_MONKEYTYPE_INSTALLED:
print("monkeytype is not installed. Skipping tests for Profile-Directed Typing", file=sys.stderr) print(
"monkeytype is not installed. Skipping tests for Profile-Directed Typing",
file=sys.stderr,
)
JitTestCase = NoTest # type: ignore[misc, assignment] # noqa: F811 JitTestCase = NoTest # type: ignore[misc, assignment] # noqa: F811
if __name__ == "__main__": if __name__ == "__main__":
@ -23,10 +27,12 @@ if __name__ == "__main__":
"instead." "instead."
) )
class TestPDT(JitTestCase): class TestPDT(JitTestCase):
""" """
A suite of tests for profile directed typing in TorchScript. A suite of tests for profile directed typing in TorchScript.
""" """
def test_nn_module(self): def test_nn_module(self):
class TestPDTModel(torch.nn.Module): class TestPDTModel(torch.nn.Module):
def forward(self, x) -> Any: def forward(self, x) -> Any:
@ -39,8 +45,14 @@ class TestPDT(JitTestCase):
make_global(TestPDTModel) make_global(TestPDTModel)
pdt_model = TestPDTModel() pdt_model = TestPDTModel()
inp: List[Tuple[Any, ...]] = [(20, ), (2.7, ), (False, ), ] inp: List[Tuple[Any, ...]] = [
scripted_pdt_model = torch.jit.script(pdt_model, example_inputs={pdt_model: inp}) (20,),
(2.7,),
(False,),
]
scripted_pdt_model = torch.jit.script(
pdt_model, example_inputs={pdt_model: inp}
)
self.assertEqual(scripted_pdt_model(50), pdt_model(50)) self.assertEqual(scripted_pdt_model(50), pdt_model(50))
self.assertEqual(scripted_pdt_model(1.8), pdt_model(1.8)) self.assertEqual(scripted_pdt_model(1.8), pdt_model(1.8))
self.assertTrue(scripted_pdt_model(True), pdt_model(True)) self.assertTrue(scripted_pdt_model(True), pdt_model(True))
@ -63,8 +75,10 @@ class TestPDT(JitTestCase):
make_global(NestedPDTInner, NestedModulePDTWrapper) make_global(NestedPDTInner, NestedModulePDTWrapper)
inner_pdt_model = NestedPDTInner() inner_pdt_model = NestedPDTInner()
wrapped_pdt_model = NestedModulePDTWrapper(inner_pdt_model) wrapped_pdt_model = NestedModulePDTWrapper(inner_pdt_model)
inp: List[Tuple[Any, ...]] = [(20, ), (False, )] inp: List[Tuple[Any, ...]] = [(20,), (False,)]
scripted_pdt_model = torch.jit.script(wrapped_pdt_model, example_inputs={wrapped_pdt_model: inp}) scripted_pdt_model = torch.jit.script(
wrapped_pdt_model, example_inputs={wrapped_pdt_model: inp}
)
self.assertEqual(scripted_pdt_model(30), wrapped_pdt_model(30)) self.assertEqual(scripted_pdt_model(30), wrapped_pdt_model(30))
self.assertEqual(scripted_pdt_model(1.9), wrapped_pdt_model(1.9)) self.assertEqual(scripted_pdt_model(1.9), wrapped_pdt_model(1.9))
self.assertTrue(scripted_pdt_model(True), wrapped_pdt_model(True)) self.assertTrue(scripted_pdt_model(True), wrapped_pdt_model(True))
@ -87,10 +101,18 @@ class TestPDT(JitTestCase):
make_global(NestedModulePDTInner, NestedModulePDTOuter) make_global(NestedModulePDTInner, NestedModulePDTOuter)
inner_pdt_model = NestedModulePDTInner() inner_pdt_model = NestedModulePDTInner()
outer_pdt_model = NestedModulePDTOuter(inner_pdt_model) outer_pdt_model = NestedModulePDTOuter(inner_pdt_model)
inner_input: List[Tuple[Any, ...]] = [(10, 10), (1.9, 20), ] inner_input: List[Tuple[Any, ...]] = [
outer_input: List[Tuple[Any, ...]] = [(20, ), (False, )] (10, 10),
scripted_pdt_model = torch.jit.script(outer_pdt_model, example_inputs={inner_pdt_model: inner_input, (1.9, 20),
outer_pdt_model: outer_input, }) ]
outer_input: List[Tuple[Any, ...]] = [(20,), (False,)]
scripted_pdt_model = torch.jit.script(
outer_pdt_model,
example_inputs={
inner_pdt_model: inner_input,
outer_pdt_model: outer_input,
},
)
self.assertEqual(scripted_pdt_model(30), outer_pdt_model(30)) self.assertEqual(scripted_pdt_model(30), outer_pdt_model(30))
self.assertEqual(scripted_pdt_model(1.9), outer_pdt_model(1.9)) self.assertEqual(scripted_pdt_model(1.9), outer_pdt_model(1.9))
self.assertTrue(scripted_pdt_model(True), outer_pdt_model(True)) self.assertTrue(scripted_pdt_model(True), outer_pdt_model(True))
@ -109,8 +131,10 @@ class TestPDT(JitTestCase):
make_global(NestedFunctionInForward) make_global(NestedFunctionInForward)
pdt_model = NestedFunctionInForward() pdt_model = NestedFunctionInForward()
inp: List[Tuple[Any, ...]] = [(-1, ), (False, )] inp: List[Tuple[Any, ...]] = [(-1,), (False,)]
scripted_pdt_model = torch.jit.script(pdt_model, example_inputs={pdt_model: inp}) scripted_pdt_model = torch.jit.script(
pdt_model, example_inputs={pdt_model: inp}
)
self.assertEqual(scripted_pdt_model(30), pdt_model(30)) self.assertEqual(scripted_pdt_model(30), pdt_model(30))
self.assertEqual(scripted_pdt_model(True), pdt_model(True)) self.assertEqual(scripted_pdt_model(True), pdt_model(True))
@ -126,14 +150,26 @@ class TestPDT(JitTestCase):
else: else:
return -1 return -1
make_global(TestModelWithExport) make_global(TestModelWithExport)
pdt_model = TestModelWithExport() pdt_model = TestModelWithExport()
inp: List[Tuple[Any, ...]] = [(20, 10, ), (2.7, 8.9, ), ] inp: List[Tuple[Any, ...]] = [
scripted_pdt_model = torch.jit.script(pdt_model, example_inputs={pdt_model.fn: inp}) (
20,
10,
),
(
2.7,
8.9,
),
]
scripted_pdt_model = torch.jit.script(
pdt_model, example_inputs={pdt_model.fn: inp}
)
self.assertEqual(scripted_pdt_model.fn(10, 90), pdt_model.fn(10, 90)) self.assertEqual(scripted_pdt_model.fn(10, 90), pdt_model.fn(10, 90))
self.assertEqual(scripted_pdt_model.fn(1.8, 2.2), pdt_model.fn(1.8, 2.2)) self.assertEqual(scripted_pdt_model.fn(1.8, 2.2), pdt_model.fn(1.8, 2.2))
self.assertTrue(scripted_pdt_model.fn(torch.ones(1), 2), pdt_model.fn(torch.ones(1), 2)) self.assertTrue(
scripted_pdt_model.fn(torch.ones(1), 2), pdt_model.fn(torch.ones(1), 2)
)
def test_class_methods(self): def test_class_methods(self):
class PDTModel: class PDTModel:
@ -142,10 +178,34 @@ class TestPDT(JitTestCase):
make_global(PDTModel) make_global(PDTModel)
pdt_model = PDTModel() pdt_model = PDTModel()
inp: List[Tuple[Any, ...]] = [([10, 20, ], ), ] inp: List[Tuple[Any, ...]] = [
scripted_pdt_model = torch.jit.script(PDTModel, example_inputs={pdt_model.test_sum: inp}) (
[
10,
20,
],
),
]
scripted_pdt_model = torch.jit.script(
PDTModel, example_inputs={pdt_model.test_sum: inp}
)
script_model = scripted_pdt_model() script_model = scripted_pdt_model()
self.assertEqual(script_model.test_sum([10, 20, 30, ], ), pdt_model.test_sum([10, 20, 30, ], )) self.assertEqual(
script_model.test_sum(
[
10,
20,
30,
],
),
pdt_model.test_sum(
[
10,
20,
30,
],
),
)
def test_class_with_multiple_methods(self): def test_class_with_multiple_methods(self):
class PDTModelWithManyMethods: class PDTModelWithManyMethods:
@ -160,14 +220,64 @@ class TestPDT(JitTestCase):
make_global(PDTModelWithManyMethods) make_global(PDTModelWithManyMethods)
pdt_model = PDTModelWithManyMethods() pdt_model = PDTModelWithManyMethods()
list_inp: List[Tuple[Any, ...]] = [([1.2, 2.3, ], ), ] list_inp: List[Tuple[Any, ...]] = [
str_inp: List[Tuple[Any, ...]] = [("abc", "b", ), ] (
scripted_pdt_model = torch.jit.script(PDTModelWithManyMethods, example_inputs={pdt_model.test_list_to_dict: list_inp, [
pdt_model.test_substring: str_inp}) 1.2,
2.3,
],
),
]
str_inp: List[Tuple[Any, ...]] = [
(
"abc",
"b",
),
]
scripted_pdt_model = torch.jit.script(
PDTModelWithManyMethods,
example_inputs={
pdt_model.test_list_to_dict: list_inp,
pdt_model.test_substring: str_inp,
},
)
script_model = scripted_pdt_model() script_model = scripted_pdt_model()
self.assertEqual(script_model.test_list_to_dict([1.1, 2.2, 3.3, ], ), pdt_model.test_list_to_dict([1.1, 2.2, 3.3, ], )) self.assertEqual(
self.assertEqual(script_model.test_substring("helloworld", "world", ), pdt_model.test_substring("helloworld", "world", )) script_model.test_list_to_dict(
self.assertEqual(script_model.test_substring("helloworld", "def", ), pdt_model.test_substring("helloworld", "def", )) [
1.1,
2.2,
3.3,
],
),
pdt_model.test_list_to_dict(
[
1.1,
2.2,
3.3,
],
),
)
self.assertEqual(
script_model.test_substring(
"helloworld",
"world",
),
pdt_model.test_substring(
"helloworld",
"world",
),
)
self.assertEqual(
script_model.test_substring(
"helloworld",
"def",
),
pdt_model.test_substring(
"helloworld",
"def",
),
)
def test_multiple_class_with_same_method(self): def test_multiple_class_with_same_method(self):
class PDTModelOne: class PDTModelOne:
@ -181,16 +291,69 @@ class TestPDT(JitTestCase):
make_global(PDTModelOne, PDTModelTwo) make_global(PDTModelOne, PDTModelTwo)
pdt_model_one = PDTModelOne() pdt_model_one = PDTModelOne()
pdt_model_two = PDTModelTwo() pdt_model_two = PDTModelTwo()
dict_inp: List[Tuple[Any, ...]] = [({1.2: True, 2.3: False, }, 1.2), ] dict_inp: List[Tuple[Any, ...]] = [
list_inp: List[Tuple[Any, ...]] = [(["abc", "b", ], "c"), ] (
scripted_pdt_model_one = torch.jit.script(PDTModelOne, example_inputs={pdt_model_one.test_find: dict_inp}) {
scripted_pdt_model_two = torch.jit.script(PDTModelTwo, example_inputs={pdt_model_two.test_find: list_inp}) 1.2: True,
2.3: False,
},
1.2,
),
]
list_inp: List[Tuple[Any, ...]] = [
(
[
"abc",
"b",
],
"c",
),
]
scripted_pdt_model_one = torch.jit.script(
PDTModelOne, example_inputs={pdt_model_one.test_find: dict_inp}
)
scripted_pdt_model_two = torch.jit.script(
PDTModelTwo, example_inputs={pdt_model_two.test_find: list_inp}
)
script_model_one, script_model_two = scripted_pdt_model_one(), scripted_pdt_model_two() script_model_one, script_model_two = (
self.assertEqual(script_model_one.test_find({1.1: True, 2.2: True, 3.3: False, }, 4.4), scripted_pdt_model_one(),
pdt_model_one.test_find({1.1: True, 2.2: True, 3.3: False, }, 4.4)) scripted_pdt_model_two(),
self.assertEqual(script_model_two.test_find(["hello", "world", ], "world"), )
pdt_model_two.test_find(["hello", "world", ], "world")) self.assertEqual(
script_model_one.test_find(
{
1.1: True,
2.2: True,
3.3: False,
},
4.4,
),
pdt_model_one.test_find(
{
1.1: True,
2.2: True,
3.3: False,
},
4.4,
),
)
self.assertEqual(
script_model_two.test_find(
[
"hello",
"world",
],
"world",
),
pdt_model_two.test_find(
[
"hello",
"world",
],
"world",
),
)
def test_pdt(self): def test_pdt(self):
def test_sum(a, b): def test_sum(a, b):
@ -218,7 +381,9 @@ class TestPDT(JitTestCase):
return torch.complex(real, img) return torch.complex(real, img)
make_global(test_args_complex) make_global(test_args_complex)
scripted_fn_complex = torch.jit.script(test_args_complex, example_inputs=[(torch.rand(3, 4), torch.rand(3, 4))]) scripted_fn_complex = torch.jit.script(
test_args_complex, example_inputs=[(torch.rand(3, 4), torch.rand(3, 4))]
)
arg1, arg2 = torch.rand(3, 4), torch.rand(3, 4) arg1, arg2 = torch.rand(3, 4), torch.rand(3, 4)
self.assertEqual(scripted_fn_complex(arg1, arg2), test_args_complex(arg1, arg2)) self.assertEqual(scripted_fn_complex(arg1, arg2), test_args_complex(arg1, arg2))
@ -248,25 +413,49 @@ class TestPDT(JitTestCase):
make_global(test_list_and_tuple) make_global(test_list_and_tuple)
scripted_fn_float_list_input = torch.jit.script(test_list_and_tuple, example_inputs=[([4.9, 8.9],)]) scripted_fn_float_list_input = torch.jit.script(
self.assertEqual(scripted_fn_float_list_input([11.9, 7.6]), test_list_and_tuple([11.9, 7.6])) test_list_and_tuple, example_inputs=[([4.9, 8.9],)]
)
self.assertEqual(
scripted_fn_float_list_input([11.9, 7.6]), test_list_and_tuple([11.9, 7.6])
)
scripted_fn_bool_list_input = torch.jit.script(test_list_and_tuple, example_inputs=[([True, False, True],)]) scripted_fn_bool_list_input = torch.jit.script(
self.assertEqual(scripted_fn_bool_list_input([True, True, True]), test_list_and_tuple([True, True, True])) test_list_and_tuple, example_inputs=[([True, False, True],)]
)
self.assertEqual(
scripted_fn_bool_list_input([True, True, True]),
test_list_and_tuple([True, True, True]),
)
scripted_fn_int_list_input = torch.jit.script(test_list_and_tuple, example_inputs=[([3, 4, 5], )]) scripted_fn_int_list_input = torch.jit.script(
self.assertEqual(scripted_fn_int_list_input([1, 2, 3]), test_list_and_tuple([1, 2, 3])) test_list_and_tuple, example_inputs=[([3, 4, 5],)]
)
self.assertEqual(
scripted_fn_int_list_input([1, 2, 3]), test_list_and_tuple([1, 2, 3])
)
scripted_fn_float_tuple_input = torch.jit.script(test_list_and_tuple, example_inputs=[((4.9, 8.9),)]) scripted_fn_float_tuple_input = torch.jit.script(
self.assertEqual(scripted_fn_float_tuple_input((11.9, 7.6)), test_list_and_tuple((11.9, 7.6))) test_list_and_tuple, example_inputs=[((4.9, 8.9),)]
)
self.assertEqual(
scripted_fn_float_tuple_input((11.9, 7.6)), test_list_and_tuple((11.9, 7.6))
)
scripted_fn_bool_tuple_input = torch.jit.script(test_list_and_tuple, scripted_fn_bool_tuple_input = torch.jit.script(
example_inputs=[((True, False, True),)]) test_list_and_tuple, example_inputs=[((True, False, True),)]
self.assertEqual(scripted_fn_bool_tuple_input((True, True, True)), )
test_list_and_tuple((True, True, True))) self.assertEqual(
scripted_fn_bool_tuple_input((True, True, True)),
test_list_and_tuple((True, True, True)),
)
scripted_fn_int_tuple_input = torch.jit.script(test_list_and_tuple, example_inputs=[((3, 4, 5), )]) scripted_fn_int_tuple_input = torch.jit.script(
self.assertEqual(scripted_fn_int_tuple_input((1, 2, 3)), test_list_and_tuple((1, 2, 3))) test_list_and_tuple, example_inputs=[((3, 4, 5),)]
)
self.assertEqual(
scripted_fn_int_tuple_input((1, 2, 3)), test_list_and_tuple((1, 2, 3))
)
def test_nested_list_and_tuple(self): def test_nested_list_and_tuple(self):
def test_nested_list(inp): def test_nested_list(inp):
@ -282,43 +471,207 @@ class TestPDT(JitTestCase):
make_global(test_nested_list, test_nested_tuple) make_global(test_nested_list, test_nested_tuple)
list_inp = [[1, 2, 3, ], [5, 6, 7, ]] list_inp = [
scripted_fn = torch.jit.script(test_nested_list, example_inputs=[(list_inp, ), ]) [
inp = [[0, 4, 7, ], [8, 11, ], [6, -1, -20, ]] 1,
self.assertEqual(scripted_fn(inp, ), test_nested_list(inp, )) 2,
3,
],
[
5,
6,
7,
],
]
scripted_fn = torch.jit.script(
test_nested_list,
example_inputs=[
(list_inp,),
],
)
inp = [
[
0,
4,
7,
],
[
8,
11,
],
[
6,
-1,
-20,
],
]
self.assertEqual(
scripted_fn(
inp,
),
test_nested_list(
inp,
),
)
list_inp = ([1, 2, 3, ], [5, 6, 7, ]) list_inp = (
scripted_fn = torch.jit.script(test_nested_list, example_inputs=[(list_inp, ), ]) [
inp = ([0, 4, 7, ], [8, 11, ], [6, -1, -20, ]) 1,
self.assertEqual(scripted_fn(inp, ), test_nested_list(inp, )) 2,
3,
],
[
5,
6,
7,
],
)
scripted_fn = torch.jit.script(
test_nested_list,
example_inputs=[
(list_inp,),
],
)
inp = (
[
0,
4,
7,
],
[
8,
11,
],
[
6,
-1,
-20,
],
)
self.assertEqual(
scripted_fn(
inp,
),
test_nested_list(
inp,
),
)
tup_inp = [(1.0, 2.6, 3.7, ), (5.7, 6.1, 1.7, )] tup_inp = [
scripted_fn = torch.jit.script(test_nested_tuple, example_inputs=[(tup_inp, ), ]) (
inp = [(1.0, 4.1, 7.4, ), (4.8, 1.1, -1.2, ), (6.3, -1.3, -2.0, )] 1.0,
self.assertEqual(scripted_fn(inp, ), test_nested_tuple(inp, )) 2.6,
3.7,
),
(
5.7,
6.1,
1.7,
),
]
scripted_fn = torch.jit.script(
test_nested_tuple,
example_inputs=[
(tup_inp,),
],
)
inp = [
(
1.0,
4.1,
7.4,
),
(
4.8,
1.1,
-1.2,
),
(
6.3,
-1.3,
-2.0,
),
]
self.assertEqual(
scripted_fn(
inp,
),
test_nested_tuple(
inp,
),
)
tup_inp = ((True, False, True, ), (False, False, False, )) tup_inp = (
scripted_fn = torch.jit.script(test_nested_tuple, example_inputs=[(tup_inp, ), ]) (
inp = ((True, True, True, ), (False, False, True, )) True,
self.assertEqual(scripted_fn(inp, ), test_nested_tuple(inp, )) False,
True,
),
(
False,
False,
False,
),
)
scripted_fn = torch.jit.script(
test_nested_tuple,
example_inputs=[
(tup_inp,),
],
)
inp = (
(
True,
True,
True,
),
(
False,
False,
True,
),
)
self.assertEqual(
scripted_fn(
inp,
),
test_nested_tuple(
inp,
),
)
def test_pdt_dict(self): def test_pdt_dict(self):
def test_dict(a): def test_dict(a):
return a['foo'] return a["foo"]
def test_dict_int_list(a): def test_dict_int_list(a):
return a[1] return a[1]
make_global(test_dict, test_dict_int_list) make_global(test_dict, test_dict_int_list)
str_bool_inp = {'foo' : True, 'bar': False} str_bool_inp = {"foo": True, "bar": False}
scripted_fn = torch.jit.script(test_dict, example_inputs=[(str_bool_inp,)]) scripted_fn = torch.jit.script(test_dict, example_inputs=[(str_bool_inp,)])
self.assertEqual(scripted_fn({'foo' : False, 'bar': True}, ), test_dict({'foo' : False, 'bar': True}, )) self.assertEqual(
scripted_fn(
{"foo": False, "bar": True},
),
test_dict(
{"foo": False, "bar": True},
),
)
str_list_inp = {0 : [True, False], 1: [False, True]} str_list_inp = {0: [True, False], 1: [False, True]}
scripted_fn = torch.jit.script(test_dict_int_list, example_inputs=[(str_list_inp,)]) scripted_fn = torch.jit.script(
self.assertEqual(scripted_fn({0 : [False, False], 1: [True, True]}, ), test_dict_int_list, example_inputs=[(str_list_inp,)]
test_dict_int_list({0 : [False, False], 1: [True, True]}, )) )
self.assertEqual(
scripted_fn(
{0: [False, False], 1: [True, True]},
),
test_dict_int_list(
{0: [False, False], 1: [True, True]},
),
)
def test_any(self): def test_any(self):
def test_multiple_types(a): def test_multiple_types(a):
@ -337,20 +690,36 @@ class TestPDT(JitTestCase):
make_global(test_multiple_types, test_multiple_type_refinement) make_global(test_multiple_types, test_multiple_type_refinement)
scripted_fn = torch.jit.script(test_multiple_types, example_inputs=[(1,), ("abc", ), (8.9,), ([3, 4, 5], )]) scripted_fn = torch.jit.script(
test_multiple_types, example_inputs=[(1,), ("abc",), (8.9,), ([3, 4, 5],)]
)
self.assertEqual(scripted_fn(10), test_multiple_types(10)) self.assertEqual(scripted_fn(10), test_multiple_types(10))
self.assertEqual(scripted_fn("def"), test_multiple_types("def")) self.assertEqual(scripted_fn("def"), test_multiple_types("def"))
self.assertEqual(scripted_fn(7.89999), test_multiple_types(7.89999)) self.assertEqual(scripted_fn(7.89999), test_multiple_types(7.89999))
self.assertEqual(scripted_fn([10, 11, 14]), test_multiple_types([10, 11, 14])) self.assertEqual(scripted_fn([10, 11, 14]), test_multiple_types([10, 11, 14]))
scripted_fn = torch.jit.script(test_multiple_type_refinement, example_inputs=[(1,), ("abc", ), (8.9,), scripted_fn = torch.jit.script(
([3, 4, 5],), (True, ), ({"a": True}, ), ]) test_multiple_type_refinement,
example_inputs=[
(1,),
("abc",),
(8.9,),
([3, 4, 5],),
(True,),
({"a": True},),
],
)
self.assertEqual(scripted_fn(10), test_multiple_type_refinement(10)) self.assertEqual(scripted_fn(10), test_multiple_type_refinement(10))
self.assertEqual(scripted_fn("def"), test_multiple_type_refinement("def")) self.assertEqual(scripted_fn("def"), test_multiple_type_refinement("def"))
self.assertEqual(scripted_fn(7.89999), test_multiple_type_refinement(7.89999)) self.assertEqual(scripted_fn(7.89999), test_multiple_type_refinement(7.89999))
self.assertEqual(scripted_fn([10, 11, 14]), test_multiple_type_refinement([10, 11, 14])) self.assertEqual(
scripted_fn([10, 11, 14]), test_multiple_type_refinement([10, 11, 14])
)
self.assertEqual(scripted_fn(False), test_multiple_type_refinement(False)) self.assertEqual(scripted_fn(False), test_multiple_type_refinement(False))
self.assertEqual(scripted_fn({"abc" : True, "def": False}), test_multiple_type_refinement({"abc" : True, "def": False})) self.assertEqual(
scripted_fn({"abc": True, "def": False}),
test_multiple_type_refinement({"abc": True, "def": False}),
)
def test_class_as_profiled_types(self): def test_class_as_profiled_types(self):
class UserDefinedClass: class UserDefinedClass:
@ -369,9 +738,33 @@ class TestPDT(JitTestCase):
make_global(UserDefinedClass, test_model) make_global(UserDefinedClass, test_model)
user_class = UserDefinedClass() user_class = UserDefinedClass()
scripted_fn = torch.jit.script(test_model, example_inputs=[(10, user_class, ), (10.9, user_class, ), ]) scripted_fn = torch.jit.script(
self.assertEqual(scripted_fn(100, user_class, ), test_model(100, user_class)) test_model,
self.assertEqual(scripted_fn(1.9, user_class, ), test_model(1.9, user_class)) example_inputs=[
(
10,
user_class,
),
(
10.9,
user_class,
),
],
)
self.assertEqual(
scripted_fn(
100,
user_class,
),
test_model(100, user_class),
)
self.assertEqual(
scripted_fn(
1.9,
user_class,
),
test_model(1.9, user_class),
)
def test_class_with_args_as_profiled_types(self): def test_class_with_args_as_profiled_types(self):
class ClassWithArgs: class ClassWithArgs:
@ -391,8 +784,26 @@ class TestPDT(JitTestCase):
make_global(ClassWithArgs, test_model_with_args) make_global(ClassWithArgs, test_model_with_args)
user_class = ClassWithArgs(False) user_class = ClassWithArgs(False)
scripted_fn = torch.jit.script(test_model_with_args, example_inputs=[(10, user_class, ), (10.9, user_class, ), ]) scripted_fn = torch.jit.script(
self.assertEqual(scripted_fn(100, ClassWithArgs(True), ), test_model_with_args(100, ClassWithArgs(True))) test_model_with_args,
example_inputs=[
(
10,
user_class,
),
(
10.9,
user_class,
),
],
)
self.assertEqual(
scripted_fn(
100,
ClassWithArgs(True),
),
test_model_with_args(100, ClassWithArgs(True)),
)
def test_nn_parameter_as_arg(self): def test_nn_parameter_as_arg(self):
class TestNNParameter(torch.nn.Module): class TestNNParameter(torch.nn.Module):
@ -408,7 +819,14 @@ class TestPDT(JitTestCase):
make_global(TestNNParameter) make_global(TestNNParameter)
pdt_model = TestNNParameter() pdt_model = TestNNParameter()
scripted_fn = torch.jit.script(pdt_model, example_inputs={pdt_model: [(10, ), ], }) scripted_fn = torch.jit.script(
pdt_model,
example_inputs={
pdt_model: [
(10,),
],
},
)
self.assertEqual(scripted_fn(20), pdt_model(20)) self.assertEqual(scripted_fn(20), pdt_model(20))
def test_fx_tracing_with_typing(self): def test_fx_tracing_with_typing(self):
@ -422,7 +840,19 @@ class TestPDT(JitTestCase):
make_global(FXModel, FXModelOutput) make_global(FXModel, FXModelOutput)
pdt_model = FXModel() pdt_model = FXModel()
scripted_fn = torch.jit.script(pdt_model, example_inputs={pdt_model: [([10, 20, ], ), ], }) scripted_fn = torch.jit.script(
pdt_model,
example_inputs={
pdt_model: [
(
[
10,
20,
],
),
],
},
)
self.assertEqual(scripted_fn([20]), pdt_model([20])) self.assertEqual(scripted_fn([20]), pdt_model([20]))
def test_nonetype_as_optional_of_type(self): def test_nonetype_as_optional_of_type(self):
@ -434,11 +864,34 @@ class TestPDT(JitTestCase):
make_global(test_none) make_global(test_none)
scripted_fn = torch.jit.script(test_none, example_inputs=[(None, ), (10.6, )]) scripted_fn = torch.jit.script(test_none, example_inputs=[(None,), (10.6,)])
self.assertEqual(scripted_fn(30.9, ), test_none(30.9, )) self.assertEqual(
scripted_fn(
30.9,
),
test_none(
30.9,
),
)
scripted_fn = torch.jit.script(test_none, example_inputs=[(None, ), (10, )]) scripted_fn = torch.jit.script(test_none, example_inputs=[(None,), (10,)])
self.assertEqual(scripted_fn(2, ), test_none(2, )) self.assertEqual(
scripted_fn(
2,
),
test_none(
2,
),
)
scripted_fn = torch.jit.script(test_none, example_inputs=[(None, ), (torch.Tensor(1), )]) scripted_fn = torch.jit.script(
self.assertEqual(scripted_fn(torch.ones(1), ), test_none(torch.ones(1), )) test_none, example_inputs=[(None,), (torch.Tensor(1),)]
)
self.assertEqual(
scripted_fn(
torch.ones(1),
),
test_none(
torch.ones(1),
),
)

View File

@ -1,17 +1,20 @@
# Owner(s): ["oncall: jit"] # Owner(s): ["oncall: jit"]
import torch import unittest
from torch.testing._internal.jit_utils import JitTestCase, RUN_CUDA, _inline_everything
from torch import nn
from torch.testing import FileCheck
from typing import Callable, List from typing import Callable, List
import unittest import torch
from torch import nn
from torch.testing import FileCheck
from torch.testing._internal.jit_utils import _inline_everything, JitTestCase, RUN_CUDA
if __name__ == "__main__":
raise RuntimeError(
"This test file is not meant to be run directly, use:\n\n"
"\tpython test/test_jit.py TESTNAME\n\n"
"instead."
)
if __name__ == '__main__':
raise RuntimeError("This test file is not meant to be run directly, use:\n\n"
"\tpython test/test_jit.py TESTNAME\n\n"
"instead.")
class TestPeephole(JitTestCase): class TestPeephole(JitTestCase):
def test_peephole_with_writes(self): def test_peephole_with_writes(self):
@ -62,11 +65,11 @@ class TestPeephole(JitTestCase):
tf = torch.jit.trace(f, (a, b)) tf = torch.jit.trace(f, (a, b))
FileCheck().check("type_as").run(str(tf.graph)) FileCheck().check("type_as").run(str(tf.graph))
self.run_pass('peephole', tf.graph) self.run_pass("peephole", tf.graph)
FileCheck().check_not("type_as").run(str(tf.graph)) FileCheck().check_not("type_as").run(str(tf.graph))
tf2 = torch.jit.trace(f, (a, c)) tf2 = torch.jit.trace(f, (a, c))
s = str(tf2.graph) s = str(tf2.graph)
self.run_pass('peephole', tf2.graph) self.run_pass("peephole", tf2.graph)
self.assertEqual(s, str(s)) self.assertEqual(s, str(s))
def test_peephole_dynamic(self): def test_peephole_dynamic(self):
@ -83,7 +86,7 @@ class TestPeephole(JitTestCase):
def foo(x, y, z): def foo(x, y, z):
return len([x, y, z]) return len([x, y, z])
self.run_pass('peephole', foo.graph) self.run_pass("peephole", foo.graph)
FileCheck().check("value=3").check_next("return").run(foo.graph) FileCheck().check("value=3").check_next("return").run(foo.graph)
@torch.jit.script @torch.jit.script
@ -93,7 +96,7 @@ class TestPeephole(JitTestCase):
li.append(x) li.append(x)
return len([x, y, z]) return len([x, y, z])
self.run_pass('peephole', foo.graph) self.run_pass("peephole", foo.graph)
FileCheck().check_not("aten::len").run(foo.graph) FileCheck().check_not("aten::len").run(foo.graph)
@torch.jit.script @torch.jit.script
@ -102,7 +105,7 @@ class TestPeephole(JitTestCase):
return li[1], li[-2] return li[1], li[-2]
FileCheck().check("aten::__getitem__").run(foo.graph) FileCheck().check("aten::__getitem__").run(foo.graph)
self.run_pass('peephole', foo.graph) self.run_pass("peephole", foo.graph)
FileCheck().check_not("aten::__getitem__").run(foo.graph) FileCheck().check_not("aten::__getitem__").run(foo.graph)
@torch.jit.script @torch.jit.script
@ -110,7 +113,7 @@ class TestPeephole(JitTestCase):
li = [x, y, z] li = [x, y, z]
return li[-7] return li[-7]
self.run_pass('peephole', foo.graph) self.run_pass("peephole", foo.graph)
FileCheck().check("aten::__getitem__").run(foo.graph) FileCheck().check("aten::__getitem__").run(foo.graph)
@torch.jit.script @torch.jit.script
@ -120,25 +123,25 @@ class TestPeephole(JitTestCase):
li.append(x) li.append(x)
return li[-2] return li[-2]
self.run_pass('peephole', foo.graph) self.run_pass("peephole", foo.graph)
FileCheck().check("aten::__getitem__").run(foo.graph) FileCheck().check("aten::__getitem__").run(foo.graph)
@unittest.skipIf(not RUN_CUDA, "cpp tests require CUDA") @unittest.skipIf(not RUN_CUDA, "cpp tests require CUDA")
def test_peephole_cuda(self): def test_peephole_cuda(self):
a = torch.tensor([0.4], device='cpu') a = torch.tensor([0.4], device="cpu")
b = torch.tensor([0.7], device='cuda') b = torch.tensor([0.7], device="cuda")
c = torch.tensor([0.7], device='cuda') c = torch.tensor([0.7], device="cuda")
def f(x, y): def f(x, y):
return x.type_as(y) return x.type_as(y)
trace = torch.jit.trace(f, (a, c)) trace = torch.jit.trace(f, (a, c))
s = str(trace.graph) s = str(trace.graph)
self.run_pass('peephole', trace.graph) self.run_pass("peephole", trace.graph)
self.assertEqual(s, str(trace.graph)) self.assertEqual(s, str(trace.graph))
trace = torch.jit.trace(f, (b, c)) trace = torch.jit.trace(f, (b, c))
self.run_pass('peephole', trace.graph) self.run_pass("peephole", trace.graph)
self.run_pass('dce', trace.graph) self.run_pass("dce", trace.graph)
FileCheck().check_not("type_as").run(str(trace.graph)) FileCheck().check_not("type_as").run(str(trace.graph))
@_inline_everything @_inline_everything
@ -152,7 +155,7 @@ class TestPeephole(JitTestCase):
return refine(torch.tensor(4)) return refine(torch.tensor(4))
FileCheck().check("prim::unchecked_cast").run(test.graph) FileCheck().check("prim::unchecked_cast").run(test.graph)
self.run_pass('peephole', test.graph) self.run_pass("peephole", test.graph)
FileCheck().check_not("prim::unchecked_cast").run(test.graph) FileCheck().check_not("prim::unchecked_cast").run(test.graph)
# refinement not optimzied out # refinement not optimzied out
@ -166,7 +169,7 @@ class TestPeephole(JitTestCase):
self.checkScript(is_int_tensor, (torch.tensor(2),)) self.checkScript(is_int_tensor, (torch.tensor(2),))
self.checkScript(is_int_tensor, (torch.tensor(2.5),)) self.checkScript(is_int_tensor, (torch.tensor(2.5),))
graph = torch.jit.script(is_int_tensor).graph graph = torch.jit.script(is_int_tensor).graph
self.run_pass('peephole', graph) self.run_pass("peephole", graph)
FileCheck().check("prim::unchecked_cast").run(graph) FileCheck().check("prim::unchecked_cast").run(graph)
def test_short_circuit_optimization(self): def test_short_circuit_optimization(self):
@ -174,8 +177,11 @@ class TestPeephole(JitTestCase):
def const_expressions(x): def const_expressions(x):
# type: (int) -> Tuple[bool, bool] # type: (int) -> Tuple[bool, bool]
return x == 1 and False, x == 1 or True return x == 1 and False, x == 1 or True
self.run_pass('constant_propagation', const_expressions.graph)
FileCheck().check_not("prim::If").check_not("aten::eq").run(const_expressions.graph) self.run_pass("constant_propagation", const_expressions.graph)
FileCheck().check_not("prim::If").check_not("aten::eq").run(
const_expressions.graph
)
self.assertEqual(const_expressions(1), (False, True)) self.assertEqual(const_expressions(1), (False, True))
@torch.jit.script @torch.jit.script
@ -183,15 +189,18 @@ class TestPeephole(JitTestCase):
# type: (int) -> Tuple[bool, bool] # type: (int) -> Tuple[bool, bool]
return x == 1 and True, x == 1 or False return x == 1 and True, x == 1 or False
self.run_pass('peephole', redundant_expressions.graph) self.run_pass("peephole", redundant_expressions.graph)
self.assertEqual(redundant_expressions(1), (True, True)) self.assertEqual(redundant_expressions(1), (True, True))
self.assertEqual(redundant_expressions(0), (False, False)) self.assertEqual(redundant_expressions(0), (False, False))
# and True / or False are removed from graph # and True / or False are removed from graph
FileCheck().check("aten::eq").check_not("prim::If").run(redundant_expressions.graph) FileCheck().check("aten::eq").check_not("prim::If").run(
redundant_expressions.graph
)
def test_conv_dim_folding(self): def test_conv_dim_folding(self):
modules = [nn.Conv1d, nn.Conv2d, nn.Conv3d] modules = [nn.Conv1d, nn.Conv2d, nn.Conv3d]
for mod in modules: for mod in modules:
class ConvDim(torch.nn.Module): class ConvDim(torch.nn.Module):
def __init__(self): def __init__(self):
super().__init__() super().__init__()
@ -233,7 +242,6 @@ class TestPeephole(JitTestCase):
FileCheck().check_count("aten::sub", 2, exactly=True).run(op_graph) FileCheck().check_count("aten::sub", 2, exactly=True).run(op_graph)
FileCheck().check_count("aten::rsub", 0, exactly=True).run(op_graph) FileCheck().check_count("aten::rsub", 0, exactly=True).run(op_graph)
def test_normalized_is_op(self): def test_normalized_is_op(self):
def convertible_is_op(x: bool, y: bool): def convertible_is_op(x: bool, y: bool):
return x is True, False is x, x is y return x is True, False is x, x is y
@ -558,7 +566,7 @@ class TestPeephole(JitTestCase):
def foo4(): def foo4():
x = torch.zeros([2, 2]) x = torch.zeros([2, 2])
return x + 0. return x + 0.0
funcs = foo1, foo2, foo3, foo4 funcs = foo1, foo2, foo3, foo4
inps = (torch.ones([2]),), (), (), () inps = (torch.ones([2]),), (), (), ()
@ -582,7 +590,7 @@ class TestPeephole(JitTestCase):
self.assertEqual(func(torch.ones([2, 2])), func_s(torch.ones([2, 2]))) self.assertEqual(func(torch.ones([2, 2])), func_s(torch.ones([2, 2])))
def func(x): def func(x):
return (x + 0.) - 5 return (x + 0.0) - 5
func_s = torch.jit.script(func) func_s = torch.jit.script(func)
inp = next(func_s.graph.inputs()) inp = next(func_s.graph.inputs())
@ -640,6 +648,7 @@ class TestPeephole(JitTestCase):
return z return z
else: else:
return z2 return z2
out = next(foo.graph.findNode("prim::If").outputs()) out = next(foo.graph.findNode("prim::If").outputs())
out.setType(torch._C.OptionalType(torch._C.IntType.get())) out.setType(torch._C.OptionalType(torch._C.IntType.get()))
self.run_pass("peephole", foo.graph) self.run_pass("peephole", foo.graph)
@ -665,12 +674,13 @@ class TestPeephole(JitTestCase):
_6 = torch.add(1 * torch.sub(_3, 3) // 1, 1) / 1 _6 = torch.add(1 * torch.sub(_3, 3) // 1, 1) / 1
return [_5, int(_6)] return [_5, int(_6)]
FileCheck().check("aten::add").check("aten::sub") \ FileCheck().check("aten::add").check("aten::sub").check("aten::mul").check(
.check("aten::mul").check("aten::floordiv") \ "aten::floordiv"
.check("aten::div").run(foo.graph) ).check("aten::div").run(foo.graph)
self.run_pass("peephole", foo.graph) self.run_pass("peephole", foo.graph)
FileCheck().check("graph").check("):") \ FileCheck().check("graph").check("):").check_next("ListConstruct").check_next(
.check_next("ListConstruct").check_next("return").run(foo.graph) "return"
).run(foo.graph)
self.assertEqual(foo(0, 1, 2, 3), [1, 3]) self.assertEqual(foo(0, 1, 2, 3), [1, 3])
def test_peephole_dict_getitem_simple(self): def test_peephole_dict_getitem_simple(self):
@ -687,9 +697,9 @@ class TestPeephole(JitTestCase):
@torch.jit.script @torch.jit.script
def foo(a: int, b: int): def foo(a: int, b: int):
d = {'0': a, '1': b} d = {"0": a, "1": b}
x = d['1'] x = d["1"]
y = d['0'] y = d["0"]
return x, y return x, y
self.run_pass("peephole", foo.graph) self.run_pass("peephole", foo.graph)
@ -815,14 +825,14 @@ class TestPeephole(JitTestCase):
graph = torch.jit.script(foo).graph graph = torch.jit.script(foo).graph
self.run_pass("peephole", graph) self.run_pass("peephole", graph)
FileCheck().check_not("aten::slice").run(graph) FileCheck().check_not("aten::slice").run(graph)
self.checkScript(foo, (3, )) self.checkScript(foo, (3,))
def test_peephole_slice_one_empty_arg(self): def test_peephole_slice_one_empty_arg(self):
def check_helper(fn: Callable[[int], None]) -> None: def check_helper(fn: Callable[[int], None]) -> None:
graph = torch.jit.script(fn).graph graph = torch.jit.script(fn).graph
self.run_pass("peephole", graph) self.run_pass("peephole", graph)
FileCheck().check_not("aten::slice").run(graph) FileCheck().check_not("aten::slice").run(graph)
self.checkScript(fn, (3, )) self.checkScript(fn, (3,))
def foo(x: int): def foo(x: int):
return [1, 2, x, 4, 5, 6, 7][1::2] return [1, 2, x, 4, 5, 6, 7][1::2]
@ -844,7 +854,7 @@ class TestPeephole(JitTestCase):
graph = torch.jit.script(fn).graph graph = torch.jit.script(fn).graph
self.run_pass("peephole", graph) self.run_pass("peephole", graph)
FileCheck().check_not("aten::slice").run(graph) FileCheck().check_not("aten::slice").run(graph)
self.checkScript(fn, (3, )) self.checkScript(fn, (3,))
def foo(x: int): def foo(x: int):
return [1, 2, x, 4, 5, 6, 7][::2] return [1, 2, x, 4, 5, 6, 7][::2]

View File

@ -9,12 +9,15 @@ from torch.testing._internal.common_utils import skipIfTorchDynamo
# Make the helper files in test/ importable # Make the helper files in test/ importable
pytorch_test_dir = os.path.dirname(os.path.dirname(os.path.realpath(__file__))) pytorch_test_dir = os.path.dirname(os.path.dirname(os.path.realpath(__file__)))
sys.path.append(pytorch_test_dir) sys.path.append(pytorch_test_dir)
from torch.testing._internal.jit_utils import JitTestCase, warmup_backward, FileCheck from torch.testing._internal.jit_utils import FileCheck, JitTestCase, warmup_backward
if __name__ == "__main__":
raise RuntimeError(
"This test file is not meant to be run directly, use:\n\n"
"\tpython test/test_jit.py TESTNAME\n\n"
"instead."
)
if __name__ == '__main__':
raise RuntimeError("This test file is not meant to be run directly, use:\n\n"
"\tpython test/test_jit.py TESTNAME\n\n"
"instead.")
@skipIfTorchDynamo() @skipIfTorchDynamo()
class TestProfiler(JitTestCase): class TestProfiler(JitTestCase):
@ -58,8 +61,9 @@ class TestProfiler(JitTestCase):
# item & add should not get pulled into the fusion group - # item & add should not get pulled into the fusion group -
# we expect to see Fusion Group (item / add) Fusion Group in ir dump # we expect to see Fusion Group (item / add) Fusion Group in ir dump
FileCheck().check("TensorExpr").check("Scalar = aten::item").check_next("Tensor = aten::add").check("TensorExpr").run(g) FileCheck().check("TensorExpr").check("Scalar = aten::item").check_next(
"Tensor = aten::add"
).check("TensorExpr").run(g)
@torch.jit.script @torch.jit.script
def non_const_dtype(x, y, cond: bool): def non_const_dtype(x, y, cond: bool):
@ -70,7 +74,9 @@ class TestProfiler(JitTestCase):
non_const_dtype(x, x, True) non_const_dtype(x, x, True)
g = torch.jit.last_executed_optimized_graph() g = torch.jit.last_executed_optimized_graph()
# because dtype is non-const, sum should not get pulled into the Fusion Group # because dtype is non-const, sum should not get pulled into the Fusion Group
FileCheck().check("TensorExpr").check("TensorExpr").check_not("aten::sum").run(g) FileCheck().check("TensorExpr").check("TensorExpr").check_not("aten::sum").run(
g
)
def test_specialize_backward(self): def test_specialize_backward(self):
def test_fuse(a, b): def test_fuse(a, b):
@ -118,13 +124,15 @@ class TestProfiler(JitTestCase):
d = c * b d = c * b
return d return d
x = torch.tensor([.5]) x = torch.tensor([0.5])
for _ in range(3): for _ in range(3):
test_fuse(x, x) test_fuse(x, x)
g = torch.jit.last_executed_optimized_graph() g = torch.jit.last_executed_optimized_graph()
# Types should remain specialized for typecheck outputs & fusion outputs # Types should remain specialized for typecheck outputs & fusion outputs
FileCheck().check("Double(").check_same("prim::TypeCheck").check_same("\n").check("Double").check_same("TensorExpr").run(g) FileCheck().check("Double(").check_same("prim::TypeCheck").check_same(
"\n"
).check("Double").check_same("TensorExpr").run(g)
# other outputs should not be specialized # other outputs should not be specialized
FileCheck().check("Tensor = prim::If").run(g) FileCheck().check("Tensor = prim::If").run(g)
@ -201,7 +209,9 @@ class TestProfiler(JitTestCase):
foo(x, y) foo(x, y)
foo(x, y) foo(x, y)
g = torch.jit.last_executed_optimized_graph() g = torch.jit.last_executed_optimized_graph()
FileCheck().check("CallFunction").check_next("Tensor = prim::TupleUnpack").run(g) FileCheck().check("CallFunction").check_next("Tensor = prim::TupleUnpack").run(
g
)
def test_autograd_fallback_graph(self): def test_autograd_fallback_graph(self):
@torch.jit.script @torch.jit.script

View File

@ -1,13 +1,13 @@
# Owner(s): ["oncall: jit"] # Owner(s): ["oncall: jit"]
import os import os
import random
import sys import sys
import tempfile import tempfile
import random
from textwrap import dedent from textwrap import dedent
import torch import torch
from torch.testing._internal.jit_utils import JitTestCase, execWrapper from torch.testing._internal.jit_utils import execWrapper, JitTestCase
# Make the helper files in test/ importable # Make the helper files in test/ importable
pytorch_test_dir = os.path.dirname(os.path.dirname(os.path.realpath(__file__))) pytorch_test_dir = os.path.dirname(os.path.dirname(os.path.realpath(__file__)))
@ -20,14 +20,17 @@ if __name__ == "__main__":
"instead." "instead."
) )
def get_fn(file_name, script_path): def get_fn(file_name, script_path):
import importlib.util import importlib.util
spec = importlib.util.spec_from_file_location(file_name, script_path) spec = importlib.util.spec_from_file_location(file_name, script_path)
module = importlib.util.module_from_spec(spec) module = importlib.util.module_from_spec(spec)
spec.loader.exec_module(module) spec.loader.exec_module(module)
fn = module.fn fn = module.fn
return fn return fn
class TestPythonBuiltinOP(JitTestCase): class TestPythonBuiltinOP(JitTestCase):
def test_add(self): def test_add(self):
def func(a, b): def func(a, b):
@ -48,16 +51,18 @@ class TestPythonBuiltinOP(JitTestCase):
self.checkScript(func, (a, b), optimize=True) self.checkScript(func, (a, b), optimize=True)
def test_matmul_py3(self): def test_matmul_py3(self):
code = dedent(""" code = dedent(
"""
def fn(a, b): def fn(a, b):
return a @ b return a @ b
""") """
)
with tempfile.TemporaryDirectory() as tmp_dir: with tempfile.TemporaryDirectory() as tmp_dir:
script_path = os.path.join(tmp_dir, 'script.py') script_path = os.path.join(tmp_dir, "script.py")
with open(script_path, 'w') as f: with open(script_path, "w") as f:
f.write(code) f.write(code)
fn = get_fn('test_matmul_py3', script_path) fn = get_fn("test_matmul_py3", script_path)
a = torch.rand(4, 3, requires_grad=True) a = torch.rand(4, 3, requires_grad=True)
b = torch.rand(3, 2, requires_grad=True) b = torch.rand(3, 2, requires_grad=True)
@ -65,18 +70,18 @@ class TestPythonBuiltinOP(JitTestCase):
def test_pow(self): def test_pow(self):
def func(a, b): def func(a, b):
return a ** b return a**b
def func2(a, b, c, d): def func2(a, b, c, d):
return c + a ** b ** d return c + a**b**d
def func3(a, b): def func3(a, b):
# type: (int, float) -> float # type: (int, float) -> float
return a ** b return a**b
def func4(): def func4():
# type: () -> float # type: () -> float
return 2 ** -2 return 2**-2
def func5(x, y): def func5(x, y):
return x.item() ** y.item() return x.item() ** y.item()
@ -90,7 +95,12 @@ class TestPythonBuiltinOP(JitTestCase):
self.checkScript(func3, (4, -0.5), optimize=True) self.checkScript(func3, (4, -0.5), optimize=True)
self.checkScript(func4, ()) self.checkScript(func4, ())
inputs = [torch.tensor(2), torch.tensor(-2), torch.tensor(.5), torch.tensor(.2)] inputs = [
torch.tensor(2),
torch.tensor(-2),
torch.tensor(0.5),
torch.tensor(0.2),
]
for x in inputs: for x in inputs:
for y in inputs: for y in inputs:
if x < 0: if x < 0:
@ -100,7 +110,7 @@ class TestPythonBuiltinOP(JitTestCase):
def test_triple(self): def test_triple(self):
def func(x): def func(x):
return 3. * x return 3.0 * x
x = torch.rand(1, dtype=torch.float, requires_grad=True) x = torch.rand(1, dtype=torch.float, requires_grad=True)
self.checkScript(func, [x], optimize=True) self.checkScript(func, [x], optimize=True)
@ -154,22 +164,36 @@ class TestPythonBuiltinOP(JitTestCase):
def test_stepped_tuple_slicing(self): def test_stepped_tuple_slicing(self):
def check_slicing_tuple(slicing, tuple_type, tuple): def check_slicing_tuple(slicing, tuple_type, tuple):
template = dedent(""" template = dedent(
"""
def func(x): def func(x):
# type: ({}) -> Any # type: ({}) -> Any
return x{} return x{}
""") """
)
self._check_code(template.format(tuple_type, slicing), "func", [tuple]) self._check_code(template.format(tuple_type, slicing), "func", [tuple])
check_slicing_tuple("[-3:3:2]", "Tuple[int, int, int]", (0, 1, 2)) check_slicing_tuple("[-3:3:2]", "Tuple[int, int, int]", (0, 1, 2))
check_slicing_tuple("[::55]", "Tuple[int, int, int, int, int]", (0, 1, 2, 3, 4)) check_slicing_tuple("[::55]", "Tuple[int, int, int, int, int]", (0, 1, 2, 3, 4))
check_slicing_tuple("[:4:4]", "Tuple[int, int, int, int, int]", (0, 1, 2, 3, 4)) check_slicing_tuple("[:4:4]", "Tuple[int, int, int, int, int]", (0, 1, 2, 3, 4))
check_slicing_tuple("[::-1]", "Tuple[int, int, int, int, int, int, int]", (0, 1, 2, 3, 4, 5, 6)) check_slicing_tuple(
check_slicing_tuple("[7:5:2]", "Tuple[int, int, int, int, int, int, int]", (0, 1, 2, 3, 4, 5, 6)) "[::-1]", "Tuple[int, int, int, int, int, int, int]", (0, 1, 2, 3, 4, 5, 6)
check_slicing_tuple("[5:7:-2]", "Tuple[int, int, int, int, int, int, int]", (0, 1, 2, 3, 4, 5, 6)) )
check_slicing_tuple(
"[7:5:2]", "Tuple[int, int, int, int, int, int, int]", (0, 1, 2, 3, 4, 5, 6)
)
check_slicing_tuple(
"[5:7:-2]",
"Tuple[int, int, int, int, int, int, int]",
(0, 1, 2, 3, 4, 5, 6),
)
check_slicing_tuple("[::-2]", "Tuple[int, int, int, int, int]", (0, 1, 2, 3, 4)) check_slicing_tuple("[::-2]", "Tuple[int, int, int, int, int]", (0, 1, 2, 3, 4))
check_slicing_tuple("[:4:-3]", "Tuple[int, int, int, int, int, int]", (0, 1, 2, 3, 4, 5)) check_slicing_tuple(
check_slicing_tuple("[3::-2]", "Tuple[int, int, int, int, int]", (0, 1, 2, 3, 4)) "[:4:-3]", "Tuple[int, int, int, int, int, int]", (0, 1, 2, 3, 4, 5)
)
check_slicing_tuple(
"[3::-2]", "Tuple[int, int, int, int, int]", (0, 1, 2, 3, 4)
)
def test_index(self): def test_index(self):
def consec(size, start=0): def consec(size, start=0):
@ -177,10 +201,12 @@ class TestPythonBuiltinOP(JitTestCase):
return torch.arange(numel).view(size) return torch.arange(numel).view(size)
def check_indexing(indexing, tensor): def check_indexing(indexing, tensor):
template = dedent(""" template = dedent(
"""
def func(x): def func(x):
return x{} return x{}
""") """
)
self._check_code(template.format(indexing), "func", [tensor]) self._check_code(template.format(indexing), "func", [tensor])
@ -188,62 +214,66 @@ class TestPythonBuiltinOP(JitTestCase):
value1 = torch.tensor(value1) value1 = torch.tensor(value1)
value2 = torch.tensor(value2) value2 = torch.tensor(value2)
template = dedent(""" template = dedent(
"""
def func(x, value1, value2): def func(x, value1, value2):
i = int(value1) i = int(value1)
j = int(value2) j = int(value2)
return x{} return x{}
""") """
)
self._check_code(template.format(indexing), "func", [tensor, value1, value2]) self._check_code(
template.format(indexing), "func", [tensor, value1, value2]
)
# basic slices # basic slices
check_indexing('[0]', consec((3, 3))) check_indexing("[0]", consec((3, 3)))
check_indexing('[1]', consec((3, 3), 10)) check_indexing("[1]", consec((3, 3), 10))
check_indexing('[2]', consec((3, 3), 19)) check_indexing("[2]", consec((3, 3), 19))
check_indexing('[2]', consec((3,))) check_indexing("[2]", consec((3,)))
check_indexing('[-1]', consec((3, 3), 19)) check_indexing("[-1]", consec((3, 3), 19))
check_indexing('[0:2]', consec((3, 3, 3))) check_indexing("[0:2]", consec((3, 3, 3)))
check_indexing('[1:-1]', consec((3, 3, 3))) check_indexing("[1:-1]", consec((3, 3, 3)))
check_indexing('[-3:-1]', consec((6, 3))) check_indexing("[-3:-1]", consec((6, 3)))
check_indexing('[1:]', consec((3, 3))) check_indexing("[1:]", consec((3, 3)))
check_indexing('[:1]', consec((3, 3))) check_indexing("[:1]", consec((3, 3)))
check_indexing('[:]', consec((3, 2))) check_indexing("[:]", consec((3, 2)))
# multi-dim: indexes # multi-dim: indexes
check_indexing('[0, 1]', consec((3, 3))) check_indexing("[0, 1]", consec((3, 3)))
check_indexing('[0, 1]', consec((3, 3, 2))) check_indexing("[0, 1]", consec((3, 3, 2)))
check_indexing('[1, 0, 2]', consec((3, 3, 3))) check_indexing("[1, 0, 2]", consec((3, 3, 3)))
check_indexing('[2, -1]', consec((3, 3))) check_indexing("[2, -1]", consec((3, 3)))
# multi-dim: mixed slicing and indexing # multi-dim: mixed slicing and indexing
check_indexing('[0, 1:2]', consec((3, 3))) check_indexing("[0, 1:2]", consec((3, 3)))
check_indexing('[0, :1]', consec((3, 3, 2))) check_indexing("[0, :1]", consec((3, 3, 2)))
check_indexing('[1, 2:]', consec((3, 3, 3))) check_indexing("[1, 2:]", consec((3, 3, 3)))
check_indexing('[-1, 1:, 0]', consec((3, 3, 3, 3))) check_indexing("[-1, 1:, 0]", consec((3, 3, 3, 3)))
check_indexing('[1:, -1, 0]', consec((3, 3, 3, 3))) check_indexing("[1:, -1, 0]", consec((3, 3, 3, 3)))
check_indexing('[-1, 2:, 1:2]', consec((3, 3, 3, 3))) check_indexing("[-1, 2:, 1:2]", consec((3, 3, 3, 3)))
check_indexing('[-1, 1:, 0]', consec((3, 3, 3, 3))) check_indexing("[-1, 1:, 0]", consec((3, 3, 3, 3)))
check_indexing('[-1, :, 0, 2]', consec((3, 3, 3, 3))) check_indexing("[-1, :, 0, 2]", consec((3, 3, 3, 3)))
# zero-sized slices # zero-sized slices
check_indexing('[0:0]', consec((2, 2))) check_indexing("[0:0]", consec((2, 2)))
check_indexing('[0:0, 1]', consec((3, 3))) check_indexing("[0:0, 1]", consec((3, 3)))
# trivial expression usage # trivial expression usage
check_indexing('[1+1]', consec((3, 3))) check_indexing("[1+1]", consec((3, 3)))
check_indexing('[1:(0 + 2)]', consec((3, 3, 3))) check_indexing("[1:(0 + 2)]", consec((3, 3, 3)))
# None for new dimensions # None for new dimensions
check_indexing('[None, 0]', consec((3, 3))) check_indexing("[None, 0]", consec((3, 3)))
check_indexing('[1, None]', consec((3, 3), 10)) check_indexing("[1, None]", consec((3, 3), 10))
check_indexing('[None, None, 2]', consec((3, 3), 19)) check_indexing("[None, None, 2]", consec((3, 3), 19))
check_indexing('[None, 2, None]', consec((3,))) check_indexing("[None, 2, None]", consec((3,)))
check_indexing('[0:2, None]', consec((3, 3, 3))) check_indexing("[0:2, None]", consec((3, 3, 3)))
check_indexing('[None, 1:-1]', consec((3, 3, 3))) check_indexing("[None, 1:-1]", consec((3, 3, 3)))
check_indexing('[None, -3:-1, None]', consec((6, 3))) check_indexing("[None, -3:-1, None]", consec((6, 3)))
check_indexing('[-1, None, 2:, None, 1:2]', consec((3, 3, 3, 3))) check_indexing("[-1, None, 2:, None, 1:2]", consec((3, 3, 3, 3)))
check_indexing('[None, -1, None, 2:, None, 1:2, None]', consec((3, 3, 3, 3))) check_indexing("[None, -1, None, 2:, None, 1:2, None]", consec((3, 3, 3, 3)))
# dynamic expression usage # dynamic expression usage
check_dynamic_indexing("[i + j]", consec((3, 3)), 0, 1) check_dynamic_indexing("[i + j]", consec((3, 3)), 0, 1)
@ -257,10 +287,12 @@ class TestPythonBuiltinOP(JitTestCase):
def check_indexing(indexing, tensor, **kwargs): def check_indexing(indexing, tensor, **kwargs):
indices_dict = kwargs indices_dict = kwargs
template = dedent(""" template = dedent(
"""
def func(x{formals}): def func(x{formals}):
return x{expr} return x{expr}
""") """
)
formals = [] formals = []
values = [] values = []
@ -268,17 +300,18 @@ class TestPythonBuiltinOP(JitTestCase):
formals.append(formal) formals.append(formal)
values.append(value) values.append(value)
formals = ''.join(map(', {}'.format, formals)) formals = "".join(map(", {}".format, formals))
inputs = [tensor] + values inputs = [tensor] + values
self._check_code(template.format(formals=formals, expr=indexing), self._check_code(
"func", inputs) template.format(formals=formals, expr=indexing), "func", inputs
)
# Indexing with tensor (basic) # Indexing with tensor (basic)
check_indexing('[i]', consec((3, 3)), i=torch.tensor([0])) check_indexing("[i]", consec((3, 3)), i=torch.tensor([0]))
check_indexing('[i]', consec((3, 3)), i=torch.tensor(1)) check_indexing("[i]", consec((3, 3)), i=torch.tensor(1))
check_indexing('[i]', consec((3, 3)), i=torch.tensor([-2])) check_indexing("[i]", consec((3, 3)), i=torch.tensor([-2]))
check_indexing('[i]', consec((3, 3), 2), i=torch.tensor([0, 0])) check_indexing("[i]", consec((3, 3), 2), i=torch.tensor([0, 0]))
check_indexing('[i]', consec((3, 3, 2, 2)), i=torch.tensor([0, -2, 1])) check_indexing("[i]", consec((3, 3, 2, 2)), i=torch.tensor([0, -2, 1]))
# NB: indexing with tensors and indexing with sequences can be implemented # NB: indexing with tensors and indexing with sequences can be implemented
# in a very similar way (sequences are converted to tensors), so only one # in a very similar way (sequences are converted to tensors), so only one
@ -290,49 +323,49 @@ class TestPythonBuiltinOP(JitTestCase):
inp = consec((4, 8, 5)) inp = consec((4, 8, 5))
to_check = [ to_check = [
# [[0, 1, 3]] # [[0, 1, 3]]
['[i]', {'i': [0, 1, 3]}], ["[i]", {"i": [0, 1, 3]}],
# [[0, 2], [1, 3]] # [[0, 2], [1, 3]]
['[i, j]', {'i': [0, 2], 'j': [1, 3]}], ["[i, j]", {"i": [0, 2], "j": [1, 3]}],
# [[[0, 1], [0, 1]], [[0, 1], [0, 1]]] # [[[0, 1], [0, 1]], [[0, 1], [0, 1]]]
['[i, j]', {'i': [[0, 1], [0, 1]], 'j': [[0, 1], [0, 1]]}], ["[i, j]", {"i": [[0, 1], [0, 1]], "j": [[0, 1], [0, 1]]}],
# [[0, 2], [1, 3], [1, 1]] # [[0, 2], [1, 3], [1, 1]]
['[i, j, k]', {'i': [0, 2], 'j': [1, 3], 'k': [1, 1]}], ["[i, j, k]", {"i": [0, 2], "j": [1, 3], "k": [1, 1]}],
# [[0, 2], 1, [1, 1]] # [[0, 2], 1, [1, 1]]
['[i, j, k]', {'i': [0, 2], 'j': 1, 'k': [1, 1]}], ["[i, j, k]", {"i": [0, 2], "j": 1, "k": [1, 1]}],
# [:, :, [0, 3, 4]] # [:, :, [0, 3, 4]]
['[:, :, i]', {'i': [0, 3, 4]}], ["[:, :, i]", {"i": [0, 3, 4]}],
# [:, [2, 4, 5, 7], 2:4] # [:, [2, 4, 5, 7], 2:4]
['[:, i, 2:4]', {'i': [0, 2, 3]}], ["[:, i, 2:4]", {"i": [0, 2, 3]}],
# [[2, 3], :, :] # [[2, 3], :, :]
['[i, :, :]', {'i': [2, 3]}], ["[i, :, :]", {"i": [2, 3]}],
# [:, [0, 2, 3], [1, 3, 4]] # [:, [0, 2, 3], [1, 3, 4]]
['[:, i, j]', {'i': [0, 2, 3], 'j': [1, 3, 4]}], ["[:, i, j]", {"i": [0, 2, 3], "j": [1, 3, 4]}],
# [:, [0], [1, 2, 4]] # [:, [0], [1, 2, 4]]
['[:, i, j]', {'i': [0], 'j': [1, 2, 4]}], ["[:, i, j]", {"i": [0], "j": [1, 2, 4]}],
# [:, [0, 1, 3], [4]] # [:, [0, 1, 3], [4]]
['[:, i, j]', {'i': [0, 1, 3], 'j': [4]}], ["[:, i, j]", {"i": [0, 1, 3], "j": [4]}],
# [:, [[0, 1], [1, 0]], [[2, 3]]] # [:, [[0, 1], [1, 0]], [[2, 3]]]
['[:, i, j]', {'i': [[0, 1], [1, 0]], 'j': [[2, 3]]}], ["[:, i, j]", {"i": [[0, 1], [1, 0]], "j": [[2, 3]]}],
# [:, [[0, 1], [2, 3]], [[0]]] # [:, [[0, 1], [2, 3]], [[0]]]
['[:, i, j]', {'i': [[0, 1], [2, 3]], 'j': [[0]]}], ["[:, i, j]", {"i": [[0, 1], [2, 3]], "j": [[0]]}],
# [:, [[5, 6]], [[0, 3], [4, 4]]] # [:, [[5, 6]], [[0, 3], [4, 4]]]
['[:, i, j]', {'i': [[5, 6]], 'j': [[0, 3], [4, 4]]}], ["[:, i, j]", {"i": [[5, 6]], "j": [[0, 3], [4, 4]]}],
# [[0, 2, 3], [1, 3, 4], :] # [[0, 2, 3], [1, 3, 4], :]
['[i, j, :]', {'i': [0, 2, 3], 'j': [1, 3, 4]}], ["[i, j, :]", {"i": [0, 2, 3], "j": [1, 3, 4]}],
# [0, [1, 2, 4], :] # [0, [1, 2, 4], :]
['[i, j, :]', {'i': 0, 'j': [1, 2, 4]}], ["[i, j, :]", {"i": 0, "j": [1, 2, 4]}],
# [[0, 1, 3], 4, :] # [[0, 1, 3], 4, :]
['[i, j, :]', {'i': [0, 1, 3], 'j': 4}], ["[i, j, :]", {"i": [0, 1, 3], "j": 4}],
# [[[0, 1], [1, 0]], [[2, 1], [3, 5]], :] # [[[0, 1], [1, 0]], [[2, 1], [3, 5]], :]
['[i, j, :]', {'i': [[0, 1], [1, 0]], 'j': [[2, 1], [3, 5]]}], ["[i, j, :]", {"i": [[0, 1], [1, 0]], "j": [[2, 1], [3, 5]]}],
# [[[0, 1], [1, 0]], [[2, 3]], :] # [[[0, 1], [1, 0]], [[2, 3]], :]
['[i, j, :]', {'i': [[0, 1], [1, 0]], 'j': [[2, 3]]}], ["[i, j, :]", {"i": [[0, 1], [1, 0]], "j": [[2, 3]]}],
# [[[0, 1], [2, 3]], [[0]], :] # [[[0, 1], [2, 3]], [[0]], :]
['[i, j, :]', {'i': [[0, 1], [2, 3]], 'j': [[0]]}], ["[i, j, :]", {"i": [[0, 1], [2, 3]], "j": [[0]]}],
# [[[2, 1]], [[0, 3], [4, 4]], :] # [[[2, 1]], [[0, 3], [4, 4]], :]
['[i, j, :]', {'i': [[2, 1]], 'j': [[0, 3], [4, 4]]}], ["[i, j, :]", {"i": [[2, 1]], "j": [[0, 3], [4, 4]]}],
# [[[2]], [[0, 3], [4, 1]], 0:2] # [[[2]], [[0, 3], [4, 1]], 0:2]
['[i, j, 0:2]', {'i': [[2]], 'j': [[0, 3], [4, 1]]}], ["[i, j, 0:2]", {"i": [[2]], "j": [[0, 3], [4, 1]]}],
] ]
for expr, argdict in to_check: for expr, argdict in to_check:
@ -372,29 +405,35 @@ class TestPythonBuiltinOP(JitTestCase):
for _ in range(100): for _ in range(100):
indices = [random.choice(vals) for _ in range(4)] indices = [random.choice(vals) for _ in range(4)]
indices[random.randint(0, len(indices) - 1)] = "..." indices[random.randint(0, len(indices) - 1)] = "..."
test_str = dedent(""" test_str = dedent(
"""
def f(): def f():
x = torch.ones(10, 9, 8, 7, 6) x = torch.ones(10, 9, 8, 7, 6)
return x{indices}.shape return x{indices}.shape
""".format(indices=indices)) """.format(
test_str = test_str.replace(r"'", r'') indices=indices
)
)
test_str = test_str.replace(r"'", r"")
scope = {} scope = {}
execWrapper(test_str, globals(), scope) execWrapper(test_str, globals(), scope)
cu = torch.jit.CompilationUnit(test_str) cu = torch.jit.CompilationUnit(test_str)
res1 = cu.f() res1 = cu.f()
res2 = scope['f']() res2 = scope["f"]()
self.assertEqual(res1, res2) self.assertEqual(res1, res2)
def test_inf(self): def test_inf(self):
@torch.jit.script @torch.jit.script
def foo(a): def foo(a):
return a < float('inf') return a < float("inf")
s = torch.rand(1) s = torch.rand(1)
self.assertTrue(foo(s)) self.assertTrue(foo(s))
@torch.jit.script @torch.jit.script
def bar(a): def bar(a):
return a > float('-inf') return a > float("-inf")
s = torch.rand(1) s = torch.rand(1)
self.assertTrue(foo(s)) self.assertTrue(foo(s))
@ -414,19 +453,22 @@ class TestPythonBuiltinOP(JitTestCase):
def test_str_to_float(self): def test_str_to_float(self):
@torch.jit.script @torch.jit.script
def foo(a): def foo(a):
return 0.5 == float('0.5 hello') return 0.5 == float("0.5 hello")
s = torch.rand(1) s = torch.rand(1)
with self.assertRaisesRegex(RuntimeError, "could not convert string to float"): with self.assertRaisesRegex(RuntimeError, "could not convert string to float"):
self.assertTrue(foo(s)) self.assertTrue(foo(s))
@torch.jit.script @torch.jit.script
def foo(a): def foo(a):
return 0.5 == float('0.5') return 0.5 == float("0.5")
s = torch.rand(1) s = torch.rand(1)
self.assertTrue(foo(s)) self.assertTrue(foo(s))
@torch.jit.script @torch.jit.script
def foo(a): def foo(a):
return 0. == float('0') return 0.0 == float("0")
s = torch.rand(1) s = torch.rand(1)
self.assertTrue(foo(s)) self.assertTrue(foo(s))

View File

@ -1,22 +1,26 @@
# Owner(s): ["oncall: jit"] # Owner(s): ["oncall: jit"]
import torch
from torch.testing import FileCheck
from torch.testing._internal.jit_utils import JitTestCase
from torch.testing._internal.common_utils import IS_MACOS
import numpy as np
import unittest import unittest
if __name__ == '__main__': import numpy as np
raise RuntimeError("This test file is not meant to be run directly, use:\n\n" import torch
"\tpython test/test_jit.py TESTNAME\n\n" from torch.testing import FileCheck
"instead.") from torch.testing._internal.common_utils import IS_MACOS
from torch.testing._internal.jit_utils import JitTestCase
if __name__ == "__main__":
raise RuntimeError(
"This test file is not meant to be run directly, use:\n\n"
"\tpython test/test_jit.py TESTNAME\n\n"
"instead."
)
class TestPythonIr(JitTestCase): class TestPythonIr(JitTestCase):
def test_param_strides(self): def test_param_strides(self):
def trace_me(arg): def trace_me(arg):
return arg return arg
t = torch.zeros(1, 3, 16, 16) t = torch.zeros(1, 3, 16, 16)
traced = torch.jit.trace(trace_me, t) traced = torch.jit.trace(trace_me, t)
value = list(traced.graph.param_node().outputs())[0] value = list(traced.graph.param_node().outputs())[0]
@ -78,8 +82,12 @@ class TestPythonIr(JitTestCase):
g = foo.graph g = foo.graph
muls = g.findAllNodes("aten::mul") muls = g.findAllNodes("aten::mul")
scalar_muls = filter(lambda x: x.matches("aten::mul(Tensor self, Scalar other) -> Tensor"), muls) scalar_muls = filter(
mul_constant_int = filter(lambda x: isinstance(list(x.inputs())[1].toIValue(), int), scalar_muls) lambda x: x.matches("aten::mul(Tensor self, Scalar other) -> Tensor"), muls
)
mul_constant_int = filter(
lambda x: isinstance(list(x.inputs())[1].toIValue(), int), scalar_muls
)
for mul in mul_constant_int: for mul in mul_constant_int:
with g.insert_point_guard(mul): with g.insert_point_guard(mul):
outputs = g.insertGraph(unrolled_mul.graph, list(mul.inputs())) outputs = g.insertGraph(unrolled_mul.graph, list(mul.inputs()))

View File

@ -5,25 +5,31 @@ import re
import sys import sys
import types import types
import typing import typing
import typing_extensions from collections import OrderedDict
from typing import List, Dict, Optional, Tuple from typing import Dict, List, Optional, Tuple
import torch import torch
import torch.jit.frontend import torch.jit.frontend
import torch.nn as nn import torch.nn as nn
import typing_extensions
from torch import Tensor from torch import Tensor
from torch.testing import FileCheck from torch.testing import FileCheck
from collections import OrderedDict
# Make the helper files in test/ importable # Make the helper files in test/ importable
pytorch_test_dir = os.path.dirname(os.path.dirname(os.path.realpath(__file__))) pytorch_test_dir = os.path.dirname(os.path.dirname(os.path.realpath(__file__)))
sys.path.append(pytorch_test_dir) sys.path.append(pytorch_test_dir)
from torch.testing._internal.jit_utils import JitTestCase, _tmp_donotuse_dont_inline_everything from torch.testing._internal.jit_utils import (
_tmp_donotuse_dont_inline_everything,
JitTestCase,
)
if __name__ == "__main__":
raise RuntimeError(
"This test file is not meant to be run directly, use:\n\n"
"\tpython test/test_jit.py TESTNAME\n\n"
"instead."
)
if __name__ == '__main__':
raise RuntimeError("This test file is not meant to be run directly, use:\n\n"
"\tpython test/test_jit.py TESTNAME\n\n"
"instead.")
class TestRecursiveScript(JitTestCase): class TestRecursiveScript(JitTestCase):
def test_inferred_nonetype(self): def test_inferred_nonetype(self):
@ -87,7 +93,9 @@ class TestRecursiveScript(JitTestCase):
return self.fn(x) return self.fn(x)
m = M(fn) m = M(fn)
with self.assertRaisesRegexWithHighlight(RuntimeError, "failed to compile", "i_dont_exist"): with self.assertRaisesRegexWithHighlight(
RuntimeError, "failed to compile", "i_dont_exist"
):
torch.jit.script(m) torch.jit.script(m)
def test_init_error(self): def test_init_error(self):
@ -119,12 +127,12 @@ class TestRecursiveScript(JitTestCase):
# sm1 was created while m had training = True # sm1 was created while m had training = True
self.assertTrue(sm1.training) self.assertTrue(sm1.training)
self.assertEqual(sm1.training, sm1._c.getattr('training')) self.assertEqual(sm1.training, sm1._c.getattr("training"))
self.assertEqual(sm1(), 2) self.assertEqual(sm1(), 2)
# sm2 was created after m was eval'ed # sm2 was created after m was eval'ed
self.assertFalse(sm2.training) self.assertFalse(sm2.training)
self.assertEqual(sm2.training, sm2._c.getattr('training')) self.assertEqual(sm2.training, sm2._c.getattr("training"))
self.assertEqual(sm2(), 0) self.assertEqual(sm2(), 0)
def test_module_name(self): def test_module_name(self):
@ -165,7 +173,7 @@ class TestRecursiveScript(JitTestCase):
def test_constants_with_final(self): def test_constants_with_final(self):
class M1(torch.nn.Module): class M1(torch.nn.Module):
x : torch.jit.Final[int] x: torch.jit.Final[int]
def __init__(self): def __init__(self):
super().__init__() super().__init__()
@ -177,7 +185,7 @@ class TestRecursiveScript(JitTestCase):
self.checkModule(M1(), (torch.randn(2, 2),)) self.checkModule(M1(), (torch.randn(2, 2),))
class M2(torch.nn.Module): class M2(torch.nn.Module):
x : typing_extensions.Final[int] x: typing_extensions.Final[int]
def __init__(self): def __init__(self):
super().__init__() super().__init__()
@ -189,7 +197,7 @@ class TestRecursiveScript(JitTestCase):
self.checkModule(M2(), (torch.randn(2, 2),)) self.checkModule(M2(), (torch.randn(2, 2),))
class M3(torch.nn.Module): class M3(torch.nn.Module):
x : typing.Final[int] x: typing.Final[int]
def __init__(self): def __init__(self):
super().__init__() super().__init__()
@ -206,12 +214,15 @@ class TestRecursiveScript(JitTestCase):
def unscriptable(self): def unscriptable(self):
return "a" + 200 return "a" + 200
class TestModule(torch.nn.Module): class TestModule(torch.nn.Module):
def forward(self, x): def forward(self, x):
return MyScriptClass() return MyScriptClass()
with self.assertRaisesRegexWithHighlight(torch.jit.frontend.FrontendError, "Cannot instantiate class", "MyScriptClass"): with self.assertRaisesRegexWithHighlight(
torch.jit.frontend.FrontendError,
"Cannot instantiate class",
"MyScriptClass",
):
t = torch.jit.script(TestModule()) t = torch.jit.script(TestModule())
def test_method_call(self): def test_method_call(self):
@ -246,13 +257,13 @@ class TestRecursiveScript(JitTestCase):
print(m) print(m)
f = FileCheck() f = FileCheck()
f.check('MyModule') f.check("MyModule")
f.check('Conv2d') f.check("Conv2d")
f.check('Linear') f.check("Linear")
f.check('Submodule') f.check("Submodule")
f.run(out[0]) f.run(out[0])
self.assertEqual(m.original_name, 'MyModule') self.assertEqual(m.original_name, "MyModule")
def test_dir(self): def test_dir(self):
def test_module_dir(mod): def test_module_dir(mod):
@ -260,8 +271,17 @@ class TestRecursiveScript(JitTestCase):
scripted_mod = torch.jit.script(mod) scripted_mod = torch.jit.script(mod)
dir_scripted = set(dir(scripted_mod)) dir_scripted = set(dir(scripted_mod))
# set not currently copied over # set not currently copied over
ignore_set = ["training", "__delitem__", "__setitem__", "clear", "items", ignore_set = [
"keys", "pop", "update", "values"] "training",
"__delitem__",
"__setitem__",
"clear",
"items",
"keys",
"pop",
"update",
"values",
]
for attr in dir_set: for attr in dir_set:
if attr in ignore_set: if attr in ignore_set:
continue continue
@ -283,7 +303,9 @@ class TestRecursiveScript(JitTestCase):
linear = nn.Linear(10, 10) linear = nn.Linear(10, 10)
test_module_dir(nn.Sequential(conv, linear)) test_module_dir(nn.Sequential(conv, linear))
test_module_dir(nn.ModuleDict(OrderedDict([("conv", conv), ("linear", linear)]))) test_module_dir(
nn.ModuleDict(OrderedDict([("conv", conv), ("linear", linear)]))
)
def test_class_compile(self): def test_class_compile(self):
def other_fn(a: int, b: Tensor) -> Tensor: def other_fn(a: int, b: Tensor) -> Tensor:
@ -296,7 +318,6 @@ class TestRecursiveScript(JitTestCase):
def helper(self, a): def helper(self, a):
return self.x + a + other_fn(self.x, a) return self.x + a + other_fn(self.x, a)
class N(torch.nn.Module): class N(torch.nn.Module):
def forward(self, x): def forward(self, x):
b = B(x) b = B(x)
@ -411,7 +432,7 @@ class TestRecursiveScript(JitTestCase):
def test_module_basic(self): def test_module_basic(self):
class Other(torch.nn.Module): class Other(torch.nn.Module):
__constants__ = ['x'] __constants__ = ["x"]
def __init__(self, x): def __init__(self, x):
super().__init__() super().__init__()
@ -426,7 +447,6 @@ class TestRecursiveScript(JitTestCase):
def forward(self, t): def forward(self, t):
return t + self.x + self.param return t + self.x + self.param
class M(torch.nn.Module): class M(torch.nn.Module):
def __init__(self): def __init__(self):
super().__init__() super().__init__()
@ -439,7 +459,7 @@ class TestRecursiveScript(JitTestCase):
def test_module_function_export(self): def test_module_function_export(self):
class Other(torch.nn.Module): class Other(torch.nn.Module):
__constants__ = ['x'] __constants__ = ["x"]
def __init__(self, x): def __init__(self, x):
super().__init__() super().__init__()
@ -453,7 +473,6 @@ class TestRecursiveScript(JitTestCase):
def forward(self, t): def forward(self, t):
return t + self.x + self.param return t + self.x + self.param
class M(torch.nn.Module): class M(torch.nn.Module):
def __init__(self): def __init__(self):
super().__init__() super().__init__()
@ -473,9 +492,7 @@ class TestRecursiveScript(JitTestCase):
def __init__(self): def __init__(self):
super().__init__() super().__init__()
self.sequential = nn.Sequential( self.sequential = nn.Sequential(
Inner(), Inner(), Inner(), nn.Sequential(Inner(), Inner())
Inner(),
nn.Sequential(Inner(), Inner())
) )
self.module_list = nn.ModuleList([Inner(), Inner()]) self.module_list = nn.ModuleList([Inner(), Inner()])
@ -511,12 +528,14 @@ class TestRecursiveScript(JitTestCase):
self.sequential = nn.Sequential( self.sequential = nn.Sequential(
SeluButReluWhenScripted(), SeluButReluWhenScripted(),
SeluButReluWhenScripted(), SeluButReluWhenScripted(),
nn.Sequential(SeluButReluWhenScripted(), shared, SeluButReluWhenScripted()), nn.Sequential(
SeluButReluWhenScripted(), shared, SeluButReluWhenScripted()
),
shared, shared,
) )
self.module_list = nn.ModuleList([SeluButReluWhenScripted(), self.module_list = nn.ModuleList(
shared, [SeluButReluWhenScripted(), shared, SeluButReluWhenScripted()]
SeluButReluWhenScripted()]) )
def forward(self, x): def forward(self, x):
for mod in self.module_list: for mod in self.module_list:
@ -553,7 +572,8 @@ class TestRecursiveScript(JitTestCase):
self.assertEqual(obj(1, 2), 3) self.assertEqual(obj(1, 2), 3)
self.assertEqual(obj(1, 2, 3, 4), 10) self.assertEqual(obj(1, 2, 3, 4), 10)
with self.assertRaisesRegex( with self.assertRaisesRegex(
torch.jit.frontend.NotSupportedError, expected_regex="can't take variable number of arguments" torch.jit.frontend.NotSupportedError,
expected_regex="can't take variable number of arguments",
): ):
torch.jit.script(obj) torch.jit.script(obj)
@ -568,7 +588,10 @@ class TestRecursiveScript(JitTestCase):
self.assertEqual(jit_obj(1, 2), 3) self.assertEqual(jit_obj(1, 2), 3)
with self.assertRaisesRegex( with self.assertRaisesRegex(
RuntimeError, expected_regex=re.escape("expected at most 2 argument(s) but received 4 argument(s)") RuntimeError,
expected_regex=re.escape(
"expected at most 2 argument(s) but received 4 argument(s)"
),
): ):
jit_obj(1, 2, 3, 4) jit_obj(1, 2, 3, 4)
@ -598,27 +621,26 @@ class TestRecursiveScript(JitTestCase):
def __getstate__(self): def __getstate__(self):
return (self.a, self.inner) return (self.a, self.inner)
untyped_values = ( untyped_values = (
('my_dict', {"I": "am", "a test": "test"}), ("my_dict", {"I": "am", "a test": "test"}),
('my_float', 2.3), ("my_float", 2.3),
('my_int', 99), ("my_int", 99),
('my_bool', False), ("my_bool", False),
('my_tuple', (1, 2, 3, 4)), ("my_tuple", (1, 2, 3, 4)),
('my_list', [(1, 2), (3, 4)]), ("my_list", [(1, 2), (3, 4)]),
# ('my_tensor', torch.randn(2, 2)), # ('my_tensor', torch.randn(2, 2)),
('my_int_list', [1, 2, 3, 4]), ("my_int_list", [1, 2, 3, 4]),
# ('my_tensor_list', [torch.ones(2, 2) + i for i in range(4)]), # ('my_tensor_list', [torch.ones(2, 2) + i for i in range(4)]),
('my_bool_list', [True, True, False, True]), ("my_bool_list", [True, True, False, True]),
('my_float_list', [1., 2., 3., 4.]), ("my_float_list", [1.0, 2.0, 3.0, 4.0]),
('my_str_list', ['hello', 'bye']), ("my_str_list", ["hello", "bye"]),
) )
typed_values = ( typed_values = (
('my_empty_list', []), ("my_empty_list", []),
('my_empty_dict', {}), ("my_empty_dict", {}),
('my_none', None), ("my_none", None),
('my_object', Foo()), ("my_object", Foo()),
('my_object2', SFoo()), ("my_object2", SFoo()),
) )
class M(torch.nn.Module): class M(torch.nn.Module):
@ -659,11 +681,11 @@ class TestRecursiveScript(JitTestCase):
# since there's no string frontend for Python classes (so the `define`) # since there's no string frontend for Python classes (so the `define`)
# trick doesn't work. # trick doesn't work.
M.__annotations__ = { M.__annotations__ = {
'my_empty_list': List[int], "my_empty_list": List[int],
'my_empty_dict': Dict[str, int], "my_empty_dict": Dict[str, int],
'my_none': Optional[int], "my_none": Optional[int],
'my_object': Foo, "my_object": Foo,
'my_object2': SFoo, "my_object2": SFoo,
} }
m = M() m = M()
@ -694,7 +716,7 @@ class TestRecursiveScript(JitTestCase):
return self.encoder(x) return self.encoder(x)
m = M() m = M()
self.checkModule(m, (torch.randn(5, 5), )) self.checkModule(m, (torch.randn(5, 5),))
def test_inner_traced_module(self): def test_inner_traced_module(self):
class Dummy(nn.Module): class Dummy(nn.Module):
@ -715,12 +737,13 @@ class TestRecursiveScript(JitTestCase):
dummy = torch.jit.trace(Dummy(), torch.randn(1, 2)) dummy = torch.jit.trace(Dummy(), torch.randn(1, 2))
dummies = nn.ModuleList([dummy]) dummies = nn.ModuleList([dummy])
model = Model(dummies) model = Model(dummies)
self.checkModule(model, (torch.rand(5, 5), )) self.checkModule(model, (torch.rand(5, 5),))
def test_script_loaded_module(self): def test_script_loaded_module(self):
""" """
Test that we can hold a loaded ScriptModule as a submodule. Test that we can hold a loaded ScriptModule as a submodule.
""" """
class Dummy(nn.Module): class Dummy(nn.Module):
def forward(self, x): def forward(self, x):
return x return x
@ -736,7 +759,7 @@ class TestRecursiveScript(JitTestCase):
def forward(self, input): def forward(self, input):
return self.encoder(input) return self.encoder(input)
self.checkModule(ContainsLoaded(), (torch.rand(2, 3), )) self.checkModule(ContainsLoaded(), (torch.rand(2, 3),))
def test_optional_module(self): def test_optional_module(self):
class Dummy(nn.Module): class Dummy(nn.Module):

View File

@ -2,20 +2,23 @@
import os import os
import sys import sys
from typing import List
import torch import torch
from torch.testing import FileCheck from torch.testing import FileCheck
from typing import List
# Make the helper files in test/ importable # Make the helper files in test/ importable
pytorch_test_dir = os.path.dirname(os.path.dirname(os.path.realpath(__file__))) pytorch_test_dir = os.path.dirname(os.path.dirname(os.path.realpath(__file__)))
sys.path.append(pytorch_test_dir) sys.path.append(pytorch_test_dir)
from torch.testing._internal.jit_utils import JitTestCase, freeze_rng_state from torch.testing._internal.jit_utils import freeze_rng_state, JitTestCase
if __name__ == "__main__":
raise RuntimeError(
"This test file is not meant to be run directly, use:\n\n"
"\tpython test/test_jit.py TESTNAME\n\n"
"instead."
)
if __name__ == '__main__':
raise RuntimeError("This test file is not meant to be run directly, use:\n\n"
"\tpython test/test_jit.py TESTNAME\n\n"
"instead.")
class TestRemoveMutation(JitTestCase): class TestRemoveMutation(JitTestCase):
def test_aten_inplace(self): def test_aten_inplace(self):
@ -26,7 +29,7 @@ class TestRemoveMutation(JitTestCase):
fn = torch.jit.script(test_not_new_alias) fn = torch.jit.script(test_not_new_alias)
graph = fn.graph graph = fn.graph
self.run_pass('remove_mutation', graph) self.run_pass("remove_mutation", graph)
FileCheck().check("aten::add_").run(graph) FileCheck().check("aten::add_").run(graph)
self.assertEqual(fn(torch.ones([2, 2])), test_not_new_alias(torch.ones([2, 2]))) self.assertEqual(fn(torch.ones([2, 2])), test_not_new_alias(torch.ones([2, 2])))
@ -38,7 +41,7 @@ class TestRemoveMutation(JitTestCase):
# there is no functional equivalent of x[0] = ... # there is no functional equivalent of x[0] = ...
fn = torch.jit.script(test_no_lowering) fn = torch.jit.script(test_no_lowering)
graph = fn.graph graph = fn.graph
self.run_pass('remove_mutation', graph) self.run_pass("remove_mutation", graph)
FileCheck().check("aten::copy_").run(graph) FileCheck().check("aten::copy_").run(graph)
self.assertEqual(fn(), test_no_lowering()) self.assertEqual(fn(), test_no_lowering())
@ -50,7 +53,7 @@ class TestRemoveMutation(JitTestCase):
fn = torch.jit.script(test_move_before_not_valid) fn = torch.jit.script(test_move_before_not_valid)
graph = fn.graph graph = fn.graph
self.run_pass('remove_mutation', graph) self.run_pass("remove_mutation", graph)
FileCheck().check("aten::add_").run(graph) FileCheck().check("aten::add_").run(graph)
self.assertEqual(fn(), test_move_before_not_valid()) self.assertEqual(fn(), test_move_before_not_valid())
@ -63,7 +66,7 @@ class TestRemoveMutation(JitTestCase):
fn = torch.jit.script(test_successful) fn = torch.jit.script(test_successful)
graph = fn.graph graph = fn.graph
self.run_pass('remove_mutation', graph) self.run_pass("remove_mutation", graph)
FileCheck().check_not("aten::add_").run(graph) FileCheck().check_not("aten::add_").run(graph)
self.assertEqual(test_successful(), fn()) self.assertEqual(test_successful(), fn())
@ -77,7 +80,7 @@ class TestRemoveMutation(JitTestCase):
fn = torch.jit.script(test_intermediary_use) fn = torch.jit.script(test_intermediary_use)
graph = fn.graph graph = fn.graph
FileCheck().check_count("aten::add_", 2).run(graph) FileCheck().check_count("aten::add_", 2).run(graph)
self.run_pass('remove_mutation', graph) self.run_pass("remove_mutation", graph)
# Unable to remove the second add_ because of the y = x + 4 use # Unable to remove the second add_ because of the y = x + 4 use
# In the future we could duplicating the value of x as a temporary and replacing # In the future we could duplicating the value of x as a temporary and replacing
# its intermediary use (so long as aliasing is safe) # its intermediary use (so long as aliasing is safe)
@ -96,7 +99,7 @@ class TestRemoveMutation(JitTestCase):
out_eager = foo(torch.tensor(5), True) out_eager = foo(torch.tensor(5), True)
foo_script = torch.jit.script(foo) foo_script = torch.jit.script(foo)
FileCheck().check("aten::add_").run(foo_script.graph) FileCheck().check("aten::add_").run(foo_script.graph)
self.run_pass('remove_mutation', foo_script.graph) self.run_pass("remove_mutation", foo_script.graph)
FileCheck().check_not("aten::add_").run(foo_script.graph) FileCheck().check_not("aten::add_").run(foo_script.graph)
self.assertEqual(out_eager, foo_script(torch.tensor(5), True)) self.assertEqual(out_eager, foo_script(torch.tensor(5), True))
@ -113,8 +116,8 @@ class TestRemoveMutation(JitTestCase):
y = x.add_(2) y = x.add_(2)
return y, li return y, li
self.run_pass('inline', foo.graph) self.run_pass("inline", foo.graph)
self.run_pass('remove_mutation', foo.graph) self.run_pass("remove_mutation", foo.graph)
FileCheck().check("aten::add_").run(foo.graph) FileCheck().check("aten::add_").run(foo.graph)
@torch.jit.script @torch.jit.script
@ -126,8 +129,8 @@ class TestRemoveMutation(JitTestCase):
z = x.add_(2) z = x.add_(2)
return z return z
self.run_pass('inline', foo.graph) self.run_pass("inline", foo.graph)
self.run_pass('remove_mutation', foo.graph) self.run_pass("remove_mutation", foo.graph)
FileCheck().check("aten::add_").run(foo.graph) FileCheck().check("aten::add_").run(foo.graph)
def test_special_mapped_op(self): def test_special_mapped_op(self):
@ -140,7 +143,7 @@ class TestRemoveMutation(JitTestCase):
fn = torch.jit.script(test_successful) fn = torch.jit.script(test_successful)
graph = fn.graph graph = fn.graph
self.run_pass('remove_mutation', graph) self.run_pass("remove_mutation", graph)
FileCheck().check_not("aten::zero_").check_not("aten::fill_").run(graph) FileCheck().check_not("aten::zero_").check_not("aten::fill_").run(graph)
self.assertEqual(test_successful(), fn()) self.assertEqual(test_successful(), fn())
@ -154,8 +157,8 @@ class TestRemoveMutation(JitTestCase):
fn = torch.jit.script(test_successful) fn = torch.jit.script(test_successful)
graph = fn.graph graph = fn.graph
self.run_pass('remove_mutation', graph) self.run_pass("remove_mutation", graph)
FileCheck().check_not('aten::fill_').run(graph) FileCheck().check_not("aten::fill_").run(graph)
def normal(): def normal():
# NOTE: For some unknown reason, the # NOTE: For some unknown reason, the
@ -167,7 +170,7 @@ class TestRemoveMutation(JitTestCase):
fn = torch.jit.script(normal) fn = torch.jit.script(normal)
graph = fn.graph graph = fn.graph
self.run_pass('remove_mutation', graph) self.run_pass("remove_mutation", graph)
FileCheck().check_not("normal_").run(graph) FileCheck().check_not("normal_").run(graph)
with freeze_rng_state(): with freeze_rng_state():
out_eager = normal() out_eager = normal()
@ -181,10 +184,12 @@ class TestRemoveMutation(JitTestCase):
fn = torch.jit.script(successful_remove) fn = torch.jit.script(successful_remove)
graph = fn.graph graph = fn.graph
self.run_pass('loop_unrolling', graph) self.run_pass("loop_unrolling", graph)
self.run_pass('remove_mutation', graph) self.run_pass("remove_mutation", graph)
self.run_pass('constant_propagation', graph) self.run_pass("constant_propagation", graph)
FileCheck().check("graph").check_next("Constant").check_next("return").run(graph) FileCheck().check("graph").check_next("Constant").check_next("return").run(
graph
)
self.assertEqual(successful_remove(), successful_remove()) self.assertEqual(successful_remove(), successful_remove())
def intermediary_use(): def intermediary_use():
@ -196,14 +201,14 @@ class TestRemoveMutation(JitTestCase):
fn = torch.jit.script(intermediary_use) fn = torch.jit.script(intermediary_use)
graph = fn.graph graph = fn.graph
FileCheck().check("append").run(graph) FileCheck().check("append").run(graph)
self.run_pass('remove_mutation', graph) self.run_pass("remove_mutation", graph)
# it is possible to remove the append here but don't currently have the logic for it # it is possible to remove the append here but don't currently have the logic for it
FileCheck().check_not("append").run(graph) FileCheck().check_not("append").run(graph)
self.assertEqual(intermediary_use(), fn()) self.assertEqual(intermediary_use(), fn())
def test_lists_insert(self): def test_lists_insert(self):
def successful_remove(): def successful_remove():
a : List[int] = [] a: List[int] = []
a.insert(0, 1) a.insert(0, 1)
a.insert(0, 2) a.insert(0, 2)
a.insert(-10, 3) a.insert(-10, 3)
@ -215,7 +220,9 @@ class TestRemoveMutation(JitTestCase):
graph = fn.graph graph = fn.graph
torch._C._jit_pass_remove_mutation(graph) torch._C._jit_pass_remove_mutation(graph)
torch._C._jit_pass_constant_propagation(graph) torch._C._jit_pass_constant_propagation(graph)
FileCheck().check("graph").check_next("Constant").check_next("return").run(graph) FileCheck().check("graph").check_next("Constant").check_next("return").run(
graph
)
self.assertEqual(successful_remove(), fn()) self.assertEqual(successful_remove(), fn())
def test_list_indexing_removal(self): def test_list_indexing_removal(self):
@ -271,6 +278,7 @@ class TestRemoveMutation(JitTestCase):
def test_common_pytorch_list_ops(self): def test_common_pytorch_list_ops(self):
for op in ["cat", "stack", "vstack", "hstack", "dstack"]: for op in ["cat", "stack", "vstack", "hstack", "dstack"]:
class OpMod(torch.nn.Module): class OpMod(torch.nn.Module):
def __init__(self, op): def __init__(self, op):
super().__init__() super().__init__()
@ -285,7 +293,7 @@ class TestRemoveMutation(JitTestCase):
torch_op = getattr(torch, op) torch_op = getattr(torch, op)
mod = OpMod(torch_op) mod = OpMod(torch_op)
mod_script = torch.jit.script(mod) mod_script = torch.jit.script(mod)
self.run_pass('remove_mutation', mod_script.forward.graph) self.run_pass("remove_mutation", mod_script.forward.graph)
FileCheck().check_not("aten::add_").run(mod_script.forward.graph) FileCheck().check_not("aten::add_").run(mod_script.forward.graph)
self.assertEqual(mod(), mod_script()) self.assertEqual(mod(), mod_script())
@ -299,7 +307,6 @@ class TestRemoveMutation(JitTestCase):
self.assertEqual(sums, [ten.sum() for ten in result]) self.assertEqual(sums, [ten.sum() for ten in result])
@torch.jit.script @torch.jit.script
def test_multiple_uses(): def test_multiple_uses():
x = torch.tensor([1, 2, 3, 4]) x = torch.tensor([1, 2, 3, 4])
@ -307,5 +314,5 @@ class TestRemoveMutation(JitTestCase):
y = [x, x] y = [x, x]
return torch.cat(y), y return torch.cat(y), y
self.run_pass('remove_mutation', mod_script.forward.graph) self.run_pass("remove_mutation", mod_script.forward.graph)
FileCheck().check("aten::add_").run(test_multiple_uses.graph) FileCheck().check("aten::add_").run(test_multiple_uses.graph)

View File

@ -8,12 +8,12 @@ from typing import NamedTuple, Optional
import torch import torch
from torch import Tensor from torch import Tensor
from torch.testing._internal.common_utils import TemporaryFileName, skipIfTorchDynamo from torch.testing._internal.common_utils import skipIfTorchDynamo, TemporaryFileName
# Make the helper files in test/ importable # Make the helper files in test/ importable
pytorch_test_dir = os.path.dirname(os.path.dirname(os.path.realpath(__file__))) pytorch_test_dir = os.path.dirname(os.path.dirname(os.path.realpath(__file__)))
sys.path.append(pytorch_test_dir) sys.path.append(pytorch_test_dir)
from torch.testing._internal.jit_utils import JitTestCase, clear_class_registry from torch.testing._internal.jit_utils import clear_class_registry, JitTestCase
if __name__ == "__main__": if __name__ == "__main__":
@ -439,7 +439,7 @@ class TestSaveLoad(JitTestCase):
global FooTuple # see [local resolution in python] global FooTuple # see [local resolution in python]
class FooTuple(NamedTuple): class FooTuple(NamedTuple):
a: 'int' a: "int"
class MyModule(torch.nn.Module): class MyModule(torch.nn.Module):
def forward(self, x: FooTuple) -> torch.Tensor: def forward(self, x: FooTuple) -> torch.Tensor:
@ -608,7 +608,6 @@ class TestSaveLoad(JitTestCase):
self.assertTrue(m_params["bar.bias"].is_cpu) self.assertTrue(m_params["bar.bias"].is_cpu)
self.assertTrue(m_loaded_params["bar.bias"].is_cpu) self.assertTrue(m_loaded_params["bar.bias"].is_cpu)
def test_save_load_with_saved_traced_inputs(self): def test_save_load_with_saved_traced_inputs(self):
""" """
Check that saving and loading with traced inputs works as expected Check that saving and loading with traced inputs works as expected
@ -637,14 +636,18 @@ class TestSaveLoad(JitTestCase):
# Validate that with no input specified the traced inputs are stored # Validate that with no input specified the traced inputs are stored
traced_module = torch.jit.trace(module, input_tensor) traced_module = torch.jit.trace(module, input_tensor)
traced_inputs = list(traced_module.graph.inputs()) traced_inputs = list(traced_module.graph.inputs())
self.assertEqual(traced_module._c._retrieve_traced_inputs()['forward'], [input_tensor]) self.assertEqual(
traced_module._c._retrieve_traced_inputs()["forward"], [input_tensor]
)
with TemporaryFileName() as fname: with TemporaryFileName() as fname:
path = pathlib.Path(fname) path = pathlib.Path(fname)
traced_module.save(path) traced_module.save(path)
loaded_module = torch.jit.load(path, _restore_shapes=True) loaded_module = torch.jit.load(path, _restore_shapes=True)
loaded_inputs = list(loaded_module.graph.inputs()) loaded_inputs = list(loaded_module.graph.inputs())
self.assertEqual(traced_inputs[1].type(), loaded_inputs[1].type()) self.assertEqual(traced_inputs[1].type(), loaded_inputs[1].type())
self.assertEqual(traced_inputs[1].type().sizes(), loaded_inputs[1].type().sizes()) self.assertEqual(
traced_inputs[1].type().sizes(), loaded_inputs[1].type().sizes()
)
# Validate that if no shapes are requested previous functionality remains # Validate that if no shapes are requested previous functionality remains
loaded_module = torch.jit.load(path) loaded_module = torch.jit.load(path)
loaded_inputs = list(loaded_module.graph.inputs()) loaded_inputs = list(loaded_module.graph.inputs())
@ -672,7 +675,7 @@ class TestSaveLoad(JitTestCase):
"1000": ( "1000": (
torch.tensor([0]), torch.tensor([0]),
torch.tensor([], dtype=torch.int64), torch.tensor([], dtype=torch.int64),
torch.tensor([]) torch.tensor([]),
) )
} }
traced_inputs, loaded_inputs = get_loaded_inputs(input1) traced_inputs, loaded_inputs = get_loaded_inputs(input1)
@ -683,28 +686,32 @@ class TestSaveLoad(JitTestCase):
"1000": ( "1000": (
torch.tensor([0]), torch.tensor([0]),
torch.tensor([1500000, 1500004], dtype=torch.int64), torch.tensor([1500000, 1500004], dtype=torch.int64),
torch.tensor([2.0, 3.0]) torch.tensor([2.0, 3.0]),
) )
} }
traced_inputs, loaded_inputs = get_loaded_inputs(input2) traced_inputs, loaded_inputs = get_loaded_inputs(input2)
self.assertEqual(traced_inputs[1].type(), loaded_inputs[1].type()) self.assertEqual(traced_inputs[1].type(), loaded_inputs[1].type())
# Testing list # Testing list
input3 = [torch.tensor([0]), input3 = [
torch.tensor([1500000, 1500004], dtype=torch.int64), torch.tensor([0]),
torch.tensor([2.0, 3.0])] torch.tensor([1500000, 1500004], dtype=torch.int64),
torch.tensor([2.0, 3.0]),
]
traced_inputs, loaded_inputs = get_loaded_inputs(input3) traced_inputs, loaded_inputs = get_loaded_inputs(input3)
self.assertEqual(traced_inputs[1].type(), loaded_inputs[1].type()) self.assertEqual(traced_inputs[1].type(), loaded_inputs[1].type())
# Testing list of dict of list # Testing list of dict of list
input4 = [{ input4 = [
"1000": ( {
torch.tensor([0]), "1000": (
torch.tensor([1500000, 1500004], dtype=torch.int64), torch.tensor([0]),
torch.tensor([2.0, 3.0]) torch.tensor([1500000, 1500004], dtype=torch.int64),
) torch.tensor([2.0, 3.0]),
}] )
}
]
traced_inputs, loaded_inputs = get_loaded_inputs(input4) traced_inputs, loaded_inputs = get_loaded_inputs(input4)
self.assertEqual(traced_inputs[1].type(), loaded_inputs[1].type()) self.assertEqual(traced_inputs[1].type(), loaded_inputs[1].type())
@ -715,14 +722,17 @@ class TestSaveLoad(JitTestCase):
Check if the model with string > 4GB can be loaded. Check if the model with string > 4GB can be loaded.
""" """
import psutil import psutil
if psutil.virtual_memory().available < 60 * 1024 * 1024 * 1024: if psutil.virtual_memory().available < 60 * 1024 * 1024 * 1024:
# Profiled the test execution, and got this number to be safe to run the test # Profiled the test execution, and got this number to be safe to run the test
self.skipTest("Doesn't have enough memory to run test_save_load_large_string_attribute") self.skipTest(
"Doesn't have enough memory to run test_save_load_large_string_attribute"
)
class Model(torch.nn.Module): class Model(torch.nn.Module):
def __init__(self): def __init__(self):
super().__init__() super().__init__()
self.x = "x" * (2 ** 32 + 1) self.x = "x" * (2**32 + 1)
def forward(self, i) -> int: def forward(self, i) -> int:
return len(self.x) + i.numel() return len(self.x) + i.numel()
@ -793,12 +803,8 @@ class TestSaveLoadFlatbuffer(JitTestCase):
class ContainsBoth(torch.nn.Module): class ContainsBoth(torch.nn.Module):
def __init__(self): def __init__(self):
super().__init__() super().__init__()
self.add_module( self.add_module("second", torch.jit.load(second_saved_module))
"second", torch.jit.load(second_saved_module) self.add_module("first", torch.jit.load(first_saved_module))
)
self.add_module(
"first", torch.jit.load(first_saved_module)
)
def forward(self, x): def forward(self, x):
x = self.first(x) x = self.first(x)
@ -846,12 +852,8 @@ class TestSaveLoadFlatbuffer(JitTestCase):
class ContainsBoth(torch.nn.Module): class ContainsBoth(torch.nn.Module):
def __init__(self): def __init__(self):
super().__init__() super().__init__()
self.add_module( self.add_module("second", torch.jit.load(second_saved_module))
"second", torch.jit.load(second_saved_module) self.add_module("first", torch.jit.load(first_saved_module))
)
self.add_module(
"first", torch.jit.load(first_saved_module)
)
def forward(self, x): def forward(self, x):
x = self.first(x) x = self.first(x)
@ -931,12 +933,8 @@ class TestSaveLoadFlatbuffer(JitTestCase):
class ContainsBoth(torch.nn.Module): class ContainsBoth(torch.nn.Module):
def __init__(self): def __init__(self):
super().__init__() super().__init__()
self.add_module( self.add_module("second", torch.jit.load(second_saved_module))
"second", torch.jit.load(second_saved_module) self.add_module("first", torch.jit.load(first_saved_module))
)
self.add_module(
"first", torch.jit.load(first_saved_module)
)
def forward(self, x): def forward(self, x):
x = self.first(x) x = self.first(x)
@ -1035,12 +1033,8 @@ class TestSaveLoadFlatbuffer(JitTestCase):
class ContainsBoth(torch.nn.Module): class ContainsBoth(torch.nn.Module):
def __init__(self): def __init__(self):
super().__init__() super().__init__()
self.add_module( self.add_module("second", torch.jit.load(second_saved_module))
"second", torch.jit.load(second_saved_module) self.add_module("first", torch.jit.load(first_saved_module))
)
self.add_module(
"first", torch.jit.load(first_saved_module)
)
def forward(self, x): def forward(self, x):
x, named_tuple_1 = self.first(x) x, named_tuple_1 = self.first(x)
@ -1118,18 +1112,18 @@ class TestSaveLoadFlatbuffer(JitTestCase):
first_script_module = torch.jit.script(Foo()) first_script_module = torch.jit.script(Foo())
first_saved_module = io.BytesIO() first_saved_module = io.BytesIO()
torch.jit.save_jit_module_to_flatbuffer( torch.jit.save_jit_module_to_flatbuffer(first_script_module, first_saved_module)
first_script_module, first_saved_module)
first_saved_module.seek(0) first_saved_module.seek(0)
ff_info = torch.jit._serialization.get_flatbuffer_module_info(first_saved_module) ff_info = torch.jit._serialization.get_flatbuffer_module_info(
self.assertEqual(ff_info['bytecode_version'], 9) first_saved_module
self.assertEqual(ff_info['operator_version'], 1) )
self.assertEqual(ff_info['type_names'], set()) self.assertEqual(ff_info["bytecode_version"], 9)
self.assertEqual(ff_info['opname_to_num_args'], {'aten::linear': 3}) self.assertEqual(ff_info["operator_version"], 1)
self.assertEqual(ff_info["type_names"], set())
self.assertEqual(len(ff_info['function_names']), 1) self.assertEqual(ff_info["opname_to_num_args"], {"aten::linear": 3})
self.assertTrue(next(iter(ff_info['function_names'])).endswith('forward'))
self.assertEqual(len(ff_info["function_names"]), 1)
self.assertTrue(next(iter(ff_info["function_names"])).endswith("forward"))
def test_save_load_params_buffers_submodules(self): def test_save_load_params_buffers_submodules(self):
""" """
@ -1179,7 +1173,6 @@ class TestSaveLoadFlatbuffer(JitTestCase):
self.assertEqual(m_name, loaded_name) self.assertEqual(m_name, loaded_name)
self.assertEqual(m_buffer, loaded_buffer) self.assertEqual(m_buffer, loaded_buffer)
def test_save_load_with_extra_files(self): def test_save_load_with_extra_files(self):
""" """
Check that parameters, buffers, and submodules are the same after loading. Check that parameters, buffers, and submodules are the same after loading.
@ -1194,7 +1187,8 @@ class TestSaveLoadFlatbuffer(JitTestCase):
extra_files = {"abc.json": b"[1,2,3]"} extra_files = {"abc.json": b"[1,2,3]"}
script_module_io = script_module._save_to_buffer_for_lite_interpreter( script_module_io = script_module._save_to_buffer_for_lite_interpreter(
_extra_files=extra_files, _use_flatbuffer=True) _extra_files=extra_files, _use_flatbuffer=True
)
re_extra_files = {} re_extra_files = {}
torch._C._get_model_extra_files_from_buffer(script_module_io, re_extra_files) torch._C._get_model_extra_files_from_buffer(script_module_io, re_extra_files)

View File

@ -1,20 +1,21 @@
# Owner(s): ["oncall: jit"] # Owner(s): ["oncall: jit"]
from itertools import product as product
import io import io
import os import os
import sys import sys
import hypothesis.strategies as st from itertools import product as product
from hypothesis import example, settings, given
from typing import Union from typing import Union
import hypothesis.strategies as st
import torch import torch
from hypothesis import example, given, settings
# Make the helper files in test/ importable # Make the helper files in test/ importable
pytorch_test_dir = os.path.dirname(os.path.dirname(os.path.realpath(__file__))) pytorch_test_dir = os.path.dirname(os.path.dirname(os.path.realpath(__file__)))
sys.path.append(pytorch_test_dir) sys.path.append(pytorch_test_dir)
from torch.testing._internal.jit_utils import JitTestCase
from torch.jit.mobile import _load_for_lite_interpreter from torch.jit.mobile import _load_for_lite_interpreter
from torch.testing._internal.jit_utils import JitTestCase
if __name__ == "__main__": if __name__ == "__main__":
raise RuntimeError( raise RuntimeError(
@ -23,6 +24,7 @@ if __name__ == "__main__":
"instead." "instead."
) )
class TestSaveLoadForOpVersion(JitTestCase): class TestSaveLoadForOpVersion(JitTestCase):
# Helper that returns the module after saving and loading # Helper that returns the module after saving and loading
def _save_load_module(self, m): def _save_load_module(self, m):
@ -53,7 +55,6 @@ class TestSaveLoadForOpVersion(JitTestCase):
node_count = sum(str(n).count(kind) for n in m.graph.nodes()) node_count = sum(str(n).count(kind) for n in m.graph.nodes())
self.assertEqual(node_count, count) self.assertEqual(node_count, count)
""" """
Tests that verify Torchscript remaps aten::div(_) from versions 0-3 Tests that verify Torchscript remaps aten::div(_) from versions 0-3
to call either aten::true_divide(_), if an input is a float type, to call either aten::true_divide(_), if an input is a float type,
@ -62,16 +63,21 @@ class TestSaveLoadForOpVersion(JitTestCase):
div behavior has not yet been updated. div behavior has not yet been updated.
""" """
@settings(max_examples=10, deadline=200000) # A total of 10 examples will be generated @settings(
max_examples=10, deadline=200000
) # A total of 10 examples will be generated
@given( @given(
sample_input=st.tuples(st.integers(min_value=5, max_value=199), st.floats(min_value=5.0, max_value=199.0)) sample_input=st.tuples(
st.integers(min_value=5, max_value=199),
st.floats(min_value=5.0, max_value=199.0),
)
) # Generate a pair (integer, float) ) # Generate a pair (integer, float)
@example((2, 3, 2.0, 3.0)) # Ensure this example will be covered @example((2, 3, 2.0, 3.0)) # Ensure this example will be covered
def test_versioned_div_tensor(self, sample_input): def test_versioned_div_tensor(self, sample_input):
def historic_div(self, other): def historic_div(self, other):
if self.is_floating_point() or other.is_floating_point(): if self.is_floating_point() or other.is_floating_point():
return self.true_divide(other) return self.true_divide(other)
return self.divide(other, rounding_mode='trunc') return self.divide(other, rounding_mode="trunc")
# Tensor x Tensor # Tensor x Tensor
class MyModule(torch.nn.Module): class MyModule(torch.nn.Module):
@ -85,7 +91,9 @@ class TestSaveLoadForOpVersion(JitTestCase):
# Loads historic module # Loads historic module
try: try:
v3_mobile_module = _load_for_lite_interpreter( v3_mobile_module = _load_for_lite_interpreter(
pytorch_test_dir + "/cpp/jit/upgrader_models/test_versioned_div_tensor_v2.ptl") pytorch_test_dir
+ "/cpp/jit/upgrader_models/test_versioned_div_tensor_v2.ptl"
)
except Exception as e: except Exception as e:
self.skipTest("Failed to load fixture!") self.skipTest("Failed to load fixture!")
@ -108,16 +116,21 @@ class TestSaveLoadForOpVersion(JitTestCase):
_helper(v3_mobile_module, historic_div) _helper(v3_mobile_module, historic_div)
_helper(current_mobile_module, torch.div) _helper(current_mobile_module, torch.div)
@settings(max_examples=10, deadline=200000) # A total of 10 examples will be generated @settings(
max_examples=10, deadline=200000
) # A total of 10 examples will be generated
@given( @given(
sample_input=st.tuples(st.integers(min_value=5, max_value=199), st.floats(min_value=5.0, max_value=199.0)) sample_input=st.tuples(
st.integers(min_value=5, max_value=199),
st.floats(min_value=5.0, max_value=199.0),
)
) # Generate a pair (integer, float) ) # Generate a pair (integer, float)
@example((2, 3, 2.0, 3.0)) # Ensure this example will be covered @example((2, 3, 2.0, 3.0)) # Ensure this example will be covered
def test_versioned_div_tensor_inplace(self, sample_input): def test_versioned_div_tensor_inplace(self, sample_input):
def historic_div_(self, other): def historic_div_(self, other):
if self.is_floating_point() or other.is_floating_point(): if self.is_floating_point() or other.is_floating_point():
return self.true_divide_(other) return self.true_divide_(other)
return self.divide_(other, rounding_mode='trunc') return self.divide_(other, rounding_mode="trunc")
class MyModule(torch.nn.Module): class MyModule(torch.nn.Module):
def forward(self, a, b): def forward(self, a, b):
@ -126,7 +139,9 @@ class TestSaveLoadForOpVersion(JitTestCase):
try: try:
v3_mobile_module = _load_for_lite_interpreter( v3_mobile_module = _load_for_lite_interpreter(
pytorch_test_dir + "/cpp/jit/upgrader_models/test_versioned_div_tensor_inplace_v2.ptl") pytorch_test_dir
+ "/cpp/jit/upgrader_models/test_versioned_div_tensor_inplace_v2.ptl"
)
except Exception as e: except Exception as e:
self.skipTest("Failed to load fixture!") self.skipTest("Failed to load fixture!")
@ -151,16 +166,25 @@ class TestSaveLoadForOpVersion(JitTestCase):
a = torch.tensor((val_a,)) a = torch.tensor((val_a,))
_helper(current_mobile_module, torch.Tensor.div_) _helper(current_mobile_module, torch.Tensor.div_)
@settings(max_examples=10, deadline=200000) # A total of 10 examples will be generated @settings(
max_examples=10, deadline=200000
) # A total of 10 examples will be generated
@given( @given(
sample_input=st.tuples(st.integers(min_value=5, max_value=199), st.floats(min_value=5.0, max_value=199.0)) sample_input=st.tuples(
st.integers(min_value=5, max_value=199),
st.floats(min_value=5.0, max_value=199.0),
)
) # Generate a pair (integer, float) ) # Generate a pair (integer, float)
@example((2, 3, 2.0, 3.0)) # Ensure this example will be covered @example((2, 3, 2.0, 3.0)) # Ensure this example will be covered
def test_versioned_div_tensor_out(self, sample_input): def test_versioned_div_tensor_out(self, sample_input):
def historic_div_out(self, other, out): def historic_div_out(self, other, out):
if self.is_floating_point() or other.is_floating_point() or out.is_floating_point(): if (
self.is_floating_point()
or other.is_floating_point()
or out.is_floating_point()
):
return torch.true_divide(self, other, out=out) return torch.true_divide(self, other, out=out)
return torch.divide(self, other, out=out, rounding_mode='trunc') return torch.divide(self, other, out=out, rounding_mode="trunc")
class MyModule(torch.nn.Module): class MyModule(torch.nn.Module):
def forward(self, a, b, out): def forward(self, a, b, out):
@ -168,7 +192,9 @@ class TestSaveLoadForOpVersion(JitTestCase):
try: try:
v3_mobile_module = _load_for_lite_interpreter( v3_mobile_module = _load_for_lite_interpreter(
pytorch_test_dir + "/cpp/jit/upgrader_models/test_versioned_div_tensor_out_v2.ptl") pytorch_test_dir
+ "/cpp/jit/upgrader_models/test_versioned_div_tensor_out_v2.ptl"
)
except Exception as e: except Exception as e:
self.skipTest("Failed to load fixture!") self.skipTest("Failed to load fixture!")
@ -179,6 +205,7 @@ class TestSaveLoadForOpVersion(JitTestCase):
b = torch.tensor((val_b,)) b = torch.tensor((val_b,))
for out in (torch.empty((1,)), torch.empty((1,), dtype=torch.long)): for out in (torch.empty((1,)), torch.empty((1,), dtype=torch.long)):
def _helper(m, fn): def _helper(m, fn):
fn_result = None fn_result = None
if fn is torch.div: if fn is torch.div:
@ -196,9 +223,14 @@ class TestSaveLoadForOpVersion(JitTestCase):
_helper(v3_mobile_module, historic_div_out) _helper(v3_mobile_module, historic_div_out)
_helper(current_mobile_module, torch.div) _helper(current_mobile_module, torch.div)
@settings(max_examples=10, deadline=200000) # A total of 10 examples will be generated @settings(
max_examples=10, deadline=200000
) # A total of 10 examples will be generated
@given( @given(
sample_input=st.tuples(st.integers(min_value=5, max_value=199), st.floats(min_value=5.0, max_value=199.0)) sample_input=st.tuples(
st.integers(min_value=5, max_value=199),
st.floats(min_value=5.0, max_value=199.0),
)
) # Generate a pair (integer, float) ) # Generate a pair (integer, float)
@example((2, 3, 2.0, 3.0)) # Ensure this example will be covered @example((2, 3, 2.0, 3.0)) # Ensure this example will be covered
def test_versioned_div_scalar(self, sample_input): def test_versioned_div_scalar(self, sample_input):
@ -208,7 +240,7 @@ class TestSaveLoadForOpVersion(JitTestCase):
def historic_div_scalar_int(self, other: int): def historic_div_scalar_int(self, other: int):
if self.is_floating_point(): if self.is_floating_point():
return torch.true_divide(self, other) return torch.true_divide(self, other)
return torch.divide(self, other, rounding_mode='trunc') return torch.divide(self, other, rounding_mode="trunc")
class MyModuleFloat(torch.nn.Module): class MyModuleFloat(torch.nn.Module):
def forward(self, a, b: float): def forward(self, a, b: float):
@ -220,9 +252,13 @@ class TestSaveLoadForOpVersion(JitTestCase):
try: try:
v3_mobile_module_float = _load_for_lite_interpreter( v3_mobile_module_float = _load_for_lite_interpreter(
pytorch_test_dir + "/jit/fixtures/test_versioned_div_scalar_float_v2.ptl") pytorch_test_dir
+ "/jit/fixtures/test_versioned_div_scalar_float_v2.ptl"
)
v3_mobile_module_int = _load_for_lite_interpreter( v3_mobile_module_int = _load_for_lite_interpreter(
pytorch_test_dir + "/cpp/jit/upgrader_models/test_versioned_div_scalar_int_v2.ptl") pytorch_test_dir
+ "/cpp/jit/upgrader_models/test_versioned_div_scalar_int_v2.ptl"
)
except Exception as e: except Exception as e:
self.skipTest("Failed to load fixture!") self.skipTest("Failed to load fixture!")
@ -249,9 +285,14 @@ class TestSaveLoadForOpVersion(JitTestCase):
_helper(v3_mobile_module_int, historic_div_scalar_int) _helper(v3_mobile_module_int, historic_div_scalar_int)
_helper(current_mobile_module_int, torch.div) _helper(current_mobile_module_int, torch.div)
@settings(max_examples=10, deadline=200000) # A total of 10 examples will be generated @settings(
max_examples=10, deadline=200000
) # A total of 10 examples will be generated
@given( @given(
sample_input=st.tuples(st.integers(min_value=5, max_value=199), st.floats(min_value=5.0, max_value=199.0)) sample_input=st.tuples(
st.integers(min_value=5, max_value=199),
st.floats(min_value=5.0, max_value=199.0),
)
) # Generate a pair (integer, float) ) # Generate a pair (integer, float)
@example((2, 3, 2.0, 3.0)) # Ensure this example will be covered @example((2, 3, 2.0, 3.0)) # Ensure this example will be covered
def test_versioned_div_scalar_reciprocal(self, sample_input): def test_versioned_div_scalar_reciprocal(self, sample_input):
@ -261,7 +302,7 @@ class TestSaveLoadForOpVersion(JitTestCase):
def historic_div_scalar_int_reciprocal(self, other: int): def historic_div_scalar_int_reciprocal(self, other: int):
if self.is_floating_point(): if self.is_floating_point():
return other / self return other / self
return torch.divide(other, self, rounding_mode='trunc') return torch.divide(other, self, rounding_mode="trunc")
class MyModuleFloat(torch.nn.Module): class MyModuleFloat(torch.nn.Module):
def forward(self, a, b: float): def forward(self, a, b: float):
@ -273,9 +314,13 @@ class TestSaveLoadForOpVersion(JitTestCase):
try: try:
v3_mobile_module_float = _load_for_lite_interpreter( v3_mobile_module_float = _load_for_lite_interpreter(
pytorch_test_dir + "/cpp/jit/upgrader_models/test_versioned_div_scalar_reciprocal_float_v2.ptl") pytorch_test_dir
+ "/cpp/jit/upgrader_models/test_versioned_div_scalar_reciprocal_float_v2.ptl"
)
v3_mobile_module_int = _load_for_lite_interpreter( v3_mobile_module_int = _load_for_lite_interpreter(
pytorch_test_dir + "/cpp/jit/upgrader_models/test_versioned_div_scalar_reciprocal_int_v2.ptl") pytorch_test_dir
+ "/cpp/jit/upgrader_models/test_versioned_div_scalar_reciprocal_int_v2.ptl"
)
except Exception as e: except Exception as e:
self.skipTest("Failed to load fixture!") self.skipTest("Failed to load fixture!")
@ -311,9 +356,14 @@ class TestSaveLoadForOpVersion(JitTestCase):
_helper(v3_mobile_module_int, current_mobile_module_int) _helper(v3_mobile_module_int, current_mobile_module_int)
_helper(current_mobile_module_int, torch.div) _helper(current_mobile_module_int, torch.div)
@settings(max_examples=10, deadline=200000) # A total of 10 examples will be generated @settings(
max_examples=10, deadline=200000
) # A total of 10 examples will be generated
@given( @given(
sample_input=st.tuples(st.integers(min_value=5, max_value=199), st.floats(min_value=5.0, max_value=199.0)) sample_input=st.tuples(
st.integers(min_value=5, max_value=199),
st.floats(min_value=5.0, max_value=199.0),
)
) # Generate a pair (integer, float) ) # Generate a pair (integer, float)
@example((2, 3, 2.0, 3.0)) # Ensure this example will be covered @example((2, 3, 2.0, 3.0)) # Ensure this example will be covered
def test_versioned_div_scalar_inplace(self, sample_input): def test_versioned_div_scalar_inplace(self, sample_input):
@ -324,7 +374,7 @@ class TestSaveLoadForOpVersion(JitTestCase):
if self.is_floating_point(): if self.is_floating_point():
return self.true_divide_(other) return self.true_divide_(other)
return self.divide_(other, rounding_mode='trunc') return self.divide_(other, rounding_mode="trunc")
class MyModuleFloat(torch.nn.Module): class MyModuleFloat(torch.nn.Module):
def forward(self, a, b: float): def forward(self, a, b: float):
@ -338,9 +388,13 @@ class TestSaveLoadForOpVersion(JitTestCase):
try: try:
v3_mobile_module_float = _load_for_lite_interpreter( v3_mobile_module_float = _load_for_lite_interpreter(
pytorch_test_dir + "/cpp/jit/upgrader_models/test_versioned_div_scalar_inplace_float_v2.ptl") pytorch_test_dir
+ "/cpp/jit/upgrader_models/test_versioned_div_scalar_inplace_float_v2.ptl"
)
v3_mobile_module_int = _load_for_lite_interpreter( v3_mobile_module_int = _load_for_lite_interpreter(
pytorch_test_dir + "/cpp/jit/upgrader_models/test_versioned_div_scalar_inplace_int_v2.ptl") pytorch_test_dir
+ "/cpp/jit/upgrader_models/test_versioned_div_scalar_inplace_int_v2.ptl"
)
except Exception as e: except Exception as e:
self.skipTest("Failed to load fixture!") self.skipTest("Failed to load fixture!")
@ -378,14 +432,16 @@ class TestSaveLoadForOpVersion(JitTestCase):
try: try:
v3_mobile_module = _load_for_lite_interpreter( v3_mobile_module = _load_for_lite_interpreter(
pytorch_test_dir + "/cpp/jit/upgrader_models/test_versioned_div_scalar_scalar_v2.ptl") pytorch_test_dir
+ "/cpp/jit/upgrader_models/test_versioned_div_scalar_scalar_v2.ptl"
)
except Exception as e: except Exception as e:
self.skipTest("Failed to load fixture!") self.skipTest("Failed to load fixture!")
current_mobile_module = self._save_load_mobile_module(MyModule) current_mobile_module = self._save_load_mobile_module(MyModule)
def _helper(m, fn): def _helper(m, fn):
vals = (5., 3, 2., 7) vals = (5.0, 3, 2.0, 7)
m_result = m(*vals) m_result = m(*vals)
fn_result = fn(*vals) fn_result = fn(*vals)
for mr, hr in zip(m_result, fn_result): for mr, hr in zip(m_result, fn_result):
@ -395,13 +451,16 @@ class TestSaveLoadForOpVersion(JitTestCase):
def test_versioned_linspace(self): def test_versioned_linspace(self):
class Module(torch.nn.Module): class Module(torch.nn.Module):
def forward(self, a: Union[int, float, complex], b: Union[int, float, complex]): def forward(
self, a: Union[int, float, complex], b: Union[int, float, complex]
):
c = torch.linspace(a, b, steps=5) c = torch.linspace(a, b, steps=5)
d = torch.linspace(a, b, steps=100) d = torch.linspace(a, b, steps=100)
return c, d return c, d
scripted_module = torch.jit.load( scripted_module = torch.jit.load(
pytorch_test_dir + "/jit/fixtures/test_versioned_linspace_v7.ptl") pytorch_test_dir + "/jit/fixtures/test_versioned_linspace_v7.ptl"
)
buffer = io.BytesIO(scripted_module._save_to_buffer_for_lite_interpreter()) buffer = io.BytesIO(scripted_module._save_to_buffer_for_lite_interpreter())
buffer.seek(0) buffer.seek(0)
@ -410,7 +469,7 @@ class TestSaveLoadForOpVersion(JitTestCase):
current_mobile_module = self._save_load_mobile_module(Module) current_mobile_module = self._save_load_mobile_module(Module)
sample_inputs = ((3, 10), (-10, 10), (4.0, 6.0), (3 + 4j, 4 + 5j)) sample_inputs = ((3, 10), (-10, 10), (4.0, 6.0), (3 + 4j, 4 + 5j))
for (a, b) in sample_inputs: for a, b in sample_inputs:
(output_with_step, output_without_step) = v7_mobile_module(a, b) (output_with_step, output_without_step) = v7_mobile_module(a, b)
(current_with_step, current_without_step) = current_mobile_module(a, b) (current_with_step, current_without_step) = current_mobile_module(a, b)
# when no step is given, should have used 100 # when no step is given, should have used 100
@ -422,10 +481,17 @@ class TestSaveLoadForOpVersion(JitTestCase):
def test_versioned_linspace_out(self): def test_versioned_linspace_out(self):
class Module(torch.nn.Module): class Module(torch.nn.Module):
def forward(self, a: Union[int, float, complex], b: Union[int, float, complex], out: torch.Tensor): def forward(
self,
a: Union[int, float, complex],
b: Union[int, float, complex],
out: torch.Tensor,
):
return torch.linspace(a, b, steps=100, out=out) return torch.linspace(a, b, steps=100, out=out)
model_path = pytorch_test_dir + "/jit/fixtures/test_versioned_linspace_out_v7.ptl" model_path = (
pytorch_test_dir + "/jit/fixtures/test_versioned_linspace_out_v7.ptl"
)
loaded_model = torch.jit.load(model_path) loaded_model = torch.jit.load(model_path)
buffer = io.BytesIO(loaded_model._save_to_buffer_for_lite_interpreter()) buffer = io.BytesIO(loaded_model._save_to_buffer_for_lite_interpreter())
buffer.seek(0) buffer.seek(0)
@ -433,12 +499,32 @@ class TestSaveLoadForOpVersion(JitTestCase):
current_mobile_module = self._save_load_mobile_module(Module) current_mobile_module = self._save_load_mobile_module(Module)
sample_inputs = ( sample_inputs = (
(3, 10, torch.empty((100,), dtype=torch.int64), torch.empty((100,), dtype=torch.int64)), (
(-10, 10, torch.empty((100,), dtype=torch.int64), torch.empty((100,), dtype=torch.int64)), 3,
(4.0, 6.0, torch.empty((100,), dtype=torch.float64), torch.empty((100,), dtype=torch.float64)), 10,
(3 + 4j, 4 + 5j, torch.empty((100,), dtype=torch.complex64), torch.empty((100,), dtype=torch.complex64)), torch.empty((100,), dtype=torch.int64),
torch.empty((100,), dtype=torch.int64),
),
(
-10,
10,
torch.empty((100,), dtype=torch.int64),
torch.empty((100,), dtype=torch.int64),
),
(
4.0,
6.0,
torch.empty((100,), dtype=torch.float64),
torch.empty((100,), dtype=torch.float64),
),
(
3 + 4j,
4 + 5j,
torch.empty((100,), dtype=torch.complex64),
torch.empty((100,), dtype=torch.complex64),
),
) )
for (start, end, out_for_old, out_for_new) in sample_inputs: for start, end, out_for_old, out_for_new in sample_inputs:
output = v7_mobile_module(start, end, out_for_old) output = v7_mobile_module(start, end, out_for_old)
output_current = current_mobile_module(start, end, out_for_new) output_current = current_mobile_module(start, end, out_for_new)
# when no step is given, should have used 100 # when no step is given, should have used 100
@ -448,13 +534,16 @@ class TestSaveLoadForOpVersion(JitTestCase):
def test_versioned_logspace(self): def test_versioned_logspace(self):
class Module(torch.nn.Module): class Module(torch.nn.Module):
def forward(self, a: Union[int, float, complex], b: Union[int, float, complex]): def forward(
self, a: Union[int, float, complex], b: Union[int, float, complex]
):
c = torch.logspace(a, b, steps=5) c = torch.logspace(a, b, steps=5)
d = torch.logspace(a, b, steps=100) d = torch.logspace(a, b, steps=100)
return c, d return c, d
scripted_module = torch.jit.load( scripted_module = torch.jit.load(
pytorch_test_dir + "/jit/fixtures/test_versioned_logspace_v8.ptl") pytorch_test_dir + "/jit/fixtures/test_versioned_logspace_v8.ptl"
)
buffer = io.BytesIO(scripted_module._save_to_buffer_for_lite_interpreter()) buffer = io.BytesIO(scripted_module._save_to_buffer_for_lite_interpreter())
buffer.seek(0) buffer.seek(0)
@ -463,7 +552,7 @@ class TestSaveLoadForOpVersion(JitTestCase):
current_mobile_module = self._save_load_mobile_module(Module) current_mobile_module = self._save_load_mobile_module(Module)
sample_inputs = ((3, 10), (-10, 10), (4.0, 6.0), (3 + 4j, 4 + 5j)) sample_inputs = ((3, 10), (-10, 10), (4.0, 6.0), (3 + 4j, 4 + 5j))
for (a, b) in sample_inputs: for a, b in sample_inputs:
(output_with_step, output_without_step) = v8_mobile_module(a, b) (output_with_step, output_without_step) = v8_mobile_module(a, b)
(current_with_step, current_without_step) = current_mobile_module(a, b) (current_with_step, current_without_step) = current_mobile_module(a, b)
# when no step is given, should have used 100 # when no step is given, should have used 100
@ -475,10 +564,17 @@ class TestSaveLoadForOpVersion(JitTestCase):
def test_versioned_logspace_out(self): def test_versioned_logspace_out(self):
class Module(torch.nn.Module): class Module(torch.nn.Module):
def forward(self, a: Union[int, float, complex], b: Union[int, float, complex], out: torch.Tensor): def forward(
self,
a: Union[int, float, complex],
b: Union[int, float, complex],
out: torch.Tensor,
):
return torch.logspace(a, b, steps=100, out=out) return torch.logspace(a, b, steps=100, out=out)
model_path = pytorch_test_dir + "/jit/fixtures/test_versioned_logspace_out_v8.ptl" model_path = (
pytorch_test_dir + "/jit/fixtures/test_versioned_logspace_out_v8.ptl"
)
loaded_model = torch.jit.load(model_path) loaded_model = torch.jit.load(model_path)
buffer = io.BytesIO(loaded_model._save_to_buffer_for_lite_interpreter()) buffer = io.BytesIO(loaded_model._save_to_buffer_for_lite_interpreter())
buffer.seek(0) buffer.seek(0)
@ -486,12 +582,32 @@ class TestSaveLoadForOpVersion(JitTestCase):
current_mobile_module = self._save_load_mobile_module(Module) current_mobile_module = self._save_load_mobile_module(Module)
sample_inputs = ( sample_inputs = (
(3, 10, torch.empty((100,), dtype=torch.int64), torch.empty((100,), dtype=torch.int64)), (
(-10, 10, torch.empty((100,), dtype=torch.int64), torch.empty((100,), dtype=torch.int64)), 3,
(4.0, 6.0, torch.empty((100,), dtype=torch.float64), torch.empty((100,), dtype=torch.float64)), 10,
(3 + 4j, 4 + 5j, torch.empty((100,), dtype=torch.complex64), torch.empty((100,), dtype=torch.complex64)), torch.empty((100,), dtype=torch.int64),
torch.empty((100,), dtype=torch.int64),
),
(
-10,
10,
torch.empty((100,), dtype=torch.int64),
torch.empty((100,), dtype=torch.int64),
),
(
4.0,
6.0,
torch.empty((100,), dtype=torch.float64),
torch.empty((100,), dtype=torch.float64),
),
(
3 + 4j,
4 + 5j,
torch.empty((100,), dtype=torch.complex64),
torch.empty((100,), dtype=torch.complex64),
),
) )
for (start, end, out_for_old, out_for_new) in sample_inputs: for start, end, out_for_old, out_for_new in sample_inputs:
output = v8_mobile_module(start, end, out_for_old) output = v8_mobile_module(start, end, out_for_old)
output_current = current_mobile_module(start, end, out_for_new) output_current = current_mobile_module(start, end, out_for_new)
# when no step is given, should have used 100 # when no step is given, should have used 100

View File

@ -11,10 +11,13 @@ pytorch_test_dir = os.path.dirname(os.path.dirname(os.path.realpath(__file__)))
sys.path.append(pytorch_test_dir) sys.path.append(pytorch_test_dir)
from torch.testing._internal.jit_utils import JitTestCase from torch.testing._internal.jit_utils import JitTestCase
if __name__ == '__main__': if __name__ == "__main__":
raise RuntimeError("This test file is not meant to be run directly, use:\n\n" raise RuntimeError(
"\tpython test/test_jit.py TESTNAME\n\n" "This test file is not meant to be run directly, use:\n\n"
"instead.") "\tpython test/test_jit.py TESTNAME\n\n"
"instead."
)
class Sequence(nn.Module): class Sequence(nn.Module):
def __init__(self): def __init__(self):
@ -38,8 +41,8 @@ class Sequence(nn.Module):
outputs = torch.cat(outputs, dim=1) outputs = torch.cat(outputs, dim=1)
return outputs return outputs
class TestScriptProfile(JitTestCase):
class TestScriptProfile(JitTestCase):
def test_basic(self): def test_basic(self):
seq = torch.jit.script(Sequence()) seq = torch.jit.script(Sequence())
p = torch.jit._ScriptProfile() p = torch.jit._ScriptProfile()
@ -57,6 +60,7 @@ class TestScriptProfile(JitTestCase):
@torch.jit.script @torch.jit.script
def fn(): def fn():
_ = seq(torch.rand((10, 100))) _ = seq(torch.rand((10, 100)))
fn() fn()
p.disable() p.disable()
@ -83,7 +87,7 @@ class TestScriptProfile(JitTestCase):
seq = Sequence() seq = Sequence()
@torch.jit.script @torch.jit.script
def fn(max : int): def fn(max: int):
_ = seq(torch.rand((10, max))) _ = seq(torch.rand((10, max)))
p = torch.jit._ScriptProfile() p = torch.jit._ScriptProfile()

View File

@ -3,22 +3,24 @@
import os import os
import sys import sys
import warnings import warnings
from typing import Dict, List, Optional
import torch import torch
from typing import List, Dict, Optional
# Make the helper files in test/ importable # Make the helper files in test/ importable
pytorch_test_dir = os.path.dirname(os.path.dirname(os.path.realpath(__file__))) pytorch_test_dir = os.path.dirname(os.path.dirname(os.path.realpath(__file__)))
sys.path.append(pytorch_test_dir) sys.path.append(pytorch_test_dir)
from torch.testing._internal.jit_utils import JitTestCase from torch.testing._internal.jit_utils import JitTestCase
if __name__ == '__main__': if __name__ == "__main__":
raise RuntimeError("This test file is not meant to be run directly, use:\n\n" raise RuntimeError(
"\tpython test/test_jit.py TESTNAME\n\n" "This test file is not meant to be run directly, use:\n\n"
"instead.") "\tpython test/test_jit.py TESTNAME\n\n"
"instead."
)
class TestScriptModuleInstanceAttributeTypeAnnotation(JitTestCase): class TestScriptModuleInstanceAttributeTypeAnnotation(JitTestCase):
# NB: There are no tests for `Tuple` or `NamedTuple` here. In fact, # NB: There are no tests for `Tuple` or `NamedTuple` here. In fact,
# reassigning a non-empty Tuple to an attribute previously typed # reassigning a non-empty Tuple to an attribute previously typed
# as containing an empty Tuple SHOULD fail. See note in `_check.py` # as containing an empty Tuple SHOULD fail. See note in `_check.py`
@ -81,7 +83,6 @@ class TestScriptModuleInstanceAttributeTypeAnnotation(JitTestCase):
def test_annotated_class_level_annotation_only(self): def test_annotated_class_level_annotation_only(self):
class M(torch.nn.Module): class M(torch.nn.Module):
x: List[int] x: List[int]
def __init__(self): def __init__(self):
@ -96,10 +97,8 @@ class TestScriptModuleInstanceAttributeTypeAnnotation(JitTestCase):
self.checkModule(M(), ([1, 2, 3],)) self.checkModule(M(), ([1, 2, 3],))
assert len(w) == 0 assert len(w) == 0
def test_annotated_class_level_annotation_and_init_annotation(self): def test_annotated_class_level_annotation_and_init_annotation(self):
class M(torch.nn.Module): class M(torch.nn.Module):
x: List[int] x: List[int]
def __init__(self): def __init__(self):
@ -116,7 +115,6 @@ class TestScriptModuleInstanceAttributeTypeAnnotation(JitTestCase):
def test_annotated_class_level_jit_annotation(self): def test_annotated_class_level_jit_annotation(self):
class M(torch.nn.Module): class M(torch.nn.Module):
x: List[int] x: List[int]
def __init__(self): def __init__(self):
@ -141,12 +139,15 @@ class TestScriptModuleInstanceAttributeTypeAnnotation(JitTestCase):
self.x = x self.x = x
return 1 return 1
with self.assertRaisesRegexWithHighlight(RuntimeError, with self.assertRaisesRegexWithHighlight(
"Tried to set nonexistent attribute", RuntimeError, "Tried to set nonexistent attribute", "self.x = x"
"self.x = x"): ):
with self.assertWarnsRegex(UserWarning, "doesn't support " with self.assertWarnsRegex(
"instance-level annotations on " UserWarning,
"empty non-base types"): "doesn't support "
"instance-level annotations on "
"empty non-base types",
):
torch.jit.script(M()) torch.jit.script(M())
def test_annotated_empty_dict(self): def test_annotated_empty_dict(self):
@ -159,12 +160,15 @@ class TestScriptModuleInstanceAttributeTypeAnnotation(JitTestCase):
self.x = x self.x = x
return 1 return 1
with self.assertRaisesRegexWithHighlight(RuntimeError, with self.assertRaisesRegexWithHighlight(
"Tried to set nonexistent attribute", RuntimeError, "Tried to set nonexistent attribute", "self.x = x"
"self.x = x"): ):
with self.assertWarnsRegex(UserWarning, "doesn't support " with self.assertWarnsRegex(
"instance-level annotations on " UserWarning,
"empty non-base types"): "doesn't support "
"instance-level annotations on "
"empty non-base types",
):
torch.jit.script(M()) torch.jit.script(M())
def test_annotated_empty_optional(self): def test_annotated_empty_optional(self):
@ -177,12 +181,15 @@ class TestScriptModuleInstanceAttributeTypeAnnotation(JitTestCase):
self.x = x self.x = x
return 1 return 1
with self.assertRaisesRegexWithHighlight(RuntimeError, with self.assertRaisesRegexWithHighlight(
"Wrong type for attribute assignment", RuntimeError, "Wrong type for attribute assignment", "self.x = x"
"self.x = x"): ):
with self.assertWarnsRegex(UserWarning, "doesn't support " with self.assertWarnsRegex(
"instance-level annotations on " UserWarning,
"empty non-base types"): "doesn't support "
"instance-level annotations on "
"empty non-base types",
):
torch.jit.script(M()) torch.jit.script(M())
def test_annotated_with_jit_empty_list(self): def test_annotated_with_jit_empty_list(self):
@ -195,12 +202,15 @@ class TestScriptModuleInstanceAttributeTypeAnnotation(JitTestCase):
self.x = x self.x = x
return 1 return 1
with self.assertRaisesRegexWithHighlight(RuntimeError, with self.assertRaisesRegexWithHighlight(
"Tried to set nonexistent attribute", RuntimeError, "Tried to set nonexistent attribute", "self.x = x"
"self.x = x"): ):
with self.assertWarnsRegex(UserWarning, "doesn't support " with self.assertWarnsRegex(
"instance-level annotations on " UserWarning,
"empty non-base types"): "doesn't support "
"instance-level annotations on "
"empty non-base types",
):
torch.jit.script(M()) torch.jit.script(M())
def test_annotated_with_jit_empty_dict(self): def test_annotated_with_jit_empty_dict(self):
@ -213,12 +223,15 @@ class TestScriptModuleInstanceAttributeTypeAnnotation(JitTestCase):
self.x = x self.x = x
return 1 return 1
with self.assertRaisesRegexWithHighlight(RuntimeError, with self.assertRaisesRegexWithHighlight(
"Tried to set nonexistent attribute", RuntimeError, "Tried to set nonexistent attribute", "self.x = x"
"self.x = x"): ):
with self.assertWarnsRegex(UserWarning, "doesn't support " with self.assertWarnsRegex(
"instance-level annotations on " UserWarning,
"empty non-base types"): "doesn't support "
"instance-level annotations on "
"empty non-base types",
):
torch.jit.script(M()) torch.jit.script(M())
def test_annotated_with_jit_empty_optional(self): def test_annotated_with_jit_empty_optional(self):
@ -231,12 +244,15 @@ class TestScriptModuleInstanceAttributeTypeAnnotation(JitTestCase):
self.x = x self.x = x
return 1 return 1
with self.assertRaisesRegexWithHighlight(RuntimeError, with self.assertRaisesRegexWithHighlight(
"Wrong type for attribute assignment", RuntimeError, "Wrong type for attribute assignment", "self.x = x"
"self.x = x"): ):
with self.assertWarnsRegex(UserWarning, "doesn't support " with self.assertWarnsRegex(
"instance-level annotations on " UserWarning,
"empty non-base types"): "doesn't support "
"instance-level annotations on "
"empty non-base types",
):
torch.jit.script(M()) torch.jit.script(M())
def test_annotated_with_torch_jit_import(self): def test_annotated_with_torch_jit_import(self):
@ -251,10 +267,13 @@ class TestScriptModuleInstanceAttributeTypeAnnotation(JitTestCase):
self.x = x self.x = x
return 1 return 1
with self.assertRaisesRegexWithHighlight(RuntimeError, with self.assertRaisesRegexWithHighlight(
"Wrong type for attribute assignment", RuntimeError, "Wrong type for attribute assignment", "self.x = x"
"self.x = x"): ):
with self.assertWarnsRegex(UserWarning, "doesn't support " with self.assertWarnsRegex(
"instance-level annotations on " UserWarning,
"empty non-base types"): "doesn't support "
"instance-level annotations on "
"empty non-base types",
):
torch.jit.script(M()) torch.jit.script(M())

View File

@ -2,19 +2,22 @@
import os import os
import sys import sys
from typing import List
import torch import torch
from typing import List
# Make the helper files in test/ importable # Make the helper files in test/ importable
pytorch_test_dir = os.path.dirname(os.path.dirname(os.path.realpath(__file__))) pytorch_test_dir = os.path.dirname(os.path.dirname(os.path.realpath(__file__)))
sys.path.append(pytorch_test_dir) sys.path.append(pytorch_test_dir)
from torch.testing._internal.jit_utils import JitTestCase from torch.testing._internal.jit_utils import JitTestCase
if __name__ == '__main__': if __name__ == "__main__":
raise RuntimeError("This test file is not meant to be run directly, use:\n\n" raise RuntimeError(
"\tpython test/test_jit.py TESTNAME\n\n" "This test file is not meant to be run directly, use:\n\n"
"instead.") "\tpython test/test_jit.py TESTNAME\n\n"
"instead."
)
# Tests that Python slice class is supported in TorchScript # Tests that Python slice class is supported in TorchScript
class TestSlice(JitTestCase): class TestSlice(JitTestCase):
@ -22,7 +25,9 @@ class TestSlice(JitTestCase):
def slice_kwarg(x: List[int]): def slice_kwarg(x: List[int]):
return x[slice(1, stop=2)] return x[slice(1, stop=2)]
with self.assertRaisesRegex(RuntimeError, "Slice does not accept any keyword arguments"): with self.assertRaisesRegex(
RuntimeError, "Slice does not accept any keyword arguments"
):
torch.jit.script(slice_kwarg) torch.jit.script(slice_kwarg)
def test_slice_three_nones(self): def test_slice_three_nones(self):
@ -46,11 +51,13 @@ class TestSlice(JitTestCase):
def test_slice_stop_only(self): def test_slice_stop_only(self):
def fn(x: List[int]): def fn(x: List[int]):
return x[slice(5)] return x[slice(5)]
self.checkScript(fn, (range(10),)) self.checkScript(fn, (range(10),))
def test_slice_stop_only_with_nones(self): def test_slice_stop_only_with_nones(self):
def fn(x: List[int]): def fn(x: List[int]):
return x[slice(None, 5, None)] return x[slice(None, 5, None)]
self.checkScript(fn, (range(10),)) self.checkScript(fn, (range(10),))
def test_slice_start_stop(self): def test_slice_start_stop(self):
@ -136,8 +143,8 @@ class TestSlice(JitTestCase):
num_outputs = {len(x.output().type().elements()) for x in slices} num_outputs = {len(x.output().type().elements()) for x in slices}
# there should be only one tupleSlice with length of 2 # there should be only one tupleSlice with length of 2
self.assertTrue(num_outputs == {2}) self.assertTrue(num_outputs == {2})
self.run_pass('lower_all_tuples', tuple_graph) self.run_pass("lower_all_tuples", tuple_graph)
self.assertTrue('Tuple' not in str(tuple_graph)) self.assertTrue("Tuple" not in str(tuple_graph))
def test_module_list_slicing(self): def test_module_list_slicing(self):
class Bar(torch.nn.Module): class Bar(torch.nn.Module):

View File

@ -1,8 +1,9 @@
# Owner(s): ["oncall: jit"] # Owner(s): ["oncall: jit"]
import io import io
import torch
import unittest import unittest
import torch
from torch.testing._internal.common_utils import IS_WINDOWS, TEST_MKL from torch.testing._internal.common_utils import IS_WINDOWS, TEST_MKL
from torch.testing._internal.jit_utils import JitTestCase from torch.testing._internal.jit_utils import JitTestCase
@ -70,9 +71,7 @@ class TestSparse(JitTestCase):
self.a = torch.rand(4, 4).to_sparse_csr() self.a = torch.rand(4, 4).to_sparse_csr()
self.b = torch.rand(4, 4).to_sparse_csr() self.b = torch.rand(4, 4).to_sparse_csr()
def forward(self, x): def forward(self, x):
return x.matmul(self.a).matmul(self.b) return x.matmul(self.a).matmul(self.b)
x = torch.rand(4, 4).to_sparse_csr() x = torch.rand(4, 4).to_sparse_csr()

View File

@ -2,65 +2,76 @@
import os import os
import sys import sys
from typing import List
import torch import torch
from typing import List
# Make the helper files in test/ importable # Make the helper files in test/ importable
pytorch_test_dir = os.path.dirname(os.path.dirname(os.path.realpath(__file__))) pytorch_test_dir = os.path.dirname(os.path.dirname(os.path.realpath(__file__)))
sys.path.append(pytorch_test_dir) sys.path.append(pytorch_test_dir)
from torch.testing._internal.jit_utils import JitTestCase from torch.testing._internal.jit_utils import JitTestCase
if __name__ == '__main__': if __name__ == "__main__":
raise RuntimeError("This test file is not meant to be run directly, use:\n\n" raise RuntimeError(
"\tpython test/test_jit.py TESTNAME\n\n" "This test file is not meant to be run directly, use:\n\n"
"instead.") "\tpython test/test_jit.py TESTNAME\n\n"
"instead."
)
class TestStringFormatting(JitTestCase): class TestStringFormatting(JitTestCase):
def test_modulo_operator(self): def test_modulo_operator(self):
def fn(dividend: int, divisor: int) -> int: def fn(dividend: int, divisor: int) -> int:
return dividend % divisor return dividend % divisor
self.checkScript(fn, (5, 2)) self.checkScript(fn, (5, 2))
def test_string_interpolation_with_string_placeholder_and_string_variable(self): def test_string_interpolation_with_string_placeholder_and_string_variable(self):
def fn(arg1: str): def fn(arg1: str):
return "%s in template" % arg1 return "%s in template" % arg1
self.checkScript(fn, ("foo",)) self.checkScript(fn, ("foo",))
def test_string_interpolation_with_string_placeholder_and_format_string_variable(self): def test_string_interpolation_with_string_placeholder_and_format_string_variable(
self,
):
def fn(arg1: str): def fn(arg1: str):
return arg1 % "foo" return arg1 % "foo"
self.checkScript(fn, ("%s in template",)) self.checkScript(fn, ("%s in template",))
def test_string_interpolation_with_double_percent_in_string(self): def test_string_interpolation_with_double_percent_in_string(self):
def fn(arg1: str): def fn(arg1: str):
return "%s in template %%" % arg1 return "%s in template %%" % arg1
self.checkScript(fn, ("foo",)) self.checkScript(fn, ("foo",))
def test_string_interpolation_with_percent_in_string(self): def test_string_interpolation_with_percent_in_string(self):
@torch.jit.script @torch.jit.script
def fn(arg1: str) -> str: def fn(arg1: str) -> str:
return "%s in template %" % arg1 # noqa: F501 return "%s in template %" % arg1 # noqa: F501
with self.assertRaisesRegexWithHighlight(RuntimeError, with self.assertRaisesRegexWithHighlight(
"Incomplete format specifier", RuntimeError, "Incomplete format specifier", '"%s in template %" % arg1'
"\"%s in template %\" % arg1"): ):
fn("foo") fn("foo")
def test_string_interpolation_with_string_placeholder_and_digit_variable(self): def test_string_interpolation_with_string_placeholder_and_digit_variable(self):
def fn(arg1: int) -> str: def fn(arg1: int) -> str:
return "%s in template" % arg1 return "%s in template" % arg1
self.checkScript(fn, (1,)) self.checkScript(fn, (1,))
def test_string_interpolation_with_digit_placeholder_and_digit_variable(self): def test_string_interpolation_with_digit_placeholder_and_digit_variable(self):
def fn(arg1: int) -> str: def fn(arg1: int) -> str:
return "%d in template" % arg1 return "%d in template" % arg1
self.checkScript(fn, (1,)) self.checkScript(fn, (1,))
def test_string_interpolation_with_alternate_digit_placeholder(self): def test_string_interpolation_with_alternate_digit_placeholder(self):
def fn(arg1: int) -> str: def fn(arg1: int) -> str:
return "%i in template" % arg1 return "%i in template" % arg1
self.checkScript(fn, (1,)) self.checkScript(fn, (1,))
def test_string_interpolation_with_digit_placeholder_and_string_variable(self): def test_string_interpolation_with_digit_placeholder_and_string_variable(self):
@ -68,9 +79,11 @@ class TestStringFormatting(JitTestCase):
def fn(arg1: str) -> str: def fn(arg1: str) -> str:
return "%d in template" % arg1 return "%d in template" % arg1
with self.assertRaisesRegexWithHighlight(RuntimeError, with self.assertRaisesRegexWithHighlight(
"%d requires a number for formatting, but got String", RuntimeError,
"\"%d in template\" % arg1"): "%d requires a number for formatting, but got String",
'"%d in template" % arg1',
):
fn("1") fn("1")
def test_string_interpolation_with_exponent_placeholder_and_string_variable(self): def test_string_interpolation_with_exponent_placeholder_and_string_variable(self):
@ -78,39 +91,51 @@ class TestStringFormatting(JitTestCase):
def fn(arg1: str) -> str: def fn(arg1: str) -> str:
return "%e in template" % arg1 return "%e in template" % arg1
with self.assertRaisesRegexWithHighlight(RuntimeError, with self.assertRaisesRegexWithHighlight(
"%e requires a number for formatting, but got String", RuntimeError,
"\"%e in template\" % arg1"): "%e requires a number for formatting, but got String",
'"%e in template" % arg1',
):
fn("1") fn("1")
def test_string_interpolation_with_lowercase_exponent_placeholder_and_digit_variable(self): def test_string_interpolation_with_lowercase_exponent_placeholder_and_digit_variable(
self,
):
def fn(arg1: int) -> str: def fn(arg1: int) -> str:
return "%e in template" % arg1 return "%e in template" % arg1
self.checkScript(fn, (1,)) self.checkScript(fn, (1,))
def test_string_interpolation_with_capital_exponent_placeholder_and_digit_variable(self): def test_string_interpolation_with_capital_exponent_placeholder_and_digit_variable(
self,
):
def fn(arg1: int) -> str: def fn(arg1: int) -> str:
return "%E in template" % arg1 return "%E in template" % arg1
self.checkScript(fn, (1,)) self.checkScript(fn, (1,))
def test_string_interpolation_with_float_placeholder_and_float_variable(self): def test_string_interpolation_with_float_placeholder_and_float_variable(self):
def fn(arg1: float) -> str: def fn(arg1: float) -> str:
return "%f in template" % arg1 return "%f in template" % arg1
self.checkScript(fn, (1.0,)) self.checkScript(fn, (1.0,))
def test_string_interpolation_with_float_placeholder_and_digit_variable(self): def test_string_interpolation_with_float_placeholder_and_digit_variable(self):
def fn(arg1: int) -> str: def fn(arg1: int) -> str:
return "%f in template" % arg1 return "%f in template" % arg1
self.checkScript(fn, (1,)) self.checkScript(fn, (1,))
def test_string_interpolation_with_char_placeholder_and_char_variable(self): def test_string_interpolation_with_char_placeholder_and_char_variable(self):
def fn(arg1: str) -> str: def fn(arg1: str) -> str:
return "%c in template" % arg1 return "%c in template" % arg1
self.checkScript(fn, ("a",)) self.checkScript(fn, ("a",))
def test_string_interpolation_with_char_placeholder_and_digit_variable(self): def test_string_interpolation_with_char_placeholder_and_digit_variable(self):
def fn(arg1: int) -> str: def fn(arg1: int) -> str:
return "%c in template" % arg1 return "%c in template" % arg1
self.checkScript(fn, (97,)) self.checkScript(fn, (97,))
def test_string_interpolation_with_char_placeholder_and_true_string_variable(self): def test_string_interpolation_with_char_placeholder_and_true_string_variable(self):
@ -118,19 +143,23 @@ class TestStringFormatting(JitTestCase):
def fn(arg1: str) -> str: def fn(arg1: str) -> str:
return "%c in template" % arg1 return "%c in template" % arg1
with self.assertRaisesRegexWithHighlight(RuntimeError, with self.assertRaisesRegexWithHighlight(
"%c requires an int or char for formatting, but got String", RuntimeError,
"\"%c in template\" % arg1"): "%c requires an int or char for formatting, but got String",
'"%c in template" % arg1',
):
fn("foo") fn("foo")
def test_string_interpolation_with_multiple_placeholders(self): def test_string_interpolation_with_multiple_placeholders(self):
def fn(arg1: str, arg2: int, arg3: float) -> str: def fn(arg1: str, arg2: int, arg3: float) -> str:
return "%s %d %f in template" % (arg1, arg2, arg3) return "%s %d %f in template" % (arg1, arg2, arg3)
self.checkScript(fn, ("foo", 1, 1)) self.checkScript(fn, ("foo", 1, 1))
def test_string_interpolation_with_subscript(self): def test_string_interpolation_with_subscript(self):
def fn(arg1: List[str]) -> str: def fn(arg1: List[str]) -> str:
return "%s in template" % arg1[0] return "%s in template" % arg1[0]
self.checkScript(fn, (["foo", "bar"],)) self.checkScript(fn, (["foo", "bar"],))
def test_string_interpolation_with_too_few_arguments(self): def test_string_interpolation_with_too_few_arguments(self):
@ -138,27 +167,33 @@ class TestStringFormatting(JitTestCase):
def fn(arg1: str) -> str: def fn(arg1: str) -> str:
return "%s %s in template" % arg1 return "%s %s in template" % arg1
with self.assertRaisesRegexWithHighlight(RuntimeError, with self.assertRaisesRegexWithHighlight(
"Too few arguments for format string", RuntimeError,
"\"%s %s in template\" % arg1"): "Too few arguments for format string",
'"%s %s in template" % arg1',
):
fn("foo") fn("foo")
def test_string_interpolation_with_too_many_arguments(self): def test_string_interpolation_with_too_many_arguments(self):
@torch.jit.script @torch.jit.script
def fn(arg1: str, arg2: str) -> str: def fn(arg1: str, arg2: str) -> str:
return "%s in template" % (arg1, arg2) # noqa: F507 return "%s in template" % (arg1, arg2) # noqa: F507
with self.assertRaisesRegexWithHighlight(RuntimeError, with self.assertRaisesRegexWithHighlight(
"Too many arguments for format string", RuntimeError,
"\"%s in template\" % (arg1, arg2"): "Too many arguments for format string",
'"%s in template" % (arg1, arg2',
):
fn("foo", "bar") fn("foo", "bar")
def test_string_interpolation_with_unknown_format_specifier(self): def test_string_interpolation_with_unknown_format_specifier(self):
@torch.jit.script @torch.jit.script
def fn(arg1: str) -> str: def fn(arg1: str) -> str:
return "%a in template" % arg1 # noqa: F501 return "%a in template" % arg1 # noqa: F501
with self.assertRaisesRegexWithHighlight(RuntimeError, with self.assertRaisesRegexWithHighlight(
"The specifier %a is not supported in TorchScript format strings", RuntimeError,
"\"%a in template\" % arg1"): "The specifier %a is not supported in TorchScript format strings",
'"%a in template" % arg1',
):
fn("foo") fn("foo")

View File

@ -3,29 +3,36 @@
import operator import operator
import unittest import unittest
from textwrap import dedent from textwrap import dedent
from typing import Any, List
import torch import torch
from torch import nn, Tensor from torch import nn, Tensor
from torch.testing import FileCheck from torch.testing import FileCheck
from torch.testing._internal.common_methods_invocations import sample_inputs_cat_concat from torch.testing._internal.common_methods_invocations import sample_inputs_cat_concat
from torch.testing._internal.common_utils import make_tensor from torch.testing._internal.common_utils import make_tensor
from torch.testing._internal.jit_utils import JitTestCase, execWrapper from torch.testing._internal.jit_utils import execWrapper, JitTestCase
from typing import List, Any
if __name__ == "__main__":
raise RuntimeError(
"This test file is not meant to be run directly, use:\n\n"
"\tpython test/test_jit.py TESTNAME\n\n"
"instead."
)
if __name__ == '__main__':
raise RuntimeError("This test file is not meant to be run directly, use:\n\n"
"\tpython test/test_jit.py TESTNAME\n\n"
"instead.")
# XXX: still in prototype # XXX: still in prototype
class TestSymbolicShapeAnalysis(JitTestCase): class TestSymbolicShapeAnalysis(JitTestCase):
def setUp(self): def setUp(self):
super(JitTestCase, self).setUp() super(JitTestCase, self).setUp()
self.prev_symbolic_shapes_test_enabled = torch._C._jit_symbolic_shapes_test_mode_enabled() self.prev_symbolic_shapes_test_enabled = (
torch._C._jit_symbolic_shapes_test_mode_enabled()
)
torch._C._jit_set_symbolic_shapes_test_mode(True) torch._C._jit_set_symbolic_shapes_test_mode(True)
def tearDown(self): def tearDown(self):
torch._C._jit_set_symbolic_shapes_test_mode(self.prev_symbolic_shapes_test_enabled) torch._C._jit_set_symbolic_shapes_test_mode(
self.prev_symbolic_shapes_test_enabled
)
def test_shape_analysis(self): def test_shape_analysis(self):
@torch.jit.script @torch.jit.script
@ -115,7 +122,9 @@ class TestSymbolicShapeAnalysis(JitTestCase):
def neg_to_one(li): def neg_to_one(li):
return [elem if elem >= 0 else -1 for elem in li] return [elem if elem >= 0 else -1 for elem in li]
self.assertEqual(neg_to_one(view.output().type().symbolic_sizes()), [-1, 3, 2, -1]) self.assertEqual(
neg_to_one(view.output().type().symbolic_sizes()), [-1, 3, 2, -1]
)
if_out = next(foo.graph.findNode("prim::If").outputs()) if_out = next(foo.graph.findNode("prim::If").outputs())
self.assertEqual(neg_to_one(if_out.type().symbolic_sizes()), [-1, 3, -1, -1]) self.assertEqual(neg_to_one(if_out.type().symbolic_sizes()), [-1, 3, -1, -1])
@ -135,9 +144,7 @@ class TestSymbolicShapeAnalysis(JitTestCase):
y = x.mul_(2) y = x.mul_(2)
return y return y
unary_ops = [ unary_ops = [mul_inplace]
mul_inplace
]
for fn in unary_ops: for fn in unary_ops:
# t = torch.jit.trace(fn, torch.rand([4, 4])) # For some reason tracing is erroring out. # t = torch.jit.trace(fn, torch.rand([4, 4])) # For some reason tracing is erroring out.
t = torch.jit.script(fn) t = torch.jit.script(fn)
@ -202,7 +209,9 @@ class TestSymbolicShapeAnalysis(JitTestCase):
inputs[1].setType(inputs[1].type().with_sizes([5, 8, sym1])) inputs[1].setType(inputs[1].type().with_sizes([5, 8, sym1]))
torch._C._jit_pass_propagate_shapes_on_graph(graph) torch._C._jit_pass_propagate_shapes_on_graph(graph)
self.assertEqual(next(graph.outputs()).type().symbolic_sizes(), [5, 8, sym1]) self.assertEqual(
next(graph.outputs()).type().symbolic_sizes(), [5, 8, sym1]
)
def test_adaptive_avg_pool2d(self): def test_adaptive_avg_pool2d(self):
inps = [ inps = [
@ -227,25 +236,105 @@ class TestSymbolicShapeAnalysis(JitTestCase):
self.checkShapeAnalysis(out_size, fn.graph, assert_propagation=True) self.checkShapeAnalysis(out_size, fn.graph, assert_propagation=True)
def test_conv_deconv(self): def test_conv_deconv(self):
for inp_shape, weight_shape, bias, stride, padding, output_padding, dilation, groups, mod in [ for (
([32, 6, 10], [16, 3, 3], None, 2, 2, 1, 1, 2, torch.nn.functional.conv1d), inp_shape,
([32, 16, 10], [16, 3, 3], None, 2, 2, 1, 1, 2, torch.nn.functional.conv_transpose1d), weight_shape,
([1, 32, 5, 10], [30, 16, 3, 3], None, [2, 2], [0, 0], 0, 1, 2, torch.nn.functional.conv2d), bias,
([1, 30, 5, 10], [30, 16, 3, 3], None, [2, 2], [0, 0], 0, 1, 2, torch.nn.functional.conv_transpose2d), stride,
([3, 14, 10, 66, 55], [2, 7, 7, 4, 4], None, 1, 1, 2, 1, 2, torch.nn.functional.conv3d), padding,
([3, 2, 10, 66, 55], [2, 7, 7, 4, 4], None, 1, 1, 0, 1, 2, torch.nn.functional.conv_transpose3d)]: output_padding,
dilation,
groups,
mod,
) in [
([32, 6, 10], [16, 3, 3], None, 2, 2, 1, 1, 2, torch.nn.functional.conv1d),
(
[32, 16, 10],
[16, 3, 3],
None,
2,
2,
1,
1,
2,
torch.nn.functional.conv_transpose1d,
),
(
[1, 32, 5, 10],
[30, 16, 3, 3],
None,
[2, 2],
[0, 0],
0,
1,
2,
torch.nn.functional.conv2d,
),
(
[1, 30, 5, 10],
[30, 16, 3, 3],
None,
[2, 2],
[0, 0],
0,
1,
2,
torch.nn.functional.conv_transpose2d,
),
(
[3, 14, 10, 66, 55],
[2, 7, 7, 4, 4],
None,
1,
1,
2,
1,
2,
torch.nn.functional.conv3d,
),
(
[3, 2, 10, 66, 55],
[2, 7, 7, 4, 4],
None,
1,
1,
0,
1,
2,
torch.nn.functional.conv_transpose3d,
),
]:
inp = torch.rand(inp_shape) inp = torch.rand(inp_shape)
weight = torch.rand(weight_shape) weight = torch.rand(weight_shape)
if mod in [torch.nn.functional.conv1d, torch.nn.functional.conv2d, torch.nn.functional.conv3d]: if mod in [
torch.nn.functional.conv1d,
torch.nn.functional.conv2d,
torch.nn.functional.conv3d,
]:
res = mod(inp, weight, bias, stride, padding, dilation, groups).size() res = mod(inp, weight, bias, stride, padding, dilation, groups).size()
else: else:
res = mod(inp, weight, bias, stride, padding, output_padding, dilation, groups).size() res = mod(
inp, weight, bias, stride, padding, output_padding, dilation, groups
).size()
def foo(inp, weight): def foo(inp, weight):
if mod in [torch.nn.functional.conv1d, torch.nn.functional.conv2d, torch.nn.functional.conv3d]: if mod in [
torch.nn.functional.conv1d,
torch.nn.functional.conv2d,
torch.nn.functional.conv3d,
]:
return mod(inp, weight, bias, stride, padding, dilation, groups) return mod(inp, weight, bias, stride, padding, dilation, groups)
else: else:
return mod(inp, weight, bias, stride, padding, output_padding, dilation, groups) return mod(
inp,
weight,
bias,
stride,
padding,
output_padding,
dilation,
groups,
)
fn = torch.jit.trace(foo, (inp, weight)) fn = torch.jit.trace(foo, (inp, weight))
torch._C._jit_erase_non_input_shape_information(fn.graph) torch._C._jit_erase_non_input_shape_information(fn.graph)
@ -280,33 +369,58 @@ class TestSymbolicShapeAnalysis(JitTestCase):
] ]
for inp in inps: for inp in inps:
funcs_template = dedent(''' funcs_template = dedent(
"""
def func(): def func():
return torch.arange({args}) return torch.arange({args})
''') """
)
inp_s = str(inp)[1:-1] # remove tuple parens inp_s = str(inp)[1:-1] # remove tuple parens
funcs_str = funcs_template.format(args=inp_s) funcs_str = funcs_template.format(args=inp_s)
scope = {} scope = {}
execWrapper(funcs_str, globals(), scope) execWrapper(funcs_str, globals(), scope)
cu = torch.jit.CompilationUnit(funcs_str) cu = torch.jit.CompilationUnit(funcs_str)
self.checkShapeAnalysis(list(cu.func().size()), cu.func.graph, assert_propagation=True, constant_prop=False) self.checkShapeAnalysis(
list(cu.func().size()),
cu.func.graph,
assert_propagation=True,
constant_prop=False,
)
def test_shape_embedding_bag(self): def test_shape_embedding_bag(self):
# TODO: merge into opinfos, having difficulties there # TODO: merge into opinfos, having difficulties there
with torch.no_grad(): with torch.no_grad():
def make_arg(shape, low=None, high=None): def make_arg(shape, low=None, high=None):
return make_tensor(shape, device='cpu', dtype=torch.int64, return make_tensor(
low=low, high=high, requires_grad=False) shape,
device="cpu",
dtype=torch.int64,
low=low,
high=high,
requires_grad=False,
)
nn_inps = ( nn_inps = (
(make_arg((40,), 0, 9), torch.nn.Embedding(20, embedding_dim=64, max_norm=1.0)), (
make_arg((40,), 0, 9),
torch.nn.Embedding(20, embedding_dim=64, max_norm=1.0),
),
(make_arg((2, 4), 0, 9), torch.nn.Embedding(10, 20, sparse=True)), (make_arg((2, 4), 0, 9), torch.nn.Embedding(10, 20, sparse=True)),
(make_arg((0,)), torch.nn.Embedding(0, 0, sparse=True)), (make_arg((0,)), torch.nn.Embedding(0, 0, sparse=True)),
(make_arg((2, 4), 0, 9), torch.nn.Embedding(10, 0, sparse=True)), (make_arg((2, 4), 0, 9), torch.nn.Embedding(10, 0, sparse=True)),
(make_arg((4,), 0, 21), torch.nn.Embedding(22, 5, max_norm=1.0)), (make_arg((4,), 0, 21), torch.nn.Embedding(22, 5, max_norm=1.0)),
(make_arg((2,), 0, 1), torch.nn.Embedding.from_pretrained(torch.arange(6.).view(2, 3), max_norm=2., (
norm_type=.5, scale_grad_by_freq=False, sparse=True)), make_arg((2,), 0, 1),
torch.nn.Embedding.from_pretrained(
torch.arange(6.0).view(2, 3),
max_norm=2.0,
norm_type=0.5,
scale_grad_by_freq=False,
sparse=True,
),
),
) )
for inp, module in nn_inps: for inp, module in nn_inps:
@ -326,14 +440,16 @@ class TestSymbolicShapeAnalysis(JitTestCase):
fn = torch.jit.trace(foo, (inp.detach(),), check_trace=False) fn = torch.jit.trace(foo, (inp.detach(),), check_trace=False)
self.checkShapeAnalysis(out_size, fn.graph, assert_propagation=True, constant_prop=False) self.checkShapeAnalysis(
out_size, fn.graph, assert_propagation=True, constant_prop=False
)
def test_shape_concat(self): def test_shape_concat(self):
# TODO: unify with opinfo tests, traces of lists dont preserve sizes in IR # TODO: unify with opinfo tests, traces of lists dont preserve sizes in IR
sample_inputs = sample_inputs_cat_concat(None, "cpu", torch.float, False) sample_inputs = sample_inputs_cat_concat(None, "cpu", torch.float, False)
class CatMod(nn.Module): class CatMod(nn.Module):
__constants__ = ['dim'] __constants__ = ["dim"]
def __init__(self, dim=0): def __init__(self, dim=0):
super().__init__() super().__init__()
@ -374,16 +490,23 @@ class TestSymbolicShapeAnalysis(JitTestCase):
# Also, as the return shapes are the input, weight, and bias shape, there is no point # Also, as the return shapes are the input, weight, and bias shape, there is no point
# in a really complicated test # in a really complicated test
input = torch.randn((16, 16, 8, 8), dtype=torch.float32, device="cpu", requires_grad=True) input = torch.randn(
weight = torch.randn((8, 4, 3, 3), dtype=torch.float32, device="cpu", requires_grad=True) (16, 16, 8, 8), dtype=torch.float32, device="cpu", requires_grad=True
)
weight = torch.randn(
(8, 4, 3, 3), dtype=torch.float32, device="cpu", requires_grad=True
)
out_grad = torch.randn((16, 8, 8, 8), dtype=torch.float32, device="cpu") out_grad = torch.randn((16, 8, 8, 8), dtype=torch.float32, device="cpu")
@torch.jit.script @torch.jit.script
def conv_bwd(input, weight, grad): def conv_bwd(input, weight, grad):
bias_sizes = [8, ] bias_sizes = [
8,
]
args = ([1, 1], [1, 1], [1, 1], False, [0, 0], 4, [True, True, True]) args = ([1, 1], [1, 1], [1, 1], False, [0, 0], 4, [True, True, True])
return torch.ops.aten.convolution_backward(grad, input, weight, bias_sizes, *args) return torch.ops.aten.convolution_backward(
grad, input, weight, bias_sizes, *args
)
self.assert_shape_equal_scripted(conv_bwd, (input, weight, out_grad)) self.assert_shape_equal_scripted(conv_bwd, (input, weight, out_grad))
@ -391,15 +514,19 @@ class TestSymbolicShapeAnalysis(JitTestCase):
def conv_bwd_2(input, weight, grad): def conv_bwd_2(input, weight, grad):
bias_sizes = None bias_sizes = None
args = ([1, 1], [1, 1], [1, 1], False, [0, 0], 4, [True, True, True]) args = ([1, 1], [1, 1], [1, 1], False, [0, 0], 4, [True, True, True])
return torch.ops.aten.convolution_backward(grad, input, weight, bias_sizes, *args) return torch.ops.aten.convolution_backward(
self.assert_shape_equal_scripted(conv_bwd_2, (input, weight, out_grad)) grad, input, weight, bias_sizes, *args
)
self.assert_shape_equal_scripted(conv_bwd_2, (input, weight, out_grad))
def test_returning_input_symbolic_shapes(self): def test_returning_input_symbolic_shapes(self):
mm = torch.jit.freeze(torch.jit.script(nn.Conv2d(16, 33, 3, stride=2).eval())) mm = torch.jit.freeze(torch.jit.script(nn.Conv2d(16, 33, 3, stride=2).eval()))
inps = list(mm.graph.inputs()) inps = list(mm.graph.inputs())
inps[1].setType(inps[1].type().with_sizes([None, None, None, None])) inps[1].setType(inps[1].type().with_sizes([None, None, None, None]))
shape_compute_graph = torch._C._jit_pass_propagate_shapes_on_graph_and_build_compute(mm.graph) shape_compute_graph = (
torch._C._jit_pass_propagate_shapes_on_graph_and_build_compute(mm.graph)
)
g = shape_compute_graph.partial_eval_shape_graph() g = shape_compute_graph.partial_eval_shape_graph()
# to make into a jit function cant have multiple outputs # to make into a jit function cant have multiple outputs
g.makeMultiOutputIntoTuple() g.makeMultiOutputIntoTuple()
@ -412,8 +539,12 @@ class TestSymbolicShapeAnalysis(JitTestCase):
def test_partial_eval_graph_conv(self): def test_partial_eval_graph_conv(self):
mm = torch.jit.freeze(torch.jit.script(nn.Conv2d(16, 33, 3, stride=2).eval())) mm = torch.jit.freeze(torch.jit.script(nn.Conv2d(16, 33, 3, stride=2).eval()))
shape_compute_graph = torch._C._jit_pass_propagate_shapes_on_graph_and_build_compute(mm.graph) shape_compute_graph = (
output_sizes = mm.graph.findNode("aten::conv2d").output().type().symbolic_sizes() torch._C._jit_pass_propagate_shapes_on_graph_and_build_compute(mm.graph)
)
output_sizes = (
mm.graph.findNode("aten::conv2d").output().type().symbolic_sizes()
)
# calculating 0, 2 and 3 index # calculating 0, 2 and 3 index
for i in [0, 2, 3]: for i in [0, 2, 3]:
self.assertTrue(output_sizes[i] < 0) self.assertTrue(output_sizes[i] < 0)
@ -428,7 +559,9 @@ class TestSymbolicShapeAnalysis(JitTestCase):
for o, oe in zip(output, output_eager[0:1] + output_eager[2:]): for o, oe in zip(output, output_eager[0:1] + output_eager[2:]):
self.assertEqual(o, oe) self.assertEqual(o, oe)
def checkSymShapeCompute(self, shape_compute_graph, nodes, node_output_sizes, shape_inputs): def checkSymShapeCompute(
self, shape_compute_graph, nodes, node_output_sizes, shape_inputs
):
g = shape_compute_graph.partial_eval_shape_graph() g = shape_compute_graph.partial_eval_shape_graph()
self.assertTrue(len(list(g.inputs())) == len(shape_inputs)) self.assertTrue(len(list(g.inputs())) == len(shape_inputs))
output_sym_map = shape_compute_graph.graph_output_to_symbolic_shape_dim() output_sym_map = shape_compute_graph.graph_output_to_symbolic_shape_dim()
@ -451,27 +584,49 @@ class TestSymbolicShapeAnalysis(JitTestCase):
self.assertEqual(sym_outputs[sym_shape_index], output_shape[i]) self.assertEqual(sym_outputs[sym_shape_index], output_shape[i])
def test_partial_eval_stitching(self): def test_partial_eval_stitching(self):
conv1 = torch.nn.Conv2d(3, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False) conv1 = torch.nn.Conv2d(
max_pool = torch.nn.MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False) 3, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False
conv2 = nn.Conv2d(64, 128, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False) )
max_pool = torch.nn.MaxPool2d(
kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False
)
conv2 = nn.Conv2d(
64, 128, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False
)
mod = torch.jit.freeze(torch.jit.script(nn.Sequential(conv1, max_pool, conv2).eval())) mod = torch.jit.freeze(
torch.jit.script(nn.Sequential(conv1, max_pool, conv2).eval())
)
conv1_output = conv1(torch.rand(1, 3, 224, 224)) conv1_output = conv1(torch.rand(1, 3, 224, 224))
max_pool_output = max_pool(conv1_output) max_pool_output = max_pool(conv1_output)
conv2_output = conv2(max_pool_output) conv2_output = conv2(max_pool_output)
shape_compute_graph = torch._C._jit_pass_propagate_shapes_on_graph_and_build_compute(mod.graph) shape_compute_graph = (
nodes = [mod.graph.findNode("aten::max_pool2d")] + list(mod.graph.findAllNodes("aten::conv2d")) torch._C._jit_pass_propagate_shapes_on_graph_and_build_compute(mod.graph)
output_shapes = [max_pool_output.size(), conv1_output.size(), conv2_output.size()] )
self.checkSymShapeCompute(shape_compute_graph, nodes, output_shapes, ([1, 3, 224, 224],)) nodes = [mod.graph.findNode("aten::max_pool2d")] + list(
mod.graph.findAllNodes("aten::conv2d")
)
output_shapes = [
max_pool_output.size(),
conv1_output.size(),
conv2_output.size(),
]
self.checkSymShapeCompute(
shape_compute_graph, nodes, output_shapes, ([1, 3, 224, 224],)
)
def test_refinement_through_graph_stitching(self): def test_refinement_through_graph_stitching(self):
class TwoConvs(torch.nn.Module): class TwoConvs(torch.nn.Module):
def __init__(self): def __init__(self):
super().__init__() super().__init__()
self.conv1 = torch.nn.Conv2d(3, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False) self.conv1 = torch.nn.Conv2d(
self.conv2 = torch.nn.Conv2d(3, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False) 3, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False
)
self.conv2 = torch.nn.Conv2d(
3, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False
)
def forward(self, x): def forward(self, x):
a = self.conv1(x) a = self.conv1(x)
@ -495,18 +650,29 @@ class TestSymbolicShapeAnalysis(JitTestCase):
self.assertEqual(out1, out2) self.assertEqual(out1, out2)
def test_stitching_multi_output(self): def test_stitching_multi_output(self):
max_pool = torch.nn.MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False, return_indices=True) max_pool = torch.nn.MaxPool2d(
kernel_size=3,
stride=2,
padding=1,
dilation=1,
ceil_mode=False,
return_indices=True,
)
tensor = torch.rand(1, 3, 224, 224) tensor = torch.rand(1, 3, 224, 224)
mod = torch.jit.trace(max_pool, (tensor,)) mod = torch.jit.trace(max_pool, (tensor,))
mod = torch.jit.freeze(mod.eval()) mod = torch.jit.freeze(mod.eval())
inp = list(mod.graph.inputs())[1] inp = list(mod.graph.inputs())[1]
inp.setType(inp.type().with_sizes([None, None, None, None])) inp.setType(inp.type().with_sizes([None, None, None, None]))
output_tensor = list(mod(tensor)[0].size()) output_tensor = list(mod(tensor)[0].size())
self.run_pass('lower_all_tuples', mod.graph) self.run_pass("lower_all_tuples", mod.graph)
shape_compute_graph = torch._C._jit_pass_propagate_shapes_on_graph_and_build_compute(mod.graph) shape_compute_graph = (
torch._C._jit_pass_propagate_shapes_on_graph_and_build_compute(mod.graph)
)
max_pool_node = mod.graph.findNode("aten::max_pool2d_with_indices") max_pool_node = mod.graph.findNode("aten::max_pool2d_with_indices")
outs = list(max_pool_node.outputs()) outs = list(max_pool_node.outputs())
self.assertEqual(outs[0].type().symbolic_sizes(), outs[1].type().symbolic_sizes()) self.assertEqual(
outs[0].type().symbolic_sizes(), outs[1].type().symbolic_sizes()
)
g = shape_compute_graph.partial_eval_shape_graph() g = shape_compute_graph.partial_eval_shape_graph()
# to make into a jit function cant have multiple outputs # to make into a jit function cant have multiple outputs
g.makeMultiOutputIntoTuple() g.makeMultiOutputIntoTuple()
@ -528,7 +694,6 @@ class TestSymbolicShapeAnalysis(JitTestCase):
self.assertEqual(out, [-2, -3]) self.assertEqual(out, [-2, -3])
def test_stitching_concat(self): def test_stitching_concat(self):
@torch.jit.script @torch.jit.script
def foo1(a, b, x, y): def foo1(a, b, x, y):
return (a / b) + torch.cat([x, y]) return (a / b) + torch.cat([x, y])
@ -542,15 +707,25 @@ class TestSymbolicShapeAnalysis(JitTestCase):
for inp in foo.graph.inputs(): for inp in foo.graph.inputs():
inp.setType(inp.type().with_sizes([None, None])) inp.setType(inp.type().with_sizes([None, None]))
shape_compute_graph = torch._C._jit_pass_propagate_shapes_on_graph_and_build_compute(foo.graph) shape_compute_graph = (
nodes = [g.findNode("aten::div")] + [g.findNode("aten::add")] + [g.findNode("aten::cat")] torch._C._jit_pass_propagate_shapes_on_graph_and_build_compute(
foo.graph
)
)
nodes = (
[g.findNode("aten::div")]
+ [g.findNode("aten::add")]
+ [g.findNode("aten::cat")]
)
inps = [1, 10], [20, 10], [15, 1], [5, 1] inps = [1, 10], [20, 10], [15, 1], [5, 1]
output_shapes = [[20, 10], [20, 10], [20, 1]] output_shapes = [[20, 10], [20, 10], [20, 1]]
self.checkSymShapeCompute(shape_compute_graph, nodes, output_shapes, inps) self.checkSymShapeCompute(shape_compute_graph, nodes, output_shapes, inps)
@unittest.skipIf(not hasattr(torch.jit, "_shapes"), "shape functions not loaded in python") @unittest.skipIf(
not hasattr(torch.jit, "_shapes"), "shape functions not loaded in python"
)
def test_shape_function_includes(self): def test_shape_function_includes(self):
inp_shape = [1, 16, 5, 10] inp_shape = [1, 16, 5, 10]
weight_shape = [33, 16, 3, 3] weight_shape = [33, 16, 3, 3]
@ -559,7 +734,9 @@ class TestSymbolicShapeAnalysis(JitTestCase):
padding = [0, 0] padding = [0, 0]
dilation = [1, 1] dilation = [1, 1]
groups = 1 groups = 1
res = torch.jit._shapes.conv2d(inp_shape, weight_shape, bias, stride, padding, dilation, groups) res = torch.jit._shapes.conv2d(
inp_shape, weight_shape, bias, stride, padding, dilation, groups
)
self.assertEqual(res, [1, 33, 2, 4]) self.assertEqual(res, [1, 33, 2, 4])
m1_shape = [10, 20] m1_shape = [10, 20]
@ -580,8 +757,11 @@ class TestSymbolicShapeAnalysis(JitTestCase):
def wrong_input_types(x, y): def wrong_input_types(x, y):
x: List[int] = [] x: List[int] = []
return x return x
with self.assertRaisesRegex(RuntimeError, "Expected supertype of int"): with self.assertRaisesRegex(RuntimeError, "Expected supertype of int"):
torch._C._jit_register_shape_compute_graph_for_node(node, wrong_input_types.graph) torch._C._jit_register_shape_compute_graph_for_node(
node, wrong_input_types.graph
)
@torch.jit.script @torch.jit.script
def wrong_output_types(x: List[int], y: List[int]): def wrong_output_types(x: List[int], y: List[int]):
@ -589,7 +769,9 @@ class TestSymbolicShapeAnalysis(JitTestCase):
return x return x
with self.assertRaisesRegex(RuntimeError, "but got graph_type"): with self.assertRaisesRegex(RuntimeError, "but got graph_type"):
torch._C._jit_register_shape_compute_graph_for_node(node, wrong_output_types.graph) torch._C._jit_register_shape_compute_graph_for_node(
node, wrong_output_types.graph
)
@torch.jit.script @torch.jit.script
def too_many_inputs(x: List[int], y: List[int], z: Any, z2: Any): def too_many_inputs(x: List[int], y: List[int], z: Any, z2: Any):
@ -597,7 +779,9 @@ class TestSymbolicShapeAnalysis(JitTestCase):
return x return x
with self.assertRaises(RuntimeError) as error: with self.assertRaises(RuntimeError) as error:
torch._C._jit_register_shape_compute_graph_for_node(node, too_many_inputs.graph) torch._C._jit_register_shape_compute_graph_for_node(
node, too_many_inputs.graph
)
self.assertTrue("fewer arguments than schema" in str(error.exception)) self.assertTrue("fewer arguments than schema" in str(error.exception))
@ -608,9 +792,22 @@ class TestSymbolicShapeAnalysis(JitTestCase):
inputs = list(foo.graph.inputs()) inputs = list(foo.graph.inputs())
inputs[0].setType(inputs[0].type().with_sizes([8, 2])) inputs[0].setType(inputs[0].type().with_sizes([8, 2]))
inputs[1].setType(inputs[1].type().with_sizes([8,])) inputs[1].setType(
inputs[1]
.type()
.with_sizes(
[
8,
]
)
)
torch._C._jit_pass_propagate_shapes_on_graph(foo.graph) torch._C._jit_pass_propagate_shapes_on_graph(foo.graph)
self.assertEqual(next(foo.graph.outputs()).type().sizes(), [8,]) self.assertEqual(
next(foo.graph.outputs()).type().sizes(),
[
8,
],
)
def test_squeeze_dims(self): def test_squeeze_dims(self):
@torch.jit.script @torch.jit.script

View File

@ -10,10 +10,13 @@ pytorch_test_dir = os.path.dirname(os.path.dirname(os.path.realpath(__file__)))
sys.path.append(pytorch_test_dir) sys.path.append(pytorch_test_dir)
from torch.testing._internal.jit_utils import JitTestCase from torch.testing._internal.jit_utils import JitTestCase
if __name__ == '__main__': if __name__ == "__main__":
raise RuntimeError("This test file is not meant to be run directly, use:\n\n" raise RuntimeError(
"\tpython test/test_jit.py TESTNAME\n\n" "This test file is not meant to be run directly, use:\n\n"
"instead.") "\tpython test/test_jit.py TESTNAME\n\n"
"instead."
)
class TestTensorCreationOps(JitTestCase): class TestTensorCreationOps(JitTestCase):
""" """
@ -27,7 +30,7 @@ class TestTensorCreationOps(JitTestCase):
# as integers, which are not comparable against eager torch.dtype. # as integers, which are not comparable against eager torch.dtype.
assert perm.dtype == torch.int64 assert perm.dtype == torch.int64
self.checkScript(randperm, (3, )) self.checkScript(randperm, (3,))
def test_randperm_specifed_dtype(self): def test_randperm_specifed_dtype(self):
def randperm(x: int): def randperm(x: int):
@ -36,7 +39,7 @@ class TestTensorCreationOps(JitTestCase):
# as integers, which are not comparable against eager torch.dtype. # as integers, which are not comparable against eager torch.dtype.
assert perm.dtype == torch.float assert perm.dtype == torch.float
self.checkScript(randperm, (3, )) self.checkScript(randperm, (3,))
def test_triu_indices_default_dtype(self): def test_triu_indices_default_dtype(self):
def triu_indices(rows: int, cols: int): def triu_indices(rows: int, cols: int):

View File

@ -8,8 +8,8 @@ import torch
# Make the helper files in test/ importable # Make the helper files in test/ importable
pytorch_test_dir = os.path.dirname(os.path.dirname(os.path.realpath(__file__))) pytorch_test_dir = os.path.dirname(os.path.dirname(os.path.realpath(__file__)))
sys.path.append(pytorch_test_dir) sys.path.append(pytorch_test_dir)
from torch.testing._internal.jit_utils import JitTestCase
from torch.testing import FileCheck from torch.testing import FileCheck
from torch.testing._internal.jit_utils import JitTestCase
if __name__ == "__main__": if __name__ == "__main__":
raise RuntimeError( raise RuntimeError(
@ -18,6 +18,7 @@ if __name__ == "__main__":
"instead." "instead."
) )
class TestTensorMethods(JitTestCase): class TestTensorMethods(JitTestCase):
def test_getitem(self): def test_getitem(self):
def tensor_getitem(inp: torch.Tensor): def tensor_getitem(inp: torch.Tensor):
@ -25,7 +26,7 @@ class TestTensorMethods(JitTestCase):
return inp.__getitem__(indices) return inp.__getitem__(indices)
inp = torch.rand(3, 4) inp = torch.rand(3, 4)
self.checkScript(tensor_getitem, (inp, )) self.checkScript(tensor_getitem, (inp,))
scripted = torch.jit.script(tensor_getitem) scripted = torch.jit.script(tensor_getitem)
FileCheck().check("aten::index").run(scripted.graph) FileCheck().check("aten::index").run(scripted.graph)
@ -35,5 +36,6 @@ class TestTensorMethods(JitTestCase):
return inp.__getitem__() return inp.__getitem__()
with self.assertRaisesRegexWithHighlight( with self.assertRaisesRegexWithHighlight(
RuntimeError, "expected exactly 1 argument", "inp.__getitem__"): RuntimeError, "expected exactly 1 argument", "inp.__getitem__"
):
torch.jit.script(tensor_getitem_invalid) torch.jit.script(tensor_getitem_invalid)

View File

@ -1,27 +1,27 @@
# Owner(s): ["oncall: jit"] # Owner(s): ["oncall: jit"]
import copy
import io import io
import os import os
import sys import sys
import copy
import unittest import unittest
from typing import Optional
import torch import torch
from typing import Optional
from torch.testing._internal.common_utils import skipIfTorchDynamo from torch.testing._internal.common_utils import skipIfTorchDynamo
# Make the helper files in test/ importable # Make the helper files in test/ importable
pytorch_test_dir = os.path.dirname(os.path.dirname(os.path.realpath(__file__))) pytorch_test_dir = os.path.dirname(os.path.dirname(os.path.realpath(__file__)))
sys.path.append(pytorch_test_dir) sys.path.append(pytorch_test_dir)
from torch.testing._internal.jit_utils import JitTestCase from torch.testing import FileCheck
from torch.testing._internal.common_utils import ( from torch.testing._internal.common_utils import (
find_library_location,
IS_FBCODE, IS_FBCODE,
IS_MACOS, IS_MACOS,
IS_SANDCASTLE, IS_SANDCASTLE,
IS_WINDOWS, IS_WINDOWS,
find_library_location,
) )
from torch.testing import FileCheck from torch.testing._internal.jit_utils import JitTestCase
if __name__ == "__main__": if __name__ == "__main__":
raise RuntimeError( raise RuntimeError(
@ -30,14 +30,15 @@ if __name__ == "__main__":
"instead." "instead."
) )
@skipIfTorchDynamo("skipping as a precaution") @skipIfTorchDynamo("skipping as a precaution")
class TestTorchbind(JitTestCase): class TestTorchbind(JitTestCase):
def setUp(self): def setUp(self):
if IS_SANDCASTLE or IS_MACOS or IS_FBCODE: if IS_SANDCASTLE or IS_MACOS or IS_FBCODE:
raise unittest.SkipTest("non-portable load_library call used in test") raise unittest.SkipTest("non-portable load_library call used in test")
lib_file_path = find_library_location('libtorchbind_test.so') lib_file_path = find_library_location("libtorchbind_test.so")
if IS_WINDOWS: if IS_WINDOWS:
lib_file_path = find_library_location('torchbind_test.dll') lib_file_path = find_library_location("torchbind_test.dll")
torch.ops.load_library(str(lib_file_path)) torch.ops.load_library(str(lib_file_path))
def test_torchbind(self): def test_torchbind(self):
@ -50,15 +51,17 @@ class TestTorchbind(JitTestCase):
val = torch.classes._TorchScriptTesting._Foo(5, 3) val = torch.classes._TorchScriptTesting._Foo(5, 3)
val.increment(1) val.increment(1)
return val return val
test_equality(f, lambda x: x) test_equality(f, lambda x: x)
with self.assertRaisesRegex(RuntimeError, "Expected a value of type 'int'"): with self.assertRaisesRegex(RuntimeError, "Expected a value of type 'int'"):
val = torch.classes._TorchScriptTesting._Foo(5, 3) val = torch.classes._TorchScriptTesting._Foo(5, 3)
val.increment('foo') val.increment("foo")
def f(): def f():
ss = torch.classes._TorchScriptTesting._StackString(["asdf", "bruh"]) ss = torch.classes._TorchScriptTesting._StackString(["asdf", "bruh"])
return ss.pop() return ss.pop()
test_equality(f, lambda x: x) test_equality(f, lambda x: x)
def f(): def f():
@ -66,6 +69,7 @@ class TestTorchbind(JitTestCase):
ss2 = torch.classes._TorchScriptTesting._StackString(["111", "222"]) ss2 = torch.classes._TorchScriptTesting._StackString(["111", "222"])
ss1.push(ss2.pop()) ss1.push(ss2.pop())
return ss1.pop() + ss2.pop() return ss1.pop() + ss2.pop()
test_equality(f, lambda x: x) test_equality(f, lambda x: x)
# test nn module with prepare_scriptable function # test nn module with prepare_scriptable function
@ -116,8 +120,11 @@ class TestTorchbind(JitTestCase):
scripted = torch.jit.script(foo) scripted = torch.jit.script(foo)
# Ensure we are creating the object and calling __init__ # Ensure we are creating the object and calling __init__
# rather than calling the __init__wrapper nonsense # rather than calling the __init__wrapper nonsense
fc = FileCheck().check('prim::CreateObject()')\ fc = (
.check('prim::CallMethod[name="__init__"]') FileCheck()
.check("prim::CreateObject()")
.check('prim::CallMethod[name="__init__"]')
)
fc.run(str(scripted.graph)) fc.run(str(scripted.graph))
out = scripted() out = scripted()
self.assertEqual(out.pop(), "mom") self.assertEqual(out.pop(), "mom")
@ -167,7 +174,7 @@ class TestTorchbind(JitTestCase):
out, result = scripted() out, result = scripted()
self.assertEqual(result, 10) self.assertEqual(result, 10)
with self.assertRaisesRegex(RuntimeError, 'can\'t set attribute'): with self.assertRaisesRegex(RuntimeError, "can't set attribute"):
out.y = 5 out.y = 5
def foo_not_setter(): def foo_not_setter():
@ -177,9 +184,11 @@ class TestTorchbind(JitTestCase):
# getY method intentionally adds 4 to x # getY method intentionally adds 4 to x
return fooGetterSetter.y return fooGetterSetter.y
with self.assertRaisesRegexWithHighlight(RuntimeError, with self.assertRaisesRegexWithHighlight(
'Tried to set read-only attribute: y', RuntimeError,
'fooGetterSetter.y = old + 4'): "Tried to set read-only attribute: y",
"fooGetterSetter.y = old + 4",
):
scripted = torch.jit.script(foo_not_setter) scripted = torch.jit.script(foo_not_setter)
def test_torchbind_def_property_readwrite(self): def test_torchbind_def_property_readwrite(self):
@ -196,9 +205,9 @@ class TestTorchbind(JitTestCase):
fooReadWrite.y = 5 fooReadWrite.y = 5
return fooReadWrite return fooReadWrite
with self.assertRaisesRegexWithHighlight(RuntimeError, with self.assertRaisesRegexWithHighlight(
'Tried to set read-only attribute: y', RuntimeError, "Tried to set read-only attribute: y", "fooReadWrite.y = 5"
'fooReadWrite.y = 5'): ):
scripted = torch.jit.script(foo_readwrite_error) scripted = torch.jit.script(foo_readwrite_error)
def test_torchbind_take_instance_as_method_arg(self): def test_torchbind_take_instance_as_method_arg(self):
@ -250,7 +259,9 @@ class TestTorchbind(JitTestCase):
return self.foo_mod.info() return self.foo_mod.info()
def to_ivalue(self): def to_ivalue(self):
torchbind_model = torch.classes._TorchScriptTesting._Foo(self.foo_mod.info(), 1) torchbind_model = torch.classes._TorchScriptTesting._Foo(
self.foo_mod.info(), 1
)
return FooBar(torchbind_model) return FooBar(torchbind_model)
inst = FooBar(torch.classes._TorchScriptTesting._Foo(2, 3)) inst = FooBar(torch.classes._TorchScriptTesting._Foo(2, 3))
@ -338,7 +349,7 @@ class TestTorchbind(JitTestCase):
self.assertEqual(torch.zeros(4, 4), traced()) self.assertEqual(torch.zeros(4, 4), traced())
def test_torchbind_pass_wrong_type(self): def test_torchbind_pass_wrong_type(self):
with self.assertRaisesRegex(RuntimeError, 'but instead found type \'Tensor\''): with self.assertRaisesRegex(RuntimeError, "but instead found type 'Tensor'"):
torch.ops._TorchScriptTesting.take_an_instance(torch.rand(3, 4)) torch.ops._TorchScriptTesting.take_an_instance(torch.rand(3, 4))
def test_torchbind_tracing_nested(self): def test_torchbind_tracing_nested(self):
@ -368,12 +379,15 @@ class TestTorchbind(JitTestCase):
self.assertEqual(nt_loaded.pop(), exp) self.assertEqual(nt_loaded.pop(), exp)
def test_torchbind_instantiate_missing_class(self): def test_torchbind_instantiate_missing_class(self):
with self.assertRaisesRegex(RuntimeError, 'Tried to instantiate class \'foo.IDontExist\', but it does not exist!'): with self.assertRaisesRegex(
RuntimeError,
"Tried to instantiate class 'foo.IDontExist', but it does not exist!",
):
torch.classes.foo.IDontExist(3, 4, 5) torch.classes.foo.IDontExist(3, 4, 5)
def test_torchbind_optional_explicit_attr(self): def test_torchbind_optional_explicit_attr(self):
class TorchBindOptionalExplicitAttr(torch.nn.Module): class TorchBindOptionalExplicitAttr(torch.nn.Module):
foo : Optional[torch.classes._TorchScriptTesting._StackString] foo: Optional[torch.classes._TorchScriptTesting._StackString]
def __init__(self): def __init__(self):
super().__init__() super().__init__()
@ -384,13 +398,13 @@ class TestTorchbind(JitTestCase):
if foo_obj is not None: if foo_obj is not None:
return foo_obj.pop() return foo_obj.pop()
else: else:
return '<None>' return "<None>"
mod = TorchBindOptionalExplicitAttr() mod = TorchBindOptionalExplicitAttr()
scripted = torch.jit.script(mod) scripted = torch.jit.script(mod)
def test_torchbind_no_init(self): def test_torchbind_no_init(self):
with self.assertRaisesRegex(RuntimeError, 'torch::init'): with self.assertRaisesRegex(RuntimeError, "torch::init"):
x = torch.classes._TorchScriptTesting._NoInit() x = torch.classes._TorchScriptTesting._NoInit()
def test_profiler_custom_op(self): def test_profiler_custom_op(self):
@ -401,17 +415,17 @@ class TestTorchbind(JitTestCase):
found_event = False found_event = False
for e in prof.function_events: for e in prof.function_events:
if e.name == '_TorchScriptTesting::take_an_instance': if e.name == "_TorchScriptTesting::take_an_instance":
found_event = True found_event = True
self.assertTrue(found_event) self.assertTrue(found_event)
def test_torchbind_getattr(self): def test_torchbind_getattr(self):
foo = torch.classes._TorchScriptTesting._StackString(["test"]) foo = torch.classes._TorchScriptTesting._StackString(["test"])
self.assertEqual(None, getattr(foo, 'bar', None)) self.assertEqual(None, getattr(foo, "bar", None))
def test_torchbind_attr_exception(self): def test_torchbind_attr_exception(self):
foo = torch.classes._TorchScriptTesting._StackString(["test"]) foo = torch.classes._TorchScriptTesting._StackString(["test"])
with self.assertRaisesRegex(AttributeError, 'does not have a field'): with self.assertRaisesRegex(AttributeError, "does not have a field"):
foo.bar foo.bar
def test_lambda_as_constructor(self): def test_lambda_as_constructor(self):

File diff suppressed because it is too large Load Diff

View File

@ -1,21 +1,24 @@
# Owner(s): ["oncall: jit"] # Owner(s): ["oncall: jit"]
import io
import os import os
import sys import sys
import io
import torch import torch
# Make the helper files in test/ importable # Make the helper files in test/ importable
pytorch_test_dir = os.path.dirname(os.path.dirname(os.path.realpath(__file__))) pytorch_test_dir = os.path.dirname(os.path.dirname(os.path.realpath(__file__)))
sys.path.append(pytorch_test_dir) sys.path.append(pytorch_test_dir)
from torch.testing._internal.jit_utils import JitTestCase
from torch.testing._internal.common_utils import suppress_warnings from torch.testing._internal.common_utils import suppress_warnings
from torch.testing._internal.jit_utils import JitTestCase
if __name__ == "__main__":
raise RuntimeError(
"This test file is not meant to be run directly, use:\n\n"
"\tpython test/test_jit.py TESTNAME\n\n"
"instead."
)
if __name__ == '__main__':
raise RuntimeError("This test file is not meant to be run directly, use:\n\n"
"\tpython test/test_jit.py TESTNAME\n\n"
"instead.")
class TestTypeSharing(JitTestCase): class TestTypeSharing(JitTestCase):
def assertSameType(self, m1, m2): def assertSameType(self, m1, m2):
@ -42,6 +45,7 @@ class TestTypeSharing(JitTestCase):
def forward(self, x): def forward(self, x):
return x return x
a = torch.rand(2, 3) a = torch.rand(2, 3)
b = torch.rand(2, 3) b = torch.rand(2, 3)
c = torch.rand(2, 3) c = torch.rand(2, 3)
@ -53,6 +57,7 @@ class TestTypeSharing(JitTestCase):
""" """
Types should be shared even if attribute values differ Types should be shared even if attribute values differ
""" """
class M(torch.nn.Module): class M(torch.nn.Module):
def __init__(self, a, b, c): def __init__(self, a, b, c):
super().__init__() super().__init__()
@ -62,6 +67,7 @@ class TestTypeSharing(JitTestCase):
def forward(self, x): def forward(self, x):
return x return x
a = torch.rand(2, 3) a = torch.rand(2, 3)
b = torch.rand(2, 3) b = torch.rand(2, 3)
c = torch.rand(2, 3) c = torch.rand(2, 3)
@ -73,6 +79,7 @@ class TestTypeSharing(JitTestCase):
""" """
Types should be shared for identical constant values, and different for different constant values Types should be shared for identical constant values, and different for different constant values
""" """
class M(torch.nn.Module): class M(torch.nn.Module):
__constants__ = ["const"] __constants__ = ["const"]
@ -111,6 +118,7 @@ class TestTypeSharing(JitTestCase):
""" """
If submodules differ, the types should differ. If submodules differ, the types should differ.
""" """
class M(torch.nn.Module): class M(torch.nn.Module):
def __init__(self, in1, out1, in2, out2): def __init__(self, in1, out1, in2, out2):
super().__init__() super().__init__()
@ -137,6 +145,7 @@ class TestTypeSharing(JitTestCase):
The same module with an `foo` as a parameter vs. attribute shouldn't The same module with an `foo` as a parameter vs. attribute shouldn't
share types share types
""" """
class M(torch.nn.Module): class M(torch.nn.Module):
def __init__(self, foo): def __init__(self, foo):
super().__init__() super().__init__()
@ -156,6 +165,7 @@ class TestTypeSharing(JitTestCase):
Even if everything about the module is the same, different originating Even if everything about the module is the same, different originating
classes should prevent type sharing. classes should prevent type sharing.
""" """
class A(torch.nn.Module): class A(torch.nn.Module):
__constants__ = ["const"] __constants__ = ["const"]
@ -192,6 +202,7 @@ class TestTypeSharing(JitTestCase):
""" """
Mutating the value of an attribute should not change type sharing Mutating the value of an attribute should not change type sharing
""" """
class M(torch.nn.Module): class M(torch.nn.Module):
def __init__(self, in1, out1, in2, out2): def __init__(self, in1, out1, in2, out2):
super().__init__() super().__init__()
@ -214,6 +225,7 @@ class TestTypeSharing(JitTestCase):
""" """
Assigning a new (python-only) attribute should not change type sharing Assigning a new (python-only) attribute should not change type sharing
""" """
class M(torch.nn.Module): class M(torch.nn.Module):
def __init__(self, in1, out1, in2, out2): def __init__(self, in1, out1, in2, out2):
super().__init__() super().__init__()
@ -244,6 +256,7 @@ class TestTypeSharing(JitTestCase):
""" """
Attributes whose type cannot be inferred should fail cleanly with nice hints Attributes whose type cannot be inferred should fail cleanly with nice hints
""" """
class M(torch.nn.Module): class M(torch.nn.Module):
def __init__(self): def __init__(self):
super().__init__() super().__init__()
@ -255,15 +268,16 @@ class TestTypeSharing(JitTestCase):
return self.foo return self.foo
m = M() m = M()
with self.assertRaisesRegexWithHighlight(RuntimeError, with self.assertRaisesRegexWithHighlight(
"failed to convert Python type", RuntimeError, "failed to convert Python type", "self.foo"
"self.foo"): ):
torch.jit.script(m) torch.jit.script(m)
def test_script_function_attribute_different(self): def test_script_function_attribute_different(self):
""" """
Different functions passed in should lead to different types Different functions passed in should lead to different types
""" """
@torch.jit.script @torch.jit.script
def fn1(x): def fn1(x):
return x + x return x + x
@ -317,6 +331,7 @@ class TestTypeSharing(JitTestCase):
""" """
Same functions passed in should lead to same types Same functions passed in should lead to same types
""" """
@torch.jit.script @torch.jit.script
def fn(x): def fn(x):
return x + x return x + x
@ -338,6 +353,7 @@ class TestTypeSharing(JitTestCase):
""" """
Different functions passed in should lead to different types Different functions passed in should lead to different types
""" """
def fn1(x): def fn1(x):
return x + x return x + x
@ -361,6 +377,7 @@ class TestTypeSharing(JitTestCase):
""" """
Same functions passed in should lead to same types Same functions passed in should lead to same types
""" """
def fn(x): def fn(x):
return x + x return x + x
@ -383,6 +400,7 @@ class TestTypeSharing(JitTestCase):
Since we can't guarantee that methods are the same between different Since we can't guarantee that methods are the same between different
trace runs, tracing must always generate a unique type. trace runs, tracing must always generate a unique type.
""" """
class M(torch.nn.Module): class M(torch.nn.Module):
def forward(self, x, y): def forward(self, x, y):
if x.sum() > y.sum(): if x.sum() > y.sum():
@ -429,8 +447,8 @@ class TestTypeSharing(JitTestCase):
def forward(self, x): def forward(self, x):
return self.traced(x) return self.traced(x)
a = M((torch.ones(1), )) a = M((torch.ones(1),))
b = M((torch.zeros(1), )) b = M((torch.zeros(1),))
self.assertDifferentType(a, b) self.assertDifferentType(a, b)
def test_loaded_modules_work(self): def test_loaded_modules_work(self):
@ -465,7 +483,6 @@ class TestTypeSharing(JitTestCase):
buffer.seek(0) buffer.seek(0)
return torch.jit.script(Wrapper(torch.jit.load(buffer))) return torch.jit.script(Wrapper(torch.jit.load(buffer)))
a = package(AB()) a = package(AB())
a() a()
b = package(A()) b = package(A())
@ -476,6 +493,7 @@ class TestTypeSharing(JitTestCase):
We should be able to differentiate between two ModuleDict instances We should be able to differentiate between two ModuleDict instances
that have different keys but the same value types. that have different keys but the same value types.
""" """
class A(torch.nn.Module): class A(torch.nn.Module):
def forward(self, x): def forward(self, x):
return x return x
@ -488,9 +506,9 @@ class TestTypeSharing(JitTestCase):
def forward(self, x): def forward(self, x):
return x return x
a = Foo({'foo': A()}) a = Foo({"foo": A()})
b = Foo({'bar': A()}) b = Foo({"bar": A()})
c = Foo({'bar': A()}) c = Foo({"bar": A()})
self.assertDifferentType(a, b) self.assertDifferentType(a, b)
self.assertSameType(b, c) self.assertSameType(b, c)
@ -500,13 +518,16 @@ class TestTypeSharing(JitTestCase):
subclass that defines methods in its __init__ are not subclass that defines methods in its __init__ are not
shared. shared.
""" """
class A(torch.jit.ScriptModule): class A(torch.jit.ScriptModule):
def __init__(self, val): def __init__(self, val):
super().__init__() super().__init__()
self.define(f""" self.define(
f"""
def forward(self) -> int: def forward(self) -> int:
return {val} return {val}
""") """
)
one = A(1) one = A(1)
two = A(2) two = A(2)
@ -518,6 +539,7 @@ class TestTypeSharing(JitTestCase):
""" """
Test that type sharing can be disabled. Test that type sharing can be disabled.
""" """
class A(torch.nn.Module): class A(torch.nn.Module):
def __init__(self, sub): def __init__(self, sub):
super().__init__() super().__init__()
@ -555,6 +577,7 @@ class TestTypeSharing(JitTestCase):
Test that types are shared if the exclusion of their Test that types are shared if the exclusion of their
ignored attributes makes them equal. ignored attributes makes them equal.
""" """
class A(torch.nn.Module): class A(torch.nn.Module):
__jit_ignored_attributes__ = ["a"] __jit_ignored_attributes__ = ["a"]
@ -579,6 +602,7 @@ class TestTypeSharing(JitTestCase):
Test that types are not shared if the exclusion of their Test that types are not shared if the exclusion of their
ignored attributes makes them not equal. ignored attributes makes them not equal.
""" """
class A(torch.nn.Module): class A(torch.nn.Module):
__jit_ignored_attributes__ = ["a"] __jit_ignored_attributes__ = ["a"]

View File

@ -1,26 +1,31 @@
# Owner(s): ["oncall: jit"] # Owner(s): ["oncall: jit"]
from collections import namedtuple
from typing import Dict, Iterator, List, Optional, Tuple
from torch.testing._internal.jit_utils import JitTestCase
from torch.testing import FileCheck
from textwrap import dedent
from jit.test_module_interface import TestModuleInterface # noqa: F401
import inspect import inspect
import os import os
import sys import sys
from collections import namedtuple
from textwrap import dedent
from typing import Dict, Iterator, List, Optional, Tuple
import torch import torch
import torch.testing._internal.jit_utils import torch.testing._internal.jit_utils
from torch.testing import FileCheck
from torch.testing._internal.jit_utils import JitTestCase
from jit.test_module_interface import TestModuleInterface # noqa: F401
# Make the helper files in test/ importable # Make the helper files in test/ importable
pytorch_test_dir = os.path.dirname(os.path.dirname(os.path.realpath(__file__))) pytorch_test_dir = os.path.dirname(os.path.dirname(os.path.realpath(__file__)))
sys.path.append(pytorch_test_dir) sys.path.append(pytorch_test_dir)
if __name__ == '__main__': if __name__ == "__main__":
raise RuntimeError("This test file is not meant to be run directly, use:\n\n" raise RuntimeError(
"\tpython test/test_jit.py TESTNAME\n\n" "This test file is not meant to be run directly, use:\n\n"
"instead.") "\tpython test/test_jit.py TESTNAME\n\n"
"instead."
)
class TestTypesAndAnnotation(JitTestCase): class TestTypesAndAnnotation(JitTestCase):
def test_pep585_type(self): def test_pep585_type(self):
@ -30,7 +35,7 @@ class TestTypesAndAnnotation(JitTestCase):
xl: list[tuple[torch.Tensor]] = [] xl: list[tuple[torch.Tensor]] = []
xd: dict[str, int] = {} xd: dict[str, int] = {}
xl.append((x,)) xl.append((x,))
xd['foo'] = 1 xd["foo"] = 1
return xl.pop(), xd return xl.pop(), xd
self.checkScript(fn, [torch.randn(2, 2)]) self.checkScript(fn, [torch.randn(2, 2)])
@ -47,7 +52,7 @@ class TestTypesAndAnnotation(JitTestCase):
self.checkScript(fn, [torch.randn(2, 2)]) self.checkScript(fn, [torch.randn(2, 2)])
GG = namedtuple('GG', ['f', 'g']) GG = namedtuple("GG", ["f", "g"])
class Foo(torch.nn.Module): class Foo(torch.nn.Module):
@torch.jit.ignore @torch.jit.ignore
@ -77,13 +82,17 @@ class TestTypesAndAnnotation(JitTestCase):
return x + 10 return x + 10
class M(torch.nn.Module): class M(torch.nn.Module):
def forward(self, in_batch: Dict[str, Optional[torch.Tensor]]) -> torch.Tensor: def forward(
self, in_batch: Dict[str, Optional[torch.Tensor]]
) -> torch.Tensor:
self.dropout_modality(in_batch) self.dropout_modality(in_batch)
fn(in_batch) fn(in_batch)
return torch.tensor(1) return torch.tensor(1)
@torch.jit.ignore @torch.jit.ignore
def dropout_modality(self, in_batch: Dict[str, Optional[torch.Tensor]]) -> Dict[str, Optional[torch.Tensor]]: def dropout_modality(
self, in_batch: Dict[str, Optional[torch.Tensor]]
) -> Dict[str, Optional[torch.Tensor]]:
return in_batch return in_batch
sm = torch.jit.script(M()) sm = torch.jit.script(M())
@ -111,16 +120,17 @@ class TestTypesAndAnnotation(JitTestCase):
return my_arg + 10 return my_arg + 10
with self.assertRaisesRegex(RuntimeError, "argument 'my_arg'"): with self.assertRaisesRegex(RuntimeError, "argument 'my_arg'"):
@torch.jit.script @torch.jit.script
def other_fn(x): def other_fn(x):
return fn('2') return fn("2")
def test_type_annotate_py3(self): def test_type_annotate_py3(self):
def fn(): def fn():
a : List[int] = [] a: List[int] = []
b : torch.Tensor = torch.ones(2, 2) b: torch.Tensor = torch.ones(2, 2)
c : Optional[torch.Tensor] = None c: Optional[torch.Tensor] = None
d : Optional[torch.Tensor] = torch.ones(3, 4) d: Optional[torch.Tensor] = torch.ones(3, 4)
for _ in range(10): for _ in range(10):
a.append(4) a.append(4)
c = torch.ones(2, 2) c = torch.ones(2, 2)
@ -130,66 +140,88 @@ class TestTypesAndAnnotation(JitTestCase):
self.checkScript(fn, ()) self.checkScript(fn, ())
def wrong_type(): def wrong_type():
wrong : List[int] = [0.5] wrong: List[int] = [0.5]
return wrong return wrong
with self.assertRaisesRegex(RuntimeError, "List type annotation" with self.assertRaisesRegex(
r" `List\[int\]` did not match the " RuntimeError,
"types of the given list elements"): "List type annotation"
r" `List\[int\]` did not match the "
"types of the given list elements",
):
torch.jit.script(wrong_type) torch.jit.script(wrong_type)
def test_optional_no_element_type_annotation(self): def test_optional_no_element_type_annotation(self):
""" """
Test that using an optional with no contained types produces an error. Test that using an optional with no contained types produces an error.
""" """
def fn_with_comment(x: torch.Tensor) -> Optional: def fn_with_comment(x: torch.Tensor) -> Optional:
return (x, x) return (x, x)
def annotated_fn(x: torch.Tensor) -> Optional: def annotated_fn(x: torch.Tensor) -> Optional:
return (x, x) return (x, x)
with self.assertRaisesRegex(RuntimeError, r"Attempted to use Optional without a contained type"): with self.assertRaisesRegex(
RuntimeError, r"Attempted to use Optional without a contained type"
):
cu = torch.jit.CompilationUnit() cu = torch.jit.CompilationUnit()
cu.define(dedent(inspect.getsource(fn_with_comment))) cu.define(dedent(inspect.getsource(fn_with_comment)))
with self.assertRaisesRegex(RuntimeError, r"Attempted to use Optional without a contained type"): with self.assertRaisesRegex(
RuntimeError, r"Attempted to use Optional without a contained type"
):
cu = torch.jit.CompilationUnit() cu = torch.jit.CompilationUnit()
cu.define(dedent(inspect.getsource(annotated_fn))) cu.define(dedent(inspect.getsource(annotated_fn)))
with self.assertRaisesRegex(RuntimeError, r"Attempted to use Optional without a contained type"): with self.assertRaisesRegex(
RuntimeError, r"Attempted to use Optional without a contained type"
):
torch.jit.script(fn_with_comment) torch.jit.script(fn_with_comment)
with self.assertRaisesRegex(RuntimeError, r"Attempted to use Optional without a contained type"): with self.assertRaisesRegex(
RuntimeError, r"Attempted to use Optional without a contained type"
):
torch.jit.script(annotated_fn) torch.jit.script(annotated_fn)
def test_tuple_no_element_type_annotation(self): def test_tuple_no_element_type_annotation(self):
""" """
Test that using a tuple with no contained types produces an error. Test that using a tuple with no contained types produces an error.
""" """
def fn_with_comment(x: torch.Tensor) -> Tuple: def fn_with_comment(x: torch.Tensor) -> Tuple:
return (x, x) return (x, x)
def annotated_fn(x: torch.Tensor) -> Tuple: def annotated_fn(x: torch.Tensor) -> Tuple:
return (x, x) return (x, x)
with self.assertRaisesRegex(RuntimeError, r"Attempted to use Tuple without a contained type"): with self.assertRaisesRegex(
RuntimeError, r"Attempted to use Tuple without a contained type"
):
cu = torch.jit.CompilationUnit() cu = torch.jit.CompilationUnit()
cu.define(dedent(inspect.getsource(fn_with_comment))) cu.define(dedent(inspect.getsource(fn_with_comment)))
with self.assertRaisesRegex(RuntimeError, r"Attempted to use Tuple without a contained type"): with self.assertRaisesRegex(
RuntimeError, r"Attempted to use Tuple without a contained type"
):
cu = torch.jit.CompilationUnit() cu = torch.jit.CompilationUnit()
cu.define(dedent(inspect.getsource(annotated_fn))) cu.define(dedent(inspect.getsource(annotated_fn)))
with self.assertRaisesRegex(RuntimeError, r"Attempted to use Tuple without a contained type"): with self.assertRaisesRegex(
RuntimeError, r"Attempted to use Tuple without a contained type"
):
torch.jit.script(fn_with_comment) torch.jit.script(fn_with_comment)
with self.assertRaisesRegex(RuntimeError, r"Attempted to use Tuple without a contained type"): with self.assertRaisesRegex(
RuntimeError, r"Attempted to use Tuple without a contained type"
):
torch.jit.script(annotated_fn) torch.jit.script(annotated_fn)
def test_ignoring_module_attributes(self): def test_ignoring_module_attributes(self):
""" """
Test that module attributes can be ignored. Test that module attributes can be ignored.
""" """
class Sub(torch.nn.Module): class Sub(torch.nn.Module):
def forward(self, a: int) -> int: def forward(self, a: int) -> int:
return sum([a]) return sum([a])
@ -229,10 +261,11 @@ class TestTypesAndAnnotation(JitTestCase):
mod = ModuleUsesIgnoredAttr(1) mod = ModuleUsesIgnoredAttr(1)
with self.assertRaisesRegexWithHighlight(RuntimeError, r"attribute was ignored during compilation", "self.sub"): with self.assertRaisesRegexWithHighlight(
RuntimeError, r"attribute was ignored during compilation", "self.sub"
):
scripted_mod = torch.jit.script(mod) scripted_mod = torch.jit.script(mod)
def test_ignoring_fn_with_nonscriptable_types(self): def test_ignoring_fn_with_nonscriptable_types(self):
class CFX: class CFX:
def __init__(self, a: List[torch.Tensor]) -> None: def __init__(self, a: List[torch.Tensor]) -> None:
@ -246,7 +279,9 @@ class TestTypesAndAnnotation(JitTestCase):
return iter(self.a) return iter(self.a)
@torch.jit._drop @torch.jit._drop
def __fx_create_arg__(self, tracer: torch.fx.Tracer) -> torch.fx.node.Argument: def __fx_create_arg__(
self, tracer: torch.fx.Tracer
) -> torch.fx.node.Argument:
# torch.fx classes are not scriptable # torch.fx classes are not scriptable
return tracer.create_node( return tracer.create_node(
"call_function", "call_function",
@ -257,35 +292,36 @@ class TestTypesAndAnnotation(JitTestCase):
torch.jit.script(CFX) torch.jit.script(CFX)
def test_unimported_type_resolution(self): def test_unimported_type_resolution(self):
# verify fallback from the python resolver to the c++ resolver # verify fallback from the python resolver to the c++ resolver
@ torch.jit.script @torch.jit.script
def fn(x): def fn(x):
# type: (number) -> number # type: (number) -> number
return x + 1 return x + 1
FileCheck().check('Scalar').run(fn.graph) FileCheck().check("Scalar").run(fn.graph)
def test_parser_bug(self): def test_parser_bug(self):
def parser_bug(o: Optional[torch.Tensor]): def parser_bug(o: Optional[torch.Tensor]):
pass pass
def test_mismatched_annotation(self): def test_mismatched_annotation(self):
with self.assertRaisesRegex(RuntimeError, 'annotated with type'): with self.assertRaisesRegex(RuntimeError, "annotated with type"):
@torch.jit.script @torch.jit.script
def foo(): def foo():
x : str = 4 x: str = 4
return x return x
def test_reannotate(self): def test_reannotate(self):
with self.assertRaisesRegex(RuntimeError, 'declare and annotate'): with self.assertRaisesRegex(RuntimeError, "declare and annotate"):
@torch.jit.script @torch.jit.script
def foo(): def foo():
x = 5 x = 5
if 1 == 1: if 1 == 1:
x : Optional[int] = 7 x: Optional[int] = 7
def test_annotate_outside_init(self): def test_annotate_outside_init(self):
msg = "annotations on instance attributes must be declared in __init__" msg = "annotations on instance attributes must be declared in __init__"
@ -293,6 +329,7 @@ class TestTypesAndAnnotation(JitTestCase):
# Simple case # Simple case
with self.assertRaisesRegexWithHighlight(ValueError, msg, highlight): with self.assertRaisesRegexWithHighlight(ValueError, msg, highlight):
@torch.jit.script @torch.jit.script
class BadModule: class BadModule:
def __init__(self, x: int): def __init__(self, x: int):
@ -303,6 +340,7 @@ class TestTypesAndAnnotation(JitTestCase):
# Type annotation in a loop # Type annotation in a loop
with self.assertRaisesRegexWithHighlight(ValueError, msg, highlight): with self.assertRaisesRegexWithHighlight(ValueError, msg, highlight):
@torch.jit.script @torch.jit.script
class BadModuleLoop: class BadModuleLoop:
def __init__(self, x: int): def __init__(self, x: int):
@ -324,8 +362,10 @@ class TestTypesAndAnnotation(JitTestCase):
def test_inferred_type_error_message(self): def test_inferred_type_error_message(self):
inferred_type = torch._C.InferredType("ErrorReason") inferred_type = torch._C.InferredType("ErrorReason")
with self.assertRaisesRegex(RuntimeError, with self.assertRaisesRegex(
"Tried to get the type from an InferredType but the type is null."): RuntimeError,
"Tried to get the type from an InferredType but the type is null.",
):
t = inferred_type.type() t = inferred_type.type()
with self.assertRaisesRegex(RuntimeError, "ErrorReason"): with self.assertRaisesRegex(RuntimeError, "ErrorReason"):

View File

@ -2,12 +2,12 @@
import os import os
import sys import sys
from collections import namedtuple
from typing import Dict, List, NamedTuple, Tuple
import torch import torch
from torch.testing._internal.jit_utils import JitTestCase, make_global
from torch.testing._internal.common_utils import IS_WINDOWS from torch.testing._internal.common_utils import IS_WINDOWS
from collections import namedtuple from torch.testing._internal.jit_utils import JitTestCase, make_global
from typing import List, Tuple, Dict, NamedTuple
# Make the helper files in test/ importable # Make the helper files in test/ importable
pytorch_test_dir = os.path.dirname(os.path.dirname(os.path.realpath(__file__))) pytorch_test_dir = os.path.dirname(os.path.dirname(os.path.realpath(__file__)))
@ -20,14 +20,15 @@ if __name__ == "__main__":
"instead." "instead."
) )
class TestTyping(JitTestCase): class TestTyping(JitTestCase):
def test_dict_in_not_in(self): def test_dict_in_not_in(self):
def test_in_dict(x): def test_in_dict(x):
# type: (Dict[str, int]) -> bool # type: (Dict[str, int]) -> bool
return 'hi' in x return "hi" in x
self.checkScript(test_in_dict, ({'hi': 2, 'bye': 3},)) self.checkScript(test_in_dict, ({"hi": 2, "bye": 3},))
self.checkScript(test_in_dict, ({'bye': 3},)) self.checkScript(test_in_dict, ({"bye": 3},))
# Check evaluation order # Check evaluation order
@torch.jit.script @torch.jit.script
@ -57,8 +58,8 @@ class TestTyping(JitTestCase):
else: else:
return True return True
self.checkScript(test_not_in_dict, ({"hello": 1, "world": 2}, )) self.checkScript(test_not_in_dict, ({"hello": 1, "world": 2},))
self.checkScript(test_not_in_dict, ({"world": 2}, )) self.checkScript(test_not_in_dict, ({"world": 2},))
def test_dict_tensor_key(a, t): def test_dict_tensor_key(a, t):
# type: (Dict[Tensor, int], Tensor) -> bool # type: (Dict[Tensor, int], Tensor) -> bool
@ -80,9 +81,12 @@ class TestTyping(JitTestCase):
l: List[int] = [1, 2, "foo", 3] l: List[int] = [1, 2, "foo", 3]
return l return l
with self.assertRaisesRegex(RuntimeError, "List type annotation" with self.assertRaisesRegex(
r" `List\[int\]` did not match the " RuntimeError,
"types of the given list elements"): "List type annotation"
r" `List\[int\]` did not match the "
"types of the given list elements",
):
torch.jit.script(fn) torch.jit.script(fn)
def test_dict_type_refinement_annotation_key_mismatch(self): def test_dict_type_refinement_annotation_key_mismatch(self):
@ -92,10 +96,13 @@ class TestTyping(JitTestCase):
d: Dict[int, str] = dict(zip(l1, l2)) d: Dict[int, str] = dict(zip(l1, l2))
return d return d
with self.assertRaisesRegex(RuntimeError, "Dicts may only " with self.assertRaisesRegex(
"contain homogeneous keys, but the " RuntimeError,
"type of the first generated key " "Dicts may only "
r"was Union\[int, str\]"): "contain homogeneous keys, but the "
"type of the first generated key "
r"was Union\[int, str\]",
):
torch.jit.script(fn) torch.jit.script(fn)
def test_dict_type_refinement_annotation_value_mismatch(self): def test_dict_type_refinement_annotation_value_mismatch(self):
@ -105,28 +112,36 @@ class TestTyping(JitTestCase):
d: Dict[str, int] = dict(zip(l1, l2)) d: Dict[str, int] = dict(zip(l1, l2))
return d return d
with self.assertRaisesRegex(RuntimeError, "Dict type annotation" with self.assertRaisesRegex(
r" `Dict\[str, int\]` did not match" RuntimeError,
" the type of an actual value type" "Dict type annotation"
r" `Union\[int, str\]`"): r" `Dict\[str, int\]` did not match"
" the type of an actual value type"
r" `Union\[int, str\]`",
):
torch.jit.script(fn) torch.jit.script(fn)
def test_dict_invalid_annotations(self): def test_dict_invalid_annotations(self):
# Check for invalid value type annotation # Check for invalid value type annotation
def wrong_value_type(dictionary: Dict[str, torch.jit.ScriptModule]): def wrong_value_type(dictionary: Dict[str, torch.jit.ScriptModule]):
return return
with self.assertRaisesRegex(ValueError, "Unknown type annotation"): with self.assertRaisesRegex(ValueError, "Unknown type annotation"):
torch.jit.script(wrong_value_type) torch.jit.script(wrong_value_type)
# Check for invalid key type annotation # Check for invalid key type annotation
def wrong_key_type(dictionary: Dict[torch.jit.ScriptModule, str]): def wrong_key_type(dictionary: Dict[torch.jit.ScriptModule, str]):
return return
with self.assertRaisesRegex(ValueError, "Unknown type annotation"): with self.assertRaisesRegex(ValueError, "Unknown type annotation"):
torch.jit.script(wrong_key_type) torch.jit.script(wrong_key_type)
# Check for invalid key and value type annotation # Check for invalid key and value type annotation
def wrong_key_value_type(dictionary: Dict[torch.jit.ScriptModule, torch.jit.ScriptModule]): def wrong_key_value_type(
dictionary: Dict[torch.jit.ScriptModule, torch.jit.ScriptModule]
):
return return
with self.assertRaisesRegex(ValueError, "Unknown type annotation"): with self.assertRaisesRegex(ValueError, "Unknown type annotation"):
torch.jit.script(wrong_key_value_type) torch.jit.script(wrong_key_value_type)
@ -138,13 +153,16 @@ class TestTyping(JitTestCase):
_, y = t2 _, y = t2
return x + y return x + y
t = torch.randn(2, 2), (1, torch.randn(2, 2)), t = (
torch.randn(2, 2),
(1, torch.randn(2, 2)),
)
f(t, "hi") f(t, "hi")
graph = f.graph_for(t, "hi") graph = f.graph_for(t, "hi")
input_types = list(next(graph.inputs()).type().elements()) input_types = list(next(graph.inputs()).type().elements())
w = input_types[0] w = input_types[0]
self.assertEqual(input_types[0].kind(), 'TensorType') self.assertEqual(input_types[0].kind(), "TensorType")
self.assertEqual(input_types[1].elements()[1].kind(), 'TensorType') self.assertEqual(input_types[1].elements()[1].kind(), "TensorType")
def test_tuple_io(self): def test_tuple_io(self):
def stuff(x): def stuff(x):
@ -165,8 +183,7 @@ class TestTyping(JitTestCase):
def foo(): def foo():
return tuple(1, 2) return tuple(1, 2)
self.checkScriptRaisesRegex(foo, (), Exception, self.checkScriptRaisesRegex(foo, (), Exception, "1 argument")
"1 argument")
def cant_infer_size(): def cant_infer_size():
return tuple([1, 2, 3]) # noqa: C409 return tuple([1, 2, 3]) # noqa: C409
@ -179,12 +196,14 @@ class TestTyping(JitTestCase):
# type: (int) -> Tuple[Tensor, Tensor] # type: (int) -> Tuple[Tensor, Tensor]
a = (torch.ones(x), torch.zeros(x)) a = (torch.ones(x), torch.zeros(x))
return a return a
self.checkScript(stuff2, (3,)) self.checkScript(stuff2, (3,))
def test_list_io(self): def test_list_io(self):
def stuff3(x): def stuff3(x):
# type: (List[int]) -> Tuple[Tensor, List[int]] # type: (List[int]) -> Tuple[Tensor, List[int]]
return torch.ones(x), x return torch.ones(x), x
self.checkScript(stuff3, ([3, 2],)) self.checkScript(stuff3, ([3, 2],))
def test_bool_list_io(self): def test_bool_list_io(self):
@ -203,6 +222,7 @@ class TestTyping(JitTestCase):
# type: (Tuple[int, List[List[int]]]) -> int # type: (Tuple[int, List[List[int]]]) -> int
x, y = z x, y = z
return y[0][1] return y[0][1]
self.checkScript(foo, ((1, [[1, 2], [3, 4]]),)) self.checkScript(foo, ((1, [[1, 2], [3, 4]]),))
def test_list_sum(self): def test_list_sum(self):
@ -215,12 +235,12 @@ class TestTyping(JitTestCase):
def fn2(x: List[bool]): def fn2(x: List[bool]):
return sum(x) return sum(x)
self.checkScript(fn, ([1, 2, 3], )) self.checkScript(fn, ([1, 2, 3],))
self.checkScript(fn1, ([1.0, 2.0, 3.0], )) self.checkScript(fn1, ([1.0, 2.0, 3.0],))
self.checkScript(fn1, ([1, 2.8, 3], )) self.checkScript(fn1, ([1, 2.8, 3],))
self.checkScript(fn2, ([True, False, False], )) self.checkScript(fn2, ([True, False, False],))
self.checkScript(fn2, ([False, False, False], )) self.checkScript(fn2, ([False, False, False],))
self.checkScript(fn2, ([0, 1, 1, 0], )) self.checkScript(fn2, ([0, 1, 1, 0],))
def test_list_unification(self): def test_list_unification(self):
def fn(): def fn():
@ -254,7 +274,6 @@ class TestTyping(JitTestCase):
self.checkScript(self.get_sum_list_fn(), ([1],)) self.checkScript(self.get_sum_list_fn(), ([1],))
def test_sum_list_literal(self): def test_sum_list_literal(self):
def sum_list(): def sum_list():
# type: () -> int # type: () -> int
sum = 0 sum = 0
@ -266,8 +285,8 @@ class TestTyping(JitTestCase):
self.checkScript(sum_list, ()) self.checkScript(sum_list, ())
def test_sum_list_wrong_type(self): def test_sum_list_wrong_type(self):
with self.assertRaisesRegex(RuntimeError, "'int' object is not iterable"): with self.assertRaisesRegex(RuntimeError, "'int' object is not iterable"):
@torch.jit.script @torch.jit.script
def sum_list(a): def sum_list(a):
# type: (int) -> int # type: (int) -> int
@ -280,14 +299,18 @@ class TestTyping(JitTestCase):
sum_list(1) sum_list(1)
def test_list_iterables(self): def test_list_iterables(self):
with self.assertRaisesRegex(RuntimeError, 'List of iterables is not supported currently'): with self.assertRaisesRegex(
cu = torch.jit.CompilationUnit(''' RuntimeError, "List of iterables is not supported currently"
):
cu = torch.jit.CompilationUnit(
"""
def list_iterables(x): def list_iterables(x):
for i, j in [2, 3, 4], [5, 6, 7]: for i, j in [2, 3, 4], [5, 6, 7]:
x += i x += i
x += j x += j
return x return x
''') """
)
def test_for_in_string(self): def test_for_in_string(self):
def test_strings(x): def test_strings(x):
@ -352,36 +375,43 @@ class TestTyping(JitTestCase):
def test_dict_comprehension(self): def test_dict_comprehension(self):
def fn(): def fn():
return {i : chr(i + 65) for i in range(4)} return {i: chr(i + 65) for i in range(4)}
self.checkScript(fn, ()) self.checkScript(fn, ())
def test_dict_comprehension_with_type_annotation(self): def test_dict_comprehension_with_type_annotation(self):
def fn(): def fn():
d: Dict[int, str] = {i : chr(i + 65) for i in range(4)} d: Dict[int, str] = {i: chr(i + 65) for i in range(4)}
return d return d
self.checkScript(fn, ()) self.checkScript(fn, ())
with self.assertRaisesRegex(RuntimeError, ""): with self.assertRaisesRegex(RuntimeError, ""):
with self.assertRaisesRegex(AssertionError, "Expected Dict " with self.assertRaisesRegex(
"type annotation for dict " AssertionError,
"comprehension, found " "Expected Dict "
"Tuple[int, str]"): "type annotation for dict "
"comprehension, found "
"Tuple[int, str]",
):
@torch.jit.script @torch.jit.script
def fn(): def fn():
d: Tuple[int, str] = {i : chr(i + 65) for i in range(4)} d: Tuple[int, str] = {i: chr(i + 65) for i in range(4)}
return d return d
def test_dict_comprehension_scope(self): def test_dict_comprehension_scope(self):
def comprehension_can_access_outer_scope_variables(): def comprehension_can_access_outer_scope_variables():
lst = ["foo", "bar", "baz"] lst = ["foo", "bar", "baz"]
return {l : len(l) for l in lst} return {l: len(l) for l in lst}
self.checkScript(comprehension_can_access_outer_scope_variables, ()) self.checkScript(comprehension_can_access_outer_scope_variables, ())
with self.assertRaisesRegex(RuntimeError, "undefined value i"): with self.assertRaisesRegex(RuntimeError, "undefined value i"):
@torch.jit.script @torch.jit.script
def outer_scope_cannot_access_comprehension_variables(): def outer_scope_cannot_access_comprehension_variables():
d = {i : chr(i + 65) for i in range(4)} d = {i: chr(i + 65) for i in range(4)}
i = i + 1 # noqa: F821 i = i + 1 # noqa: F821
def test_for_tuple_assign(self): def test_for_tuple_assign(self):
@ -402,22 +432,28 @@ class TestTyping(JitTestCase):
sum += a[1] sum += a[1]
return sum return sum
self.checkScript(test_tuple_assign, (((1, 2), (4, 7)), )) self.checkScript(test_tuple_assign, (((1, 2), (4, 7)),))
def test_single_starred_lhs(self): def test_single_starred_lhs(self):
with self.assertRaisesRegex(RuntimeError, 'A Starred expression may only appear on the lhs within the presence' with self.assertRaisesRegex(
' of another non-starred expression'): RuntimeError,
cu = torch.jit.CompilationUnit(''' "A Starred expression may only appear on the lhs within the presence"
" of another non-starred expression",
):
cu = torch.jit.CompilationUnit(
"""
def single_starred_lhs(x): def single_starred_lhs(x):
a = (x, x, x) a = (x, x, x)
*b, = a *b, = a
return b return b
''') """
)
def test_singleton_tuple_unpack(self): def test_singleton_tuple_unpack(self):
def foo(a): def foo(a):
b, = (a,) (b,) = (a,)
return b + 1 return b + 1
self.checkScript(foo, (torch.rand(3),)) self.checkScript(foo, (torch.rand(3),))
def test_tuple_assignments(self): def test_tuple_assignments(self):
@ -441,7 +477,9 @@ class TestTyping(JitTestCase):
a[i], (x[i], b) = 1, (2, 3) a[i], (x[i], b) = 1, (2, 3)
return a[i] + 1, x + 5, b return a[i] + 1, x + 5, b
self.checkScript(subscript_tuple_assign, ([12, 7, 9, 11], torch.tensor((3, 13, 17)), 0)) self.checkScript(
subscript_tuple_assign, ([12, 7, 9, 11], torch.tensor((3, 13, 17)), 0)
)
def star_tuple_assign(): def star_tuple_assign():
# type: () -> Tuple[int, int, Tuple[int, int], Tuple[int, int]] # type: () -> Tuple[int, int, Tuple[int, int], Tuple[int, int]]
@ -455,7 +493,7 @@ class TestTyping(JitTestCase):
a[0] += 1 a[0] += 1
return a return a
with self.assertRaisesRegex(RuntimeError, 'does not support augmented assign'): with self.assertRaisesRegex(RuntimeError, "does not support augmented assign"):
scripted_aug_assign = torch.jit.script(subscript_tuple_augmented_assign) scripted_aug_assign = torch.jit.script(subscript_tuple_augmented_assign)
def test_multiple_assign(self): def test_multiple_assign(self):
@ -505,7 +543,6 @@ class TestTyping(JitTestCase):
# type: (Optional[int]) -> int # type: (Optional[int]) -> int
return torch.jit._unwrap_optional(x) return torch.jit._unwrap_optional(x)
@torch.jit.script @torch.jit.script
def fn(x): def fn(x):
# type: (int) -> int # type: (int) -> int
@ -540,7 +577,7 @@ class TestTyping(JitTestCase):
# type: (Tuple[float, float]) -> int # type: (Tuple[float, float]) -> int
return opt_list(x) + broadcast_opt_list(x) return opt_list(x) + broadcast_opt_list(x)
self.assertEqual(opt_list_tuple_caller((2., 3.)), 4) self.assertEqual(opt_list_tuple_caller((2.0, 3.0)), 4)
def test_optional_tuple(self): def test_optional_tuple(self):
def fn(x=None): def fn(x=None):
@ -556,10 +593,11 @@ class TestTyping(JitTestCase):
def test_namedtuple_redefine(self): def test_namedtuple_redefine(self):
global _1, _2 global _1, _2
_1 = namedtuple('GoogLeNetOutputs', ['logits', 'aux_logits2', 'aux_logits1']) _1 = namedtuple("GoogLeNetOutputs", ["logits", "aux_logits2", "aux_logits1"])
_2 = namedtuple('GoogLeNetOutputs', ['different']) _2 = namedtuple("GoogLeNetOutputs", ["different"])
with self.assertRaisesRegex(RuntimeError, r"redefine"):
with self.assertRaisesRegex(RuntimeError, r'redefine'):
@torch.jit.script @torch.jit.script
def foo(x, y): def foo(x, y):
# type: (_1, _2) -> _1 # type: (_1, _2) -> _1
@ -567,7 +605,9 @@ class TestTyping(JitTestCase):
def test_namedtuple_py2(self): def test_namedtuple_py2(self):
global _GoogLeNetOutputs # see [local resolution in python] global _GoogLeNetOutputs # see [local resolution in python]
_GoogLeNetOutputs = namedtuple('GoogLeNetOutputs', ['logits', 'aux_logits2', 'aux_logits1']) _GoogLeNetOutputs = namedtuple(
"GoogLeNetOutputs", ["logits", "aux_logits2", "aux_logits1"]
)
@torch.jit.script @torch.jit.script
def foo(x): def foo(x):
@ -575,22 +615,27 @@ class TestTyping(JitTestCase):
return x return x
vals = torch.rand(3), torch.rand(4), torch.rand(5) vals = torch.rand(3), torch.rand(4), torch.rand(5)
out = foo(_GoogLeNetOutputs(logits=vals[0], aux_logits2=vals[1], aux_logits1=vals[2])) out = foo(
_GoogLeNetOutputs(logits=vals[0], aux_logits2=vals[1], aux_logits1=vals[2])
)
self.assertEqual(out.logits, vals[0]) self.assertEqual(out.logits, vals[0])
self.assertEqual(out.aux_logits2, vals[1]) self.assertEqual(out.aux_logits2, vals[1])
self.assertEqual(out.aux_logits1, vals[2]) self.assertEqual(out.aux_logits1, vals[2])
def test_namedtuple_good_error(self): def test_namedtuple_good_error(self):
global _GoogLeNetOutputs # see [local resolution in python] global _GoogLeNetOutputs # see [local resolution in python]
_GoogLeNetOutputs = namedtuple('GoogLeNetOutputs', ['logits', 'aux_logits2', 'aux_logits1']) _GoogLeNetOutputs = namedtuple(
"GoogLeNetOutputs", ["logits", "aux_logits2", "aux_logits1"]
)
@torch.jit.script @torch.jit.script
def foo(x): def foo(x):
# type: (_GoogLeNetOutputs) -> _GoogLeNetOutputs # type: (_GoogLeNetOutputs) -> _GoogLeNetOutputs
return x return x
with self.assertRaisesRegex(RuntimeError, with self.assertRaisesRegex(
r'aka NamedTuple\(logits, aux_logits2, aux_logits1\)'): RuntimeError, r"aka NamedTuple\(logits, aux_logits2, aux_logits1\)"
):
out = foo(_GoogLeNetOutputs(logits="3", aux_logits2="4", aux_logits1="5")) out = foo(_GoogLeNetOutputs(logits="3", aux_logits2="4", aux_logits1="5"))
def test_namedtuple_error_source_attribution(self): def test_namedtuple_error_source_attribution(self):

View File

@ -3,22 +3,25 @@
import io import io
import os import os
import sys import sys
import torch
from torch.testing import FileCheck
from enum import Enum from enum import Enum
from textwrap import dedent from textwrap import dedent
from typing import Dict, List, Optional, Tuple, Union from typing import Dict, List, Optional, Tuple, Union
import torch
from torch.testing import FileCheck
# Make the helper files in test/ importable # Make the helper files in test/ importable
pytorch_test_dir = os.path.dirname(os.path.dirname(os.path.realpath(__file__))) pytorch_test_dir = os.path.dirname(os.path.dirname(os.path.realpath(__file__)))
sys.path.append(pytorch_test_dir) sys.path.append(pytorch_test_dir)
from torch.testing._internal.jit_utils import JitTestCase, make_global from torch.testing._internal.jit_utils import JitTestCase, make_global
if __name__ == '__main__': if __name__ == "__main__":
raise RuntimeError("This test file is not meant to be run directly, use:\n\n" raise RuntimeError(
"\tpython test/test_jit.py TESTNAME\n\n" "This test file is not meant to be run directly, use:\n\n"
"instead.") "\tpython test/test_jit.py TESTNAME\n\n"
"instead."
)
class TestUnion(JitTestCase): class TestUnion(JitTestCase):
""" """
@ -57,9 +60,12 @@ class TestUnion(JitTestCase):
scripted = torch.jit.script(fn) scripted = torch.jit.script(fn)
with self.assertRaisesRegex(RuntimeError, "Expected a member of" with self.assertRaisesRegex(
r" Union\[float, int\] but " RuntimeError,
"instead found type str"): "Expected a member of"
r" Union\[float, int\] but "
"instead found type str",
):
scripted("1") scripted("1")
def test_union_with_collections(self): def test_union_with_collections(self):
@ -71,22 +77,31 @@ class TestUnion(JitTestCase):
scripted = torch.jit.script(fn) scripted = torch.jit.script(fn)
with self.assertRaisesRegex(RuntimeError, "Expected a member of" with self.assertRaisesRegex(
r" Union\[List\[int\], Dict\[str, " RuntimeError,
r"int\]\] but instead found type " "Expected a member of"
r"Dict\[str, str\]"): r" Union\[List\[int\], Dict\[str, "
r"int\]\] but instead found type "
r"Dict\[str, str\]",
):
scripted({"foo": "bar", "baz": "qux"}) scripted({"foo": "bar", "baz": "qux"})
with self.assertRaisesRegex(RuntimeError, "Expected a member of" with self.assertRaisesRegex(
r" Union\[List\[int\], Dict\[str, " RuntimeError,
r"int\]\] but instead found type " "Expected a member of"
r"List\[str\]"): r" Union\[List\[int\], Dict\[str, "
r"int\]\] but instead found type "
r"List\[str\]",
):
scripted(["foo", "bar", "baz"]) scripted(["foo", "bar", "baz"])
with self.assertRaisesRegex(RuntimeError, "Expected a member of" with self.assertRaisesRegex(
r" Union\[List\[int\], Dict\[str, " RuntimeError,
r"int\]\] but instead found type " "Expected a member of"
"str"): r" Union\[List\[int\], Dict\[str, "
r"int\]\] but instead found type "
"str",
):
scripted("1") scripted("1")
def test_union_with_enum(self): def test_union_with_enum(self):
@ -104,16 +119,18 @@ class TestUnion(JitTestCase):
scripted = torch.jit.script(fn) scripted = torch.jit.script(fn)
with self.assertRaisesRegex(RuntimeError, "Expected a member of" with self.assertRaisesRegex(
r" Union\[__torch__.jit.test_union." RuntimeError,
r"Color, str\] but instead found " "Expected a member of"
"type int"): r" Union\[__torch__.jit.test_union."
r"Color, str\] but instead found "
"type int",
):
scripted(1) scripted(1)
def test_union_in_class_constructor(self): def test_union_in_class_constructor(self):
@torch.jit.script # noqa: B903 @torch.jit.script # noqa: B903
class A: # noqa: B903 class A: # noqa: B903
def __init__(self, x: Union[int, str]) -> None: def __init__(self, x: Union[int, str]) -> None:
self.x = x self.x = x
@ -125,9 +142,12 @@ class TestUnion(JitTestCase):
scripted = torch.jit.script(fn) scripted = torch.jit.script(fn)
with self.assertRaisesRegex(RuntimeError, "Expected a member of" with self.assertRaisesRegex(
r" Union\[int, str\] but instead " RuntimeError,
r"found type List\[str\]"): "Expected a member of"
r" Union\[int, str\] but instead "
r"found type List\[str\]",
):
scripted(["foo", "bar", "baz"]) scripted(["foo", "bar", "baz"])
def test_union_return_type(self): def test_union_return_type(self):
@ -171,7 +191,7 @@ class TestUnion(JitTestCase):
def test_union_variable_can_be_reassigned(self): def test_union_variable_can_be_reassigned(self):
@torch.jit.script @torch.jit.script
def aux1(i: int): def aux1(i: int):
return int(i ** 2) return int(i**2)
@torch.jit.script @torch.jit.script
def aux2(s: str): def aux2(s: str):
@ -225,8 +245,7 @@ class TestUnion(JitTestCase):
s = fn.graph s = fn.graph
FileCheck().check("x : Union(float, int, str)") \ FileCheck().check("x : Union(float, int, str)").run(s)
.run(s)
def test_unions_of_a_single_argument_vanish(self): def test_unions_of_a_single_argument_vanish(self):
@torch.jit.script @torch.jit.script
@ -235,8 +254,7 @@ class TestUnion(JitTestCase):
s = fn.graph s = fn.graph
FileCheck().check("x : int") \ FileCheck().check("x : int").run(s)
.run(s)
def test_union_redundant_arguments_are_skipped(self): def test_union_redundant_arguments_are_skipped(self):
@torch.jit.script @torch.jit.script
@ -245,8 +263,7 @@ class TestUnion(JitTestCase):
s = fn.graph s = fn.graph
FileCheck().check("x : Union(int, str)") \ FileCheck().check("x : Union(int, str)").run(s)
.run(s)
def test_union_redundant_arguments_are_skipped_optional(self): def test_union_redundant_arguments_are_skipped_optional(self):
@torch.jit.script @torch.jit.script
@ -255,8 +272,7 @@ class TestUnion(JitTestCase):
s = fn.graph s = fn.graph
FileCheck().check("x : Union(float, int, NoneType)") \ FileCheck().check("x : Union(float, int, NoneType)").run(s)
.run(s)
def test_union_redundant_arguments_are_skipped_subtyping(self): def test_union_redundant_arguments_are_skipped_subtyping(self):
@torch.jit.script @torch.jit.script
@ -265,8 +281,7 @@ class TestUnion(JitTestCase):
s = fn.graph s = fn.graph
FileCheck().check("x : Union((int?, int), str)") \ FileCheck().check("x : Union((int?, int), str)").run(s)
.run(s)
def test_union_redundant_arguments_are_skipped_container(self): def test_union_redundant_arguments_are_skipped_container(self):
@torch.jit.script @torch.jit.script
@ -275,8 +290,7 @@ class TestUnion(JitTestCase):
s = fn.graph s = fn.graph
FileCheck().check("x : Union(float[], str[])") \ FileCheck().check("x : Union(float[], str[])").run(s)
.run(s)
def test_union_argument_order_is_ignored(self): def test_union_argument_order_is_ignored(self):
@torch.jit.script @torch.jit.script
@ -288,8 +302,7 @@ class TestUnion(JitTestCase):
return "foo" return "foo"
for s in (fn1.graph, fn2.graph): for s in (fn1.graph, fn2.graph):
FileCheck().check("x : Union(int, str)") \ FileCheck().check("x : Union(int, str)").run(s)
.run(s)
def test_union_argument_order_is_ignored_container(self): def test_union_argument_order_is_ignored_container(self):
@torch.jit.script @torch.jit.script
@ -301,8 +314,7 @@ class TestUnion(JitTestCase):
return "foo" return "foo"
for s in (fn1.graph, fn2.graph): for s in (fn1.graph, fn2.graph):
FileCheck().check("x : Union(int[], str[])") \ FileCheck().check("x : Union(int[], str[])").run(s)
.run(s)
def test_union_T_None_is_equivalent_to_optional_T(self): def test_union_T_None_is_equivalent_to_optional_T(self):
@torch.jit.script @torch.jit.script
@ -366,9 +378,9 @@ class TestUnion(JitTestCase):
s = l.code s = l.code
FileCheck().check("Union[int, NoneType, str]") \ FileCheck().check("Union[int, NoneType, str]").check(
.check("Union[int, NoneType, str]") \ "Union[int, NoneType, str]"
.run(s) ).run(s)
def test_union_subclasses_larger_union(self): def test_union_subclasses_larger_union(self):
def fn() -> Union[int, str, torch.Tensor]: def fn() -> Union[int, str, torch.Tensor]:
@ -386,9 +398,12 @@ class TestUnion(JitTestCase):
x[1] = 2 x[1] = 2
return x[1] return x[1]
with self.assertRaisesRegex(RuntimeError, "only int, float, " with self.assertRaisesRegex(
"complex, Tensor, device and string keys " RuntimeError,
"are supported"): "only int, float, "
"complex, Tensor, device and string keys "
"are supported",
):
torch.jit.script(fn) torch.jit.script(fn)
def test_union_as_dict_value(self): def test_union_as_dict_value(self):
@ -402,7 +417,6 @@ class TestUnion(JitTestCase):
def test_union_module_with_union_instance_variable(self): def test_union_module_with_union_instance_variable(self):
class M(torch.nn.Module): class M(torch.nn.Module):
x: Union[int, str] x: Union[int, str]
def __init__(self, x: Union[int, str]): def __init__(self, x: Union[int, str]):
@ -413,7 +427,12 @@ class TestUnion(JitTestCase):
self.x = y self.x = y
return self.x return self.x
self.checkModule(M(2,), (1,)) self.checkModule(
M(
2,
),
(1,),
)
self.checkModule(M("bar"), ("foo",)) self.checkModule(M("bar"), ("foo",))
def test_union_module_with_union_class_variable(self): def test_union_module_with_union_class_variable(self):
@ -508,9 +527,7 @@ class TestUnion(JitTestCase):
s = fn.graph s = fn.graph
# Check that we don't have any branching statements # Check that we don't have any branching statements
FileCheck().check_not("block0()") \ FileCheck().check_not("block0()").check_not("block1()").run(s)
.check_not("block1()") \
.run(s)
def test_union_type_refinement_statically_true(self): def test_union_type_refinement_statically_true(self):
@torch.jit.script @torch.jit.script
@ -525,9 +542,7 @@ class TestUnion(JitTestCase):
s = fn.graph s = fn.graph
# Check that we don't have any branching statements # Check that we don't have any branching statements
FileCheck().check_not("block0()") \ FileCheck().check_not("block0()").check_not("block1()").run(s)
.check_not("block1()") \
.run(s)
def test_union_type_refinement_partial_static_refinement_tuple_rhs(self): def test_union_type_refinement_partial_static_refinement_tuple_rhs(self):
def fn(x: Union[List[int], int]) -> int: def fn(x: Union[List[int], int]) -> int:
@ -556,7 +571,7 @@ class TestUnion(JitTestCase):
def test_union_type_refinement_internal_declaration(self): def test_union_type_refinement_internal_declaration(self):
def fn(flag: bool) -> str: def fn(flag: bool) -> str:
x: Union[int, str, None] = None x: Union[int, str, None] = None
if (flag): if flag:
y = "foo" y = "foo"
else: else:
y = 1 y = 1
@ -589,9 +604,12 @@ class TestUnion(JitTestCase):
else: else:
return "bar" return "bar"
with self.assertRaisesRegex(RuntimeError, "y is set to type str" with self.assertRaisesRegex(
" in the true branch and type int " RuntimeError,
"in the false branch"): "y is set to type str"
" in the true branch and type int "
"in the false branch",
):
torch.jit.script(fn) torch.jit.script(fn)
def test_union_branching_does_not_widen_existing_inferred_type(self): def test_union_branching_does_not_widen_existing_inferred_type(self):
@ -606,9 +624,12 @@ class TestUnion(JitTestCase):
else: else:
return "baz" return "baz"
with self.assertRaisesRegex(RuntimeError, "previously had type " with self.assertRaisesRegex(
"str but is now being assigned to a" RuntimeError,
" value of type int"): "previously had type "
"str but is now being assigned to a"
" value of type int",
):
torch.jit.script(fn) torch.jit.script(fn)
def test_union_schema_matching_on_internal_type(self): def test_union_schema_matching_on_internal_type(self):
@ -645,8 +666,8 @@ class TestUnion(JitTestCase):
def test_union_memory_aliasing(self): def test_union_memory_aliasing(self):
def fn(): def fn():
x : List[torch.Tensor] = [] x: List[torch.Tensor] = []
z : List[Optional[List[torch.Tensor]]] = [] z: List[Optional[List[torch.Tensor]]] = []
z.append(x) z.append(x)
x_alias = z[0] x_alias = z[0]
if torch.jit.isinstance(x_alias, List[torch.Tensor]): if torch.jit.isinstance(x_alias, List[torch.Tensor]):
@ -682,203 +703,212 @@ class TestUnion(JitTestCase):
code = template.format(ann=ann, lhs=lhs) code = template.format(ann=ann, lhs=lhs)
with self.assertRaisesRegex(RuntimeError, msg): with self.assertRaisesRegex(RuntimeError, msg):
cu = torch.jit.CompilationUnit(code, _frames_up=1) cu = torch.jit.CompilationUnit(code, _frames_up=1)
string_frontend = getattr(cu, "fn") # noqa: B009 string_frontend = getattr(cu, "fn") # noqa: B009
def test_union_with_list_assignment(self): def test_union_with_list_assignment(self):
template = dedent(''' template = dedent(
"""
def fn(): def fn():
x: {ann} = {lhs} x: {ann} = {lhs}
if torch.jit.isinstance(x, List[torch.Tensor]): if torch.jit.isinstance(x, List[torch.Tensor]):
x.append(torch.tensor(3)) x.append(torch.tensor(3))
return x return x
''') """
)
lhs = {"list_literal_empty" : "[]", lhs = {
"list_literal_empty": "[]",
"list_literal_of_tensor" : "[torch.arange(3), torch.arange(5)]", "list_literal_of_tensor": "[torch.arange(3), torch.arange(5)]",
"list_literal_of_str": '["foo", "bar", "baz"]',
"list_literal_of_str" : "[\"foo\", \"bar\", \"baz\"]", "list_literal_of_mixed": "[torch.arange(5), 1]",
"list_comprehension_of_tensor": "[torch.add(x, 1) for x in [torch.arange(3), torch.arange(5)]]",
"list_literal_of_mixed" : "[torch.arange(5), 1]", "list_comprehension_of_str": '[x + "!" for x in ["foo", "bar", "baz"]]',
"list_comprehension_of_mixed": "[torch.add(1, x) for x in [torch.arange(5), 1]]",
"list_comprehension_of_tensor" : }
"[torch.add(x, 1) for x in [torch.arange(3), torch.arange(5)]]",
"list_comprehension_of_str" :
"[x + \"!\" for x in [\"foo\", \"bar\", \"baz\"]]",
"list_comprehension_of_mixed" :
"[torch.add(1, x) for x in [torch.arange(5), 1]]"}
""" """
Union[List[str], List[torch.Tensor]] Union[List[str], List[torch.Tensor]]
""" """
self._assert_raises(template, self._assert_raises(
"Union[List[str], List[torch.Tensor]]", template,
lhs["list_literal_empty"], "Union[List[str], List[torch.Tensor]]",
"there are multiple possible List type " lhs["list_literal_empty"],
"candidates in the Union annotation") "there are multiple possible List type "
"candidates in the Union annotation",
)
self._assert_passes(template, self._assert_passes(
"Union[List[str], List[torch.Tensor]]", template,
lhs["list_literal_of_tensor"]) "Union[List[str], List[torch.Tensor]]",
lhs["list_literal_of_tensor"],
)
self._assert_passes(template, self._assert_passes(
"Union[List[str], List[torch.Tensor]]", template, "Union[List[str], List[torch.Tensor]]", lhs["list_literal_of_str"]
lhs["list_literal_of_str"]) )
self._assert_raises(template, self._assert_raises(
"Union[List[str], List[torch.Tensor]]", template,
lhs["list_literal_of_mixed"], "Union[List[str], List[torch.Tensor]]",
"none of those types match the types of the" lhs["list_literal_of_mixed"],
" given list elements") "none of those types match the types of the" " given list elements",
)
self._assert_passes(template, self._assert_passes(
"Union[List[str], List[torch.Tensor]]", template,
lhs["list_comprehension_of_tensor"]) "Union[List[str], List[torch.Tensor]]",
lhs["list_comprehension_of_tensor"],
)
self._assert_passes(template, self._assert_passes(
"Union[List[str], List[torch.Tensor]]", template,
lhs["list_comprehension_of_str"]) "Union[List[str], List[torch.Tensor]]",
lhs["list_comprehension_of_str"],
)
# TODO: Support mixed list comprehensions # TODO: Support mixed list comprehensions
self._assert_raises(template, self._assert_raises(
"Union[List[str], List[torch.Tensor]]", template,
lhs["list_comprehension_of_mixed"], "Union[List[str], List[torch.Tensor]]",
"Arguments for call are not valid") lhs["list_comprehension_of_mixed"],
"Arguments for call are not valid",
)
""" """
Union[int, torch.Tensor] Union[int, torch.Tensor]
""" """
self._assert_raises(template, self._assert_raises(
"Union[int, torch.Tensor]", template,
lhs["list_literal_empty"], "Union[int, torch.Tensor]",
"Expected an Union type annotation with an " lhs["list_literal_empty"],
"inner List type") "Expected an Union type annotation with an " "inner List type",
)
self._assert_raises(template, "Union[int, torch.Tensor]", self._assert_raises(
lhs["list_literal_of_tensor"], template,
"Expected an Union type annotation with an " "Union[int, torch.Tensor]",
"inner List type") lhs["list_literal_of_tensor"],
"Expected an Union type annotation with an " "inner List type",
)
self._assert_raises(template, "Union[int, torch.Tensor]", self._assert_raises(
lhs["list_comprehension_of_tensor"], template,
"Expected an Union type annotation with an " "Union[int, torch.Tensor]",
"inner List type") lhs["list_comprehension_of_tensor"],
"Expected an Union type annotation with an " "inner List type",
)
""" """
Union[List[torch.Tensor], int] Union[List[torch.Tensor], int]
""" """
self._assert_passes(template, self._assert_passes(
"Union[List[torch.Tensor], int]", template, "Union[List[torch.Tensor], int]", lhs["list_literal_empty"]
lhs["list_literal_empty"]) )
self._assert_passes(template, self._assert_passes(
"Union[List[torch.Tensor], int]", template, "Union[List[torch.Tensor], int]", lhs["list_literal_of_tensor"]
lhs["list_literal_of_tensor"]) )
self._assert_raises(template, "Union[List[torch.Tensor], int]", self._assert_raises(
lhs["list_literal_of_str"], template,
r"List type annotation `List\[Tensor\]` did " "Union[List[torch.Tensor], int]",
"not match the types of the given list " lhs["list_literal_of_str"],
"elements") r"List type annotation `List\[Tensor\]` did "
"not match the types of the given list "
"elements",
)
self._assert_raises(template, "Union[List[torch.Tensor], int]", self._assert_raises(
lhs["list_literal_of_mixed"], template,
r"List type annotation `List\[Tensor\]` did " "Union[List[torch.Tensor], int]",
"not match the types of the given list " lhs["list_literal_of_mixed"],
"elements") r"List type annotation `List\[Tensor\]` did "
"not match the types of the given list "
"elements",
)
self._assert_passes(template, self._assert_passes(
"Union[List[torch.Tensor], int]", template,
lhs["list_comprehension_of_tensor"]) "Union[List[torch.Tensor], int]",
lhs["list_comprehension_of_tensor"],
)
self._assert_raises(template, self._assert_raises(
"Union[List[torch.Tensor], int]", template,
lhs["list_comprehension_of_str"], "Union[List[torch.Tensor], int]",
r"List type annotation `List\[Tensor\]` did " lhs["list_comprehension_of_str"],
"not match the types of the given list " r"List type annotation `List\[Tensor\]` did "
"elements") "not match the types of the given list "
"elements",
)
# TODO(@ansley): Support mixed list comprehensions # TODO(@ansley): Support mixed list comprehensions
self._assert_raises(template, self._assert_raises(
"Union[List[torch.Tensor], int]", template,
lhs["list_comprehension_of_mixed"], "Union[List[torch.Tensor], int]",
"Arguments for call are not valid") lhs["list_comprehension_of_mixed"],
"Arguments for call are not valid",
)
def test_union_with_dict_assignment(self): def test_union_with_dict_assignment(self):
template = dedent(''' template = dedent(
"""
def fn(): def fn():
x: {ann} = {lhs} x: {ann} = {lhs}
if torch.jit.isinstance(x, Dict[str, torch.Tensor]): if torch.jit.isinstance(x, Dict[str, torch.Tensor]):
x["foo"] = torch.tensor(3) x["foo"] = torch.tensor(3)
return x return x
''') """
)
lhs = {"dict_literal_empty" : "{}", lhs = {
"dict_literal_empty": "{}",
"dict_literal_of_str_tensor" : "dict_literal_of_str_tensor": '{"foo" : torch.arange(3), "bar" : torch.arange(5)}',
"{\"foo\" : torch.arange(3), \"bar\" : torch.arange(5)}", "dict_literal_of_str_int": '{"foo" : 1, "bar" : 2}',
"dict_literal_of_mixed": '{"foo" : torch.arange(3), "bar" : 2}',
"dict_literal_of_str_int" : "dict_comprehension_of_str_tensor": '{x : torch.add(y, 1) for x, y in \
"{\"foo\" : 1, \"bar\" : 2}", zip(["foo", "bar"], [torch.arange(3), torch.arange(5)])}',
"dict_comprehension_of_str_int": '{x : torch.add(y, 1) for x, y in \
"dict_literal_of_mixed" : zip(["foo", "bar"], [1, 2]}',
"{\"foo\" : torch.arange(3), \"bar\" : 2}", "dict_comprehension_of_mixed": '{x : torch.add(y, 1) for x, y in \
zip(["foo", "bar"], [torch.arange(3), 2])}',
"dict_comprehension_of_str_tensor" : "dict_keyword": "dict(foo=torch.arange(3), baz=torch.arange(5))",
"{x : torch.add(y, 1) for x, y in \ "dict_keyword_with_iterable": 'dict([("foo", torch.arange(3)), ("bar", torch.arange(5))])',
zip([\"foo\", \"bar\"], [torch.arange(3), torch.arange(5)])}", "dict_keyword_with_empty_iterable": "dict([])",
"dict_keyword_with_internal_aggregate_function": 'dict(zip(["foo", "bar"], [torch.arange(3), torch.arange(5)])',
"dict_comprehension_of_str_int" : "dict_keyword_with_mapping": 'dict({"foo" : torch.arange(3), "bar" : torch.arange(5)})',
"{x : torch.add(y, 1) for x, y in \ "dict_keyword_with_mapping_and_kwargs": 'dict({"foo" : torch.arange(3), "bar" : torch.arange(5)}, baz=torch.arange(7))',
zip([\"foo\", \"bar\"], [1, 2]}", }
"dict_comprehension_of_mixed" :
"{x : torch.add(y, 1) for x, y in \
zip([\"foo\", \"bar\"], [torch.arange(3), 2])}",
"dict_keyword" :
"dict(foo=torch.arange(3), baz=torch.arange(5))",
"dict_keyword_with_iterable" :
"dict([(\"foo\", torch.arange(3)), (\"bar\", torch.arange(5))])",
"dict_keyword_with_empty_iterable" :
"dict([])",
"dict_keyword_with_internal_aggregate_function" :
"dict(zip([\"foo\", \"bar\"], [torch.arange(3), torch.arange(5)])",
"dict_keyword_with_mapping" :
"dict({\"foo\" : torch.arange(3), \"bar\" : torch.arange(5)})",
"dict_keyword_with_mapping_and_kwargs" :
"dict({\"foo\" : torch.arange(3), \"bar\" : torch.arange(5)}, baz=torch.arange(7))",
}
""" """
Union[Dict[str, torch.Tensor], Dict[str, int]] Union[Dict[str, torch.Tensor], Dict[str, int]]
""" """
self._assert_raises(template, self._assert_raises(
"Union[List[str], List[torch.Tensor]]", template,
lhs["dict_literal_empty"], "Union[List[str], List[torch.Tensor]]",
"Expected an Union type annotation with an " lhs["dict_literal_empty"],
"inner Dict type") "Expected an Union type annotation with an " "inner Dict type",
)
self._assert_passes(template, self._assert_passes(
"Union[Dict[str, torch.Tensor], Dict[str, int]]", template,
lhs["dict_literal_of_str_tensor"]) "Union[Dict[str, torch.Tensor], Dict[str, int]]",
lhs["dict_literal_of_str_tensor"],
)
self._assert_passes(template, self._assert_passes(
"Union[Dict[str, torch.Tensor], Dict[str, int]]", template,
lhs["dict_literal_of_str_int"]) "Union[Dict[str, torch.Tensor], Dict[str, int]]",
lhs["dict_literal_of_str_int"],
)
self._assert_raises(template, "Union[Dict[str, torch.Tensor], Dict[str, int]]", self._assert_raises(
lhs["dict_literal_of_mixed"], template,
"none of those dict types can hold the " "Union[Dict[str, torch.Tensor], Dict[str, int]]",
"types of the given keys and values") lhs["dict_literal_of_mixed"],
"none of those dict types can hold the "
"types of the given keys and values",
)
# TODO: String frontend does not support tuple unpacking # TODO: String frontend does not support tuple unpacking
# https://github.com/pytorch/pytorch/issues/64096 # https://github.com/pytorch/pytorch/issues/64096
@ -899,45 +929,57 @@ class TestUnion(JitTestCase):
# TODO(@ansley): Follow-up project needed for full type # TODO(@ansley): Follow-up project needed for full type
# inference with dict keyword (supported for dict comprehension # inference with dict keyword (supported for dict comprehension
# and dict literal already; should not be a blocker for anyone) # and dict literal already; should not be a blocker for anyone)
self._assert_raises(template, self._assert_raises(
"Union[Dict[str, torch.Tensor], Dict[str, int]]", template,
lhs["dict_keyword"], "Union[Dict[str, torch.Tensor], Dict[str, int]]",
"full type inference is not yet supported") lhs["dict_keyword"],
"full type inference is not yet supported",
)
self._assert_raises(template, self._assert_raises(
"Union[Dict[str, torch.Tensor], Dict[str, int]]", template,
lhs["dict_keyword_with_iterable"], "Union[Dict[str, torch.Tensor], Dict[str, int]]",
"full type inference is not yet supported") lhs["dict_keyword_with_iterable"],
"full type inference is not yet supported",
)
self._assert_raises(template, self._assert_raises(
"Union[Dict[str, torch.Tensor], Dict[str, int]]", template,
lhs["dict_keyword_with_empty_iterable"], "Union[Dict[str, torch.Tensor], Dict[str, int]]",
"full type inference is not yet supported") lhs["dict_keyword_with_empty_iterable"],
"full type inference is not yet supported",
)
self._assert_raises(template, self._assert_raises(
"Union[Dict[str, torch.Tensor], Dict[str, int]]", template,
lhs["dict_keyword_with_mapping"], "Union[Dict[str, torch.Tensor], Dict[str, int]]",
"full type inference is not yet supported") lhs["dict_keyword_with_mapping"],
"full type inference is not yet supported",
)
self._assert_raises(template, self._assert_raises(
"Union[Dict[str, torch.Tensor], Dict[str, int]]", template,
lhs["dict_keyword_with_mapping_and_kwargs"], "Union[Dict[str, torch.Tensor], Dict[str, int]]",
"full type inference is not yet supported") lhs["dict_keyword_with_mapping_and_kwargs"],
"full type inference is not yet supported",
)
""" """
Union[int, torch.Tensor] Union[int, torch.Tensor]
""" """
self._assert_raises(template, self._assert_raises(
"Union[int, torch.Tensor]", template,
lhs["dict_literal_empty"], "Union[int, torch.Tensor]",
"Expected an Union type annotation with " lhs["dict_literal_empty"],
"an inner Dict type") "Expected an Union type annotation with " "an inner Dict type",
)
self._assert_raises(template, self._assert_raises(
"Union[int, torch.Tensor]", template,
lhs["dict_literal_of_str_tensor"], "Union[int, torch.Tensor]",
"Expected an Union type annotation with " lhs["dict_literal_of_str_tensor"],
"an inner Dict type") "Expected an Union type annotation with " "an inner Dict type",
)
# See above--string frontend does not support tuple unpacking # See above--string frontend does not support tuple unpacking
# self._assert_raises(template, "Union[int, torch.Tensor]", # self._assert_raises(template, "Union[int, torch.Tensor]",
@ -947,47 +989,61 @@ class TestUnion(JitTestCase):
""" """
Union[Dict[str, torch.Tensor], int] Union[Dict[str, torch.Tensor], int]
""" """
self._assert_passes(template, self._assert_passes(
"Union[Dict[str, torch.Tensor], int]", template, "Union[Dict[str, torch.Tensor], int]", lhs["dict_literal_empty"]
lhs["dict_literal_empty"]) )
self._assert_passes(template, self._assert_passes(
"Union[Dict[str, torch.Tensor], int]", template,
lhs["dict_literal_of_str_tensor"]) "Union[Dict[str, torch.Tensor], int]",
lhs["dict_literal_of_str_tensor"],
)
self._assert_raises(template, self._assert_raises(
"Union[Dict[str, torch.Tensor], int]", template,
lhs["dict_literal_of_str_int"], "Union[Dict[str, torch.Tensor], int]",
"Type annotation was inferred to be " lhs["dict_literal_of_str_int"],
r"`Dict\[str, Tensor\]`, but the type of " "Type annotation was inferred to be "
"values given by the dict literal is") r"`Dict\[str, Tensor\]`, but the type of "
"values given by the dict literal is",
)
self._assert_raises(template, self._assert_raises(
"Union[Dict[str, torch.Tensor], int]", template,
lhs["dict_literal_of_mixed"], "Union[Dict[str, torch.Tensor], int]",
"Type annotation was inferred to be " lhs["dict_literal_of_mixed"],
r"`Dict\[str, Tensor\]`, but the type of " "Type annotation was inferred to be "
"values given by the dict literal is") r"`Dict\[str, Tensor\]`, but the type of "
"values given by the dict literal is",
)
self._assert_passes(template, self._assert_passes(
"Union[Dict[str, torch.Tensor], int]", template, "Union[Dict[str, torch.Tensor], int]", lhs["dict_keyword"]
lhs["dict_keyword"]) )
self._assert_passes(template, self._assert_passes(
"Union[Dict[str, torch.Tensor], int]", template,
lhs["dict_keyword_with_iterable"]) "Union[Dict[str, torch.Tensor], int]",
lhs["dict_keyword_with_iterable"],
)
self._assert_passes(template, self._assert_passes(
"Union[Dict[str, torch.Tensor], int]", template,
lhs["dict_keyword_with_empty_iterable"]) "Union[Dict[str, torch.Tensor], int]",
lhs["dict_keyword_with_empty_iterable"],
)
self._assert_passes(template, self._assert_passes(
"Union[Dict[str, torch.Tensor], int]", template,
lhs["dict_keyword_with_mapping"]) "Union[Dict[str, torch.Tensor], int]",
lhs["dict_keyword_with_mapping"],
)
self._assert_passes(template, self._assert_passes(
"Union[Dict[str, torch.Tensor], int]", template,
lhs["dict_keyword_with_mapping_and_kwargs"]) "Union[Dict[str, torch.Tensor], int]",
lhs["dict_keyword_with_mapping_and_kwargs"],
)
# See above--string frontend does not support tuple unpacking # See above--string frontend does not support tuple unpacking
# self._assert_passes(template, # self._assert_passes(template,

View File

@ -2,19 +2,21 @@
import os import os
import sys import sys
import unittest
import torch import torch
import unittest
# Make the helper files in test/ importable # Make the helper files in test/ importable
pytorch_test_dir = os.path.dirname(os.path.dirname(os.path.realpath(__file__))) pytorch_test_dir = os.path.dirname(os.path.dirname(os.path.realpath(__file__)))
sys.path.append(pytorch_test_dir) sys.path.append(pytorch_test_dir)
from torch.testing._internal.jit_utils import JitTestCase from torch.testing._internal.jit_utils import JitTestCase
if __name__ == '__main__': if __name__ == "__main__":
raise RuntimeError("This test file is not meant to be run directly, use:\n\n" raise RuntimeError(
"\tpython test/test_jit.py TESTNAME\n\n" "This test file is not meant to be run directly, use:\n\n"
"instead.") "\tpython test/test_jit.py TESTNAME\n\n"
"instead."
)
# NOTE: FIXING FAILING TESTS # NOTE: FIXING FAILING TESTS
# If you are seeing a test failure from this file, congrats, you improved # If you are seeing a test failure from this file, congrats, you improved
@ -22,6 +24,7 @@ if __name__ == '__main__':
# the corresponding section in documentation that states the unsupported behavior. # the corresponding section in documentation that states the unsupported behavior.
# see: `jit_unsupported.rst` # see: `jit_unsupported.rst`
class TestUnsupportedOps(JitTestCase): class TestUnsupportedOps(JitTestCase):
def test_factory_ops_requires_grad_fail(self): def test_factory_ops_requires_grad_fail(self):
# Keyword argument {name} unknown is a JIT-only error message, # Keyword argument {name} unknown is a JIT-only error message,
@ -32,31 +35,31 @@ class TestUnsupportedOps(JitTestCase):
def ones(): def ones():
return torch.ones([2], requires_grad=True) return torch.ones([2], requires_grad=True)
with self.assertRaisesRegexWithHighlight(Exception, with self.assertRaisesRegexWithHighlight(
"Keyword argument requires_grad unknown", Exception, "Keyword argument requires_grad unknown", "torch.ones"
"torch.ones"): ):
torch.jit.script(ones) torch.jit.script(ones)
def randn(): def randn():
return torch.randn([2], requires_grad=True) return torch.randn([2], requires_grad=True)
with self.assertRaisesRegexWithHighlight(Exception, with self.assertRaisesRegexWithHighlight(
"Keyword argument requires_grad unknown", Exception, "Keyword argument requires_grad unknown", "torch.randn"
"torch.randn"): ):
torch.jit.script(randn) torch.jit.script(randn)
def zeros(): def zeros():
return torch.zeros([2], requires_grad=True) return torch.zeros([2], requires_grad=True)
with self.assertRaisesRegexWithHighlight(Exception, with self.assertRaisesRegexWithHighlight(
"Keyword argument requires_grad unknown", Exception, "Keyword argument requires_grad unknown", "torch.zeros"
"torch.zeros"): ):
torch.jit.script(zeros) torch.jit.script(zeros)
@unittest.skipIf(not torch._C.has_lapack, "PyTorch compiled without Lapack") @unittest.skipIf(not torch._C.has_lapack, "PyTorch compiled without Lapack")
def test_init_ops(self): def test_init_ops(self):
def calculate_gain(): def calculate_gain():
return torch.nn.init.calculate_gain('leaky_relu', 0.2) return torch.nn.init.calculate_gain("leaky_relu", 0.2)
def eye_(): def eye_():
return torch.nn.init.eye_(torch.zeros([2, 2])) return torch.nn.init.eye_(torch.zeros([2, 2]))
@ -71,9 +74,16 @@ class TestUnsupportedOps(JitTestCase):
return torch.nn.init.orthogonal_(torch.empty(3, 5)) return torch.nn.init.orthogonal_(torch.empty(3, 5))
def sparse(): def sparse():
return torch.nn.init.sparse_(torch.empty(3, 5), sparsity=.1) return torch.nn.init.sparse_(torch.empty(3, 5), sparsity=0.1)
for func in [calculate_gain, eye_, dirac_, kaiming_uniform_, orthogonal_, sparse]: for func in [
calculate_gain,
eye_,
dirac_,
kaiming_uniform_,
orthogonal_,
sparse,
]:
# doesn't error in eager # doesn't error in eager
func() func()
with self.assertRaisesRegex(Exception, ""): with self.assertRaisesRegex(Exception, ""):

View File

@ -3,20 +3,24 @@
import io import io
import os import os
import sys import sys
import torch
import zipfile import zipfile
from torch.testing import FileCheck
from typing import Union from typing import Union
import torch
from torch.testing import FileCheck
# Make the helper files in test/ importable # Make the helper files in test/ importable
pytorch_test_dir = os.path.dirname(os.path.dirname(os.path.realpath(__file__))) pytorch_test_dir = os.path.dirname(os.path.dirname(os.path.realpath(__file__)))
sys.path.append(pytorch_test_dir) sys.path.append(pytorch_test_dir)
from torch.testing._internal.jit_utils import JitTestCase from torch.testing._internal.jit_utils import JitTestCase
if __name__ == '__main__': if __name__ == "__main__":
raise RuntimeError("This test file is not meant to be run directly, use:\n\n" raise RuntimeError(
"\tpython test/test_jit.py TESTNAME\n\n" "This test file is not meant to be run directly, use:\n\n"
"instead.") "\tpython test/test_jit.py TESTNAME\n\n"
"instead."
)
class TestUpgraders(JitTestCase): class TestUpgraders(JitTestCase):
def _load_model_version(self, loaded_model): def _load_model_version(self, loaded_model):
@ -28,10 +32,10 @@ class TestUpgraders(JitTestCase):
# in a package between version 3 and 7. # in a package between version 3 and 7.
# So we have to check for both. # So we have to check for both.
try: try:
version = int(zipped_model.read('archive/version').decode("utf-8")) version = int(zipped_model.read("archive/version").decode("utf-8"))
return version return version
except KeyError: except KeyError:
version = int(zipped_model.read('archive/.data/version').decode("utf-8")) version = int(zipped_model.read("archive/.data/version").decode("utf-8"))
return version return version
# TODO (tugsuu) We should ideally be generating this test cases. # TODO (tugsuu) We should ideally be generating this test cases.
@ -62,15 +66,23 @@ class TestUpgraders(JitTestCase):
upgrader_bumped_version = 3 upgrader_bumped_version = 3
upgrader_name = "_test_serialization_subcmul_0_2" upgrader_name = "_test_serialization_subcmul_0_2"
upgrader_schema = "aten::_test_serialization_subcmul(Tensor self, Tensor other, Scalar alpha=2) -> Tensor" upgrader_schema = "aten::_test_serialization_subcmul(Tensor self, Tensor other, Scalar alpha=2) -> Tensor"
dummy_entry = torch._C._UpgraderEntry(upgrader_bumped_version, upgrader_name, upgrader_schema) dummy_entry = torch._C._UpgraderEntry(
upgrader_bumped_version, upgrader_name, upgrader_schema
)
torch._C._test_only_add_entry_to_op_version_map("aten::_test_serialization_subcmul", dummy_entry) torch._C._test_only_add_entry_to_op_version_map(
"aten::_test_serialization_subcmul", dummy_entry
)
map_after_test = torch._C._get_operator_version_map() map_after_test = torch._C._get_operator_version_map()
self.assertTrue("aten::_test_serialization_subcmul" in map_after_test) self.assertTrue("aten::_test_serialization_subcmul" in map_after_test)
self.assertTrue(len(map_after_test) - len(map_before_test) == 1) self.assertTrue(len(map_after_test) - len(map_before_test) == 1)
torch._C._test_only_remove_entry_to_op_version_map("aten::_test_serialization_subcmul") torch._C._test_only_remove_entry_to_op_version_map(
"aten::_test_serialization_subcmul"
)
map_after_remove_test = torch._C._get_operator_version_map() map_after_remove_test = torch._C._get_operator_version_map()
self.assertTrue("aten::_test_serialization_subcmul" not in map_after_remove_test) self.assertTrue(
"aten::_test_serialization_subcmul" not in map_after_remove_test
)
self.assertEqual(len(map_after_remove_test), len(map_before_test)) self.assertEqual(len(map_after_remove_test), len(map_before_test))
def test_populated_test_upgrader_graph(self): def test_populated_test_upgrader_graph(self):
@ -151,7 +163,7 @@ class TestUpgraders(JitTestCase):
model_path = pytorch_test_dir + "/jit/fixtures/test_versioned_linspace_v7.ptl" model_path = pytorch_test_dir + "/jit/fixtures/test_versioned_linspace_v7.ptl"
loaded_model = torch.jit.load(model_path) loaded_model = torch.jit.load(model_path)
sample_inputs = ((3, 10), (-10, 10), (4.0, 6.0), (3 + 4j, 4 + 5j)) sample_inputs = ((3, 10), (-10, 10), (4.0, 6.0), (3 + 4j, 4 + 5j))
for (a, b) in sample_inputs: for a, b in sample_inputs:
output_with_step, output_without_step = loaded_model(a, b) output_with_step, output_without_step = loaded_model(a, b)
# when no step is given, should have used 100 # when no step is given, should have used 100
self.assertTrue(output_without_step.size(dim=0) == 100) self.assertTrue(output_without_step.size(dim=0) == 100)
@ -161,7 +173,9 @@ class TestUpgraders(JitTestCase):
self.assertTrue(version == 8) self.assertTrue(version == 8)
def test_aten_linspace_out(self): def test_aten_linspace_out(self):
model_path = pytorch_test_dir + "/jit/fixtures/test_versioned_linspace_out_v7.ptl" model_path = (
pytorch_test_dir + "/jit/fixtures/test_versioned_linspace_out_v7.ptl"
)
loaded_model = torch.jit.load(model_path) loaded_model = torch.jit.load(model_path)
sample_inputs = ( sample_inputs = (
(3, 10, torch.empty((100,), dtype=torch.int64)), (3, 10, torch.empty((100,), dtype=torch.int64)),
@ -169,7 +183,7 @@ class TestUpgraders(JitTestCase):
(4.0, 6.0, torch.empty((100,), dtype=torch.float64)), (4.0, 6.0, torch.empty((100,), dtype=torch.float64)),
(3 + 4j, 4 + 5j, torch.empty((100,), dtype=torch.complex64)), (3 + 4j, 4 + 5j, torch.empty((100,), dtype=torch.complex64)),
) )
for (a, b, c) in sample_inputs: for a, b, c in sample_inputs:
output = loaded_model(a, b, c) output = loaded_model(a, b, c)
# when no step is given, should have used 100 # when no step is given, should have used 100
self.assertTrue(output.size(dim=0) == 100) self.assertTrue(output.size(dim=0) == 100)
@ -181,7 +195,7 @@ class TestUpgraders(JitTestCase):
model_path = pytorch_test_dir + "/jit/fixtures/test_versioned_logspace_v8.ptl" model_path = pytorch_test_dir + "/jit/fixtures/test_versioned_logspace_v8.ptl"
loaded_model = torch.jit.load(model_path) loaded_model = torch.jit.load(model_path)
sample_inputs = ((3, 10), (-10, 10), (4.0, 6.0), (3 + 4j, 4 + 5j)) sample_inputs = ((3, 10), (-10, 10), (4.0, 6.0), (3 + 4j, 4 + 5j))
for (a, b) in sample_inputs: for a, b in sample_inputs:
output_with_step, output_without_step = loaded_model(a, b) output_with_step, output_without_step = loaded_model(a, b)
# when no step is given, should have used 100 # when no step is given, should have used 100
self.assertTrue(output_without_step.size(dim=0) == 100) self.assertTrue(output_without_step.size(dim=0) == 100)
@ -191,7 +205,9 @@ class TestUpgraders(JitTestCase):
self.assertTrue(version == 9) self.assertTrue(version == 9)
def test_aten_logspace_out(self): def test_aten_logspace_out(self):
model_path = pytorch_test_dir + "/jit/fixtures/test_versioned_logspace_out_v8.ptl" model_path = (
pytorch_test_dir + "/jit/fixtures/test_versioned_logspace_out_v8.ptl"
)
loaded_model = torch.jit.load(model_path) loaded_model = torch.jit.load(model_path)
sample_inputs = ( sample_inputs = (
(3, 10, torch.empty((100,), dtype=torch.int64)), (3, 10, torch.empty((100,), dtype=torch.int64)),
@ -199,7 +215,7 @@ class TestUpgraders(JitTestCase):
(4.0, 6.0, torch.empty((100,), dtype=torch.float64)), (4.0, 6.0, torch.empty((100,), dtype=torch.float64)),
(3 + 4j, 4 + 5j, torch.empty((100,), dtype=torch.complex64)), (3 + 4j, 4 + 5j, torch.empty((100,), dtype=torch.complex64)),
) )
for (a, b, c) in sample_inputs: for a, b, c in sample_inputs:
output = loaded_model(a, b, c) output = loaded_model(a, b, c)
# when no step is given, should have used 100 # when no step is given, should have used 100
self.assertTrue(output.size(dim=0) == 100) self.assertTrue(output.size(dim=0) == 100)
@ -208,21 +224,36 @@ class TestUpgraders(JitTestCase):
self.assertTrue(version == 9) self.assertTrue(version == 9)
def test_aten_test_serialization(self): def test_aten_test_serialization(self):
model_path = pytorch_test_dir + "/jit/fixtures/_test_serialization_subcmul_v2.pt" model_path = (
pytorch_test_dir + "/jit/fixtures/_test_serialization_subcmul_v2.pt"
)
# add test version entry to the version map # add test version entry to the version map
upgrader_bumped_version = 3 upgrader_bumped_version = 3
upgrader_name = "_test_serialization_subcmul_0_2" upgrader_name = "_test_serialization_subcmul_0_2"
upgrader_schema = "aten::_test_serialization_subcmul(Tensor self, Tensor other, Scalar alpha=2) -> Tensor" upgrader_schema = "aten::_test_serialization_subcmul(Tensor self, Tensor other, Scalar alpha=2) -> Tensor"
dummy_entry = torch._C._UpgraderEntry(upgrader_bumped_version, upgrader_name, upgrader_schema) dummy_entry = torch._C._UpgraderEntry(
upgrader_bumped_version, upgrader_name, upgrader_schema
)
torch._C._test_only_add_entry_to_op_version_map("aten::_test_serialization_subcmul", dummy_entry) torch._C._test_only_add_entry_to_op_version_map(
"aten::_test_serialization_subcmul", dummy_entry
)
# add test upgrader in the upgraders map # add test upgrader in the upgraders map
@torch.jit.script @torch.jit.script
def _test_serialization_subcmul_0_2(self: torch.Tensor, other: torch.Tensor, alpha: Union[int, float] = 2) -> torch.Tensor: def _test_serialization_subcmul_0_2(
self: torch.Tensor, other: torch.Tensor, alpha: Union[int, float] = 2
) -> torch.Tensor:
return other - (self * alpha) return other - (self * alpha)
torch._C._test_only_populate_upgraders({"_test_serialization_subcmul_0_2": str(_test_serialization_subcmul_0_2.graph)})
torch._C._test_only_populate_upgraders(
{
"_test_serialization_subcmul_0_2": str(
_test_serialization_subcmul_0_2.graph
)
}
)
# test if the server is able to find the test upgraders and apply to IR # test if the server is able to find the test upgraders and apply to IR
loaded_model = torch.jit.load(model_path) loaded_model = torch.jit.load(model_path)
@ -238,11 +269,21 @@ class TestUpgraders(JitTestCase):
# we check by its' code because graph variable names # we check by its' code because graph variable names
# can be different every time # can be different every time
self.assertEqual(loaded_model.code, loaded_model_twice.code) self.assertEqual(loaded_model.code, loaded_model_twice.code)
torch._C._test_only_remove_entry_to_op_version_map("aten::_test_serialization_subcmul") torch._C._test_only_remove_entry_to_op_version_map(
torch._C._test_only_remove_upgraders({"_test_serialization_subcmul_0_2": str(_test_serialization_subcmul_0_2.graph)}) "aten::_test_serialization_subcmul"
)
torch._C._test_only_remove_upgraders(
{
"_test_serialization_subcmul_0_2": str(
_test_serialization_subcmul_0_2.graph
)
}
)
def test_aten_div_scalar_at_3(self): def test_aten_div_scalar_at_3(self):
model_path = pytorch_test_dir + "/jit/fixtures/test_versioned_div_scalar_float_v3.pt" model_path = (
pytorch_test_dir + "/jit/fixtures/test_versioned_div_scalar_float_v3.pt"
)
loaded_model = torch.jit.load(model_path) loaded_model = torch.jit.load(model_path)
FileCheck().check("prim::If").run(loaded_model.graph) FileCheck().check("prim::If").run(loaded_model.graph)
FileCheck().check_count("aten::div", 2).run(loaded_model.graph) FileCheck().check_count("aten::div", 2).run(loaded_model.graph)
@ -254,11 +295,15 @@ class TestUpgraders(JitTestCase):
self.assertEqual(version, 4) self.assertEqual(version, 4)
loaded_model_twice = torch.jit.load(buffer) loaded_model_twice = torch.jit.load(buffer)
self.assertEqual(loaded_model(torch.Tensor([5.0, 3.0]), 2.0), self.assertEqual(
loaded_model_twice(torch.Tensor([5.0, 3.0]), 2.0)) loaded_model(torch.Tensor([5.0, 3.0]), 2.0),
loaded_model_twice(torch.Tensor([5.0, 3.0]), 2.0),
)
def test_aten_div_tensor_out_at_3(self): def test_aten_div_tensor_out_at_3(self):
model_path = pytorch_test_dir + "/jit/fixtures/test_versioned_div_tensor_out_v3.pt" model_path = (
pytorch_test_dir + "/jit/fixtures/test_versioned_div_tensor_out_v3.pt"
)
loaded_model = torch.jit.load(model_path) loaded_model = torch.jit.load(model_path)
FileCheck().check("prim::If").run(loaded_model.graph) FileCheck().check("prim::If").run(loaded_model.graph)
FileCheck().check_count("aten::div", 2).run(loaded_model.graph) FileCheck().check_count("aten::div", 2).run(loaded_model.graph)
@ -274,7 +319,9 @@ class TestUpgraders(JitTestCase):
self.assertEqual(loaded_model.code, loaded_model_twice.code) self.assertEqual(loaded_model.code, loaded_model_twice.code)
def test_aten_full_at_4(self): def test_aten_full_at_4(self):
model_path = pytorch_test_dir + "/jit/fixtures/test_versioned_full_integer_value_v4.pt" model_path = (
pytorch_test_dir + "/jit/fixtures/test_versioned_full_integer_value_v4.pt"
)
loaded_model = torch.jit.load(model_path) loaded_model = torch.jit.load(model_path)
FileCheck().check_count("aten::Float", 1).run(loaded_model.graph) FileCheck().check_count("aten::Float", 1).run(loaded_model.graph)
FileCheck().check_count("aten::full", 2).run(loaded_model.graph) FileCheck().check_count("aten::full", 2).run(loaded_model.graph)
@ -290,7 +337,9 @@ class TestUpgraders(JitTestCase):
self.assertEqual(loaded_model.code, loaded_model_twice.code) self.assertEqual(loaded_model.code, loaded_model_twice.code)
def test_aten_full_out_at_4(self): def test_aten_full_out_at_4(self):
model_path = pytorch_test_dir + "/jit/fixtures/test_versioned_full_preserved_v4.pt" model_path = (
pytorch_test_dir + "/jit/fixtures/test_versioned_full_preserved_v4.pt"
)
loaded_model = torch.jit.load(model_path) loaded_model = torch.jit.load(model_path)
FileCheck().check_count("aten::full", 5).run(loaded_model.graph) FileCheck().check_count("aten::full", 5).run(loaded_model.graph)
version = self._load_model_version(loaded_model) version = self._load_model_version(loaded_model)

View File

@ -1,12 +1,12 @@
# Owner(s): ["oncall: jit"] # Owner(s): ["oncall: jit"]
import io
import os import os
import sys import sys
import io
import torch
import warnings import warnings
from contextlib import redirect_stderr from contextlib import redirect_stderr
import torch
from torch.testing import FileCheck from torch.testing import FileCheck
# Make the helper files in test/ importable # Make the helper files in test/ importable
@ -14,10 +14,12 @@ pytorch_test_dir = os.path.dirname(os.path.dirname(os.path.realpath(__file__)))
sys.path.append(pytorch_test_dir) sys.path.append(pytorch_test_dir)
from torch.testing._internal.jit_utils import JitTestCase from torch.testing._internal.jit_utils import JitTestCase
if __name__ == '__main__': if __name__ == "__main__":
raise RuntimeError("This test file is not meant to be run directly, use:\n\n" raise RuntimeError(
"\tpython test/test_jit.py TESTNAME\n\n" "This test file is not meant to be run directly, use:\n\n"
"instead.") "\tpython test/test_jit.py TESTNAME\n\n"
"instead."
)
class TestWarn(JitTestCase): class TestWarn(JitTestCase):
@ -30,12 +32,9 @@ class TestWarn(JitTestCase):
with redirect_stderr(f): with redirect_stderr(f):
fn() fn()
FileCheck() \ FileCheck().check_count(
.check_count( str="UserWarning: I am warning you", count=1, exactly=True
str="UserWarning: I am warning you", ).run(f.getvalue())
count=1,
exactly=True) \
.run(f.getvalue())
def test_warn_only_once(self): def test_warn_only_once(self):
@torch.jit.script @torch.jit.script
@ -47,12 +46,9 @@ class TestWarn(JitTestCase):
with redirect_stderr(f): with redirect_stderr(f):
fn() fn()
FileCheck() \ FileCheck().check_count(
.check_count( str="UserWarning: I am warning you", count=1, exactly=True
str="UserWarning: I am warning you", ).run(f.getvalue())
count=1,
exactly=True) \
.run(f.getvalue())
def test_warn_only_once_in_loop_func(self): def test_warn_only_once_in_loop_func(self):
def w(): def w():
@ -67,12 +63,9 @@ class TestWarn(JitTestCase):
with redirect_stderr(f): with redirect_stderr(f):
fn() fn()
FileCheck() \ FileCheck().check_count(
.check_count( str="UserWarning: I am warning you", count=1, exactly=True
str="UserWarning: I am warning you", ).run(f.getvalue())
count=1,
exactly=True) \
.run(f.getvalue())
def test_warn_once_per_func(self): def test_warn_once_per_func(self):
def w1(): def w1():
@ -90,12 +83,9 @@ class TestWarn(JitTestCase):
with redirect_stderr(f): with redirect_stderr(f):
fn() fn()
FileCheck() \ FileCheck().check_count(
.check_count( str="UserWarning: I am warning you", count=2, exactly=True
str="UserWarning: I am warning you", ).run(f.getvalue())
count=2,
exactly=True) \
.run(f.getvalue())
def test_warn_once_per_func_in_loop(self): def test_warn_once_per_func_in_loop(self):
def w1(): def w1():
@ -114,12 +104,9 @@ class TestWarn(JitTestCase):
with redirect_stderr(f): with redirect_stderr(f):
fn() fn()
FileCheck() \ FileCheck().check_count(
.check_count( str="UserWarning: I am warning you", count=2, exactly=True
str="UserWarning: I am warning you", ).run(f.getvalue())
count=2,
exactly=True) \
.run(f.getvalue())
def test_warn_multiple_calls_multiple_warnings(self): def test_warn_multiple_calls_multiple_warnings(self):
@torch.jit.script @torch.jit.script
@ -131,12 +118,9 @@ class TestWarn(JitTestCase):
fn() fn()
fn() fn()
FileCheck() \ FileCheck().check_count(
.check_count( str="UserWarning: I am warning you", count=2, exactly=True
str="UserWarning: I am warning you", ).run(f.getvalue())
count=2,
exactly=True) \
.run(f.getvalue())
def test_warn_multiple_calls_same_func_diff_stack(self): def test_warn_multiple_calls_same_func_diff_stack(self):
def warn(caller: str): def warn(caller: str):
@ -155,13 +139,10 @@ class TestWarn(JitTestCase):
foo() foo()
bar() bar()
FileCheck() \ FileCheck().check_count(
.check_count( str="UserWarning: I am warning you from foo", count=1, exactly=True
str="UserWarning: I am warning you from foo", ).check_count(
count=1, str="UserWarning: I am warning you from bar", count=1, exactly=True
exactly=True) \ ).run(
.check_count( f.getvalue()
str="UserWarning: I am warning you from bar", )
count=1,
exactly=True) \
.run(f.getvalue())

View File

@ -32,6 +32,7 @@ class TestWith(JitTestCase):
Check that with statements that use the 'as' keyword to bind expressions Check that with statements that use the 'as' keyword to bind expressions
to targets work as expected. to targets work as expected.
""" """
@torch.jit.script @torch.jit.script
class Context: class Context:
""" """
@ -189,6 +190,7 @@ class TestWith(JitTestCase):
Check that with statements that do not use the 'as' keyword to bind expressions Check that with statements that do not use the 'as' keyword to bind expressions
to targets work as expected. to targets work as expected.
""" """
@torch.jit.script @torch.jit.script
class Context: class Context:
""" """
@ -345,6 +347,7 @@ class TestWith(JitTestCase):
Check that exceptions thrown in the bodies of with-statements are Check that exceptions thrown in the bodies of with-statements are
handled correctly. handled correctly.
""" """
@torch.jit.script @torch.jit.script
class Context: class Context:
""" """
@ -416,15 +419,21 @@ class TestWith(JitTestCase):
# checkScript and checkScriptRaisesRegex cannot be used because the string frontend will # checkScript and checkScriptRaisesRegex cannot be used because the string frontend will
# not compile class types (of which Context, the context manager being used for this test # not compile class types (of which Context, the context manager being used for this test
# is one). # is one).
with self.assertRaisesRegexWithHighlight(Exception, r"raised exception", "raise Exception(\"raised exception"): with self.assertRaisesRegexWithHighlight(
Exception, r"raised exception", 'raise Exception("raised exception'
):
test_exception(torch.randn(2), c) test_exception(torch.randn(2), c)
self.assertEqual(c.count, 1) self.assertEqual(c.count, 1)
with self.assertRaisesRegexWithHighlight(Exception, r"raised exception", "raise Exception(\"raised exception"): with self.assertRaisesRegexWithHighlight(
Exception, r"raised exception", 'raise Exception("raised exception'
):
test_exception_nested(torch.randn(2), c) test_exception_nested(torch.randn(2), c)
self.assertEqual(c.count, 1) self.assertEqual(c.count, 1)
with self.assertRaisesRegexWithHighlight(Exception, r"raised exception", "raise Exception(\"raised exception"): with self.assertRaisesRegexWithHighlight(
Exception, r"raised exception", 'raise Exception("raised exception'
):
test_exception_fn_call(torch.randn(2), c) test_exception_fn_call(torch.randn(2), c)
self.assertEqual(c.count, 1) self.assertEqual(c.count, 1)
@ -505,7 +514,9 @@ class TestWith(JitTestCase):
return x return x
def test_exit_incorrect_types(x: torch.Tensor, cm: ExitIncorrectTypes) -> torch.Tensor: def test_exit_incorrect_types(
x: torch.Tensor, cm: ExitIncorrectTypes
) -> torch.Tensor:
with cm as _: with cm as _:
pass pass
@ -523,7 +534,9 @@ class TestWith(JitTestCase):
self.checkScript(test_no_enter_no_exit, (test_tensor, NoEnterNoExit())) self.checkScript(test_no_enter_no_exit, (test_tensor, NoEnterNoExit()))
with self.assertRaisesRegexWithHighlight( with self.assertRaisesRegexWithHighlight(
RuntimeError, r"__enter__ must have only one argument and one return value", "cm" RuntimeError,
r"__enter__ must have only one argument and one return value",
"cm",
): ):
self.checkScript(test_bad_enter, (test_tensor, BadEnter())) self.checkScript(test_bad_enter, (test_tensor, BadEnter()))
@ -539,7 +552,9 @@ class TestWith(JitTestCase):
test_exit_incorrect_types, (test_tensor, ExitIncorrectTypes()) test_exit_incorrect_types, (test_tensor, ExitIncorrectTypes())
) )
with self.assertRaisesRegexWithHighlight(RuntimeError, r"must return an object", "\"not_object\""): with self.assertRaisesRegexWithHighlight(
RuntimeError, r"must return an object", '"not_object"'
):
self.checkScript(test_enter_without_object, ()) self.checkScript(test_enter_without_object, ())
def test_with_no_grad(self): def test_with_no_grad(self):
@ -603,6 +618,7 @@ class TestWith(JitTestCase):
Check that torch.autograd.profiler.record_function context manager is Check that torch.autograd.profiler.record_function context manager is
torchscriptable. torchscriptable.
""" """
def with_rf(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor: def with_rf(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
with torch.autograd.profiler.record_function("foo"): with torch.autograd.profiler.record_function("foo"):
# Nested record_function. # Nested record_function.

View File

@ -7,6 +7,7 @@ import torch._C
torch.ops.load_library("//caffe2:xnnpack_backend") torch.ops.load_library("//caffe2:xnnpack_backend")
class TestXNNPackBackend(unittest.TestCase): class TestXNNPackBackend(unittest.TestCase):
def test_xnnpack_constant_data(self): def test_xnnpack_constant_data(self):
class Module(torch.nn.Module): class Module(torch.nn.Module):
@ -24,17 +25,19 @@ class TestXNNPackBackend(unittest.TestCase):
scripted_module, scripted_module,
{ {
"forward": { "forward": {
"inputs" : [torch.randn(4, 4, 4)], "inputs": [torch.randn(4, 4, 4)],
"outputs": [torch.randn(4, 4, 4)] "outputs": [torch.randn(4, 4, 4)],
} }
} },
) )
for i in range(0, 20): for i in range(0, 20):
sample_input = torch.randn(4, 4, 4) sample_input = torch.randn(4, 4, 4)
actual_output = scripted_module(sample_input) actual_output = scripted_module(sample_input)
expected_output = lowered_module(sample_input) expected_output = lowered_module(sample_input)
self.assertTrue(torch.allclose(actual_output, expected_output, atol=1e-03, rtol=1e-03)) self.assertTrue(
torch.allclose(actual_output, expected_output, atol=1e-03, rtol=1e-03)
)
def test_xnnpack_lowering(self): def test_xnnpack_lowering(self):
class Module(torch.nn.Module): class Module(torch.nn.Module):
@ -45,13 +48,11 @@ class TestXNNPackBackend(unittest.TestCase):
faulty_compile_spec = { faulty_compile_spec = {
"backward": { "backward": {
"inputs" : [torch.zeros(1)], "inputs": [torch.zeros(1)],
"outputs": [torch.zeros(1)], "outputs": [torch.zeros(1)],
} }
} }
error_msg = ( error_msg = 'method_compile_spec does not contain the "forward" key.'
"method_compile_spec does not contain the \"forward\" key."
)
with self.assertRaisesRegex( with self.assertRaisesRegex(
RuntimeError, RuntimeError,
@ -64,21 +65,21 @@ class TestXNNPackBackend(unittest.TestCase):
) )
mismatch_compile_spec = { mismatch_compile_spec = {
"forward" : { "forward": {
"inputs" : [torch.zeros(1), torch.zeros(1)], "inputs": [torch.zeros(1), torch.zeros(1)],
"outputs" : [torch.zeros(1)] "outputs": [torch.zeros(1)],
} }
} }
error_msg = ("method_compile_spec inputs do not match expected number of forward inputs") error_msg = (
"method_compile_spec inputs do not match expected number of forward inputs"
)
with self.assertRaisesRegex( with self.assertRaisesRegex(
RuntimeError, RuntimeError,
error_msg, error_msg,
): ):
_ = torch._C._jit_to_backend( _ = torch._C._jit_to_backend(
"xnnpack", "xnnpack", scripted_module, mismatch_compile_spec
scripted_module,
mismatch_compile_spec
) )
lowered = torch._C._jit_to_backend( lowered = torch._C._jit_to_backend(
@ -86,10 +87,10 @@ class TestXNNPackBackend(unittest.TestCase):
scripted_module, scripted_module,
{ {
"forward": { "forward": {
"inputs" : [torch.zeros(1)], "inputs": [torch.zeros(1)],
"outputs": [torch.zeros(1)], "outputs": [torch.zeros(1)],
} }
} },
) )
lowered(torch.zeros(1)) lowered(torch.zeros(1))
@ -113,14 +114,16 @@ class TestXNNPackBackend(unittest.TestCase):
add_module, add_module,
{ {
"forward": { "forward": {
"inputs" : [sample_inputs[0].clone(), sample_inputs[1].clone()], "inputs": [sample_inputs[0].clone(), sample_inputs[1].clone()],
"outputs": [sample_output] "outputs": [sample_output],
} }
} },
) )
actual_output = lowered_add_module.forward(sample_inputs[0], sample_inputs[1]) actual_output = lowered_add_module.forward(sample_inputs[0], sample_inputs[1])
self.assertTrue(torch.allclose(actual_output, expected_output, atol=1e-03, rtol=1e-03)) self.assertTrue(
torch.allclose(actual_output, expected_output, atol=1e-03, rtol=1e-03)
)
def test_xnnpack_broadcasting(self): def test_xnnpack_broadcasting(self):
class AddModule(torch.nn.Module): class AddModule(torch.nn.Module):
@ -139,14 +142,16 @@ class TestXNNPackBackend(unittest.TestCase):
add_module, add_module,
{ {
"forward": { "forward": {
"inputs" : [sample_inputs[0], sample_inputs[1]], "inputs": [sample_inputs[0], sample_inputs[1]],
"outputs": [sample_output] "outputs": [sample_output],
} }
} },
) )
actual_output = lowered_add_module.forward(sample_inputs[0], sample_inputs[1]) actual_output = lowered_add_module.forward(sample_inputs[0], sample_inputs[1])
self.assertTrue(torch.allclose(actual_output, expected_output, atol=1e-03, rtol=1e-03)) self.assertTrue(
torch.allclose(actual_output, expected_output, atol=1e-03, rtol=1e-03)
)
def test_xnnpack_unsupported(self): def test_xnnpack_unsupported(self):
class AddSpliceModule(torch.nn.Module): class AddSpliceModule(torch.nn.Module):
@ -173,8 +178,8 @@ class TestXNNPackBackend(unittest.TestCase):
add_module, add_module,
{ {
"forward": { "forward": {
"inputs" : [sample_inputs[0], sample_inputs[1]], "inputs": [sample_inputs[0], sample_inputs[1]],
"outputs": [sample_output] "outputs": [sample_output],
} }
} },
) )

View File

@ -1,23 +1,30 @@
import argparse import argparse
import os import os
import sys import sys
import torch import torch
# grab modules from test_jit_hooks.cpp # grab modules from test_jit_hooks.cpp
pytorch_test_dir = os.path.dirname(os.path.dirname(os.path.realpath(__file__))) pytorch_test_dir = os.path.dirname(os.path.dirname(os.path.realpath(__file__)))
sys.path.append(pytorch_test_dir) sys.path.append(pytorch_test_dir)
from jit.test_hooks_modules import ( from jit.test_hooks_modules import (
create_forward_tuple_input, create_module_forward_multiple_inputs, create_forward_tuple_input,
create_module_forward_single_input, create_module_hook_return_nothing, create_module_forward_multiple_inputs,
create_module_forward_single_input,
create_module_hook_return_nothing,
create_module_multiple_hooks_multiple_inputs, create_module_multiple_hooks_multiple_inputs,
create_module_multiple_hooks_single_input, create_module_no_forward_input, create_module_multiple_hooks_single_input,
create_module_same_hook_repeated, create_submodule_forward_multiple_inputs, create_module_no_forward_input,
create_module_same_hook_repeated,
create_submodule_forward_multiple_inputs,
create_submodule_forward_single_input, create_submodule_forward_single_input,
create_submodule_hook_return_nothing, create_submodule_hook_return_nothing,
create_submodule_multiple_hooks_multiple_inputs, create_submodule_multiple_hooks_multiple_inputs,
create_submodule_multiple_hooks_single_input, create_submodule_multiple_hooks_single_input,
create_submodule_same_hook_repeated, create_submodule_same_hook_repeated,
create_submodule_to_call_directly_with_hooks) create_submodule_to_call_directly_with_hooks,
)
# Create saved modules for JIT forward hooks and pre-hooks # Create saved modules for JIT forward hooks and pre-hooks
def main(): def main():
@ -30,23 +37,45 @@ def main():
save_name = options.export_script_module_to + "_" save_name = options.export_script_module_to + "_"
tests = [ tests = [
("test_submodule_forward_single_input", create_submodule_forward_single_input()), (
("test_submodule_forward_multiple_inputs", create_submodule_forward_multiple_inputs()), "test_submodule_forward_single_input",
("test_submodule_multiple_hooks_single_input", create_submodule_multiple_hooks_single_input()), create_submodule_forward_single_input(),
("test_submodule_multiple_hooks_multiple_inputs", create_submodule_multiple_hooks_multiple_inputs()), ),
(
"test_submodule_forward_multiple_inputs",
create_submodule_forward_multiple_inputs(),
),
(
"test_submodule_multiple_hooks_single_input",
create_submodule_multiple_hooks_single_input(),
),
(
"test_submodule_multiple_hooks_multiple_inputs",
create_submodule_multiple_hooks_multiple_inputs(),
),
("test_submodule_hook_return_nothing", create_submodule_hook_return_nothing()), ("test_submodule_hook_return_nothing", create_submodule_hook_return_nothing()),
("test_submodule_same_hook_repeated", create_submodule_same_hook_repeated()), ("test_submodule_same_hook_repeated", create_submodule_same_hook_repeated()),
("test_module_forward_single_input", create_module_forward_single_input()), ("test_module_forward_single_input", create_module_forward_single_input()),
("test_module_forward_multiple_inputs", create_module_forward_multiple_inputs()), (
("test_module_multiple_hooks_single_input", create_module_multiple_hooks_single_input()), "test_module_forward_multiple_inputs",
("test_module_multiple_hooks_multiple_inputs", create_module_multiple_hooks_multiple_inputs()), create_module_forward_multiple_inputs(),
),
(
"test_module_multiple_hooks_single_input",
create_module_multiple_hooks_single_input(),
),
(
"test_module_multiple_hooks_multiple_inputs",
create_module_multiple_hooks_multiple_inputs(),
),
("test_module_hook_return_nothing", create_module_hook_return_nothing()), ("test_module_hook_return_nothing", create_module_hook_return_nothing()),
("test_module_same_hook_repeated", create_module_same_hook_repeated()), ("test_module_same_hook_repeated", create_module_same_hook_repeated()),
("test_module_no_forward_input", create_module_no_forward_input()), ("test_module_no_forward_input", create_module_no_forward_input()),
("test_forward_tuple_input", create_forward_tuple_input()), ("test_forward_tuple_input", create_forward_tuple_input()),
("test_submodule_to_call_directly_with_hooks", create_submodule_to_call_directly_with_hooks()) (
"test_submodule_to_call_directly_with_hooks",
create_submodule_to_call_directly_with_hooks(),
),
] ]
for name, model in tests: for name, model in tests: