mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 21:14:14 +08:00
Enable UFMT on all of test/jit (#123623)
Partially addresses #123062 Ran lintrunner on: - `test/jit` with command: ```bash lintrunner -a --take UFMT --all-files ``` Pull Request resolved: https://github.com/pytorch/pytorch/pull/123623 Approved by: https://github.com/ezyang
This commit is contained in:
committed by
PyTorch MergeBot
parent
d0ccf599cc
commit
604c9c5601
@ -1162,94 +1162,6 @@ exclude_patterns = [
|
||||
'test/functorch/test_vmap.py',
|
||||
'test/functorch/test_vmap_registrations.py',
|
||||
'test/functorch/xfail_suggester.py',
|
||||
'test/jit/__init__.py',
|
||||
'test/jit/_imported_class_test/__init__.py',
|
||||
'test/jit/_imported_class_test/bar.py',
|
||||
'test/jit/_imported_class_test/foo.py',
|
||||
'test/jit/_imported_class_test/very/__init__.py',
|
||||
'test/jit/_imported_class_test/very/very/__init__.py',
|
||||
'test/jit/_imported_class_test/very/very/nested.py',
|
||||
'test/jit/fixtures_srcs/__init__.py',
|
||||
'test/jit/fixtures_srcs/fixtures_src.py',
|
||||
'test/jit/fixtures_srcs/generate_models.py',
|
||||
'test/jit/fixtures_srcs/test_upgrader_models_generation.py',
|
||||
'test/jit/myexception.py',
|
||||
'test/jit/test_alias_analysis.py',
|
||||
'test/jit/test_async.py',
|
||||
'test/jit/test_aten_pow.py',
|
||||
'test/jit/test_attr.py',
|
||||
'test/jit/test_autodiff.py',
|
||||
'test/jit/test_autodiff_subgraph_slicing.py',
|
||||
'test/jit/test_await.py',
|
||||
'test/jit/test_backend_nnapi.py',
|
||||
'test/jit/test_backends.py',
|
||||
'test/jit/test_batch_mm.py',
|
||||
'test/jit/test_builtins.py',
|
||||
'test/jit/test_class_type.py',
|
||||
'test/jit/test_complex.py',
|
||||
'test/jit/test_complexity.py',
|
||||
'test/jit/test_convert_activation.py',
|
||||
'test/jit/test_cuda.py',
|
||||
'test/jit/test_custom_operators.py',
|
||||
'test/jit/test_data_parallel.py',
|
||||
'test/jit/test_dataclasses.py',
|
||||
'test/jit/test_dce.py',
|
||||
'test/jit/test_device_analysis.py',
|
||||
'test/jit/test_dtype_analysis.py',
|
||||
'test/jit/test_enum.py',
|
||||
'test/jit/test_exception.py',
|
||||
'test/jit/test_freezing.py',
|
||||
'test/jit/test_functional_blocks.py',
|
||||
'test/jit/test_fuser_common.py',
|
||||
'test/jit/test_graph_rewrite_passes.py',
|
||||
'test/jit/test_hash.py',
|
||||
'test/jit/test_hooks.py',
|
||||
'test/jit/test_hooks_modules.py',
|
||||
'test/jit/test_ignorable_args.py',
|
||||
'test/jit/test_ignore_context_manager.py',
|
||||
'test/jit/test_isinstance.py',
|
||||
'test/jit/test_jit_utils.py',
|
||||
'test/jit/test_list_dict.py',
|
||||
'test/jit/test_logging.py',
|
||||
'test/jit/test_misc.py',
|
||||
'test/jit/test_models.py',
|
||||
'test/jit/test_module_apis.py',
|
||||
'test/jit/test_module_containers.py',
|
||||
'test/jit/test_module_interface.py',
|
||||
'test/jit/test_modules.py',
|
||||
'test/jit/test_op_decompositions.py',
|
||||
'test/jit/test_optimize_for_mobile_preserve_debug_info.py',
|
||||
'test/jit/test_parametrization.py',
|
||||
'test/jit/test_pdt.py',
|
||||
'test/jit/test_peephole.py',
|
||||
'test/jit/test_profiler.py',
|
||||
'test/jit/test_python_bindings.py',
|
||||
'test/jit/test_python_builtins.py',
|
||||
'test/jit/test_python_ir.py',
|
||||
'test/jit/test_recursive_script.py',
|
||||
'test/jit/test_remove_mutation.py',
|
||||
'test/jit/test_save_load.py',
|
||||
'test/jit/test_save_load_for_op_version.py',
|
||||
'test/jit/test_script_profile.py',
|
||||
'test/jit/test_scriptmod_ann.py',
|
||||
'test/jit/test_slice.py',
|
||||
'test/jit/test_sparse.py',
|
||||
'test/jit/test_string_formatting.py',
|
||||
'test/jit/test_symbolic_shape_analysis.py',
|
||||
'test/jit/test_tensor_creation_ops.py',
|
||||
'test/jit/test_tensor_methods.py',
|
||||
'test/jit/test_torchbind.py',
|
||||
'test/jit/test_tracer.py',
|
||||
'test/jit/test_type_sharing.py',
|
||||
'test/jit/test_types.py',
|
||||
'test/jit/test_typing.py',
|
||||
'test/jit/test_union.py',
|
||||
'test/jit/test_unsupported_ops.py',
|
||||
'test/jit/test_upgraders.py',
|
||||
'test/jit/test_warn.py',
|
||||
'test/jit/test_with.py',
|
||||
'test/jit/xnnpack/test_xnnpack_delegate.py',
|
||||
'test/jit_hooks/model.py',
|
||||
'test/lazy/__init__.py',
|
||||
'test/lazy/test_bindings.py',
|
||||
'test/lazy/test_debug_util.py',
|
||||
|
||||
@ -1,4 +1,5 @@
|
||||
import torch
|
||||
|
||||
# This file contains definitions of script classes.
|
||||
# They are used by test_jit.py to test ScriptClass imports
|
||||
|
||||
|
||||
@ -1,5 +1,7 @@
|
||||
import torch
|
||||
|
||||
from . import bar
|
||||
|
||||
# This file contains definitions of script classes.
|
||||
# They are used by test_jit.py to test ScriptClass imports
|
||||
|
||||
|
||||
@ -1,4 +1,5 @@
|
||||
import torch
|
||||
|
||||
# This file contains definitions of script classes.
|
||||
# They are used by test_jit.py to test ScriptClass imports
|
||||
|
||||
|
||||
@ -1,6 +1,8 @@
|
||||
import torch
|
||||
from typing import Union
|
||||
|
||||
import torch
|
||||
|
||||
|
||||
class TestVersionedDivTensorExampleV7(torch.nn.Module):
|
||||
def forward(self, a, b):
|
||||
result_0 = a / b
|
||||
@ -8,35 +10,52 @@ class TestVersionedDivTensorExampleV7(torch.nn.Module):
|
||||
result_2 = a.div(b)
|
||||
return result_0, result_1, result_2
|
||||
|
||||
|
||||
class TestVersionedLinspaceV7(torch.nn.Module):
|
||||
def forward(self, a: Union[int, float, complex], b: Union[int, float, complex]):
|
||||
c = torch.linspace(a, b, steps=5)
|
||||
d = torch.linspace(a, b)
|
||||
return c, d
|
||||
|
||||
|
||||
class TestVersionedLinspaceOutV7(torch.nn.Module):
|
||||
def forward(self, a: Union[int, float, complex], b: Union[int, float, complex], out: torch.Tensor):
|
||||
def forward(
|
||||
self,
|
||||
a: Union[int, float, complex],
|
||||
b: Union[int, float, complex],
|
||||
out: torch.Tensor,
|
||||
):
|
||||
return torch.linspace(a, b, out=out)
|
||||
|
||||
|
||||
class TestVersionedLogspaceV8(torch.nn.Module):
|
||||
def forward(self, a: Union[int, float, complex], b: Union[int, float, complex]):
|
||||
c = torch.logspace(a, b, steps=5)
|
||||
d = torch.logspace(a, b)
|
||||
return c, d
|
||||
|
||||
|
||||
class TestVersionedLogspaceOutV8(torch.nn.Module):
|
||||
def forward(self, a: Union[int, float, complex], b: Union[int, float, complex], out: torch.Tensor):
|
||||
def forward(
|
||||
self,
|
||||
a: Union[int, float, complex],
|
||||
b: Union[int, float, complex],
|
||||
out: torch.Tensor,
|
||||
):
|
||||
return torch.logspace(a, b, out=out)
|
||||
|
||||
|
||||
class TestVersionedGeluV9(torch.nn.Module):
|
||||
def forward(self, x):
|
||||
return torch._C._nn.gelu(x)
|
||||
|
||||
|
||||
class TestVersionedGeluOutV9(torch.nn.Module):
|
||||
def forward(self, x):
|
||||
out = torch.zeros_like(x)
|
||||
return torch._C._nn.gelu(x, out=out)
|
||||
|
||||
|
||||
class TestVersionedRandomV10(torch.nn.Module):
|
||||
def forward(self, x):
|
||||
out = torch.zeros_like(x)
|
||||
|
||||
@ -6,9 +6,11 @@ from pathlib import Path
|
||||
from typing import Set
|
||||
|
||||
import torch
|
||||
|
||||
# Use asterisk symbol so developer doesn't need to import here when they add tests for upgraders.
|
||||
from test.jit.fixtures_srcs.fixtures_src import * # noqa: F403
|
||||
from torch.jit.mobile import _load_for_lite_interpreter, _export_operator_list
|
||||
from torch.jit.mobile import _export_operator_list, _load_for_lite_interpreter
|
||||
|
||||
logging.basicConfig(stream=sys.stdout, level=logging.INFO)
|
||||
logger = logging.getLogger(__name__)
|
||||
logger.setLevel(logging.DEBUG)
|
||||
@ -105,28 +107,41 @@ ALL_MODULES = {
|
||||
Get the path to `test/jit/fixtures`, where all test models for operator changes
|
||||
(upgrader/downgrader) are stored
|
||||
"""
|
||||
|
||||
|
||||
def get_fixtures_path() -> Path:
|
||||
pytorch_dir = Path(__file__).resolve().parents[3]
|
||||
fixtures_path = pytorch_dir / "test" / "jit" / "fixtures"
|
||||
return fixtures_path
|
||||
|
||||
|
||||
"""
|
||||
Get all models' name in `test/jit/fixtures`
|
||||
"""
|
||||
|
||||
|
||||
def get_all_models(model_directory_path: Path) -> Set[str]:
|
||||
files_in_fixtures = model_directory_path.glob('**/*')
|
||||
all_models_from_fixtures = [fixture.stem for fixture in files_in_fixtures if fixture.is_file()]
|
||||
files_in_fixtures = model_directory_path.glob("**/*")
|
||||
all_models_from_fixtures = [
|
||||
fixture.stem for fixture in files_in_fixtures if fixture.is_file()
|
||||
]
|
||||
return set(all_models_from_fixtures)
|
||||
|
||||
|
||||
"""
|
||||
Check if a given model already exist in `test/jit/fixtures`
|
||||
"""
|
||||
|
||||
|
||||
def model_exist(model_file_name: str, all_models: Set[str]) -> bool:
|
||||
return model_file_name in all_models
|
||||
|
||||
|
||||
"""
|
||||
Get the operator list given a module
|
||||
"""
|
||||
|
||||
|
||||
def get_operator_list(script_module: torch) -> Set[str]:
|
||||
buffer = io.BytesIO(script_module._save_to_buffer_for_lite_interpreter())
|
||||
buffer.seek(0)
|
||||
@ -134,21 +149,25 @@ def get_operator_list(script_module: torch) -> Set[str]:
|
||||
operator_list = _export_operator_list(mobile_module)
|
||||
return operator_list
|
||||
|
||||
|
||||
"""
|
||||
Get the output model operator version, given a module
|
||||
"""
|
||||
|
||||
|
||||
def get_output_model_version(script_module: torch.nn.Module) -> int:
|
||||
buffer = io.BytesIO()
|
||||
torch.jit.save(script_module, buffer)
|
||||
buffer.seek(0)
|
||||
zipped_model = zipfile.ZipFile(buffer)
|
||||
try:
|
||||
version = int(zipped_model.read('archive/version').decode("utf-8"))
|
||||
version = int(zipped_model.read("archive/version").decode("utf-8"))
|
||||
return version
|
||||
except KeyError:
|
||||
version = int(zipped_model.read('archive/.data/version').decode("utf-8"))
|
||||
version = int(zipped_model.read("archive/.data/version").decode("utf-8"))
|
||||
return version
|
||||
|
||||
|
||||
"""
|
||||
Loop through all test modules. If the corresponding model doesn't exist in
|
||||
`test/jit/fixtures`, generate one. For the following reason, a model won't be exported:
|
||||
@ -165,6 +184,8 @@ likely this script is running with the commit to make the change.
|
||||
3. The model already exists in `test/jit/fixtures`.
|
||||
|
||||
"""
|
||||
|
||||
|
||||
def generate_models(model_directory_path: Path):
|
||||
all_models = get_all_models(model_directory_path)
|
||||
for a_module, expect_operator in ALL_MODULES.items():
|
||||
@ -176,13 +197,17 @@ def generate_models(model_directory_path: Path):
|
||||
"The module %s "
|
||||
"is not a torch.nn.module instance. "
|
||||
"Please ensure it's a subclass of torch.nn.module in fixtures_src.py"
|
||||
"and it's registered as an instance in ALL_MODULES in generated_models.py", torch_module_name)
|
||||
|
||||
"and it's registered as an instance in ALL_MODULES in generated_models.py",
|
||||
torch_module_name,
|
||||
)
|
||||
|
||||
# The corresponding model name is: test_versioned_div_tensor_example_v4
|
||||
model_name = ''.join([
|
||||
'_' + char.lower() if char.isupper() else char for char in torch_module_name
|
||||
]).lstrip('_')
|
||||
model_name = "".join(
|
||||
[
|
||||
"_" + char.lower() if char.isupper() else char
|
||||
for char in torch_module_name
|
||||
]
|
||||
).lstrip("_")
|
||||
|
||||
# Some models may not compile anymore, so skip the ones
|
||||
# that already has pt file for them.
|
||||
@ -199,7 +224,10 @@ def generate_models(model_directory_path: Path):
|
||||
logger.error(
|
||||
"Actual model version %s "
|
||||
"is equal or larger than %s + 1. "
|
||||
"Please run the script before the commit to change operator.", actual_model_version, current_operator_version)
|
||||
"Please run the script before the commit to change operator.",
|
||||
actual_model_version,
|
||||
current_operator_version,
|
||||
)
|
||||
continue
|
||||
|
||||
actual_operator_list = get_operator_list(script_module)
|
||||
@ -207,16 +235,23 @@ def generate_models(model_directory_path: Path):
|
||||
logger.error(
|
||||
"The model includes operator: %s, "
|
||||
"however it doesn't cover the operator %s."
|
||||
"Please ensure the output model includes the tested operator.", actual_operator_list, expect_operator)
|
||||
"Please ensure the output model includes the tested operator.",
|
||||
actual_operator_list,
|
||||
expect_operator,
|
||||
)
|
||||
continue
|
||||
|
||||
export_model_path = str(model_directory_path / (str(model_name) + ".ptl"))
|
||||
script_module._save_for_lite_interpreter(export_model_path)
|
||||
logger.info("Generating model %s and it's save to %s", model_name, export_model_path)
|
||||
logger.info(
|
||||
"Generating model %s and it's save to %s", model_name, export_model_path
|
||||
)
|
||||
|
||||
|
||||
def main() -> None:
|
||||
model_directory_path = get_fixtures_path()
|
||||
generate_models(model_directory_path)
|
||||
|
||||
if __name__ == '__main__':
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
|
||||
@ -2,7 +2,7 @@
|
||||
|
||||
import torch
|
||||
from test.jit.fixtures_srcs.generate_models import ALL_MODULES
|
||||
from torch.testing._internal.common_utils import TestCase, run_tests
|
||||
from torch.testing._internal.common_utils import run_tests, TestCase
|
||||
|
||||
|
||||
class TestUpgraderModelGeneration(TestCase):
|
||||
@ -14,7 +14,9 @@ class TestUpgraderModelGeneration(TestCase):
|
||||
f"The module {module_name} "
|
||||
f"is not a torch.nn.module instance. "
|
||||
f"Please ensure it's a subclass of torch.nn.module in fixtures_src.py"
|
||||
f"and it's registered as an instance in ALL_MODULES in generated_models.py")
|
||||
f"and it's registered as an instance in ALL_MODULES in generated_models.py",
|
||||
)
|
||||
|
||||
if __name__ == '__main__':
|
||||
|
||||
if __name__ == "__main__":
|
||||
run_tests()
|
||||
|
||||
@ -3,5 +3,7 @@ Define exceptions used in test_exception.py. We define them in a
|
||||
separate file on purpose to make sure the fully qualified exception class name
|
||||
is captured correctly in suce cases.
|
||||
"""
|
||||
|
||||
|
||||
class MyKeyError(KeyError):
|
||||
pass
|
||||
|
||||
@ -1,15 +1,18 @@
|
||||
# Owner(s): ["oncall: jit"]
|
||||
|
||||
import torch
|
||||
from torch._C import parse_ir
|
||||
from torch.testing._internal.common_utils import TemporaryFileName
|
||||
from torch.testing._internal.jit_utils import JitTestCase
|
||||
from torch._C import parse_ir
|
||||
import torch
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
raise RuntimeError("This test file is not meant to be run directly, use:\n\n"
|
||||
"\tpython test/test_jit.py TESTNAME\n\n"
|
||||
"instead.")
|
||||
if __name__ == "__main__":
|
||||
raise RuntimeError(
|
||||
"This test file is not meant to be run directly, use:\n\n"
|
||||
"\tpython test/test_jit.py TESTNAME\n\n"
|
||||
"instead."
|
||||
)
|
||||
|
||||
|
||||
class TestAliasAnalysis(JitTestCase):
|
||||
def test_becomes_wildcard_annotations(self):
|
||||
@ -26,9 +29,13 @@ class TestAliasAnalysis(JitTestCase):
|
||||
alias_db = graph.alias_db()
|
||||
split_node = graph.findNode("aten::split")
|
||||
# split input enters wildcard set, list initalized as containing wildcard set
|
||||
self.assertTrue(alias_db.may_contain_alias(next(split_node.inputs()), split_node.output()))
|
||||
self.assertTrue(
|
||||
alias_db.may_contain_alias(next(split_node.inputs()), split_node.output())
|
||||
)
|
||||
# because %x.1 enters wildcard set, it now aliases other members of wildcard set (graph inputs)
|
||||
self.assertTrue(alias_db.may_contain_alias(next(split_node.inputs()), next(graph.inputs())))
|
||||
self.assertTrue(
|
||||
alias_db.may_contain_alias(next(split_node.inputs()), next(graph.inputs()))
|
||||
)
|
||||
|
||||
def test_nested_list_construct_not_wildcard(self):
|
||||
@torch.jit.script
|
||||
@ -42,7 +49,9 @@ class TestAliasAnalysis(JitTestCase):
|
||||
ten_construct = graph.findNode("aten::rand").output()
|
||||
output = next(graph.outputs())
|
||||
self.assertTrue(alias_db.may_contain_alias(ten_construct, output))
|
||||
self.assertFalse(alias_db.may_contain_alias(next(graph.inputs()), ten_construct))
|
||||
self.assertFalse(
|
||||
alias_db.may_contain_alias(next(graph.inputs()), ten_construct)
|
||||
)
|
||||
|
||||
def test_recursive_calls(self):
|
||||
@torch.jit.script
|
||||
@ -108,7 +117,9 @@ class TestAliasAnalysis(JitTestCase):
|
||||
class MultiTmpFile:
|
||||
def __init__(self, N):
|
||||
self.N = N
|
||||
self.ctxs = [TemporaryFileName(mode="w", suffix=".py") for _ in range(N)]
|
||||
self.ctxs = [
|
||||
TemporaryFileName(mode="w", suffix=".py") for _ in range(N)
|
||||
]
|
||||
|
||||
def __enter__(self):
|
||||
return [x.__enter__() for x in self.ctxs]
|
||||
|
||||
@ -3,18 +3,20 @@
|
||||
import os
|
||||
import sys
|
||||
|
||||
from typing import Any, Tuple
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
|
||||
from typing import Any, Tuple
|
||||
|
||||
# Make the helper files in test/ importable
|
||||
pytorch_test_dir = os.path.dirname(os.path.dirname(os.path.realpath(__file__)))
|
||||
sys.path.append(pytorch_test_dir)
|
||||
from torch.testing._internal.jit_utils import JitTestCase, _inline_everything
|
||||
from typing import List
|
||||
|
||||
from torch import Tensor
|
||||
from torch.jit import Future
|
||||
from torch.testing._internal.jit_utils import _inline_everything, JitTestCase
|
||||
|
||||
|
||||
class TestAsync(JitTestCase):
|
||||
def test_async_python(self):
|
||||
@ -51,8 +53,7 @@ class TestAsync(JitTestCase):
|
||||
futures = torch.jit.annotate(List[Future[List[Tensor]]], [])
|
||||
for _ in range(3):
|
||||
future = torch.jit.annotate(
|
||||
Future[List[Tensor]],
|
||||
torch.jit.fork(foo, x)
|
||||
Future[List[Tensor]], torch.jit.fork(foo, x)
|
||||
)
|
||||
futures.append(future)
|
||||
|
||||
@ -85,7 +86,7 @@ class TestAsync(JitTestCase):
|
||||
|
||||
def test_async_script_capture(self):
|
||||
class Mod(torch.jit.ScriptModule):
|
||||
__constants__ = ['const']
|
||||
__constants__ = ["const"]
|
||||
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
@ -139,7 +140,10 @@ class TestAsync(JitTestCase):
|
||||
def test_async_script_no_script_mod(self):
|
||||
x = torch.rand(3, 4)
|
||||
|
||||
with self.assertRaisesRegexWithHighlight(RuntimeError, 'cannot call a value', 'torch.jit._fork(x'):
|
||||
with self.assertRaisesRegexWithHighlight(
|
||||
RuntimeError, "cannot call a value", "torch.jit._fork(x"
|
||||
):
|
||||
|
||||
@torch.jit.script
|
||||
def wait_script(x):
|
||||
fut = torch.jit._fork(x)
|
||||
@ -213,7 +217,7 @@ class TestAsync(JitTestCase):
|
||||
lambda x1, x2: torch.jit._wait(torch.jit._fork(foo, x1, x2)),
|
||||
lambda x1, x2: torch.jit._wait(torch.jit._fork(foo, x1, x2=x2)),
|
||||
lambda x1, x2: torch.jit._wait(torch.jit._fork(foo, x1=x1, x2=x2)),
|
||||
lambda x1, x2: torch.jit._wait(torch.jit._fork(foo, x2=x2, x1=x1))
|
||||
lambda x1, x2: torch.jit._wait(torch.jit._fork(foo, x2=x2, x1=x1)),
|
||||
]:
|
||||
for wrapper in [
|
||||
func,
|
||||
@ -234,8 +238,8 @@ class TestAsync(JitTestCase):
|
||||
return torch.jit._wait(torch.jit._fork(foo, x1=x1, x2=x2))
|
||||
|
||||
for wrapper in [
|
||||
foo_script_args,
|
||||
foo_script_kwargs,
|
||||
foo_script_args,
|
||||
foo_script_kwargs,
|
||||
]:
|
||||
self.assertEqual(wrapper(x1, x2), y_hat)
|
||||
self.assertEqual(wrapper(x1, x2=x2), y_hat)
|
||||
@ -255,7 +259,9 @@ class TestAsync(JitTestCase):
|
||||
self.traced = torch.jit.trace(Traced(), (x), _force_outplace=True)
|
||||
|
||||
@torch.jit.script_method
|
||||
def forward(self, x: Tensor) -> Tuple[List[Tensor], Tuple[Tensor, Tensor], Tensor]:
|
||||
def forward(
|
||||
self, x: Tensor
|
||||
) -> Tuple[List[Tensor], Tuple[Tensor, Tensor], Tensor]:
|
||||
future1 = torch.jit._fork(self.traced, x)
|
||||
future2 = torch.jit._fork(torch.neg, x)
|
||||
|
||||
@ -284,10 +290,16 @@ class TestAsync(JitTestCase):
|
||||
module = torch.jit.trace(TupleCl(), (x), _force_outplace=True)
|
||||
|
||||
# Make sure we have forks
|
||||
self.assertGraphContainsExactly(module.graph, kind='prim::fork', num_kind_nodes=2)
|
||||
self.assertGraphContainsExactly(
|
||||
module.graph, kind="prim::fork", num_kind_nodes=2
|
||||
)
|
||||
# Make sure 1 ::neg is in the root graph and 2 ::negs are in the subgraphs
|
||||
self.assertGraphContainsExactly(module.graph, kind='aten::neg', num_kind_nodes=1)
|
||||
self.assertGraphContainsExactly(module.graph, kind='aten::neg', num_kind_nodes=3, consider_subgraphs=True)
|
||||
self.assertGraphContainsExactly(
|
||||
module.graph, kind="aten::neg", num_kind_nodes=1
|
||||
)
|
||||
self.assertGraphContainsExactly(
|
||||
module.graph, kind="aten::neg", num_kind_nodes=3, consider_subgraphs=True
|
||||
)
|
||||
|
||||
y = torch.neg(x)
|
||||
self.assertEqual(module(x), (y, y, y, y, x, x))
|
||||
@ -311,19 +323,23 @@ class TestAsync(JitTestCase):
|
||||
return torch.jit._wait(fut)
|
||||
|
||||
# no future
|
||||
error_msg = 'The size.*must match the size of tensor'
|
||||
with self.assertRaisesRegexWithHighlight(Exception, error_msg, 'x.t() + x'):
|
||||
error_msg = "The size.*must match the size of tensor"
|
||||
with self.assertRaisesRegexWithHighlight(Exception, error_msg, "x.t() + x"):
|
||||
foo(x)
|
||||
|
||||
# one future
|
||||
with self.assertRaisesRegexWithHighlight(Exception, error_msg, 'torch.jit._fork(foo, x'):
|
||||
with self.assertRaisesRegexWithHighlight(
|
||||
Exception, error_msg, "torch.jit._fork(foo, x"
|
||||
):
|
||||
wait_script(x)
|
||||
|
||||
# two futures with a different error
|
||||
x = torch.rand(3, 4, 5)
|
||||
with self.assertRaisesRegexWithHighlight(Exception,
|
||||
'expects a tensor with <= 2 dimensions',
|
||||
'torch.jit._fork(wait_script, x'):
|
||||
with self.assertRaisesRegexWithHighlight(
|
||||
Exception,
|
||||
"expects a tensor with <= 2 dimensions",
|
||||
"torch.jit._fork(wait_script, x",
|
||||
):
|
||||
wait_script_nest(x)
|
||||
|
||||
def test_async_grad_guard_with_grad(self):
|
||||
@ -381,9 +397,15 @@ class TestAsync(JitTestCase):
|
||||
x = torch.rand(3, 4)
|
||||
self.assertEqual(fn(x), traced(x))
|
||||
|
||||
self.assertGraphContainsExactly(traced.graph, kind='prim::fork', num_kind_nodes=1)
|
||||
self.assertGraphContainsExactly(traced.graph, kind='aten::wait', num_kind_nodes=1)
|
||||
self.assertGraphContainsExactly(traced.graph, kind='aten::neg', num_kind_nodes=2, consider_subgraphs=True)
|
||||
self.assertGraphContainsExactly(
|
||||
traced.graph, kind="prim::fork", num_kind_nodes=1
|
||||
)
|
||||
self.assertGraphContainsExactly(
|
||||
traced.graph, kind="aten::wait", num_kind_nodes=1
|
||||
)
|
||||
self.assertGraphContainsExactly(
|
||||
traced.graph, kind="aten::neg", num_kind_nodes=2, consider_subgraphs=True
|
||||
)
|
||||
|
||||
def test_trace_fork_wait_leaking(self):
|
||||
my_list = []
|
||||
@ -397,9 +419,13 @@ class TestAsync(JitTestCase):
|
||||
val = torch.jit._wait(fut)
|
||||
return my_list[0]
|
||||
|
||||
with self.assertRaisesRegexWithHighlight(RuntimeError, 'did not have observable data dependence with trace inputs; '
|
||||
'this probably indicates your program cannot be understood '
|
||||
'by the tracer.', ''):
|
||||
with self.assertRaisesRegexWithHighlight(
|
||||
RuntimeError,
|
||||
"did not have observable data dependence with trace inputs; "
|
||||
"this probably indicates your program cannot be understood "
|
||||
"by the tracer.",
|
||||
"",
|
||||
):
|
||||
traced = torch.jit.trace(fn, (torch.rand(3, 4),), check_trace=False)
|
||||
|
||||
def test_trace_fork_wait_inline(self):
|
||||
@ -413,9 +439,15 @@ class TestAsync(JitTestCase):
|
||||
|
||||
traced = torch.jit.trace(fn, (torch.rand(3, 4),))
|
||||
torch._C._jit_pass_inline_fork_wait(traced.graph)
|
||||
self.assertGraphContainsExactly(traced.graph, kind='prim::fork', num_kind_nodes=0)
|
||||
self.assertGraphContainsExactly(traced.graph, kind='aten::wait', num_kind_nodes=0)
|
||||
self.assertGraphContainsExactly(traced.graph, kind='aten::add', num_kind_nodes=2)
|
||||
self.assertGraphContainsExactly(
|
||||
traced.graph, kind="prim::fork", num_kind_nodes=0
|
||||
)
|
||||
self.assertGraphContainsExactly(
|
||||
traced.graph, kind="aten::wait", num_kind_nodes=0
|
||||
)
|
||||
self.assertGraphContainsExactly(
|
||||
traced.graph, kind="aten::add", num_kind_nodes=2
|
||||
)
|
||||
|
||||
def test_trace_fork_wait_list_modulecalls(self):
|
||||
def add_one(input):
|
||||
@ -472,7 +504,10 @@ class TestAsync(JitTestCase):
|
||||
self.checkTrace(TestModule(), (torch.randn(5, 5),))
|
||||
|
||||
def test_no_future_subtype_message(self):
|
||||
with self.assertRaisesRegexWithHighlight(RuntimeError, 'Future without a contained type', ''):
|
||||
with self.assertRaisesRegexWithHighlight(
|
||||
RuntimeError, "Future without a contained type", ""
|
||||
):
|
||||
|
||||
@torch.jit.script
|
||||
def forward(self, x):
|
||||
futs = torch.jit.annotate(List[torch.jit.Future], [])
|
||||
@ -481,6 +516,7 @@ class TestAsync(JitTestCase):
|
||||
"""
|
||||
Test that futures subtype each other properly.
|
||||
"""
|
||||
|
||||
# Successful subtyping.
|
||||
def returns_int(x: int) -> int:
|
||||
return x + x + 1
|
||||
@ -495,10 +531,11 @@ class TestAsync(JitTestCase):
|
||||
|
||||
# Unsuccessful subtyping.
|
||||
with self.assertRaisesRegexWithHighlight(
|
||||
RuntimeError,
|
||||
r"was annotated as having type Future\[float\] but is actually of type Future\[int\]",
|
||||
"fut = returns_future_float(x"
|
||||
RuntimeError,
|
||||
r"was annotated as having type Future\[float\] but is actually of type Future\[int\]",
|
||||
"fut = returns_future_float(x",
|
||||
):
|
||||
|
||||
def returns_future_float(x: int) -> torch.jit.Future[float]:
|
||||
return torch.jit._fork(returns_int, (x))
|
||||
|
||||
@ -508,8 +545,9 @@ class TestAsync(JitTestCase):
|
||||
return fut.wait()
|
||||
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
raise RuntimeError("This test file is not meant to be run directly, use:\n\n"
|
||||
"\tpython test/test_jit.py TESTNAME\n\n"
|
||||
"instead.")
|
||||
if __name__ == "__main__":
|
||||
raise RuntimeError(
|
||||
"This test file is not meant to be run directly, use:\n\n"
|
||||
"\tpython test/test_jit.py TESTNAME\n\n"
|
||||
"instead."
|
||||
)
|
||||
|
||||
@ -3,35 +3,40 @@
|
||||
import torch
|
||||
from torch.testing._internal.common_utils import TestCase
|
||||
|
||||
|
||||
class TestAtenPow(TestCase):
|
||||
def test_aten_pow_zero_negative_exponent(self):
|
||||
'''
|
||||
"""
|
||||
1. Testing a = int, b = int
|
||||
'''
|
||||
"""
|
||||
|
||||
@torch.jit.script
|
||||
def fn_int_int(a: int, b: int):
|
||||
return a ** b
|
||||
return a**b
|
||||
|
||||
# Existing correct behaviors of aten::pow
|
||||
self.assertEqual(fn_int_int(2, 1), 2 ** 1)
|
||||
self.assertEqual(fn_int_int(2, 0), 2 ** 0)
|
||||
self.assertEqual(fn_int_int(2, 1), 2**1)
|
||||
self.assertEqual(fn_int_int(2, 0), 2**0)
|
||||
self.assertEqual(fn_int_int(2, -2), 2 ** (-2))
|
||||
self.assertEqual(fn_int_int(-2, 2), (-2) ** 2)
|
||||
self.assertEqual(fn_int_int(-2, 0), (-2) ** 0)
|
||||
self.assertEqual(fn_int_int(-2, -2), (-2) ** (-2))
|
||||
self.assertEqual(fn_int_int(-2, -1), (-2) ** (-1))
|
||||
self.assertEqual(fn_int_int(0, 2), 0 ** 1)
|
||||
self.assertEqual(fn_int_int(0, 0), 0 ** 0)
|
||||
self.assertEqual(fn_int_int(0, 2), 0**1)
|
||||
self.assertEqual(fn_int_int(0, 0), 0**0)
|
||||
# zero base and negative exponent case that should trigger RunTimeError
|
||||
self.assertRaises(RuntimeError, fn_int_int, 0, -2)
|
||||
|
||||
'''
|
||||
"""
|
||||
2. Testing a = int, b = float
|
||||
'''
|
||||
"""
|
||||
|
||||
@torch.jit.script
|
||||
def fn_int_float(a: int, b: float):
|
||||
return a ** b
|
||||
return a**b
|
||||
|
||||
# Existing correct behaviors of aten::pow
|
||||
self.assertEqual(fn_int_float(2, 2.5), 2 ** 2.5)
|
||||
self.assertEqual(fn_int_float(2, 2.5), 2**2.5)
|
||||
self.assertEqual(fn_int_float(2, -2.5), 2 ** (-2.5))
|
||||
self.assertEqual(fn_int_float(2, -0.0), 2 ** (-0.0))
|
||||
self.assertEqual(fn_int_float(2, 0.0), 2 ** (0.0))
|
||||
@ -40,53 +45,57 @@ class TestAtenPow(TestCase):
|
||||
self.assertEqual(fn_int_float(-2, -3.0), (-2) ** (-3.0))
|
||||
self.assertEqual(fn_int_float(-2, -0.0), (-2) ** (-0.0))
|
||||
self.assertEqual(fn_int_float(-2, 0.0), (-2) ** (0.0))
|
||||
self.assertEqual(fn_int_float(0, 2.0), 0 ** 2.0)
|
||||
self.assertEqual(fn_int_float(0, 0.5), 0 ** 0.5)
|
||||
self.assertEqual(fn_int_float(0, 0.0), 0 ** 0.0)
|
||||
self.assertEqual(fn_int_float(0, 2.0), 0**2.0)
|
||||
self.assertEqual(fn_int_float(0, 0.5), 0**0.5)
|
||||
self.assertEqual(fn_int_float(0, 0.0), 0**0.0)
|
||||
self.assertEqual(fn_int_float(0, -0.0), 0 ** (-0.0))
|
||||
# zero base and negative exponent case that should trigger RunTimeError
|
||||
self.assertRaises(RuntimeError, fn_int_float, 0, -2.5)
|
||||
|
||||
'''
|
||||
"""
|
||||
3. Testing a = float, b = int
|
||||
'''
|
||||
"""
|
||||
|
||||
@torch.jit.script
|
||||
def fn_float_int(a: float, b: int):
|
||||
return a ** b
|
||||
return a**b
|
||||
|
||||
# Existing correct behaviors of aten::pow
|
||||
self.assertEqual(fn_float_int(2.5, 2), 2.5 ** 2)
|
||||
self.assertEqual(fn_float_int(2.5, 2), 2.5**2)
|
||||
self.assertEqual(fn_float_int(2.5, -2), 2.5 ** (-2))
|
||||
self.assertEqual(fn_float_int(2.5, -0), 2.5 ** (-0))
|
||||
self.assertEqual(fn_float_int(2.5, 0), 2.5 ** 0)
|
||||
self.assertEqual(fn_float_int(-2.5, 2), 2.5 ** 2)
|
||||
self.assertEqual(fn_float_int(2.5, 0), 2.5**0)
|
||||
self.assertEqual(fn_float_int(-2.5, 2), 2.5**2)
|
||||
self.assertEqual(fn_float_int(-2.5, -2), (-2.5) ** (-2))
|
||||
self.assertEqual(fn_float_int(-2.5, -3), (-2.5) ** (-3))
|
||||
self.assertEqual(fn_float_int(-2.5, -0), (-2.5) ** (-0))
|
||||
self.assertEqual(fn_float_int(-2.5, 0), (-2.5) ** 0)
|
||||
self.assertEqual(fn_float_int(0.0, 2), 0 ** 2)
|
||||
self.assertEqual(fn_float_int(0.0, 0), 0 ** 0)
|
||||
self.assertEqual(fn_float_int(0.0, 2), 0**2)
|
||||
self.assertEqual(fn_float_int(0.0, 0), 0**0)
|
||||
self.assertEqual(fn_float_int(0.0, -0), 0 ** (-0))
|
||||
# zero base and negative exponent case that should trigger RunTimeError
|
||||
self.assertRaises(RuntimeError, fn_float_int, 0.0, -2)
|
||||
|
||||
'''
|
||||
"""
|
||||
4. Testing a = float, b = float
|
||||
'''
|
||||
"""
|
||||
|
||||
@torch.jit.script
|
||||
def fn_float_float(a: float, b: float):
|
||||
return a ** b
|
||||
return a**b
|
||||
|
||||
# Existing correct behaviors of aten::pow
|
||||
self.assertEqual(fn_float_float(2.5, 2.0), 2.5 ** 2.0)
|
||||
self.assertEqual(fn_float_float(2.5, 2.0), 2.5**2.0)
|
||||
self.assertEqual(fn_float_float(2.5, -2.0), 2.5 ** (-2.0))
|
||||
self.assertEqual(fn_float_float(2.5, -0.0), 2.5 ** (-0.0))
|
||||
self.assertEqual(fn_float_float(2.5, 0.0), 2.5 ** 0.0)
|
||||
self.assertEqual(fn_float_float(-2.5, 2.0), 2.5 ** 2.0)
|
||||
self.assertEqual(fn_float_float(2.5, 0.0), 2.5**0.0)
|
||||
self.assertEqual(fn_float_float(-2.5, 2.0), 2.5**2.0)
|
||||
self.assertEqual(fn_float_float(-2.5, -2.0), (-2.5) ** (-2.0))
|
||||
self.assertEqual(fn_float_float(-2.5, -3.0), (-2.5) ** (-3.0))
|
||||
self.assertEqual(fn_float_float(-2.5, -0.0), (-2.5) ** (-0.0))
|
||||
self.assertEqual(fn_float_float(-2.5, 0.0), (-2.5) ** 0.0)
|
||||
self.assertEqual(fn_float_float(0.0, 2.0), 0.0 ** 2.0)
|
||||
self.assertEqual(fn_float_float(0.0, 0.0), 0.0 ** 0.0)
|
||||
self.assertEqual(fn_float_float(0.0, 2.0), 0.0**2.0)
|
||||
self.assertEqual(fn_float_float(0.0, 0.0), 0.0**0.0)
|
||||
self.assertEqual(fn_float_float(0.0, -0.0), 0.0 ** (-0.0))
|
||||
# zero base and negative exponent case that should trigger RunTimeError
|
||||
self.assertRaises(RuntimeError, fn_float_float, 0.0, -2.0)
|
||||
|
||||
@ -1,20 +1,22 @@
|
||||
# Owner(s): ["oncall: jit"]
|
||||
|
||||
from torch.testing import FileCheck
|
||||
from torch.testing._internal.jit_utils import JitTestCase
|
||||
import torch
|
||||
from typing import NamedTuple, Tuple
|
||||
|
||||
import torch
|
||||
from torch.testing import FileCheck
|
||||
from torch.testing._internal.jit_utils import JitTestCase
|
||||
|
||||
if __name__ == '__main__':
|
||||
raise RuntimeError("This test file is not meant to be run directly, use:\n\n"
|
||||
"\tpython test/test_jit.py TESTNAME\n\n"
|
||||
"instead.")
|
||||
|
||||
if __name__ == "__main__":
|
||||
raise RuntimeError(
|
||||
"This test file is not meant to be run directly, use:\n\n"
|
||||
"\tpython test/test_jit.py TESTNAME\n\n"
|
||||
"instead."
|
||||
)
|
||||
|
||||
|
||||
class TestGetDefaultAttr(JitTestCase):
|
||||
def test_getattr_with_default(self):
|
||||
|
||||
class A(torch.nn.Module):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
@ -22,7 +24,7 @@ class TestGetDefaultAttr(JitTestCase):
|
||||
|
||||
def forward(self, x):
|
||||
y = getattr(self, "init_attr_val") # noqa: B009
|
||||
w : list[float] = [1.0]
|
||||
w: list[float] = [1.0]
|
||||
z = getattr(self, "missing", w) # noqa: B009
|
||||
z.append(y)
|
||||
return z
|
||||
@ -32,7 +34,7 @@ class TestGetDefaultAttr(JitTestCase):
|
||||
graph = torch.jit.script(A()).graph
|
||||
|
||||
# The "init_attr_val" attribute exists
|
||||
FileCheck().check("prim::GetAttr[name=\"init_attr_val\"]").run(graph)
|
||||
FileCheck().check('prim::GetAttr[name="init_attr_val"]').run(graph)
|
||||
# The "missing" attribute does not exist, so there should be no corresponding GetAttr in AST
|
||||
FileCheck().check_not("missing").run(graph)
|
||||
# instead the getattr call will emit the default value, which is a list with one float element
|
||||
@ -46,7 +48,11 @@ class TestGetDefaultAttr(JitTestCase):
|
||||
y: torch.Tensor
|
||||
|
||||
def fn(x: MyTuple) -> Tuple[str, torch.Tensor, int]:
|
||||
return getattr(x, "x", "fdsa"), getattr(x, "y", torch.ones((3, 3))), getattr(x, "z", 7)
|
||||
return (
|
||||
getattr(x, "x", "fdsa"),
|
||||
getattr(x, "y", torch.ones((3, 3))),
|
||||
getattr(x, "z", 7),
|
||||
)
|
||||
|
||||
inp = MyTuple(x="test", y=torch.ones(3, 3) * 2)
|
||||
ref = fn(inp)
|
||||
|
||||
@ -1,10 +1,11 @@
|
||||
# Owner(s): ["oncall: jit"]
|
||||
|
||||
from typing import List
|
||||
|
||||
import torch
|
||||
|
||||
from torch.testing._internal.common_utils import skipIfTorchDynamo
|
||||
from torch.testing._internal.jit_utils import JitTestCase
|
||||
from typing import List
|
||||
|
||||
|
||||
@skipIfTorchDynamo()
|
||||
@ -119,7 +120,6 @@ class TestAutodiffJit(JitTestCase):
|
||||
self.assertEqual(y_s.requires_grad, y.requires_grad)
|
||||
self.assertEqual(z_s.requires_grad, z.requires_grad)
|
||||
|
||||
|
||||
def test_autodiff_requires_grad_nograd(self):
|
||||
@torch.jit.ignore
|
||||
def python_fn(x):
|
||||
|
||||
@ -3,26 +3,38 @@
|
||||
import os
|
||||
import sys
|
||||
import unittest
|
||||
from torch.testing._internal.common_utils import GRAPH_EXECUTOR, ProfilingMode, \
|
||||
num_profiled_runs, enable_profiling_mode_for_profiling_tests
|
||||
from torch.testing._internal.common_jit import check_against_reference
|
||||
|
||||
import torch
|
||||
from torch.testing._internal.common_jit import check_against_reference
|
||||
from torch.testing._internal.common_utils import (
|
||||
enable_profiling_mode_for_profiling_tests,
|
||||
GRAPH_EXECUTOR,
|
||||
num_profiled_runs,
|
||||
ProfilingMode,
|
||||
)
|
||||
|
||||
# Make the helper files in test/ importable
|
||||
pytorch_test_dir = os.path.dirname(os.path.dirname(os.path.realpath(__file__)))
|
||||
sys.path.append(pytorch_test_dir)
|
||||
from torch.testing._internal.jit_utils import JitTestCase, disable_autodiff_subgraph_inlining
|
||||
from typing import List, Optional, Tuple
|
||||
|
||||
from torch.testing import FileCheck
|
||||
from torch.testing._internal.jit_utils import (
|
||||
disable_autodiff_subgraph_inlining,
|
||||
JitTestCase,
|
||||
)
|
||||
|
||||
from typing import List, Tuple, Optional
|
||||
|
||||
if __name__ == '__main__':
|
||||
raise RuntimeError("This test file is not meant to be run directly, use:\n\n"
|
||||
"\tpython test/test_jit.py TESTNAME\n\n"
|
||||
"instead.")
|
||||
if __name__ == "__main__":
|
||||
raise RuntimeError(
|
||||
"This test file is not meant to be run directly, use:\n\n"
|
||||
"\tpython test/test_jit.py TESTNAME\n\n"
|
||||
"instead."
|
||||
)
|
||||
|
||||
|
||||
@unittest.skipIf(GRAPH_EXECUTOR == ProfilingMode.SIMPLE, "Simple Executor doesn't support gradients")
|
||||
@unittest.skipIf(
|
||||
GRAPH_EXECUTOR == ProfilingMode.SIMPLE, "Simple Executor doesn't support gradients"
|
||||
)
|
||||
class TestAutodiffSubgraphSlicing(JitTestCase):
|
||||
# TODO: It is better if we can test directly on graphs instead of the current
|
||||
# end-to-end fashion.
|
||||
@ -35,11 +47,17 @@ class TestAutodiffSubgraphSlicing(JitTestCase):
|
||||
return ge.graph_for(*inputs)
|
||||
|
||||
def assertGraphSize(self, graph, size):
|
||||
nodes = list(filter(lambda n: (n.kind() != "prim::BailOut" and
|
||||
n.kind() != "prim::BailoutTemplate" and
|
||||
n.kind() != "prim::TypeCheck" and
|
||||
n.kind() != "prim::RequiresGradCheck"),
|
||||
graph.nodes()))
|
||||
nodes = list(
|
||||
filter(
|
||||
lambda n: (
|
||||
n.kind() != "prim::BailOut"
|
||||
and n.kind() != "prim::BailoutTemplate"
|
||||
and n.kind() != "prim::TypeCheck"
|
||||
and n.kind() != "prim::RequiresGradCheck"
|
||||
),
|
||||
graph.nodes(),
|
||||
)
|
||||
)
|
||||
self.assertEqual(len(list(nodes)), size)
|
||||
|
||||
def test_chunk_constant_script_ad(self):
|
||||
@ -52,16 +70,21 @@ class TestAutodiffSubgraphSlicing(JitTestCase):
|
||||
with disable_autodiff_subgraph_inlining():
|
||||
with enable_profiling_mode_for_profiling_tests():
|
||||
output = func(input, profile_and_replay=True)
|
||||
FileCheck().check_not("prim::DifferentiableGraph").run(func.graph_for(input))
|
||||
FileCheck().check_not("prim::DifferentiableGraph").run(
|
||||
func.graph_for(input)
|
||||
)
|
||||
|
||||
@unittest.skipIf(GRAPH_EXECUTOR != ProfilingMode.PROFILING, "This threshold is only valid for Profiling Executor")
|
||||
@unittest.skipIf(
|
||||
GRAPH_EXECUTOR != ProfilingMode.PROFILING,
|
||||
"This threshold is only valid for Profiling Executor",
|
||||
)
|
||||
def test_diff_graph_inline_threshold(self):
|
||||
with enable_profiling_mode_for_profiling_tests():
|
||||
NUM_RUNS = 1
|
||||
with num_profiled_runs(NUM_RUNS):
|
||||
|
||||
@torch.jit.script
|
||||
def foo(x):
|
||||
|
||||
# two nodes should be fused
|
||||
# see https://github.com/pytorch/pytorch/blob/master/torch/csrc/jit/runtime/graph_executor_impl.h#L49
|
||||
return torch.sigmoid(torch.sigmoid(x))
|
||||
@ -78,12 +101,16 @@ class TestAutodiffSubgraphSlicing(JitTestCase):
|
||||
bar(input)
|
||||
bar(input)
|
||||
|
||||
self.assertGraphContainsExactly(foo.graph_for(input), 'prim::DifferentiableGraph', 1)
|
||||
self.assertGraphContainsExactly(bar.graph_for(input), 'prim::DifferentiableGraph', 0)
|
||||
self.assertGraphContainsExactly(
|
||||
foo.graph_for(input), "prim::DifferentiableGraph", 1
|
||||
)
|
||||
self.assertGraphContainsExactly(
|
||||
bar.graph_for(input), "prim::DifferentiableGraph", 0
|
||||
)
|
||||
|
||||
def test_bias_as_module_attr(self):
|
||||
|
||||
with enable_profiling_mode_for_profiling_tests():
|
||||
|
||||
class M(torch.nn.Module):
|
||||
def __init__(self, has_bias):
|
||||
super().__init__()
|
||||
@ -99,19 +126,40 @@ class TestAutodiffSubgraphSlicing(JitTestCase):
|
||||
scripted_no_bias(x, x)
|
||||
scripted_no_bias(x, x)
|
||||
has_bias = M(True)
|
||||
check_against_reference(self, scripted_no_bias, no_bias, lambda x: x, (x, x,), check_types=False)
|
||||
check_against_reference(
|
||||
self,
|
||||
scripted_no_bias,
|
||||
no_bias,
|
||||
lambda x: x,
|
||||
(
|
||||
x,
|
||||
x,
|
||||
),
|
||||
check_types=False,
|
||||
)
|
||||
scripted_has_bias = torch.jit.script(has_bias)
|
||||
scripted_has_bias(x, x)
|
||||
scripted_has_bias(x, x)
|
||||
scripted_has_bias(x, x)
|
||||
check_against_reference(self, scripted_has_bias, has_bias, lambda x: x, (x, x,), check_types=False)
|
||||
check_against_reference(
|
||||
self,
|
||||
scripted_has_bias,
|
||||
has_bias,
|
||||
lambda x: x,
|
||||
(
|
||||
x,
|
||||
x,
|
||||
),
|
||||
check_types=False,
|
||||
)
|
||||
|
||||
def test_constructed_bias(self):
|
||||
|
||||
with enable_profiling_mode_for_profiling_tests():
|
||||
|
||||
def method1(x, weight, b1, b2):
|
||||
bias = b1 * b2
|
||||
return torch.nn.functional.linear(x, weight, bias)
|
||||
|
||||
N = 10
|
||||
x = torch.rand(N, N, requires_grad=True)
|
||||
weight = torch.rand(N, N, requires_grad=True)
|
||||
@ -119,35 +167,58 @@ class TestAutodiffSubgraphSlicing(JitTestCase):
|
||||
b2 = torch.rand(N, N, requires_grad=True)
|
||||
scripted = self.checkScript(method1, (x, weight, b1, b2))
|
||||
# check_types requires last_graph on scripted to be set, so we just skip it
|
||||
check_against_reference(self, scripted, method1, lambda x: x, (x, weight, b1, b2), check_types=False)
|
||||
check_against_reference(
|
||||
self,
|
||||
scripted,
|
||||
method1,
|
||||
lambda x: x,
|
||||
(x, weight, b1, b2),
|
||||
check_types=False,
|
||||
)
|
||||
|
||||
def test_bias_as_arg(self):
|
||||
|
||||
with enable_profiling_mode_for_profiling_tests():
|
||||
|
||||
def method1(x, weight, bias: Optional[torch.Tensor]):
|
||||
return torch.nn.functional.linear(x, weight, bias).relu() + 2
|
||||
|
||||
N = 10
|
||||
x = torch.rand(N, N, requires_grad=True)
|
||||
weight = torch.rand(N, N, requires_grad=True)
|
||||
bias = None
|
||||
scripted = self.checkScript(method1, (x, weight, bias))
|
||||
# check_types requires last_graph on scripted to be set, so we just skip it
|
||||
check_against_reference(self, scripted, method1, lambda x: x, (x, weight, bias), check_types=False)
|
||||
check_against_reference(
|
||||
self,
|
||||
scripted,
|
||||
method1,
|
||||
lambda x: x,
|
||||
(x, weight, bias),
|
||||
check_types=False,
|
||||
)
|
||||
bias = torch.rand(N, N, requires_grad=True)
|
||||
scripted = self.checkScript(method1, (x, weight, bias))
|
||||
# check_types requires last_graph on scripted to be set, so we just skip it
|
||||
check_against_reference(self, scripted, method1, lambda x: x, (x, weight, bias), check_types=False)
|
||||
check_against_reference(
|
||||
self,
|
||||
scripted,
|
||||
method1,
|
||||
lambda x: x,
|
||||
(x, weight, bias),
|
||||
check_types=False,
|
||||
)
|
||||
|
||||
def test_requires_grad_for_tensor_list(self):
|
||||
|
||||
with enable_profiling_mode_for_profiling_tests():
|
||||
|
||||
# output & var_list[0] should have requires_grad set to True
|
||||
def func(input0: torch.Tensor, input1: torch.Tensor) -> Tuple[torch.Tensor, List[torch.Tensor]]:
|
||||
def func(
|
||||
input0: torch.Tensor, input1: torch.Tensor
|
||||
) -> Tuple[torch.Tensor, List[torch.Tensor]]:
|
||||
var_list = [input0, input1]
|
||||
var = torch.cat(var_list)
|
||||
output = var + 1.0
|
||||
return output, var_list
|
||||
|
||||
jit_f = torch.jit.script(func)
|
||||
input0 = torch.randn((2,), requires_grad=True)
|
||||
input1 = torch.randn((2,))
|
||||
@ -158,12 +229,14 @@ class TestAutodiffSubgraphSlicing(JitTestCase):
|
||||
assert output_ref[1][0].requires_grad == output[1][0].requires_grad
|
||||
assert output_ref[1][1].requires_grad == output[1][1].requires_grad
|
||||
|
||||
@unittest.skip("disable until we property handle tensor lists with undefined gradients")
|
||||
@unittest.skip(
|
||||
"disable until we property handle tensor lists with undefined gradients"
|
||||
)
|
||||
def test_differentiable_graph_ops_requires_grad(self):
|
||||
x = torch.randn(8, 2, dtype=torch.float).requires_grad_()
|
||||
y = torch.randn(8, 2, dtype=torch.float)
|
||||
|
||||
def t(x : torch.Tensor, y : torch.Tensor, flag : bool):
|
||||
def t(x: torch.Tensor, y: torch.Tensor, flag: bool):
|
||||
o = x + 1.0
|
||||
o1 = torch.relu(o)
|
||||
o = y + 1.5
|
||||
@ -186,13 +259,14 @@ class TestAutodiffSubgraphSlicing(JitTestCase):
|
||||
return o1, o2, o3, oo1, oo2, oo3
|
||||
|
||||
with enable_profiling_mode_for_profiling_tests():
|
||||
|
||||
t_jit = torch.jit.script(t)
|
||||
jit_o = t_jit(x, y, False)
|
||||
jit_o = t_jit(x, y, False)
|
||||
o = t(x, y, False)
|
||||
|
||||
FileCheck().check("prim::DifferentiableGraph").run(t_jit.graph_for(x, y, False))
|
||||
FileCheck().check("prim::DifferentiableGraph").run(
|
||||
t_jit.graph_for(x, y, False)
|
||||
)
|
||||
# validate the differentiableGraphOps are marking proper requires_grad
|
||||
for oo, jit_oo in zip(o, jit_o):
|
||||
self.assertEqual(oo.requires_grad, jit_oo.requires_grad)
|
||||
@ -204,22 +278,28 @@ class TestAutodiffSubgraphSlicing(JitTestCase):
|
||||
self.assertEqual(oo.requires_grad, jit_oo.requires_grad)
|
||||
self.assertEqual(oo, jit_oo)
|
||||
|
||||
@unittest.skipIf(GRAPH_EXECUTOR == ProfilingMode.PROFILING, "Simple Executor doesn't support gradients")
|
||||
@unittest.skipIf(
|
||||
GRAPH_EXECUTOR == ProfilingMode.PROFILING,
|
||||
"Simple Executor doesn't support gradients",
|
||||
)
|
||||
def test_prune_grad(self):
|
||||
@torch.jit.script
|
||||
def t(input, bias):
|
||||
return torch.nn.functional.relu(input + bias)
|
||||
|
||||
input = torch.randn(2, 8, requires_grad=True)
|
||||
bias = torch.randn(8, requires_grad=False) # bias does NOT require grad
|
||||
bias = torch.randn(8, requires_grad=False) # bias does NOT require grad
|
||||
NUM_PROFILED_RUNS = 1
|
||||
with num_profiled_runs(NUM_PROFILED_RUNS):
|
||||
WARMUP = 3 # 2 runs to reach backward + 1 to optimize it
|
||||
WARMUP = 3 # 2 runs to reach backward + 1 to optimize it
|
||||
for x in range(WARMUP):
|
||||
o = t(input, bias)
|
||||
o.sum().backward()
|
||||
|
||||
fwd_plan = list(t.get_debug_state().execution_plans.values())[0]
|
||||
bwd_graph = list(fwd_plan.code.grad_executor_states()[0].execution_plans.values())[0].graph
|
||||
bwd_graph = list(
|
||||
fwd_plan.code.grad_executor_states()[0].execution_plans.values()
|
||||
)[0].graph
|
||||
tup = next(bwd_graph.outputs())
|
||||
self.assertEqual(len(list(tup.node().inputs())), 1)
|
||||
|
||||
@ -233,7 +313,7 @@ class TestAutodiffSubgraphSlicing(JitTestCase):
|
||||
graph = self._perform_ad_subgraph_slicing(fn, 1, 1, 1)
|
||||
|
||||
self.assertGraphSize(graph, 1)
|
||||
self.assertGraphContainsExactly(graph, 'prim::DifferentiableGraph', 1)
|
||||
self.assertGraphContainsExactly(graph, "prim::DifferentiableGraph", 1)
|
||||
|
||||
def test_simple_no_merge(self):
|
||||
# o: autodiff supported. x: not autodiff supported.
|
||||
@ -245,8 +325,10 @@ class TestAutodiffSubgraphSlicing(JitTestCase):
|
||||
|
||||
graph = self._perform_ad_subgraph_slicing(fn, 1, 1, 1)
|
||||
g_str = str(graph)
|
||||
FileCheck().check("aten::Int").check("aten::zeros").check_not("aten::mul").run(g_str[0:g_str.find("return")])
|
||||
self.assertGraphContainsExactly(graph, 'prim::DifferentiableGraph', 1)
|
||||
FileCheck().check("aten::Int").check("aten::zeros").check_not("aten::mul").run(
|
||||
g_str[0 : g_str.find("return")]
|
||||
)
|
||||
self.assertGraphContainsExactly(graph, "prim::DifferentiableGraph", 1)
|
||||
|
||||
def test_does_not_merge_unrelated(self):
|
||||
# o o
|
||||
@ -258,7 +340,7 @@ class TestAutodiffSubgraphSlicing(JitTestCase):
|
||||
graph = self._perform_ad_subgraph_slicing(fn, 1, 1, 1, 1)
|
||||
|
||||
self.assertGraphSize(graph, 3)
|
||||
self.assertGraphContainsExactly(graph, 'prim::DifferentiableGraph', 2)
|
||||
self.assertGraphContainsExactly(graph, "prim::DifferentiableGraph", 2)
|
||||
|
||||
def test_merges_without_cycles(self):
|
||||
# o --> o --> o
|
||||
@ -273,7 +355,7 @@ class TestAutodiffSubgraphSlicing(JitTestCase):
|
||||
graph = self._perform_ad_subgraph_slicing(fn, 1, 1, 1)
|
||||
|
||||
self.assertGraphSize(graph, 1)
|
||||
self.assertGraphContainsExactly(graph, 'prim::DifferentiableGraph', 1)
|
||||
self.assertGraphContainsExactly(graph, "prim::DifferentiableGraph", 1)
|
||||
|
||||
def test_merges_dense(self):
|
||||
# o o
|
||||
@ -290,7 +372,7 @@ class TestAutodiffSubgraphSlicing(JitTestCase):
|
||||
graph = self._perform_ad_subgraph_slicing(fn, 2, 2)
|
||||
|
||||
self.assertGraphSize(graph, 2)
|
||||
self.assertGraphContainsExactly(graph, 'prim::DifferentiableGraph', 1)
|
||||
self.assertGraphContainsExactly(graph, "prim::DifferentiableGraph", 1)
|
||||
|
||||
def test_does_not_create_cycles(self):
|
||||
# o --> x --> o
|
||||
@ -303,7 +385,7 @@ class TestAutodiffSubgraphSlicing(JitTestCase):
|
||||
return c
|
||||
|
||||
graph = self._perform_ad_subgraph_slicing(fn, 1, 1, 1)
|
||||
self.assertGraphContainsExactly(graph, 'prim::DifferentiableGraph', 2)
|
||||
self.assertGraphContainsExactly(graph, "prim::DifferentiableGraph", 2)
|
||||
|
||||
def test_merges_up(self):
|
||||
# o --> x o
|
||||
@ -317,8 +399,8 @@ class TestAutodiffSubgraphSlicing(JitTestCase):
|
||||
|
||||
graph = self._perform_ad_subgraph_slicing(fn, 1, 1, 1, 1)
|
||||
g_str = str(graph)
|
||||
FileCheck().check_not("aten::add").run(g_str[0:g_str.find("return")])
|
||||
self.assertGraphContainsExactly(graph, 'prim::DifferentiableGraph', 1)
|
||||
FileCheck().check_not("aten::add").run(g_str[0 : g_str.find("return")])
|
||||
self.assertGraphContainsExactly(graph, "prim::DifferentiableGraph", 1)
|
||||
|
||||
def test_merges_down(self):
|
||||
# o x --> o
|
||||
@ -335,8 +417,8 @@ class TestAutodiffSubgraphSlicing(JitTestCase):
|
||||
num_nodes = 4 if GRAPH_EXECUTOR == ProfilingMode.PROFILING else 3
|
||||
# add moved down
|
||||
g_str = str(graph)
|
||||
FileCheck().check_not("aten::add").run(g_str[0:g_str.find("return")])
|
||||
self.assertGraphContainsExactly(graph, 'prim::DifferentiableGraph', 1)
|
||||
FileCheck().check_not("aten::add").run(g_str[0 : g_str.find("return")])
|
||||
self.assertGraphContainsExactly(graph, "prim::DifferentiableGraph", 1)
|
||||
|
||||
def test_respects_lexical_scoping(self):
|
||||
def fn(x, k):
|
||||
@ -346,12 +428,10 @@ class TestAutodiffSubgraphSlicing(JitTestCase):
|
||||
z = y * k
|
||||
return z, k
|
||||
|
||||
|
||||
graph = self._perform_ad_subgraph_slicing(fn, 1, 1)
|
||||
# We should not have combined the two multiplications into
|
||||
# the same group; they should each be a separate DiffGraph
|
||||
self.assertGraphContainsExactly(graph, 'prim::DifferentiableGraph', 3)
|
||||
|
||||
self.assertGraphContainsExactly(graph, "prim::DifferentiableGraph", 3)
|
||||
|
||||
def test_merge_respects_aliasing(self):
|
||||
def fn(x, k, cond):
|
||||
@ -368,15 +448,13 @@ class TestAutodiffSubgraphSlicing(JitTestCase):
|
||||
|
||||
graph = self._perform_ad_subgraph_slicing(fn, [2, 2], [2, 2], 1)
|
||||
# z2 did did not get merged into the subgraph
|
||||
FileCheck().check("prim::If").check("aten::select").check_next("aten::select")\
|
||||
.check_next("aten::add_").check("Differentiable").run(graph)
|
||||
self.assertGraphContainsExactly(graph, 'prim::DifferentiableGraph', 2)
|
||||
FileCheck().check("prim::If").check("aten::select").check_next(
|
||||
"aten::select"
|
||||
).check_next("aten::add_").check("Differentiable").run(graph)
|
||||
self.assertGraphContainsExactly(graph, "prim::DifferentiableGraph", 2)
|
||||
|
||||
def test_aliased_outputs(self):
|
||||
|
||||
with enable_profiling_mode_for_profiling_tests():
|
||||
|
||||
|
||||
# Case 1: aliasing between relu and t
|
||||
# is within a DifferentiableGraph. It should be valid
|
||||
# to merge both split_with_sizes in relu in one graph
|
||||
@ -389,9 +467,9 @@ class TestAutodiffSubgraphSlicing(JitTestCase):
|
||||
|
||||
graph = torch._C.parse_ir(input_str)
|
||||
torch._C._jit_pass_create_autodiff_subgraphs(graph, 1)
|
||||
FileCheck().check("with prim::DifferentiableGraph") \
|
||||
.check("aten::relu").check("aten::t") \
|
||||
.run(graph)
|
||||
FileCheck().check("with prim::DifferentiableGraph").check(
|
||||
"aten::relu"
|
||||
).check("aten::t").run(graph)
|
||||
|
||||
# Case 2: aliasing between relu and split_with_sizes
|
||||
# are both outputs of a Diff graph. It should be invalid
|
||||
@ -410,11 +488,11 @@ class TestAutodiffSubgraphSlicing(JitTestCase):
|
||||
|
||||
graph = torch._C.parse_ir(input_str)
|
||||
torch._C._jit_pass_create_autodiff_subgraphs(graph, 1)
|
||||
FileCheck().check("Tensor = prim::DifferentiableGraph") \
|
||||
.check("with prim::DifferentiableGraph") \
|
||||
.check("Tensor = aten::relu") \
|
||||
.check_not("aten::split_with_sizes") \
|
||||
.run(graph)
|
||||
FileCheck().check("Tensor = prim::DifferentiableGraph").check(
|
||||
"with prim::DifferentiableGraph"
|
||||
).check("Tensor = aten::relu").check_not("aten::split_with_sizes").run(
|
||||
graph
|
||||
)
|
||||
|
||||
# Case 3: two aliased nodes in a graph.
|
||||
# Both `split_with_sizes` should be unfused
|
||||
@ -432,11 +510,11 @@ class TestAutodiffSubgraphSlicing(JitTestCase):
|
||||
|
||||
graph = torch._C.parse_ir(input_str)
|
||||
torch._C._jit_pass_create_autodiff_subgraphs(graph, 1)
|
||||
FileCheck().check("Tensor = prim::DifferentiableGraph") \
|
||||
.check("with prim::DifferentiableGraph") \
|
||||
.check("Tensor = aten::relu") \
|
||||
.check_not("aten::split_with_sizes") \
|
||||
.run(graph)
|
||||
FileCheck().check("Tensor = prim::DifferentiableGraph").check(
|
||||
"with prim::DifferentiableGraph"
|
||||
).check("Tensor = aten::relu").check_not("aten::split_with_sizes").run(
|
||||
graph
|
||||
)
|
||||
|
||||
# Case 4: the aliased output has a descendant
|
||||
# Both should be unfused. Note, %3 comes before %2
|
||||
@ -454,11 +532,9 @@ class TestAutodiffSubgraphSlicing(JitTestCase):
|
||||
|
||||
graph = torch._C.parse_ir(input_str)
|
||||
torch._C._jit_pass_create_autodiff_subgraphs(graph, 1)
|
||||
FileCheck().check("Tensor = prim::DifferentiableGraph") \
|
||||
.check("with prim::DifferentiableGraph") \
|
||||
.check("Tensor = aten::relu") \
|
||||
.check_not("aten::t") \
|
||||
.run(graph)
|
||||
FileCheck().check("Tensor = prim::DifferentiableGraph").check(
|
||||
"with prim::DifferentiableGraph"
|
||||
).check("Tensor = aten::relu").check_not("aten::t").run(graph)
|
||||
|
||||
# Case 5: multiple aliased groups
|
||||
# Both should be unfused. Note, %3 comes before %2
|
||||
@ -478,11 +554,9 @@ class TestAutodiffSubgraphSlicing(JitTestCase):
|
||||
|
||||
graph = torch._C.parse_ir(input_str)
|
||||
torch._C._jit_pass_create_autodiff_subgraphs(graph, 1)
|
||||
FileCheck().check("Tensor = prim::DifferentiableGraph") \
|
||||
.check("with prim::DifferentiableGraph") \
|
||||
.check("Tensor = aten::relu") \
|
||||
.check_not("aten::t") \
|
||||
.run(graph)
|
||||
FileCheck().check("Tensor = prim::DifferentiableGraph").check(
|
||||
"with prim::DifferentiableGraph"
|
||||
).check("Tensor = aten::relu").check_not("aten::t").run(graph)
|
||||
|
||||
def test_has_profiled_info_aliasing_outputs(self):
|
||||
# The expectation is that CallFunction will prevent the final profile node from
|
||||
@ -511,9 +585,6 @@ class TestAutodiffSubgraphSlicing(JitTestCase):
|
||||
output = outputs[0]
|
||||
self.assertEqual(False, output.requiresGrad())
|
||||
|
||||
FileCheck().check("= prim::DifferentiableGraph") \
|
||||
.check("with prim::DifferentiableGraph") \
|
||||
.check(" = aten::relu") \
|
||||
.check("requires_grad=0") \
|
||||
.check("aten::relu") \
|
||||
.run(graph)
|
||||
FileCheck().check("= prim::DifferentiableGraph").check(
|
||||
"with prim::DifferentiableGraph"
|
||||
).check(" = aten::relu").check("requires_grad=0").check("aten::relu").run(graph)
|
||||
|
||||
@ -1,18 +1,19 @@
|
||||
# Owner(s): ["oncall: jit"]
|
||||
|
||||
import io
|
||||
import torch
|
||||
from torch.testing._internal.jit_utils import JitTestCase
|
||||
from torch.testing._internal.jit_utils import make_global
|
||||
from typing import List, Optional, Tuple
|
||||
|
||||
import torch
|
||||
from torch import Tensor
|
||||
from torch._awaits import _Await as Await
|
||||
from torch.testing._internal.jit_utils import JitTestCase, make_global
|
||||
|
||||
|
||||
class TestAwait(JitTestCase):
|
||||
def test_await_python(self):
|
||||
def foo(x: int) -> int:
|
||||
return x + 13
|
||||
|
||||
aw: Await[int] = torch.jit._awaitable(foo, 13)
|
||||
self.assertTrue(aw.fn()(*aw.args()) == torch.jit._awaitable_wait(aw))
|
||||
nw = torch.jit._awaitable_nowait(33)
|
||||
@ -22,6 +23,7 @@ class TestAwait(JitTestCase):
|
||||
def test_await_type_python(self):
|
||||
def foo() -> Tensor:
|
||||
return torch.randn()
|
||||
|
||||
awaits = torch.jit.annotate(List[Await[Tensor]], [])
|
||||
awaits.append(torch.jit._awaitable(foo))
|
||||
|
||||
@ -82,9 +84,7 @@ class TestAwait(JitTestCase):
|
||||
self.assertTrue(torch.allclose(torch.eye(2), script_out))
|
||||
self.assertTrue(torch.allclose(script_out, out))
|
||||
|
||||
|
||||
def test_await_class_arg(self):
|
||||
|
||||
class C:
|
||||
def __init__(self, a: Tensor, b: Tensor):
|
||||
self.__a = a
|
||||
@ -104,6 +104,7 @@ class TestAwait(JitTestCase):
|
||||
_a = torch.eye(2)
|
||||
c2_t = torch.jit._awaitable_wait(aw)
|
||||
return _a + c2_t + x
|
||||
|
||||
inp = torch.zeros(2)
|
||||
|
||||
sm = torch.jit.script(fn)
|
||||
@ -120,7 +121,6 @@ class TestAwait(JitTestCase):
|
||||
self._a = a
|
||||
self._b = b
|
||||
|
||||
|
||||
make_global(C)
|
||||
|
||||
# Can not stay in the class as Jit does not support Recursive annotations
|
||||
@ -143,7 +143,6 @@ class TestAwait(JitTestCase):
|
||||
self.assertTrue(torch.allclose(script_out, out))
|
||||
|
||||
def test_await_class_return(self):
|
||||
|
||||
class C:
|
||||
__slots__ = ["a", "b"]
|
||||
|
||||
@ -151,7 +150,6 @@ class TestAwait(JitTestCase):
|
||||
self.a = a
|
||||
self.b = b
|
||||
|
||||
|
||||
make_global(C)
|
||||
|
||||
# Can not stay in the class as Jit does not support Recursive annotations
|
||||
@ -175,7 +173,9 @@ class TestAwait(JitTestCase):
|
||||
script_out = sm(inp)
|
||||
self.assertTrue(torch.allclose(torch.eye(2) + 6 * torch.ones(2), script_out))
|
||||
self.assertTrue(torch.allclose(script_out, out))
|
||||
self.assertGraphContainsExactly(sm.graph, kind='prim::awaitable_wait', num_kind_nodes=1)
|
||||
self.assertGraphContainsExactly(
|
||||
sm.graph, kind="prim::awaitable_wait", num_kind_nodes=1
|
||||
)
|
||||
|
||||
def test_await_getattr_implicit_convertion(self):
|
||||
class C:
|
||||
@ -186,7 +186,6 @@ class TestAwait(JitTestCase):
|
||||
def b(self):
|
||||
return self._b
|
||||
|
||||
|
||||
make_global(C)
|
||||
|
||||
# Can not stay in the class as Jit does not support Recursive annotations
|
||||
@ -212,10 +211,11 @@ class TestAwait(JitTestCase):
|
||||
script_out = sm(inp)
|
||||
self.assertTrue(torch.allclose(torch.eye(2) + 7 * torch.ones(2), script_out))
|
||||
self.assertTrue(torch.allclose(script_out, out))
|
||||
self.assertGraphContainsExactly(sm.graph, kind='prim::awaitable_wait', num_kind_nodes=2)
|
||||
self.assertGraphContainsExactly(
|
||||
sm.graph, kind="prim::awaitable_wait", num_kind_nodes=2
|
||||
)
|
||||
|
||||
def test_await_nested(self):
|
||||
|
||||
class C:
|
||||
def __init__(self, a: Tensor, b: Tensor):
|
||||
self.__a = a
|
||||
@ -250,6 +250,7 @@ class TestAwait(JitTestCase):
|
||||
def __init__(self, v):
|
||||
self.parent = torch.jit.annotate(Optional[Tree], None)
|
||||
self.v = v
|
||||
|
||||
make_global(Tree)
|
||||
|
||||
def delayed(t: Tree):
|
||||
@ -275,12 +276,15 @@ class TestAwait(JitTestCase):
|
||||
sm = torch.jit.script(main)
|
||||
out = main(inp)
|
||||
script_out = sm(inp)
|
||||
self.assertTrue(torch.allclose(2 * torch.eye(2) + 2 * torch.ones(2), script_out))
|
||||
self.assertTrue(
|
||||
torch.allclose(2 * torch.eye(2) + 2 * torch.ones(2), script_out)
|
||||
)
|
||||
self.assertTrue(torch.allclose(script_out, out))
|
||||
|
||||
def test_await_eager_lazy(self):
|
||||
def delayed(x: Tensor) -> Tensor:
|
||||
return 2 * (x + 1)
|
||||
|
||||
t = torch.ones(2, dtype=torch.int64)
|
||||
aw = torch.jit._awaitable(delayed, t)
|
||||
self.assertTrue(isinstance(aw, torch._C._Await))
|
||||
@ -302,7 +306,9 @@ class TestAwait(JitTestCase):
|
||||
|
||||
script_out_aw = sm(inp)
|
||||
script_out = torch.jit._awaitable_wait(script_out_aw)
|
||||
self.assertTrue(torch.allclose(2 * torch.eye(2) + 2 * torch.ones(2), script_out))
|
||||
self.assertTrue(
|
||||
torch.allclose(2 * torch.eye(2) + 2 * torch.ones(2), script_out)
|
||||
)
|
||||
self.assertTrue(torch.allclose(script_out, out))
|
||||
|
||||
def test_jit_trace(self):
|
||||
|
||||
@ -3,10 +3,10 @@
|
||||
import os
|
||||
import sys
|
||||
import unittest
|
||||
from pathlib import Path
|
||||
|
||||
import torch
|
||||
import torch._C
|
||||
from pathlib import Path
|
||||
from torch.testing._internal.common_utils import IS_FBCODE, skipIfTorchDynamo
|
||||
|
||||
# hacky way to skip these tests in fbcode:
|
||||
@ -15,9 +15,11 @@ from torch.testing._internal.common_utils import IS_FBCODE, skipIfTorchDynamo
|
||||
# it sees tests but then fails when it tries to actuall run them.
|
||||
if not IS_FBCODE:
|
||||
from test_nnapi import TestNNAPI
|
||||
|
||||
HAS_TEST_NNAPI = True
|
||||
else:
|
||||
from torch.testing._internal.common_utils import TestCase as TestNNAPI
|
||||
|
||||
HAS_TEST_NNAPI = False
|
||||
|
||||
|
||||
@ -39,10 +41,14 @@ without the delegate API.
|
||||
"""
|
||||
# First skip is needed for IS_WINDOWS or IS_MACOS to skip the tests.
|
||||
torch_root = Path(__file__).resolve().parent.parent.parent
|
||||
lib_path = torch_root / 'build' / 'lib' / 'libnnapi_backend.so'
|
||||
lib_path = torch_root / "build" / "lib" / "libnnapi_backend.so"
|
||||
|
||||
|
||||
@skipIfTorchDynamo("weird py38 failures")
|
||||
@unittest.skipIf(not os.path.exists(lib_path),
|
||||
"Skipping the test as libnnapi_backend.so was not found")
|
||||
@unittest.skipIf(
|
||||
not os.path.exists(lib_path),
|
||||
"Skipping the test as libnnapi_backend.so was not found",
|
||||
)
|
||||
@unittest.skipIf(IS_FBCODE, "test_nnapi.py not found")
|
||||
class TestNnapiBackend(TestNNAPI):
|
||||
def setUp(self):
|
||||
@ -89,35 +95,44 @@ method_compile_spec must use the following format:
|
||||
|
||||
# No forward key
|
||||
compile_spec = {"backward": {"inputs": args}}
|
||||
with self.assertRaisesRegex(RuntimeError, "method_compile_spec does not contain the \"forward\" key." + errorMsgTail):
|
||||
with self.assertRaisesRegex(
|
||||
RuntimeError,
|
||||
'method_compile_spec does not contain the "forward" key.' + errorMsgTail,
|
||||
):
|
||||
torch._C._jit_to_backend("nnapi", traced, compile_spec)
|
||||
|
||||
# No dictionary under the forward key
|
||||
compile_spec = {"forward": 1}
|
||||
with self.assertRaisesRegex(RuntimeError,
|
||||
"method_compile_spec does not contain a dictionary with an \"inputs\" key, "
|
||||
"under it's \"forward\" key."
|
||||
+ errorMsgTail):
|
||||
with self.assertRaisesRegex(
|
||||
RuntimeError,
|
||||
'method_compile_spec does not contain a dictionary with an "inputs" key, '
|
||||
'under it\'s "forward" key.' + errorMsgTail,
|
||||
):
|
||||
torch._C._jit_to_backend("nnapi", traced, compile_spec)
|
||||
|
||||
# No inputs key (in the dictionary under the forward key)
|
||||
compile_spec = {"forward": {"not inputs": args}}
|
||||
with self.assertRaisesRegex(RuntimeError,
|
||||
"method_compile_spec does not contain a dictionary with an \"inputs\" key, "
|
||||
"under it's \"forward\" key."
|
||||
+ errorMsgTail):
|
||||
with self.assertRaisesRegex(
|
||||
RuntimeError,
|
||||
'method_compile_spec does not contain a dictionary with an "inputs" key, '
|
||||
'under it\'s "forward" key.' + errorMsgTail,
|
||||
):
|
||||
torch._C._jit_to_backend("nnapi", traced, compile_spec)
|
||||
|
||||
# No Tensor or TensorList under the inputs key
|
||||
compile_spec = {"forward": {"inputs": 1}}
|
||||
with self.assertRaisesRegex(RuntimeError,
|
||||
"method_compile_spec does not contain either a Tensor or TensorList, under it's \"inputs\" key."
|
||||
+ errorMsgTail):
|
||||
with self.assertRaisesRegex(
|
||||
RuntimeError,
|
||||
'method_compile_spec does not contain either a Tensor or TensorList, under it\'s "inputs" key.'
|
||||
+ errorMsgTail,
|
||||
):
|
||||
torch._C._jit_to_backend("nnapi", traced, compile_spec)
|
||||
compile_spec = {"forward": {"inputs": [1]}}
|
||||
with self.assertRaisesRegex(RuntimeError,
|
||||
"method_compile_spec does not contain either a Tensor or TensorList, under it's \"inputs\" key."
|
||||
+ errorMsgTail):
|
||||
with self.assertRaisesRegex(
|
||||
RuntimeError,
|
||||
'method_compile_spec does not contain either a Tensor or TensorList, under it\'s "inputs" key.'
|
||||
+ errorMsgTail,
|
||||
):
|
||||
torch._C._jit_to_backend("nnapi", traced, compile_spec)
|
||||
|
||||
def tearDown(self):
|
||||
|
||||
@ -1,6 +1,5 @@
|
||||
# Owner(s): ["oncall: jit"]
|
||||
|
||||
from torch.testing._internal.jit_utils import JitTestCase
|
||||
import io
|
||||
import os
|
||||
import sys
|
||||
@ -8,18 +7,20 @@ import unittest
|
||||
|
||||
import torch
|
||||
import torch._C
|
||||
from torch.testing import FileCheck
|
||||
from torch.jit.mobile import _load_for_lite_interpreter
|
||||
from torch.testing import FileCheck
|
||||
|
||||
from torch.testing._internal.common_utils import (
|
||||
find_library_location,
|
||||
IS_FBCODE,
|
||||
IS_MACOS,
|
||||
IS_SANDCASTLE,
|
||||
IS_WINDOWS,
|
||||
TEST_WITH_ROCM,
|
||||
skipIfRocm,
|
||||
find_library_location,
|
||||
TEST_WITH_ROCM,
|
||||
)
|
||||
from torch.testing._internal.jit_utils import JitTestCase
|
||||
|
||||
# Make the helper files in test/ importable
|
||||
pytorch_test_dir = os.path.dirname(os.path.dirname(os.path.realpath(__file__)))
|
||||
sys.path.append(pytorch_test_dir)
|
||||
@ -33,7 +34,9 @@ if __name__ == "__main__":
|
||||
|
||||
|
||||
def to_test_backend(module, method_compile_spec):
|
||||
return torch._C._jit_to_backend("test_backend", module, {"forward": method_compile_spec})
|
||||
return torch._C._jit_to_backend(
|
||||
"test_backend", module, {"forward": method_compile_spec}
|
||||
)
|
||||
|
||||
|
||||
def to_test_backend_multi(module, method_compile_spec):
|
||||
@ -63,8 +66,10 @@ class BasicModule(torch.nn.Module):
|
||||
|
||||
|
||||
# This is ignored in IS_WINDOWS or IS_MACOS cases. Hence we need the one in TestBackends.
|
||||
@unittest.skipIf(TEST_WITH_ROCM or IS_SANDCASTLE or IS_WINDOWS or IS_MACOS or IS_FBCODE,
|
||||
"Non-portable load_library call used in test")
|
||||
@unittest.skipIf(
|
||||
TEST_WITH_ROCM or IS_SANDCASTLE or IS_WINDOWS or IS_MACOS or IS_FBCODE,
|
||||
"Non-portable load_library call used in test",
|
||||
)
|
||||
class JitBackendTestCase(JitTestCase):
|
||||
"""
|
||||
A common base class for JIT backend tests that contains common utility
|
||||
@ -73,7 +78,7 @@ class JitBackendTestCase(JitTestCase):
|
||||
|
||||
def setUp(self):
|
||||
super().setUp()
|
||||
lib_file_path = find_library_location('libjitbackend_test.so')
|
||||
lib_file_path = find_library_location("libjitbackend_test.so")
|
||||
torch.ops.load_library(str(lib_file_path))
|
||||
# Subclasses are expected to set up three variables in their setUp methods:
|
||||
# module - a regular, Python version of the module being tested
|
||||
@ -154,13 +159,17 @@ class BasicModuleTest(JitBackendTestCase):
|
||||
self.test_execution()
|
||||
|
||||
# Save the compile spec to compare against the version retrieved after loading.
|
||||
pre_compile_spec = self.lowered_module.__getattr__("__loweredModule__").__getattr__("__method_compile_spec")
|
||||
pre_compile_spec = self.lowered_module.__getattr__(
|
||||
"__loweredModule__"
|
||||
).__getattr__("__method_compile_spec")
|
||||
|
||||
# Save and load the lowered module.
|
||||
self.save_load()
|
||||
|
||||
# Get the compile spec after loading.
|
||||
post_compile_spec = self.lowered_module.__getattr__("__loweredModule__").__getattr__("__method_compile_spec")
|
||||
post_compile_spec = self.lowered_module.__getattr__(
|
||||
"__loweredModule__"
|
||||
).__getattr__("__method_compile_spec")
|
||||
|
||||
# Compile specs should match.
|
||||
self.assertEqual(pre_compile_spec, post_compile_spec)
|
||||
@ -195,9 +204,11 @@ class BasicModuleUnavailableTest(JitBackendTestCase):
|
||||
input = torch.randn(5)
|
||||
|
||||
# Test exception is thrown.
|
||||
with self.assertRaisesRegexWithHighlight(Exception,
|
||||
r"Backend is not available.",
|
||||
"raise Exception(\"Backend is not available.\""):
|
||||
with self.assertRaisesRegexWithHighlight(
|
||||
Exception,
|
||||
r"Backend is not available.",
|
||||
'raise Exception("Backend is not available."',
|
||||
):
|
||||
backend_method = self.lowered_module.__getattr__("forward")
|
||||
backend_output = backend_method(*(input, input))
|
||||
|
||||
@ -207,9 +218,11 @@ class BasicModuleUnavailableTest(JitBackendTestCase):
|
||||
buffer = io.BytesIO()
|
||||
torch.jit.save(self.lowered_module, buffer)
|
||||
buffer.seek(0)
|
||||
with self.assertRaisesRegexWithHighlight(Exception,
|
||||
r"Backend is not available.",
|
||||
"raise Exception(\"Backend is not available.\""):
|
||||
with self.assertRaisesRegexWithHighlight(
|
||||
Exception,
|
||||
r"Backend is not available.",
|
||||
'raise Exception("Backend is not available."',
|
||||
):
|
||||
imported = torch.jit.load(buffer)
|
||||
|
||||
|
||||
@ -218,6 +231,7 @@ class NestedModuleTest(JitBackendTestCase):
|
||||
Tests for NestedModule that check that a module lowered to a backend can be used
|
||||
as a submodule.
|
||||
"""
|
||||
|
||||
class NestedModule(torch.nn.Module):
|
||||
"""
|
||||
A Module with one submodule that is used to test that lowered Modules
|
||||
@ -237,7 +251,9 @@ class NestedModuleTest(JitBackendTestCase):
|
||||
# Both modules in self.module are regular Python modules.
|
||||
self.module = NestedModuleTest.NestedModule(BasicModule())
|
||||
# Both modules in self.scripted_module are ScriptModules.
|
||||
self.scripted_module = torch.jit.script(NestedModuleTest.NestedModule(BasicModule()))
|
||||
self.scripted_module = torch.jit.script(
|
||||
NestedModuleTest.NestedModule(BasicModule())
|
||||
)
|
||||
|
||||
# First, script another instance of NestedModule with share_types=False so that it can be
|
||||
# selectively lowered without modifying the type of self.scripted_module.
|
||||
@ -246,7 +262,9 @@ class NestedModuleTest(JitBackendTestCase):
|
||||
{"accum": {"": ""}, "sub_accum": {"": ""}, "forward": {"": ""}},
|
||||
)
|
||||
# self.lowered_module is a ScriptModule, but its submodule is a lowered module.
|
||||
self.lowered_module = torch.jit.script(NestedModuleTest.NestedModule(lowered_module))
|
||||
self.lowered_module = torch.jit.script(
|
||||
NestedModuleTest.NestedModule(lowered_module)
|
||||
)
|
||||
|
||||
def test_execution(self):
|
||||
# Test execution with backend against Python and JIT.
|
||||
@ -270,6 +288,7 @@ class SelectiveLoweringTest(JitBackendTestCase):
|
||||
"""
|
||||
Tests for the selective lowering API.
|
||||
"""
|
||||
|
||||
class OuterModule(torch.nn.Module):
|
||||
def __init__(self, sub1, sub2, other):
|
||||
super().__init__()
|
||||
@ -299,7 +318,10 @@ class SelectiveLoweringTest(JitBackendTestCase):
|
||||
MiddleModule = SelectiveLoweringTest.MiddleModule
|
||||
|
||||
def script_without_type_sharing(mod):
|
||||
return torch.jit._recursive.create_script_module(mod, torch.jit._recursive.infer_methods_to_compile, share_types=False)
|
||||
return torch.jit._recursive.create_script_module(
|
||||
mod, torch.jit._recursive.infer_methods_to_compile, share_types=False
|
||||
)
|
||||
|
||||
# Create Python, JIT and backend versions of a hierarchy that looks like this:
|
||||
# --------- OuterModule --------
|
||||
# | | |
|
||||
@ -308,13 +330,28 @@ class SelectiveLoweringTest(JitBackendTestCase):
|
||||
# BasicModule BasicModule BasicModule
|
||||
#
|
||||
# Two BasicModules will be lowered and the third will not.
|
||||
self.module = OuterModule(MiddleModule(BasicModule()), MiddleModule(BasicModule()), MiddleModule(BasicModule()))
|
||||
self.scripted_module = script_without_type_sharing(OuterModule(MiddleModule(
|
||||
BasicModule()), MiddleModule(BasicModule()), MiddleModule(BasicModule())))
|
||||
self.lowered_module = script_without_type_sharing(OuterModule(MiddleModule(
|
||||
BasicModule()), MiddleModule(BasicModule()), MiddleModule(BasicModule())))
|
||||
self.lowered_module = to_test_backend_selective(self.lowered_module, {"forward": ""}, [
|
||||
"sub1.submodule", "sub2.submodule"])
|
||||
self.module = OuterModule(
|
||||
MiddleModule(BasicModule()),
|
||||
MiddleModule(BasicModule()),
|
||||
MiddleModule(BasicModule()),
|
||||
)
|
||||
self.scripted_module = script_without_type_sharing(
|
||||
OuterModule(
|
||||
MiddleModule(BasicModule()),
|
||||
MiddleModule(BasicModule()),
|
||||
MiddleModule(BasicModule()),
|
||||
)
|
||||
)
|
||||
self.lowered_module = script_without_type_sharing(
|
||||
OuterModule(
|
||||
MiddleModule(BasicModule()),
|
||||
MiddleModule(BasicModule()),
|
||||
MiddleModule(BasicModule()),
|
||||
)
|
||||
)
|
||||
self.lowered_module = to_test_backend_selective(
|
||||
self.lowered_module, {"forward": ""}, ["sub1.submodule", "sub2.submodule"]
|
||||
)
|
||||
|
||||
def test_execution(self):
|
||||
input = torch.randn(5)
|
||||
@ -335,93 +372,93 @@ class SelectiveLoweringTest(JitBackendTestCase):
|
||||
"""
|
||||
# Check that self.lowered_module was not lowered, but that it does contain test_backendLoweredModule due to it
|
||||
# calling the lowered module directly.
|
||||
FileCheck() \
|
||||
.check("OuterModule") \
|
||||
.check("BasicModule") \
|
||||
.run(self.scripted_module.graph)
|
||||
FileCheck() \
|
||||
.check("OuterModule") \
|
||||
.check_not("__torch__.torch.classes.__backends__.test_backend") \
|
||||
.check("LoweredWrapper.test_backend") \
|
||||
.run(self.lowered_module.graph)
|
||||
FileCheck().check("OuterModule").check("BasicModule").run(
|
||||
self.scripted_module.graph
|
||||
)
|
||||
FileCheck().check("OuterModule").check_not(
|
||||
"__torch__.torch.classes.__backends__.test_backend"
|
||||
).check("LoweredWrapper.test_backend").run(self.lowered_module.graph)
|
||||
|
||||
# Check that self.lowered_module.sub1/sub2 were not lowered but that BasicModule has been replaced in their graphs.
|
||||
FileCheck() \
|
||||
.check("MiddleModule") \
|
||||
.check("BasicModule") \
|
||||
.check_not("LoweredWrapper.test_backend") \
|
||||
.run(self.scripted_module.sub1.graph)
|
||||
FileCheck() \
|
||||
.check("MiddleModule") \
|
||||
.check_not("__torch__.torch.classes.__backends__.test_backend") \
|
||||
.check("LoweredWrapper.test_backend") \
|
||||
.run(self.lowered_module.sub1.graph)
|
||||
FileCheck().check("MiddleModule").check("BasicModule").check_not(
|
||||
"LoweredWrapper.test_backend"
|
||||
).run(self.scripted_module.sub1.graph)
|
||||
FileCheck().check("MiddleModule").check_not(
|
||||
"__torch__.torch.classes.__backends__.test_backend"
|
||||
).check("LoweredWrapper.test_backend").run(self.lowered_module.sub1.graph)
|
||||
|
||||
FileCheck() \
|
||||
.check("MiddleModule") \
|
||||
.check("BasicModule") \
|
||||
.check_not("LoweredWrapper.test_backend") \
|
||||
.run(self.scripted_module.sub2.graph)
|
||||
FileCheck() \
|
||||
.check("MiddleModule") \
|
||||
.check_not("__torch__.torch.classes.__backends__.test_backend") \
|
||||
.check("LoweredWrapper.test_backend") \
|
||||
.run(self.lowered_module.sub2.graph)
|
||||
FileCheck().check("MiddleModule").check("BasicModule").check_not(
|
||||
"LoweredWrapper.test_backend"
|
||||
).run(self.scripted_module.sub2.graph)
|
||||
FileCheck().check("MiddleModule").check_not(
|
||||
"__torch__.torch.classes.__backends__.test_backend"
|
||||
).check("LoweredWrapper.test_backend").run(self.lowered_module.sub2.graph)
|
||||
|
||||
# Check that self.lowered_module.sub1/sub2.submodule were lowered. They should have a new attribute
|
||||
# __loweredModule__ whose graph should mention __torch__.torch.classes.__backends__.test_backend,
|
||||
# the TorchBind class for executing functions on the test JIT backend.
|
||||
FileCheck() \
|
||||
.check("LoweredModule.test_backend") \
|
||||
.check("__torch__.torch.classes.__backends__.test_backend") \
|
||||
.run(self.lowered_module.sub1.submodule.__loweredModule__.graph)
|
||||
FileCheck().check("LoweredModule.test_backend").check(
|
||||
"__torch__.torch.classes.__backends__.test_backend"
|
||||
).run(self.lowered_module.sub1.submodule.__loweredModule__.graph)
|
||||
|
||||
FileCheck() \
|
||||
.check("LoweredModule.test_backend") \
|
||||
.check("__torch__.torch.classes.__backends__.test_backend") \
|
||||
.run(self.lowered_module.sub2.submodule.__loweredModule__.graph)
|
||||
FileCheck().check("LoweredModule.test_backend").check(
|
||||
"__torch__.torch.classes.__backends__.test_backend"
|
||||
).run(self.lowered_module.sub2.submodule.__loweredModule__.graph)
|
||||
|
||||
# Check that self.other and self.other.submodule have been left untouched by the selective lowering process.
|
||||
FileCheck() \
|
||||
.check("MiddleModule") \
|
||||
.check("BasicModule") \
|
||||
.check_not("__torch__.torch.classes.__backends__.test_backend") \
|
||||
.check_not("LoweredWrapper.test_backend") \
|
||||
.run(self.scripted_module.other.graph)
|
||||
FileCheck() \
|
||||
.check("BasicModule") \
|
||||
.check_not("__torch__.torch.classes.__backends__.test_backend") \
|
||||
.check_not("LoweredModule.test_backend") \
|
||||
.run(self.scripted_module.other.submodule.graph)
|
||||
FileCheck().check("MiddleModule").check("BasicModule").check_not(
|
||||
"__torch__.torch.classes.__backends__.test_backend"
|
||||
).check_not("LoweredWrapper.test_backend").run(self.scripted_module.other.graph)
|
||||
FileCheck().check("BasicModule").check_not(
|
||||
"__torch__.torch.classes.__backends__.test_backend"
|
||||
).check_not("LoweredModule.test_backend").run(
|
||||
self.scripted_module.other.submodule.graph
|
||||
)
|
||||
|
||||
def test_errors(self):
|
||||
"""
|
||||
Check errors associated with selective lowering.
|
||||
"""
|
||||
# Check error messages thrown when attempting to lower something that is not a ScriptModule.
|
||||
with self.assertRaisesRegexWithHighlight(RuntimeError, r"Object .* is not a ScriptModule", ""):
|
||||
with self.assertRaisesRegexWithHighlight(
|
||||
RuntimeError, r"Object .* is not a ScriptModule", ""
|
||||
):
|
||||
to_test_backend_selective(torch.nn.ReLU(), {"forward": ""}, ["submodule"])
|
||||
|
||||
MiddleModule = SelectiveLoweringTest.MiddleModule
|
||||
mod = MiddleModule(BasicModule())
|
||||
mod.new_attr = 3
|
||||
|
||||
with self.assertRaisesRegexWithHighlight(RuntimeError, r"Attribute named new_attr is not a Module", ""):
|
||||
to_test_backend_selective(torch.jit.script(mod), {"forward": ""}, ["new_attr"])
|
||||
with self.assertRaisesRegexWithHighlight(
|
||||
RuntimeError, r"Attribute named new_attr is not a Module", ""
|
||||
):
|
||||
to_test_backend_selective(
|
||||
torch.jit.script(mod), {"forward": ""}, ["new_attr"]
|
||||
)
|
||||
|
||||
# Check error message thrown when module hierarchy doesn't have unique types.
|
||||
OuterModule = SelectiveLoweringTest.OuterModule
|
||||
mod = OuterModule(MiddleModule(BasicModule()), MiddleModule(BasicModule()), MiddleModule(BasicModule()))
|
||||
mod = OuterModule(
|
||||
MiddleModule(BasicModule()),
|
||||
MiddleModule(BasicModule()),
|
||||
MiddleModule(BasicModule()),
|
||||
)
|
||||
|
||||
with self.assertRaisesRegexWithHighlight(RuntimeError,
|
||||
r"Selective lowering is only supported for module hierarchies with unique types",
|
||||
""):
|
||||
to_test_backend_selective(torch.jit.script(mod), {"forward": ""}, ["sub1.submodule"])
|
||||
with self.assertRaisesRegexWithHighlight(
|
||||
RuntimeError,
|
||||
r"Selective lowering is only supported for module hierarchies with unique types",
|
||||
"",
|
||||
):
|
||||
to_test_backend_selective(
|
||||
torch.jit.script(mod), {"forward": ""}, ["sub1.submodule"]
|
||||
)
|
||||
|
||||
|
||||
# This is needed for IS_WINDOWS or IS_MACOS to skip the tests.
|
||||
@unittest.skipIf(TEST_WITH_ROCM or IS_SANDCASTLE or IS_WINDOWS or IS_MACOS or IS_FBCODE,
|
||||
"Non-portable load_library call used in test")
|
||||
@unittest.skipIf(
|
||||
TEST_WITH_ROCM or IS_SANDCASTLE or IS_WINDOWS or IS_MACOS or IS_FBCODE,
|
||||
"Non-portable load_library call used in test",
|
||||
)
|
||||
class TestBackends(JitTestCase):
|
||||
"""
|
||||
This class wraps and invokes all subclasses of JitBackendTestCase so that each one
|
||||
@ -461,6 +498,7 @@ class TestBackends(JitTestCase):
|
||||
def test_errors(self):
|
||||
self.selective_lowering_test.test_errors()
|
||||
|
||||
|
||||
"""
|
||||
Unit Tests for backend with compiler
|
||||
This test case and the existing TestBackends are separate because they cover different aspects.
|
||||
@ -468,6 +506,8 @@ The actual backend implementation in this test is different.
|
||||
It has a simple demo compiler to test the end-to-end flow in mobile.
|
||||
However, this test cannot cover the selective_lowering for now, which is covered in TestBackends.
|
||||
"""
|
||||
|
||||
|
||||
class BasicModuleAdd(torch.nn.Module):
|
||||
"""
|
||||
A simple add Module used to test to_backend lowering machinery.
|
||||
@ -476,9 +516,12 @@ class BasicModuleAdd(torch.nn.Module):
|
||||
def forward(self, x, h):
|
||||
return x + h
|
||||
|
||||
|
||||
# This is ignored in IS_WINDOWS or IS_MACOS cases. Hence we need the one in TestBackends.
|
||||
@unittest.skipIf(TEST_WITH_ROCM or IS_SANDCASTLE or IS_WINDOWS or IS_MACOS or IS_FBCODE,
|
||||
"Non-portable load_library call used in test")
|
||||
@unittest.skipIf(
|
||||
TEST_WITH_ROCM or IS_SANDCASTLE or IS_WINDOWS or IS_MACOS or IS_FBCODE,
|
||||
"Non-portable load_library call used in test",
|
||||
)
|
||||
class JitBackendTestCaseWithCompiler(JitTestCase):
|
||||
"""
|
||||
A common base class for JIT backend tests with compilers that contains common utility
|
||||
@ -487,7 +530,7 @@ class JitBackendTestCaseWithCompiler(JitTestCase):
|
||||
|
||||
def setUp(self):
|
||||
super().setUp()
|
||||
lib_file_path = find_library_location('libbackend_with_compiler.so')
|
||||
lib_file_path = find_library_location("libbackend_with_compiler.so")
|
||||
torch.ops.load_library(str(lib_file_path))
|
||||
# Subclasses are expected to set up four variables in their setUp methods:
|
||||
# module - a regular, Python version of the module being tested
|
||||
@ -524,6 +567,7 @@ class JitBackendTestCaseWithCompiler(JitTestCase):
|
||||
"""
|
||||
pass
|
||||
|
||||
|
||||
class BasicModuleTestWithCompiler(JitBackendTestCaseWithCompiler):
|
||||
"""
|
||||
Tests for BasicModuleAdd.
|
||||
@ -541,7 +585,8 @@ class BasicModuleTestWithCompiler(JitBackendTestCaseWithCompiler):
|
||||
},
|
||||
}
|
||||
self.lowered_module = torch._C._jit_to_backend(
|
||||
"backend_with_compiler_demo", self.scripted_module, compile_spec)
|
||||
"backend_with_compiler_demo", self.scripted_module, compile_spec
|
||||
)
|
||||
# Create mobile version of BasicModuleAdd
|
||||
buffer = io.BytesIO(self.lowered_module._save_to_buffer_for_lite_interpreter())
|
||||
buffer.seek(0)
|
||||
@ -552,6 +597,7 @@ class BasicModuleTestWithCompiler(JitBackendTestCaseWithCompiler):
|
||||
input = torch.ones(1, dtype=torch.float)
|
||||
self.check_forward((input, input))
|
||||
|
||||
|
||||
class ErrorMessagesWithCompiler(JitBackendTestCase):
|
||||
"""
|
||||
Tests for errors that occur with compiler, specifically:
|
||||
@ -562,22 +608,31 @@ class ErrorMessagesWithCompiler(JitBackendTestCase):
|
||||
"""
|
||||
A module with an operator that is not supported.
|
||||
"""
|
||||
|
||||
def forward(self, x, h):
|
||||
return x * h
|
||||
self._loweredmodule.forward()
|
||||
|
||||
def test_errors(self):
|
||||
scripted_module_n = torch.jit.script(ErrorMessagesWithCompiler.ModuleNotSupported())
|
||||
scripted_module_n = torch.jit.script(
|
||||
ErrorMessagesWithCompiler.ModuleNotSupported()
|
||||
)
|
||||
# Test exception is thrown when lowering a module with an unsupported operator
|
||||
with self.assertRaisesRegexWithHighlight(RuntimeError,
|
||||
# Special escape characters are replaced with '.'
|
||||
r"""The node of aten::mul is not supported in this compiler. .*
|
||||
with self.assertRaisesRegexWithHighlight(
|
||||
RuntimeError,
|
||||
# Special escape characters are replaced with '.'
|
||||
r"""The node of aten::mul is not supported in this compiler. .*
|
||||
def forward.self, x, h.:
|
||||
return x . h
|
||||
~~~~~ <--- HERE
|
||||
self._loweredmodule.forward..
|
||||
""", ""):
|
||||
lowered_module_n = torch._C._jit_to_backend("backend_with_compiler_demo", scripted_module_n, {"forward": {"": ""}})
|
||||
""",
|
||||
"",
|
||||
):
|
||||
lowered_module_n = torch._C._jit_to_backend(
|
||||
"backend_with_compiler_demo", scripted_module_n, {"forward": {"": ""}}
|
||||
)
|
||||
|
||||
|
||||
class CompModuleTestWithCompiler(JitBackendTestCase):
|
||||
"""
|
||||
@ -588,6 +643,7 @@ class CompModuleTestWithCompiler(JitBackendTestCase):
|
||||
"""
|
||||
A simple subtraction Module to be used in CompModule.
|
||||
"""
|
||||
|
||||
def forward(self, x, h):
|
||||
return x - h
|
||||
|
||||
@ -617,14 +673,19 @@ class CompModuleTestWithCompiler(JitBackendTestCase):
|
||||
},
|
||||
}
|
||||
lowered_add = torch._C._jit_to_backend(
|
||||
"backend_with_compiler_demo", torch.jit.script(BasicModuleAdd()), compile_spec)
|
||||
"backend_with_compiler_demo",
|
||||
torch.jit.script(BasicModuleAdd()),
|
||||
compile_spec,
|
||||
)
|
||||
lowered_sub = torch._C._jit_to_backend(
|
||||
"backend_with_compiler_demo",
|
||||
torch.jit.script(CompModuleTestWithCompiler.BasicModuleSub()),
|
||||
{"forward": {"": ""}}
|
||||
{"forward": {"": ""}},
|
||||
)
|
||||
self.module = CompModuleTestWithCompiler.CompModule(lowered_add, lowered_sub)
|
||||
self.scripted_module = torch.jit.script(CompModuleTestWithCompiler.CompModule(lowered_add, lowered_sub))
|
||||
self.scripted_module = torch.jit.script(
|
||||
CompModuleTestWithCompiler.CompModule(lowered_add, lowered_sub)
|
||||
)
|
||||
# No backend version of CompModule currently, so this is filler.
|
||||
self.lowered_module = self.scripted_module
|
||||
# Create a mobile version of CompModule from JIT version
|
||||
@ -640,9 +701,12 @@ class CompModuleTestWithCompiler(JitBackendTestCase):
|
||||
# Test forward.
|
||||
self.check_function("forward", (input1, input2, input2))
|
||||
|
||||
|
||||
# This is needed for IS_WINDOWS or IS_MACOS to skip the tests.
|
||||
@unittest.skipIf(IS_SANDCASTLE or IS_WINDOWS or IS_MACOS or IS_FBCODE,
|
||||
"Non-portable load_library call used in test")
|
||||
@unittest.skipIf(
|
||||
IS_SANDCASTLE or IS_WINDOWS or IS_MACOS or IS_FBCODE,
|
||||
"Non-portable load_library call used in test",
|
||||
)
|
||||
class TestBackendsWithCompiler(JitTestCase):
|
||||
"""
|
||||
This class wraps and invokes all subclasses of JitBackendTestCaseWithCompiler
|
||||
@ -711,7 +775,6 @@ class CompModuleTestSameNameWithCompiler(JitBackendTestCase):
|
||||
y = s * (c * d)
|
||||
return y
|
||||
|
||||
|
||||
def setUp(self):
|
||||
super().setUp()
|
||||
|
||||
@ -728,6 +791,7 @@ class CompModuleTestSameNameWithCompiler(JitBackendTestCase):
|
||||
# Test forward.
|
||||
self.check_function("forward", (a, b, s))
|
||||
|
||||
|
||||
class AddedAttributesTest(JitBackendTestCase):
|
||||
"""
|
||||
Tests for adding attributes to a model after lowering.
|
||||
@ -747,11 +811,19 @@ class AddedAttributesTest(JitBackendTestCase):
|
||||
input = [(torch.ones(5),)]
|
||||
pre_bundled = self.lowered_module(*input[0])
|
||||
# Attach bundled inputs which adds several attributes and functions to the model
|
||||
self.lowered_module = torch.utils.bundled_inputs.augment_model_with_bundled_inputs(lowered_module, input) # noqa: F821
|
||||
post_bundled = self.lowered_module(*self.lowered_module.get_all_bundled_inputs()[0])
|
||||
self.lowered_module = (
|
||||
torch.utils.bundled_inputs.augment_model_with_bundled_inputs(
|
||||
lowered_module, input # noqa: F821
|
||||
)
|
||||
)
|
||||
post_bundled = self.lowered_module(
|
||||
*self.lowered_module.get_all_bundled_inputs()[0]
|
||||
)
|
||||
# Save and load the lowered module.
|
||||
self.save_load()
|
||||
# Use bundled after save and load to prove its preserved
|
||||
post_load = self.lowered_module(*self.lowered_module.get_all_bundled_inputs()[0])
|
||||
post_load = self.lowered_module(
|
||||
*self.lowered_module.get_all_bundled_inputs()[0]
|
||||
)
|
||||
self.assertEqual(pre_bundled, post_bundled)
|
||||
self.assertEqual(post_bundled, post_load)
|
||||
|
||||
@ -56,7 +56,6 @@ class TestBatchMM(JitTestCase):
|
||||
actual = test_batch_mm_scripted(*tensors)
|
||||
self.assertEqual(expected, actual, atol=1e-9, rtol=1e-9)
|
||||
|
||||
|
||||
def test_batch_mm_permitted_mutation(self):
|
||||
def test_batch_mm(
|
||||
T1: torch.Tensor,
|
||||
|
||||
@ -1,8 +1,8 @@
|
||||
# Owner(s): ["oncall: jit"]
|
||||
|
||||
import inspect
|
||||
import os
|
||||
import sys
|
||||
import inspect
|
||||
import unittest
|
||||
from typing import Dict, List
|
||||
|
||||
@ -14,10 +14,12 @@ pytorch_test_dir = os.path.dirname(os.path.dirname(os.path.realpath(__file__)))
|
||||
sys.path.append(pytorch_test_dir)
|
||||
from torch.testing._internal.jit_utils import JitTestCase, RUN_CUDA
|
||||
|
||||
if __name__ == '__main__':
|
||||
raise RuntimeError("This test file is not meant to be run directly, use:\n\n"
|
||||
"\tpython test/test_jit.py TESTNAME\n\n"
|
||||
"instead.")
|
||||
if __name__ == "__main__":
|
||||
raise RuntimeError(
|
||||
"This test file is not meant to be run directly, use:\n\n"
|
||||
"\tpython test/test_jit.py TESTNAME\n\n"
|
||||
"instead."
|
||||
)
|
||||
|
||||
|
||||
class TestBuiltins(JitTestCase):
|
||||
@ -86,24 +88,27 @@ class TestBuiltins(JitTestCase):
|
||||
self.checkScript(fn, ([1, 2, 3],))
|
||||
|
||||
with self.assertRaisesRegexWithHighlight(RuntimeError, "undefined value", "a"):
|
||||
|
||||
@torch.jit.script
|
||||
def fn(x):
|
||||
a = x ** 2
|
||||
a = x**2
|
||||
del a
|
||||
return a # noqa: F821
|
||||
|
||||
with self.assertRaisesRegexWithHighlight(RuntimeError, "undefined value", "a"):
|
||||
|
||||
@torch.jit.script
|
||||
def fn(x):
|
||||
a = x ** 2
|
||||
a = x**2
|
||||
if a:
|
||||
del a
|
||||
return a
|
||||
|
||||
with self.assertRaisesRegexWithHighlight(RuntimeError, "undefined value", "b"):
|
||||
|
||||
@torch.jit.script
|
||||
def fn(x):
|
||||
a = x ** 2
|
||||
a = x**2
|
||||
del b # noqa: F821
|
||||
return a
|
||||
|
||||
@ -124,7 +129,7 @@ class TestBuiltins(JitTestCase):
|
||||
self.assertEqual(py_out, jit_out)
|
||||
|
||||
def del_dict_multiple_operands(x: Dict[str, int]) -> Dict[str, int]:
|
||||
del x['hi'], x['there']
|
||||
del x["hi"], x["there"]
|
||||
return x
|
||||
|
||||
py_out = del_dict_multiple_operands({"hi": 5, "there": 6})
|
||||
@ -137,7 +142,7 @@ class TestTensorBuiltins(JitTestCase):
|
||||
def should_keep(tensor, name):
|
||||
if inspect.isroutine(getattr(tensor, name)):
|
||||
return False
|
||||
if name.startswith('_'):
|
||||
if name.startswith("_"):
|
||||
return False
|
||||
return True
|
||||
|
||||
@ -145,8 +150,8 @@ class TestTensorBuiltins(JitTestCase):
|
||||
keys = dir(tensor)
|
||||
|
||||
# real and imag are only implemented for complex tensors.
|
||||
self.assertRaises(RuntimeError, lambda: should_keep(tensor, 'imag'))
|
||||
keys.remove('imag')
|
||||
self.assertRaises(RuntimeError, lambda: should_keep(tensor, "imag"))
|
||||
keys.remove("imag")
|
||||
|
||||
properties = [p for p in keys if should_keep(tensor, p)]
|
||||
|
||||
@ -158,16 +163,16 @@ class TestTensorBuiltins(JitTestCase):
|
||||
EQUALITY_MISMATCH = {
|
||||
# TorchScript doesn't have real enums so they return an int instead
|
||||
# of the actual value
|
||||
'dtype',
|
||||
'layout',
|
||||
"dtype",
|
||||
"layout",
|
||||
}
|
||||
MISSING_PROPERTIES = {
|
||||
'grad_fn',
|
||||
"grad_fn",
|
||||
# This is an undocumented property so it's not included
|
||||
"output_nr",
|
||||
# This has a longer implementation, maybe not worth copying to
|
||||
# TorchScript if named tensors don't work there anyways
|
||||
'names',
|
||||
"names",
|
||||
}
|
||||
|
||||
for p in properties:
|
||||
@ -232,7 +237,8 @@ class TestTensorBuiltins(JitTestCase):
|
||||
def func():
|
||||
c = 1
|
||||
return c.add(1)
|
||||
with self.assertRaisesRegex(RuntimeError, 'object has no attribute or method'):
|
||||
|
||||
with self.assertRaisesRegex(RuntimeError, "object has no attribute or method"):
|
||||
torch.jit.script(func)
|
||||
|
||||
# testing implicit conversion of tensors to scalars to match function arguments
|
||||
@ -265,10 +271,12 @@ class TestTensorBuiltins(JitTestCase):
|
||||
|
||||
x = torch.zeros(10)
|
||||
# float tensor, float tensor with grad, int tensor (can't set grad on int tensor)
|
||||
tensors = [torch.tensor(1.1),
|
||||
torch.tensor(1.1, requires_grad=True),
|
||||
torch.tensor(0),
|
||||
torch.tensor([2])]
|
||||
tensors = [
|
||||
torch.tensor(1.1),
|
||||
torch.tensor(1.1, requires_grad=True),
|
||||
torch.tensor(0),
|
||||
torch.tensor([2]),
|
||||
]
|
||||
|
||||
script_funs = [tensor_to_int_script, tensor_to_float_script]
|
||||
funs = [tensor_to_int, tensor_to_float]
|
||||
@ -286,4 +294,6 @@ class TestTensorBuiltins(JitTestCase):
|
||||
# assert result or exception equal for each (function, inputs)
|
||||
for tensor in tensors:
|
||||
for i in range(len(script_funs)):
|
||||
self.assertEqual(test_func(script_funs[i], x, tensor), test_func(funs[i], x, tensor))
|
||||
self.assertEqual(
|
||||
test_func(script_funs[i], x, tensor), test_func(funs[i], x, tensor)
|
||||
)
|
||||
|
||||
@ -4,24 +4,28 @@ import io
|
||||
import os
|
||||
import sys
|
||||
import unittest
|
||||
from typing import Any
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
from torch.testing import FileCheck
|
||||
from typing import Any
|
||||
|
||||
# Make the helper files in test/ importable
|
||||
pytorch_test_dir = os.path.dirname(os.path.dirname(os.path.realpath(__file__)))
|
||||
sys.path.append(pytorch_test_dir)
|
||||
from torch.testing._internal.jit_utils import JitTestCase, make_global
|
||||
from typing import Dict, Iterable, List, Optional, Tuple
|
||||
|
||||
import torch.testing._internal.jit_utils
|
||||
from torch.testing._internal.common_utils import IS_SANDCASTLE, skipIfTorchDynamo
|
||||
from typing import List, Tuple, Iterable, Optional, Dict
|
||||
from torch.testing._internal.jit_utils import JitTestCase, make_global
|
||||
|
||||
if __name__ == "__main__":
|
||||
raise RuntimeError(
|
||||
"This test file is not meant to be run directly, use:\n\n"
|
||||
"\tpython test/test_jit.py TESTNAME\n\n"
|
||||
"instead."
|
||||
)
|
||||
|
||||
if __name__ == '__main__':
|
||||
raise RuntimeError("This test file is not meant to be run directly, use:\n\n"
|
||||
"\tpython test/test_jit.py TESTNAME\n\n"
|
||||
"instead.")
|
||||
|
||||
class TestClassType(JitTestCase):
|
||||
def test_reference_semantics(self):
|
||||
@ -29,6 +33,7 @@ class TestClassType(JitTestCase):
|
||||
Test that modifications made to a class instance in TorchScript
|
||||
are visible in eager.
|
||||
"""
|
||||
|
||||
class Foo:
|
||||
def __init__(self, a: int):
|
||||
self.a = a
|
||||
@ -92,12 +97,12 @@ class TestClassType(JitTestCase):
|
||||
pass
|
||||
|
||||
def __contains__(self, key: str) -> bool:
|
||||
return key == 'hi'
|
||||
return key == "hi"
|
||||
|
||||
@torch.jit.script
|
||||
def fn():
|
||||
foo = FooTest()
|
||||
return 'hi' in foo, 'no' in foo
|
||||
return "hi" in foo, "no" in foo
|
||||
|
||||
self.assertEqual(fn(), (True, False))
|
||||
|
||||
@ -118,7 +123,10 @@ class TestClassType(JitTestCase):
|
||||
self.assertEqual(fn(1), 3)
|
||||
|
||||
def test_set_attr_type_mismatch(self):
|
||||
with self.assertRaisesRegexWithHighlight(RuntimeError, "Wrong type for attribute assignment", "self.foo = 10"):
|
||||
with self.assertRaisesRegexWithHighlight(
|
||||
RuntimeError, "Wrong type for attribute assignment", "self.foo = 10"
|
||||
):
|
||||
|
||||
@torch.jit.script
|
||||
class FooTest:
|
||||
def __init__(self, x):
|
||||
@ -126,7 +134,10 @@ class TestClassType(JitTestCase):
|
||||
self.foo = 10 # should error since int != Tensor
|
||||
|
||||
def test_get_attr_not_initialized(self):
|
||||
with self.assertRaisesRegexWithHighlight(RuntimeError, "object has no attribute or method", "self.asdf"):
|
||||
with self.assertRaisesRegexWithHighlight(
|
||||
RuntimeError, "object has no attribute or method", "self.asdf"
|
||||
):
|
||||
|
||||
@torch.jit.script
|
||||
class FooTest:
|
||||
def __init__(self, x):
|
||||
@ -136,7 +147,10 @@ class TestClassType(JitTestCase):
|
||||
return self.asdf # asdf isn't an attr
|
||||
|
||||
def test_set_attr_non_initialized(self):
|
||||
with self.assertRaisesRegexWithHighlight(RuntimeError, "Tried to set nonexistent attribute", "self.bar = y"):
|
||||
with self.assertRaisesRegexWithHighlight(
|
||||
RuntimeError, "Tried to set nonexistent attribute", "self.bar = y"
|
||||
):
|
||||
|
||||
@torch.jit.script
|
||||
class FooTest:
|
||||
def __init__(self, x):
|
||||
@ -153,12 +167,16 @@ class TestClassType(JitTestCase):
|
||||
Expected a value of type 'Optional[int]' for argument 'size' but instead found type 'Tensor'.
|
||||
"""
|
||||
with self.assertRaisesRegexWithHighlight(RuntimeError, "nearest", ""):
|
||||
|
||||
@torch.jit.script
|
||||
def FooTest(x):
|
||||
return torch.nn.functional.interpolate(x, 'bad')
|
||||
return torch.nn.functional.interpolate(x, "bad")
|
||||
|
||||
def test_type_annotations(self):
|
||||
with self.assertRaisesRegexWithHighlight(RuntimeError, "Expected a value of type \'bool", ""):
|
||||
with self.assertRaisesRegexWithHighlight(
|
||||
RuntimeError, "Expected a value of type 'bool", ""
|
||||
):
|
||||
|
||||
@torch.jit.script # noqa: B903
|
||||
class FooTest: # noqa: B903
|
||||
def __init__(self, x: bool) -> None:
|
||||
@ -171,7 +189,10 @@ class TestClassType(JitTestCase):
|
||||
fn(2)
|
||||
|
||||
def test_conditional_set_attr(self):
|
||||
with self.assertRaisesRegexWithHighlight(RuntimeError, "assignment cannot be in a control-flow block", ""):
|
||||
with self.assertRaisesRegexWithHighlight(
|
||||
RuntimeError, "assignment cannot be in a control-flow block", ""
|
||||
):
|
||||
|
||||
@torch.jit.script
|
||||
class FooTest:
|
||||
def __init__(self, x):
|
||||
@ -236,7 +257,6 @@ class TestClassType(JitTestCase):
|
||||
# classes are globally registered for now, so we need to clear the JIT
|
||||
# registry to simulate loading a new model
|
||||
|
||||
|
||||
buffer.seek(0)
|
||||
m_loaded = torch.jit.load(buffer)
|
||||
|
||||
@ -320,7 +340,7 @@ class TestClassType(JitTestCase):
|
||||
self.x = x
|
||||
self.y = y
|
||||
|
||||
make_global(Foo) # see [local resolution in python]
|
||||
make_global(Foo) # see [local resolution in python]
|
||||
|
||||
@torch.jit.script
|
||||
def use_foo(foo: Foo) -> Foo:
|
||||
@ -419,15 +439,22 @@ class TestClassType(JitTestCase):
|
||||
|
||||
self.assertEqual(test_nested_inside_tuple(), [(1, 11), (1, 12)])
|
||||
|
||||
with self.assertRaisesRegexWithHighlight(RuntimeError, "bool\' for argument \'reverse", ""):
|
||||
with self.assertRaisesRegexWithHighlight(
|
||||
RuntimeError, "bool' for argument 'reverse", ""
|
||||
):
|
||||
|
||||
@torch.jit.script
|
||||
def test():
|
||||
li = [Foo(1)]
|
||||
li.sort(li)
|
||||
return li
|
||||
|
||||
test()
|
||||
|
||||
with self.assertRaisesRegexWithHighlight(RuntimeError, "must define a __lt__", ""):
|
||||
with self.assertRaisesRegexWithHighlight(
|
||||
RuntimeError, "must define a __lt__", ""
|
||||
):
|
||||
|
||||
@torch.jit.script
|
||||
class NoMethod:
|
||||
def __init__(self):
|
||||
@ -438,6 +465,7 @@ class TestClassType(JitTestCase):
|
||||
li = [NoMethod(), NoMethod()]
|
||||
li.sort()
|
||||
return li
|
||||
|
||||
test()
|
||||
|
||||
@torch.jit.script
|
||||
@ -449,12 +477,16 @@ class TestClassType(JitTestCase):
|
||||
def __lt__(self, other):
|
||||
pass
|
||||
|
||||
with self.assertRaisesRegexWithHighlight(RuntimeError, "must define a __lt__", ""):
|
||||
with self.assertRaisesRegexWithHighlight(
|
||||
RuntimeError, "must define a __lt__", ""
|
||||
):
|
||||
|
||||
@torch.jit.script
|
||||
def test():
|
||||
li = [WrongLt(), WrongLt()]
|
||||
li.sort()
|
||||
return li
|
||||
|
||||
test()
|
||||
|
||||
def test_class_inheritance(self):
|
||||
@ -466,18 +498,21 @@ class TestClassType(JitTestCase):
|
||||
def two(self, x):
|
||||
return x + self.b
|
||||
|
||||
with self.assertRaisesRegexWithHighlight(RuntimeError, "does not support inheritance", ""):
|
||||
with self.assertRaisesRegexWithHighlight(
|
||||
RuntimeError, "does not support inheritance", ""
|
||||
):
|
||||
|
||||
@torch.jit.script
|
||||
class Derived(Base):
|
||||
def two(self, x):
|
||||
return x + self.b + 2
|
||||
|
||||
|
||||
def test_class_inheritance_implicit(self):
|
||||
"""
|
||||
Test that inheritance is detected in
|
||||
implicit scripting codepaths (e.g. try_ann_to_type).
|
||||
"""
|
||||
|
||||
class A:
|
||||
def __init__(self, t):
|
||||
self.t = t
|
||||
@ -502,14 +537,16 @@ class TestClassType(JitTestCase):
|
||||
else:
|
||||
return B.f(x.t)
|
||||
|
||||
with self.assertRaisesRegexWithHighlight(RuntimeError, "object has no attribute or method", ""):
|
||||
with self.assertRaisesRegexWithHighlight(
|
||||
RuntimeError, "object has no attribute or method", ""
|
||||
):
|
||||
sc = torch.jit.script(fun)
|
||||
|
||||
@skipIfTorchDynamo("Test does not work with TorchDynamo")
|
||||
@unittest.skipIf(IS_SANDCASTLE, "Importing like this doesn't work in fbcode")
|
||||
def test_imported_classes(self):
|
||||
import jit._imported_class_test.foo
|
||||
import jit._imported_class_test.bar
|
||||
import jit._imported_class_test.foo
|
||||
import jit._imported_class_test.very.very.nested
|
||||
|
||||
class MyMod(torch.jit.ScriptModule):
|
||||
@ -593,6 +630,7 @@ class TestClassType(JitTestCase):
|
||||
|
||||
def one(self, x, y):
|
||||
return x + y
|
||||
|
||||
# missing two
|
||||
|
||||
@torch.jit.script
|
||||
@ -616,6 +654,7 @@ class TestClassType(JitTestCase):
|
||||
x = c[i].one(x, x)
|
||||
x = c[i].two(x)
|
||||
return x
|
||||
|
||||
self.checkScript(use_them, (torch.rand(3, 4),))
|
||||
|
||||
@torch.jit.script
|
||||
@ -626,22 +665,33 @@ class TestClassType(JitTestCase):
|
||||
def inherit(x: OneTwoThree) -> OneTwo:
|
||||
return as_interface(x)
|
||||
|
||||
with self.assertRaisesRegexWithHighlight(RuntimeError, "does not have method", ""):
|
||||
with self.assertRaisesRegexWithHighlight(
|
||||
RuntimeError, "does not have method", ""
|
||||
):
|
||||
|
||||
@torch.jit.script
|
||||
def wrong1():
|
||||
return as_interface(NotMember())
|
||||
|
||||
with self.assertRaisesRegexWithHighlight(RuntimeError, "is not compatible with interface", ""):
|
||||
with self.assertRaisesRegexWithHighlight(
|
||||
RuntimeError, "is not compatible with interface", ""
|
||||
):
|
||||
|
||||
@torch.jit.script
|
||||
def wrong2():
|
||||
return as_interface(NotMember2())
|
||||
|
||||
with self.assertRaisesRegexWithHighlight(RuntimeError, "does not have method", ""):
|
||||
with self.assertRaisesRegexWithHighlight(
|
||||
RuntimeError, "does not have method", ""
|
||||
):
|
||||
|
||||
@torch.jit.script
|
||||
def wrong3():
|
||||
return inherit(as_interface(Foo()))
|
||||
|
||||
with self.assertRaisesRegexWithHighlight(RuntimeError, "is not compatible with interface", ""):
|
||||
with self.assertRaisesRegexWithHighlight(
|
||||
RuntimeError, "is not compatible with interface", ""
|
||||
):
|
||||
|
||||
@torch.jit.script
|
||||
def wrong4(x: OneTwoWrong) -> int:
|
||||
@ -656,7 +706,7 @@ class TestClassType(JitTestCase):
|
||||
def forward(self, x):
|
||||
return self.proxy_mod.two(x)
|
||||
|
||||
TestPyAssign.__annotations__ = {'proxy_mod': OneTwo}
|
||||
TestPyAssign.__annotations__ = {"proxy_mod": OneTwo}
|
||||
|
||||
input = torch.rand(3, 4)
|
||||
scripted_pyassign_mod = torch.jit.script(TestPyAssign())
|
||||
@ -671,10 +721,11 @@ class TestClassType(JitTestCase):
|
||||
def forward(self, x):
|
||||
return self.proxy_mod.two(x)
|
||||
|
||||
TestPyAssignError.__annotations__ = {'proxy_mod': OneTwoThree}
|
||||
TestPyAssignError.__annotations__ = {"proxy_mod": OneTwoThree}
|
||||
|
||||
with self.assertRaisesRegexWithHighlight(RuntimeError,
|
||||
"is not compatible with interface __torch__", ""):
|
||||
with self.assertRaisesRegexWithHighlight(
|
||||
RuntimeError, "is not compatible with interface __torch__", ""
|
||||
):
|
||||
torch.jit.script(TestPyAssignError(Foo()))
|
||||
|
||||
# test pure python object assignment to interface fails
|
||||
@ -682,8 +733,9 @@ class TestClassType(JitTestCase):
|
||||
def __init__(self):
|
||||
pass
|
||||
|
||||
with self.assertRaisesRegexWithHighlight(RuntimeError,
|
||||
"the value is not a TorchScript compatible type", ""):
|
||||
with self.assertRaisesRegexWithHighlight(
|
||||
RuntimeError, "the value is not a TorchScript compatible type", ""
|
||||
):
|
||||
torch.jit.script(TestPyAssignError(PyClass()))
|
||||
# TODO test: interface-interface class-interface inheritance errors,
|
||||
# NamedTuple inheritance errors
|
||||
@ -729,7 +781,7 @@ class TestClassType(JitTestCase):
|
||||
return self.x * other
|
||||
|
||||
def __pow__(self, other: int) -> int:
|
||||
return int(self.x ** other)
|
||||
return int(self.x**other)
|
||||
|
||||
def __truediv__(self, other: int) -> float:
|
||||
return self.x / other
|
||||
@ -773,54 +825,89 @@ class TestClassType(JitTestCase):
|
||||
def __call__(self, val: int) -> int:
|
||||
return self.x * val * 3
|
||||
|
||||
|
||||
make_global(Foo) # see [local resolution in python]
|
||||
|
||||
def add():
|
||||
return MyClass(4) + 3
|
||||
|
||||
def sub(): # noqa: E306
|
||||
return MyClass(4) - 3
|
||||
|
||||
def mul(): # noqa: E306
|
||||
return MyClass(4) * 3
|
||||
|
||||
def pow(): # noqa: E306
|
||||
return MyClass(4) ** 3
|
||||
|
||||
def truediv(): # noqa: E306
|
||||
return MyClass(4) / 3
|
||||
|
||||
def ne(): # noqa: E306
|
||||
return MyClass(4) != 3
|
||||
|
||||
def eq(): # noqa: E306
|
||||
return MyClass(4) == 3
|
||||
|
||||
def lt(): # noqa: E306
|
||||
return MyClass(4) < 3
|
||||
|
||||
def gt(): # noqa: E306
|
||||
return MyClass(4) > 3
|
||||
|
||||
def le(): # noqa: E306
|
||||
return MyClass(4) <= 3
|
||||
|
||||
def ge(): # noqa: E306
|
||||
return MyClass(4) >= 3
|
||||
|
||||
def _and(): # noqa: E306
|
||||
return MyClass(4) & 3
|
||||
|
||||
def _or(): # noqa: E306
|
||||
return MyClass(4) | 3
|
||||
|
||||
def _xor(): # noqa: E306
|
||||
return MyClass(4) ^ 3
|
||||
|
||||
def getitem(): # noqa: E306
|
||||
return MyClass(4)[1]
|
||||
|
||||
def setitem(): # noqa: E306
|
||||
a = MyClass(4)
|
||||
a[1] = 5
|
||||
return a.x
|
||||
|
||||
def call(): # noqa: E306
|
||||
a = MyClass(5)
|
||||
return a(2)
|
||||
|
||||
ops = [add, sub, mul, pow, ne, eq, lt, gt, le, ge, _and, _or, _xor, getitem, setitem, call]
|
||||
ops = [
|
||||
add,
|
||||
sub,
|
||||
mul,
|
||||
pow,
|
||||
ne,
|
||||
eq,
|
||||
lt,
|
||||
gt,
|
||||
le,
|
||||
ge,
|
||||
_and,
|
||||
_or,
|
||||
_xor,
|
||||
getitem,
|
||||
setitem,
|
||||
call,
|
||||
]
|
||||
|
||||
ops.append(truediv)
|
||||
for func in ops:
|
||||
self.checkScript(func, ())
|
||||
|
||||
with self.assertRaisesRegexWithHighlight(RuntimeError, "object has no attribute or method", ""):
|
||||
with self.assertRaisesRegexWithHighlight(
|
||||
RuntimeError, "object has no attribute or method", ""
|
||||
):
|
||||
|
||||
@torch.jit.script
|
||||
def test():
|
||||
return Foo(torch.tensor(1)) + Foo(torch.tensor(1))
|
||||
@ -852,7 +939,7 @@ class TestClassType(JitTestCase):
|
||||
|
||||
fn = torch.jit.script(test)
|
||||
self.assertEqual(fn(Foo(0.5)), test(0.5))
|
||||
self.assertEqual(fn(Foo(0.)), test(0.0))
|
||||
self.assertEqual(fn(Foo(0.0)), test(0.0))
|
||||
# str has slightly different formatting
|
||||
self.assertTrue("0.5" in (str(Foo(0.5))))
|
||||
self.assertTrue("0." in (str(Foo(0.0))))
|
||||
@ -865,7 +952,10 @@ class TestClassType(JitTestCase):
|
||||
def __bool__(self):
|
||||
return (1, 2)
|
||||
|
||||
with self.assertRaisesRegexWithHighlight(RuntimeError, "expected a bool expression for condition", ""):
|
||||
with self.assertRaisesRegexWithHighlight(
|
||||
RuntimeError, "expected a bool expression for condition", ""
|
||||
):
|
||||
|
||||
@torch.jit.script
|
||||
def test():
|
||||
if BadBool():
|
||||
@ -921,6 +1011,7 @@ class TestClassType(JitTestCase):
|
||||
Recursive class types not yet supported. We should give a good error message.
|
||||
"""
|
||||
with self.assertRaises(RuntimeError):
|
||||
|
||||
@torch.jit.script # noqa: B903
|
||||
class Tree: # noqa: B903
|
||||
def __init__(self):
|
||||
@ -940,7 +1031,7 @@ class TestClassType(JitTestCase):
|
||||
return x, y
|
||||
|
||||
# Test serialization/deserialization of class constant
|
||||
for c in (2, 1.0, None, True, 'str', (2, 3), [5.9, 7.3]):
|
||||
for c in (2, 1.0, None, True, "str", (2, 3), [5.9, 7.3]):
|
||||
m = torch.jit.script(M(c))
|
||||
buffer = io.BytesIO()
|
||||
torch.jit.save(m, buffer)
|
||||
@ -954,28 +1045,31 @@ class TestClassType(JitTestCase):
|
||||
|
||||
def test_py_class_to_ivalue_missing_attribute(self):
|
||||
class Foo:
|
||||
i : int
|
||||
f : float
|
||||
i: int
|
||||
f: float
|
||||
|
||||
def __init__(self, i : int, f : float):
|
||||
def __init__(self, i: int, f: float):
|
||||
self.i = i
|
||||
self.f = f
|
||||
|
||||
make_global(Foo) # see [local resolution in python]
|
||||
|
||||
@torch.jit.script
|
||||
def test_fn(x : Foo) -> float:
|
||||
def test_fn(x: Foo) -> float:
|
||||
return x.i + x.f
|
||||
|
||||
test_fn(Foo(3, 4.0))
|
||||
|
||||
with self.assertRaisesRegexWithHighlight(RuntimeError, 'missing attribute i', ""):
|
||||
with self.assertRaisesRegexWithHighlight(
|
||||
RuntimeError, "missing attribute i", ""
|
||||
):
|
||||
test_fn(torch.rand(3, 4))
|
||||
|
||||
def test_unused_method(self):
|
||||
"""
|
||||
Test unused methods on scripted classes.
|
||||
"""
|
||||
|
||||
@torch.jit.script
|
||||
class Unused:
|
||||
def __init__(self):
|
||||
@ -1028,12 +1122,13 @@ class TestClassType(JitTestCase):
|
||||
Test that a scripted class can have a method that refers to the class itself
|
||||
in its type annotations.
|
||||
"""
|
||||
|
||||
@torch.jit.script
|
||||
class Meta:
|
||||
def __init__(self, a: int):
|
||||
self.a = a
|
||||
|
||||
def method(self, other: List['Meta']) -> 'Meta':
|
||||
def method(self, other: List["Meta"]) -> "Meta":
|
||||
return Meta(len(other))
|
||||
|
||||
class ModuleWithMeta(torch.nn.Module):
|
||||
@ -1051,19 +1146,20 @@ class TestClassType(JitTestCase):
|
||||
"""
|
||||
Test that annotating container attributes with types works correctly
|
||||
"""
|
||||
|
||||
@torch.jit.script
|
||||
class CompetitiveLinkingTokenReplacementUtils:
|
||||
def __init__(self):
|
||||
self.my_list : List[Tuple[float, int, int]] = []
|
||||
self.my_dict : Dict[int, int] = {}
|
||||
self.my_list: List[Tuple[float, int, int]] = []
|
||||
self.my_dict: Dict[int, int] = {}
|
||||
|
||||
@torch.jit.script
|
||||
def foo():
|
||||
y = CompetitiveLinkingTokenReplacementUtils()
|
||||
new_dict : Dict[int, int] = {1: 1, 2: 2}
|
||||
new_dict: Dict[int, int] = {1: 1, 2: 2}
|
||||
y.my_dict = new_dict
|
||||
|
||||
new_list : List[Tuple[float, int, int]] = [(1.0, 1, 1)]
|
||||
new_list: List[Tuple[float, int, int]] = [(1.0, 1, 1)]
|
||||
y.my_list = new_list
|
||||
return y
|
||||
|
||||
@ -1071,6 +1167,7 @@ class TestClassType(JitTestCase):
|
||||
"""
|
||||
Test that methods on class types can have default arguments.
|
||||
"""
|
||||
|
||||
@torch.jit.script
|
||||
class ClassWithDefaultArgs:
|
||||
def __init__(
|
||||
@ -1105,7 +1202,9 @@ class TestClassType(JitTestCase):
|
||||
return obj.int + obj.list[2] + obj.dict[1]
|
||||
|
||||
def override_defaults() -> int:
|
||||
obj: ClassWithDefaultArgs = ClassWithDefaultArgs(3, [9, 10, 11], (12, 13, 14), {3: 4}, "str")
|
||||
obj: ClassWithDefaultArgs = ClassWithDefaultArgs(
|
||||
3, [9, 10, 11], (12, 13, 14), {3: 4}, "str"
|
||||
)
|
||||
s: int = obj.int
|
||||
|
||||
for x in obj.list:
|
||||
@ -1154,7 +1253,7 @@ class TestClassType(JitTestCase):
|
||||
|
||||
# The constructor of this class below has mutable arguments. This should throw
|
||||
# an error.
|
||||
class ClassWithMutableArgs: # noqa: B903
|
||||
class ClassWithMutableArgs: # noqa: B903
|
||||
def __init__(
|
||||
self,
|
||||
a: List[int] = [1, 2, 3], # noqa: B006
|
||||
@ -1164,13 +1263,16 @@ class TestClassType(JitTestCase):
|
||||
def should_fail():
|
||||
obj: ClassWithMutableArgs = ClassWithMutableArgs()
|
||||
|
||||
with self.assertRaisesRegexWithHighlight(RuntimeError, "Mutable default parameters are not supported", ""):
|
||||
with self.assertRaisesRegexWithHighlight(
|
||||
RuntimeError, "Mutable default parameters are not supported", ""
|
||||
):
|
||||
torch.jit.script(should_fail)
|
||||
|
||||
def test_staticmethod(self):
|
||||
"""
|
||||
Test static methods on class types.
|
||||
"""
|
||||
|
||||
@torch.jit.script
|
||||
class ClassWithStaticMethod:
|
||||
def __init__(self, a: int, b: int):
|
||||
@ -1183,22 +1285,22 @@ class TestClassType(JitTestCase):
|
||||
def get_b(self):
|
||||
return self.b
|
||||
|
||||
def __eq__(self, other: 'ClassWithStaticMethod'):
|
||||
def __eq__(self, other: "ClassWithStaticMethod"):
|
||||
return self.a == other.a and self.b == other.b
|
||||
|
||||
# staticmethod that calls constructor.
|
||||
@staticmethod
|
||||
def create(args: List['ClassWithStaticMethod']) -> 'ClassWithStaticMethod':
|
||||
def create(args: List["ClassWithStaticMethod"]) -> "ClassWithStaticMethod":
|
||||
return ClassWithStaticMethod(args[0].a, args[0].b)
|
||||
|
||||
# staticmethod that calls another staticmethod.
|
||||
@staticmethod
|
||||
def create_from(a: int, b: int) -> 'ClassWithStaticMethod':
|
||||
def create_from(a: int, b: int) -> "ClassWithStaticMethod":
|
||||
a = ClassWithStaticMethod(a, b)
|
||||
return ClassWithStaticMethod.create([a])
|
||||
|
||||
# Script function that calls staticmethod.
|
||||
def test_function(a: int, b: int) -> 'ClassWithStaticMethod':
|
||||
def test_function(a: int, b: int) -> "ClassWithStaticMethod":
|
||||
return ClassWithStaticMethod.create_from(a, b)
|
||||
|
||||
make_global(ClassWithStaticMethod)
|
||||
@ -1209,21 +1311,22 @@ class TestClassType(JitTestCase):
|
||||
"""
|
||||
Test classmethods on class types.
|
||||
"""
|
||||
|
||||
@torch.jit.script
|
||||
class ClassWithClassMethod:
|
||||
def __init__(self, a: int):
|
||||
self.a: int = a
|
||||
|
||||
def __eq__(self, other: 'ClassWithClassMethod'):
|
||||
def __eq__(self, other: "ClassWithClassMethod"):
|
||||
return self.a == other.a
|
||||
|
||||
@classmethod
|
||||
def create(cls, a: int) -> 'ClassWithClassMethod':
|
||||
def create(cls, a: int) -> "ClassWithClassMethod":
|
||||
return cls(a)
|
||||
|
||||
make_global(ClassWithClassMethod)
|
||||
|
||||
def test_function(a: int) -> 'ClassWithClassMethod':
|
||||
def test_function(a: int) -> "ClassWithClassMethod":
|
||||
x = ClassWithClassMethod(a)
|
||||
# Support calling classmethod with an instance
|
||||
# Calling with the class is not supported.
|
||||
@ -1236,6 +1339,7 @@ class TestClassType(JitTestCase):
|
||||
"""
|
||||
Test that a scripted class can make use of the @property decorator.
|
||||
"""
|
||||
|
||||
def free_function(x: int) -> int:
|
||||
return x + 1
|
||||
|
||||
@ -1308,13 +1412,22 @@ class TestClassType(JitTestCase):
|
||||
|
||||
return self.props.attr + no_setter.attr + method_uses_property.forward()
|
||||
|
||||
self.checkModule(ModuleWithProperties(5), (5, 6, 7, 8,))
|
||||
self.checkModule(
|
||||
ModuleWithProperties(5),
|
||||
(
|
||||
5,
|
||||
6,
|
||||
7,
|
||||
8,
|
||||
),
|
||||
)
|
||||
|
||||
def test_custom_delete(self):
|
||||
"""
|
||||
Test that del can be called on an instance of a class that
|
||||
overrides __delitem__.
|
||||
"""
|
||||
|
||||
class Example:
|
||||
def __init__(self):
|
||||
self._data: Dict[str, torch.Tensor] = {"1": torch.tensor(1.0)}
|
||||
@ -1346,7 +1459,9 @@ class TestClassType(JitTestCase):
|
||||
del example[key]
|
||||
return example.check(key)
|
||||
|
||||
with self.assertRaisesRegexWithHighlight(RuntimeError, r"Class does not define __delitem__", "example[key]"):
|
||||
with self.assertRaisesRegexWithHighlight(
|
||||
RuntimeError, r"Class does not define __delitem__", "example[key]"
|
||||
):
|
||||
self.checkScript(fn, ())
|
||||
|
||||
def test_recursive_script_builtin_type_resolution(self):
|
||||
@ -1369,7 +1484,7 @@ class TestClassType(JitTestCase):
|
||||
def g(self, x: device_t) -> device_ty:
|
||||
return x
|
||||
|
||||
def h(self, a: 'A') -> 'A':
|
||||
def h(self, a: "A") -> "A":
|
||||
return A()
|
||||
|
||||
def i(self, a: List[int]) -> int:
|
||||
@ -1404,14 +1519,14 @@ class TestClassType(JitTestCase):
|
||||
Test resolution of built-in torch types(e.g. torch.Tensor, torch.device) when a class is recursively compiled
|
||||
when compiling a module.
|
||||
"""
|
||||
class Wrapper():
|
||||
|
||||
class Wrapper:
|
||||
def __init__(self, t):
|
||||
self.t = t
|
||||
|
||||
def to(self, l: List[torch.device], device: Optional[torch.device] = None):
|
||||
return self.t.to(device=device)
|
||||
|
||||
|
||||
class A(nn.Module):
|
||||
def forward(self):
|
||||
return Wrapper(torch.rand(4, 4))
|
||||
@ -1424,6 +1539,7 @@ class TestClassType(JitTestCase):
|
||||
Test that the error message displayed when convering a class type
|
||||
to an IValue that has an attribute of the wrong type.
|
||||
"""
|
||||
|
||||
@torch.jit.script # noqa: B903
|
||||
class ValHolder: # noqa: B903
|
||||
def __init__(self, val):
|
||||
@ -1442,7 +1558,9 @@ class TestClassType(JitTestCase):
|
||||
mod = self.mod2
|
||||
return mod.val
|
||||
|
||||
with self.assertRaisesRegexWithHighlight(RuntimeError, "Could not cast attribute 'val' to type Tensor", ""):
|
||||
with self.assertRaisesRegexWithHighlight(
|
||||
RuntimeError, "Could not cast attribute 'val' to type Tensor", ""
|
||||
):
|
||||
torch.jit.script(Mod())
|
||||
|
||||
def test_recursive_scripting(self):
|
||||
@ -1450,6 +1568,7 @@ class TestClassType(JitTestCase):
|
||||
Test that class types are recursively scripted when an Python instance of one
|
||||
is encountered as a module attribute.
|
||||
"""
|
||||
|
||||
class Class:
|
||||
def __init__(self, a: int):
|
||||
self.a = a
|
||||
@ -1473,6 +1592,7 @@ class TestClassType(JitTestCase):
|
||||
are added as failed attributes and do not cause compilation itself
|
||||
to fail unless they are used in scripted code.
|
||||
"""
|
||||
|
||||
class UnscriptableClass:
|
||||
def __init__(self, a: int):
|
||||
self.a = a
|
||||
@ -1490,7 +1610,9 @@ class TestClassType(JitTestCase):
|
||||
def forward(self) -> bool:
|
||||
return self.obj.get_a()
|
||||
|
||||
with self.assertRaisesRegexWithHighlight(RuntimeError, "failed to convert Python type", ""):
|
||||
with self.assertRaisesRegexWithHighlight(
|
||||
RuntimeError, "failed to convert Python type", ""
|
||||
):
|
||||
torch.jit.script(ShouldNotCompile(UnscriptableClass(4)))
|
||||
|
||||
# This Module has an attribute of type UnscriptableClass
|
||||
@ -1509,7 +1631,6 @@ class TestClassType(JitTestCase):
|
||||
|
||||
self.checkModule(ShouldCompile(UnscriptableClass(4)), (4,))
|
||||
|
||||
|
||||
def test_unresolved_class_attributes(self):
|
||||
class UnresolvedAttrClass:
|
||||
def __init__(self):
|
||||
@ -1538,7 +1659,9 @@ class TestClassType(JitTestCase):
|
||||
u = UnresolvedAttrClass()
|
||||
return u.attr_e
|
||||
|
||||
error_message_regex = "object has no attribute or method.*is defined as a class attribute"
|
||||
error_message_regex = (
|
||||
"object has no attribute or method.*is defined as a class attribute"
|
||||
)
|
||||
for fn in (fn_a, fn_b, fn_c, fn_d, fn_e):
|
||||
with self.assertRaisesRegex(RuntimeError, error_message_regex):
|
||||
torch.jit.script(fn)
|
||||
|
||||
@ -1,19 +1,21 @@
|
||||
# Owner(s): ["oncall: jit"]
|
||||
|
||||
import torch
|
||||
import cmath
|
||||
import os
|
||||
import sys
|
||||
from torch.testing._internal.jit_utils import JitTestCase, execWrapper
|
||||
from torch.testing._internal.common_utils import IS_MACOS
|
||||
from typing import List, Dict
|
||||
from itertools import product
|
||||
from textwrap import dedent
|
||||
import cmath
|
||||
from typing import Dict, List
|
||||
|
||||
import torch
|
||||
from torch.testing._internal.common_utils import IS_MACOS
|
||||
from torch.testing._internal.jit_utils import execWrapper, JitTestCase
|
||||
|
||||
# Make the helper files in test/ importable
|
||||
pytorch_test_dir = os.path.dirname(os.path.dirname(os.path.realpath(__file__)))
|
||||
sys.path.append(pytorch_test_dir)
|
||||
|
||||
|
||||
class TestComplex(JitTestCase):
|
||||
def test_script(self):
|
||||
def fn(a: complex):
|
||||
@ -32,7 +34,7 @@ class TestComplex(JitTestCase):
|
||||
def fn(a: Dict[complex, complex], key: complex) -> complex:
|
||||
return a[key]
|
||||
|
||||
input = {2 + 3j : 2 - 3j, -4.3 - 2j: 3j}
|
||||
input = {2 + 3j: 2 - 3j, -4.3 - 2j: 3j}
|
||||
self.checkScript(fn, (input, -4.3 - 2j))
|
||||
|
||||
def test_pickle(self):
|
||||
@ -41,7 +43,7 @@ class TestComplex(JitTestCase):
|
||||
super().__init__()
|
||||
self.a = 3 + 5j
|
||||
self.b = [2 + 3j, 3 + 4j, 0 - 3j, -4 + 0j]
|
||||
self.c = {2 + 3j : 2 - 3j, -4.3 - 2j: 3j}
|
||||
self.c = {2 + 3j: 2 - 3j, -4.3 - 2j: 3j}
|
||||
|
||||
@torch.jit.script_method
|
||||
def forward(self, b: int):
|
||||
@ -50,7 +52,7 @@ class TestComplex(JitTestCase):
|
||||
loaded = self.getExportImportCopy(ComplexModule())
|
||||
self.assertEqual(loaded.a, 3 + 5j)
|
||||
self.assertEqual(loaded.b, [2 + 3j, 3 + 4j, -3j, -4])
|
||||
self.assertEqual(loaded.c, {2 + 3j : 2 - 3j, -4.3 - 2j: 3j})
|
||||
self.assertEqual(loaded.c, {2 + 3j: 2 - 3j, -4.3 - 2j: 3j})
|
||||
self.assertEqual(loaded(2), 2 + 2j)
|
||||
|
||||
def test_complex_parse(self):
|
||||
@ -65,14 +67,19 @@ class TestComplex(JitTestCase):
|
||||
self.checkScript(fn, (t1, t2, 2))
|
||||
|
||||
def test_complex_constants_and_ops(self):
|
||||
vals = ([0.0, 1.0, 2.2, -1.0, -0.0, -2.2, 1, 0, 2]
|
||||
+ [10.0 ** i for i in range(2)] + [-(10.0 ** i) for i in range(2)])
|
||||
vals = (
|
||||
[0.0, 1.0, 2.2, -1.0, -0.0, -2.2, 1, 0, 2]
|
||||
+ [10.0**i for i in range(2)]
|
||||
+ [-(10.0**i) for i in range(2)]
|
||||
)
|
||||
complex_vals = tuple(complex(x, y) for x, y in product(vals, vals))
|
||||
|
||||
funcs_template = dedent('''
|
||||
funcs_template = dedent(
|
||||
"""
|
||||
def func(a: complex):
|
||||
return cmath.{func_or_const}(a)
|
||||
''')
|
||||
"""
|
||||
)
|
||||
|
||||
def checkCmath(func_name, funcs_template=funcs_template):
|
||||
funcs_str = funcs_template.format(func_or_const=func_name)
|
||||
@ -80,11 +87,13 @@ class TestComplex(JitTestCase):
|
||||
execWrapper(funcs_str, globals(), scope)
|
||||
cu = torch.jit.CompilationUnit(funcs_str)
|
||||
f_script = cu.func
|
||||
f = scope['func']
|
||||
f = scope["func"]
|
||||
|
||||
if func_name in ['isinf', 'isnan', 'isfinite']:
|
||||
new_vals = vals + ([float('inf'), float('nan'), -1 * float('inf')])
|
||||
final_vals = tuple(complex(x, y) for x, y in product(new_vals, new_vals))
|
||||
if func_name in ["isinf", "isnan", "isfinite"]:
|
||||
new_vals = vals + ([float("inf"), float("nan"), -1 * float("inf")])
|
||||
final_vals = tuple(
|
||||
complex(x, y) for x, y in product(new_vals, new_vals)
|
||||
)
|
||||
else:
|
||||
final_vals = complex_vals
|
||||
|
||||
@ -107,8 +116,27 @@ class TestComplex(JitTestCase):
|
||||
msg = f"Failed on {func_name} with input {a}. Python: {res_python}, Script: {res_script}"
|
||||
self.assertEqual(res_python, res_script, msg=msg)
|
||||
|
||||
unary_ops = ['log', 'log10', 'sqrt', 'exp', 'sin', 'cos', 'asin', 'acos', 'atan', 'sinh', 'cosh',
|
||||
'tanh', 'asinh', 'acosh', 'atanh', 'phase', 'isinf', 'isnan', 'isfinite']
|
||||
unary_ops = [
|
||||
"log",
|
||||
"log10",
|
||||
"sqrt",
|
||||
"exp",
|
||||
"sin",
|
||||
"cos",
|
||||
"asin",
|
||||
"acos",
|
||||
"atan",
|
||||
"sinh",
|
||||
"cosh",
|
||||
"tanh",
|
||||
"asinh",
|
||||
"acosh",
|
||||
"atanh",
|
||||
"phase",
|
||||
"isinf",
|
||||
"isnan",
|
||||
"isfinite",
|
||||
]
|
||||
|
||||
# --- Unary ops ---
|
||||
for op in unary_ops:
|
||||
@ -118,7 +146,7 @@ class TestComplex(JitTestCase):
|
||||
return abs(x)
|
||||
|
||||
for val in complex_vals:
|
||||
self.checkScript(fn, (val, ))
|
||||
self.checkScript(fn, (val,))
|
||||
|
||||
def pow_complex_float(x: complex, y: float):
|
||||
return pow(x, y)
|
||||
@ -126,7 +154,6 @@ class TestComplex(JitTestCase):
|
||||
def pow_float_complex(x: float, y: complex):
|
||||
return pow(x, y)
|
||||
|
||||
|
||||
self.checkScript(pow_float_complex, (2, 3j))
|
||||
self.checkScript(pow_complex_float, (3j, 2))
|
||||
|
||||
@ -135,7 +162,7 @@ class TestComplex(JitTestCase):
|
||||
|
||||
for x, y in zip(complex_vals, complex_vals):
|
||||
# Reference: https://github.com/pytorch/pytorch/issues/54622
|
||||
if (x == 0):
|
||||
if x == 0:
|
||||
continue
|
||||
self.checkScript(pow_complex_complex, (x, y))
|
||||
|
||||
@ -143,16 +170,25 @@ class TestComplex(JitTestCase):
|
||||
# --- Binary op ---
|
||||
def rect_fn(x: float, y: float):
|
||||
return cmath.rect(x, y)
|
||||
for x, y in product(vals, vals):
|
||||
self.checkScript(rect_fn, (x, y, ))
|
||||
|
||||
func_constants_template = dedent('''
|
||||
for x, y in product(vals, vals):
|
||||
self.checkScript(
|
||||
rect_fn,
|
||||
(
|
||||
x,
|
||||
y,
|
||||
),
|
||||
)
|
||||
|
||||
func_constants_template = dedent(
|
||||
"""
|
||||
def func():
|
||||
return cmath.{func_or_const}
|
||||
''')
|
||||
float_consts = ['pi', 'e', 'tau', 'inf', 'nan']
|
||||
complex_consts = ['infj', 'nanj']
|
||||
for x in (float_consts + complex_consts):
|
||||
"""
|
||||
)
|
||||
float_consts = ["pi", "e", "tau", "inf", "nan"]
|
||||
complex_consts = ["infj", "nanj"]
|
||||
for x in float_consts + complex_consts:
|
||||
checkCmath(x, funcs_template=func_constants_template)
|
||||
|
||||
def test_infj_nanj_pickle(self):
|
||||
@ -177,77 +213,293 @@ class TestComplex(JitTestCase):
|
||||
def fn_int(real: int, img: int):
|
||||
return complex(real, img)
|
||||
|
||||
self.checkScript(fn_int, (0, 0, ))
|
||||
self.checkScript(fn_int, (-1234, 0, ))
|
||||
self.checkScript(fn_int, (0, -1256, ))
|
||||
self.checkScript(fn_int, (-167, -1256, ))
|
||||
self.checkScript(
|
||||
fn_int,
|
||||
(
|
||||
0,
|
||||
0,
|
||||
),
|
||||
)
|
||||
self.checkScript(
|
||||
fn_int,
|
||||
(
|
||||
-1234,
|
||||
0,
|
||||
),
|
||||
)
|
||||
self.checkScript(
|
||||
fn_int,
|
||||
(
|
||||
0,
|
||||
-1256,
|
||||
),
|
||||
)
|
||||
self.checkScript(
|
||||
fn_int,
|
||||
(
|
||||
-167,
|
||||
-1256,
|
||||
),
|
||||
)
|
||||
|
||||
def fn_float(real: float, img: float):
|
||||
return complex(real, img)
|
||||
|
||||
self.checkScript(fn_float, (0.0, 0.0, ))
|
||||
self.checkScript(fn_float, (-1234.78, 0, ))
|
||||
self.checkScript(fn_float, (0, 56.18, ))
|
||||
self.checkScript(fn_float, (-1.9, -19.8, ))
|
||||
self.checkScript(
|
||||
fn_float,
|
||||
(
|
||||
0.0,
|
||||
0.0,
|
||||
),
|
||||
)
|
||||
self.checkScript(
|
||||
fn_float,
|
||||
(
|
||||
-1234.78,
|
||||
0,
|
||||
),
|
||||
)
|
||||
self.checkScript(
|
||||
fn_float,
|
||||
(
|
||||
0,
|
||||
56.18,
|
||||
),
|
||||
)
|
||||
self.checkScript(
|
||||
fn_float,
|
||||
(
|
||||
-1.9,
|
||||
-19.8,
|
||||
),
|
||||
)
|
||||
|
||||
def fn_bool(real: bool, img: bool):
|
||||
return complex(real, img)
|
||||
|
||||
self.checkScript(fn_bool, (True, True, ))
|
||||
self.checkScript(fn_bool, (False, False, ))
|
||||
self.checkScript(fn_bool, (False, True, ))
|
||||
self.checkScript(fn_bool, (True, False, ))
|
||||
self.checkScript(
|
||||
fn_bool,
|
||||
(
|
||||
True,
|
||||
True,
|
||||
),
|
||||
)
|
||||
self.checkScript(
|
||||
fn_bool,
|
||||
(
|
||||
False,
|
||||
False,
|
||||
),
|
||||
)
|
||||
self.checkScript(
|
||||
fn_bool,
|
||||
(
|
||||
False,
|
||||
True,
|
||||
),
|
||||
)
|
||||
self.checkScript(
|
||||
fn_bool,
|
||||
(
|
||||
True,
|
||||
False,
|
||||
),
|
||||
)
|
||||
|
||||
def fn_bool_int(real: bool, img: int):
|
||||
return complex(real, img)
|
||||
|
||||
self.checkScript(fn_bool_int, (True, 0, ))
|
||||
self.checkScript(fn_bool_int, (False, 0, ))
|
||||
self.checkScript(fn_bool_int, (False, -1, ))
|
||||
self.checkScript(fn_bool_int, (True, 3, ))
|
||||
self.checkScript(
|
||||
fn_bool_int,
|
||||
(
|
||||
True,
|
||||
0,
|
||||
),
|
||||
)
|
||||
self.checkScript(
|
||||
fn_bool_int,
|
||||
(
|
||||
False,
|
||||
0,
|
||||
),
|
||||
)
|
||||
self.checkScript(
|
||||
fn_bool_int,
|
||||
(
|
||||
False,
|
||||
-1,
|
||||
),
|
||||
)
|
||||
self.checkScript(
|
||||
fn_bool_int,
|
||||
(
|
||||
True,
|
||||
3,
|
||||
),
|
||||
)
|
||||
|
||||
def fn_int_bool(real: int, img: bool):
|
||||
return complex(real, img)
|
||||
|
||||
self.checkScript(fn_int_bool, (0, True, ))
|
||||
self.checkScript(fn_int_bool, (0, False, ))
|
||||
self.checkScript(fn_int_bool, (-3, True, ))
|
||||
self.checkScript(fn_int_bool, (6, False, ))
|
||||
self.checkScript(
|
||||
fn_int_bool,
|
||||
(
|
||||
0,
|
||||
True,
|
||||
),
|
||||
)
|
||||
self.checkScript(
|
||||
fn_int_bool,
|
||||
(
|
||||
0,
|
||||
False,
|
||||
),
|
||||
)
|
||||
self.checkScript(
|
||||
fn_int_bool,
|
||||
(
|
||||
-3,
|
||||
True,
|
||||
),
|
||||
)
|
||||
self.checkScript(
|
||||
fn_int_bool,
|
||||
(
|
||||
6,
|
||||
False,
|
||||
),
|
||||
)
|
||||
|
||||
def fn_bool_float(real: bool, img: float):
|
||||
return complex(real, img)
|
||||
|
||||
self.checkScript(fn_bool_float, (True, 0.0, ))
|
||||
self.checkScript(fn_bool_float, (False, 0.0, ))
|
||||
self.checkScript(fn_bool_float, (False, -1.0, ))
|
||||
self.checkScript(fn_bool_float, (True, 3.0, ))
|
||||
self.checkScript(
|
||||
fn_bool_float,
|
||||
(
|
||||
True,
|
||||
0.0,
|
||||
),
|
||||
)
|
||||
self.checkScript(
|
||||
fn_bool_float,
|
||||
(
|
||||
False,
|
||||
0.0,
|
||||
),
|
||||
)
|
||||
self.checkScript(
|
||||
fn_bool_float,
|
||||
(
|
||||
False,
|
||||
-1.0,
|
||||
),
|
||||
)
|
||||
self.checkScript(
|
||||
fn_bool_float,
|
||||
(
|
||||
True,
|
||||
3.0,
|
||||
),
|
||||
)
|
||||
|
||||
def fn_float_bool(real: float, img: bool):
|
||||
return complex(real, img)
|
||||
|
||||
self.checkScript(fn_float_bool, (0.0, True, ))
|
||||
self.checkScript(fn_float_bool, (0.0, False, ))
|
||||
self.checkScript(fn_float_bool, (-3.0, True, ))
|
||||
self.checkScript(fn_float_bool, (6.0, False, ))
|
||||
self.checkScript(
|
||||
fn_float_bool,
|
||||
(
|
||||
0.0,
|
||||
True,
|
||||
),
|
||||
)
|
||||
self.checkScript(
|
||||
fn_float_bool,
|
||||
(
|
||||
0.0,
|
||||
False,
|
||||
),
|
||||
)
|
||||
self.checkScript(
|
||||
fn_float_bool,
|
||||
(
|
||||
-3.0,
|
||||
True,
|
||||
),
|
||||
)
|
||||
self.checkScript(
|
||||
fn_float_bool,
|
||||
(
|
||||
6.0,
|
||||
False,
|
||||
),
|
||||
)
|
||||
|
||||
def fn_float_int(real: float, img: int):
|
||||
return complex(real, img)
|
||||
|
||||
self.checkScript(fn_float_int, (0.0, 1, ))
|
||||
self.checkScript(fn_float_int, (0.0, -1, ))
|
||||
self.checkScript(fn_float_int, (1.8, -3, ))
|
||||
self.checkScript(fn_float_int, (2.7, 8, ))
|
||||
self.checkScript(
|
||||
fn_float_int,
|
||||
(
|
||||
0.0,
|
||||
1,
|
||||
),
|
||||
)
|
||||
self.checkScript(
|
||||
fn_float_int,
|
||||
(
|
||||
0.0,
|
||||
-1,
|
||||
),
|
||||
)
|
||||
self.checkScript(
|
||||
fn_float_int,
|
||||
(
|
||||
1.8,
|
||||
-3,
|
||||
),
|
||||
)
|
||||
self.checkScript(
|
||||
fn_float_int,
|
||||
(
|
||||
2.7,
|
||||
8,
|
||||
),
|
||||
)
|
||||
|
||||
def fn_int_float(real: int, img: float):
|
||||
return complex(real, img)
|
||||
|
||||
self.checkScript(fn_int_float, (1, 0.0, ))
|
||||
self.checkScript(fn_int_float, (-1, 1.7, ))
|
||||
self.checkScript(fn_int_float, (-3, 0.0, ))
|
||||
self.checkScript(fn_int_float, (2, -8.9, ))
|
||||
self.checkScript(
|
||||
fn_int_float,
|
||||
(
|
||||
1,
|
||||
0.0,
|
||||
),
|
||||
)
|
||||
self.checkScript(
|
||||
fn_int_float,
|
||||
(
|
||||
-1,
|
||||
1.7,
|
||||
),
|
||||
)
|
||||
self.checkScript(
|
||||
fn_int_float,
|
||||
(
|
||||
-3,
|
||||
0.0,
|
||||
),
|
||||
)
|
||||
self.checkScript(
|
||||
fn_int_float,
|
||||
(
|
||||
2,
|
||||
-8.9,
|
||||
),
|
||||
)
|
||||
|
||||
def test_torch_complex_constructor_with_tensor(self):
|
||||
tensors = ([torch.rand(1), torch.randint(-5, 5, (1, )), torch.tensor([False])])
|
||||
tensors = [torch.rand(1), torch.randint(-5, 5, (1,)), torch.tensor([False])]
|
||||
|
||||
def fn_tensor_float(real, img: float):
|
||||
return complex(real, img)
|
||||
@ -280,7 +532,13 @@ class TestComplex(JitTestCase):
|
||||
return complex(real, img) + complex(2)
|
||||
|
||||
for x, y in product(tensors, tensors):
|
||||
self.checkScript(fn_tensor_tensor, (x, y, ))
|
||||
self.checkScript(
|
||||
fn_tensor_tensor,
|
||||
(
|
||||
x,
|
||||
y,
|
||||
),
|
||||
)
|
||||
|
||||
def test_comparison_ops(self):
|
||||
def fn1(a: complex, b: complex):
|
||||
@ -316,7 +574,7 @@ class TestComplex(JitTestCase):
|
||||
def fn(x: List[complex]):
|
||||
return sum(x)
|
||||
|
||||
self.checkScript(fn, (torch.randn(4, dtype=torch.cdouble).tolist(), ))
|
||||
self.checkScript(fn, (torch.randn(4, dtype=torch.cdouble).tolist(),))
|
||||
|
||||
def test_tensor_attributes(self):
|
||||
def tensor_real(x):
|
||||
@ -326,8 +584,8 @@ class TestComplex(JitTestCase):
|
||||
return x.imag
|
||||
|
||||
t = torch.randn(2, 3, dtype=torch.cdouble)
|
||||
self.checkScript(tensor_real, (t, ))
|
||||
self.checkScript(tensor_imag, (t, ))
|
||||
self.checkScript(tensor_real, (t,))
|
||||
self.checkScript(tensor_imag, (t,))
|
||||
|
||||
def test_binary_op_complex_tensor(self):
|
||||
def mul(x: complex, y: torch.Tensor):
|
||||
@ -350,7 +608,7 @@ class TestComplex(JitTestCase):
|
||||
|
||||
ops = [mul, add, eq, ne, sub, div]
|
||||
|
||||
for shape in [(1, ), (2, 2)]:
|
||||
for shape in [(1,), (2, 2)]:
|
||||
x = 0.71 + 0.71j
|
||||
y = torch.randn(shape, dtype=torch.cfloat)
|
||||
for op in ops:
|
||||
|
||||
@ -10,18 +10,29 @@ import torch
|
||||
# Make the helper files in test/ importable
|
||||
pytorch_test_dir = os.path.dirname(os.path.dirname(os.path.realpath(__file__)))
|
||||
sys.path.append(pytorch_test_dir)
|
||||
from torch.testing._internal.jit_utils import JitTestCase, enable_profiling_mode
|
||||
from torch.testing._internal.jit_metaprogramming_utils import try_get_nn_module_compiled_mod_and_inputs, \
|
||||
get_nn_mod_test_name, get_all_nn_module_tests, nn_functional_tests, get_nn_functional_compiled_fn_and_inputs
|
||||
from torch.testing._internal.common_utils import run_tests, set_default_dtype, suppress_warnings, IS_FBCODE
|
||||
from torch.testing._internal.common_utils import (
|
||||
IS_FBCODE,
|
||||
run_tests,
|
||||
set_default_dtype,
|
||||
suppress_warnings,
|
||||
)
|
||||
from torch.testing._internal.jit_metaprogramming_utils import (
|
||||
get_all_nn_module_tests,
|
||||
get_nn_functional_compiled_fn_and_inputs,
|
||||
get_nn_mod_test_name,
|
||||
nn_functional_tests,
|
||||
try_get_nn_module_compiled_mod_and_inputs,
|
||||
)
|
||||
from torch.testing._internal.jit_utils import enable_profiling_mode, JitTestCase
|
||||
|
||||
|
||||
def num_ifs_loops(graph):
|
||||
graph_str = str(graph)
|
||||
# only look at body of graph
|
||||
graph_body = graph_str[0:graph_str.find("return")]
|
||||
graph_body = graph_str[0 : graph_str.find("return")]
|
||||
return graph_body.count("prim::Loop") + graph_body.count("prim::If")
|
||||
|
||||
|
||||
def num_non_tensor_nodes(block):
|
||||
num_non_tensor = 0
|
||||
for node in block.nodes():
|
||||
@ -40,6 +51,7 @@ def num_non_tensor_nodes(block):
|
||||
num_non_tensor += int(not tensor_out)
|
||||
return num_non_tensor
|
||||
|
||||
|
||||
class TestComplexity(JitTestCase):
|
||||
def setUp(self):
|
||||
super().setUp()
|
||||
@ -90,5 +102,6 @@ class TestComplexity(JitTestCase):
|
||||
for line in stats:
|
||||
print(line)
|
||||
|
||||
if __name__ == '__main__':
|
||||
|
||||
if __name__ == "__main__":
|
||||
run_tests()
|
||||
|
||||
@ -2,16 +2,18 @@
|
||||
|
||||
import os
|
||||
import sys
|
||||
import unittest
|
||||
|
||||
from itertools import product
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
from torch.testing import FileCheck
|
||||
import unittest
|
||||
|
||||
try:
|
||||
import torchvision
|
||||
|
||||
HAS_TORCHVISION = True
|
||||
except ImportError:
|
||||
HAS_TORCHVISION = False
|
||||
@ -22,10 +24,12 @@ pytorch_test_dir = os.path.dirname(os.path.dirname(os.path.realpath(__file__)))
|
||||
sys.path.append(pytorch_test_dir)
|
||||
from torch.testing._internal.jit_utils import JitTestCase
|
||||
|
||||
if __name__ == '__main__':
|
||||
raise RuntimeError("This test file is not meant to be run directly, use:\n\n"
|
||||
"\tpython test/test_jit.py TESTNAME\n\n"
|
||||
"instead.")
|
||||
if __name__ == "__main__":
|
||||
raise RuntimeError(
|
||||
"This test file is not meant to be run directly, use:\n\n"
|
||||
"\tpython test/test_jit.py TESTNAME\n\n"
|
||||
"instead."
|
||||
)
|
||||
|
||||
activations = [
|
||||
F.celu,
|
||||
@ -41,6 +45,7 @@ activations = [
|
||||
F.silu,
|
||||
]
|
||||
|
||||
|
||||
class TestFunctionalToInplaceActivation(JitTestCase):
|
||||
def test_check_no_type_promotion(self):
|
||||
dtypes = [
|
||||
@ -67,6 +72,7 @@ class TestFunctionalToInplaceActivation(JitTestCase):
|
||||
|
||||
def test_functional_to_inplace_activation(self):
|
||||
for activation in activations:
|
||||
|
||||
def test_basic(x):
|
||||
y = x + 1
|
||||
z = activation(y)
|
||||
@ -76,7 +82,7 @@ class TestFunctionalToInplaceActivation(JitTestCase):
|
||||
self.run_pass("inline", fn.graph)
|
||||
self.run_pass("constant_propagation", fn.graph)
|
||||
FileCheck().check(f"aten::{activation.__name__}(").run(fn.graph)
|
||||
self.run_pass('functional_to_inplace_activation', fn.graph)
|
||||
self.run_pass("functional_to_inplace_activation", fn.graph)
|
||||
FileCheck().check_not(f"aten::{activation.__name__}(").run(fn.graph)
|
||||
FileCheck().check(f"aten::{activation.__name__}_").run(fn.graph)
|
||||
inp = torch.rand([2, 2])
|
||||
@ -91,7 +97,7 @@ class TestFunctionalToInplaceActivation(JitTestCase):
|
||||
return z
|
||||
|
||||
fn = torch.jit.script(test1)
|
||||
self.run_pass('functional_to_inplace_activation', fn.graph)
|
||||
self.run_pass("functional_to_inplace_activation", fn.graph)
|
||||
FileCheck().check_not("aten::sigmoid_").run(fn.graph)
|
||||
|
||||
# inplace conversion should not happen because y is alias
|
||||
@ -102,7 +108,7 @@ class TestFunctionalToInplaceActivation(JitTestCase):
|
||||
return z
|
||||
|
||||
fn = torch.jit.script(test2)
|
||||
self.run_pass('functional_to_inplace_activation', fn.graph)
|
||||
self.run_pass("functional_to_inplace_activation", fn.graph)
|
||||
FileCheck().check_not("aten::relu_").run(fn.graph)
|
||||
|
||||
# inplace conversion should not happen because self.x is
|
||||
@ -117,22 +123,33 @@ class TestFunctionalToInplaceActivation(JitTestCase):
|
||||
return y
|
||||
|
||||
fn = torch.jit.script(Test3(torch.rand([2, 2])).eval())
|
||||
self.run_pass('functional_to_inplace_activation', fn.graph)
|
||||
self.run_pass("functional_to_inplace_activation", fn.graph)
|
||||
FileCheck().check_not("aten::relu_").run(fn.graph)
|
||||
|
||||
@skipIfNoTorchVision
|
||||
def test_resnet18_correctness(self):
|
||||
model = torchvision.models.resnet18()
|
||||
frozen_model = torch.jit.freeze(torch.jit.script(model.eval()))
|
||||
N, C, H, W, = 10, 3, 224, 224
|
||||
(
|
||||
N,
|
||||
C,
|
||||
H,
|
||||
W,
|
||||
) = (
|
||||
10,
|
||||
3,
|
||||
224,
|
||||
224,
|
||||
)
|
||||
inp = torch.randn(N, C, H, W)
|
||||
self.run_pass('functional_to_inplace_activation', frozen_model.graph)
|
||||
self.run_pass("functional_to_inplace_activation", frozen_model.graph)
|
||||
self.assertEqual(model(inp), frozen_model(inp))
|
||||
|
||||
|
||||
class TestInplaceToFunctionalActivation(JitTestCase):
|
||||
def test_inplace_to_functional_activation(self):
|
||||
for activation in activations:
|
||||
|
||||
def test_basic(x):
|
||||
y = x + 1
|
||||
activation(y, inplace=True)
|
||||
@ -142,7 +159,7 @@ class TestInplaceToFunctionalActivation(JitTestCase):
|
||||
self.run_pass("inline", fn.graph)
|
||||
self.run_pass("constant_propagation", fn.graph)
|
||||
FileCheck().check(f"aten::{activation.__name__}_").run(fn.graph)
|
||||
self.run_pass('inplace_to_functional_activation', fn.graph)
|
||||
self.run_pass("inplace_to_functional_activation", fn.graph)
|
||||
FileCheck().check_not(f"aten::{activation.__name__}_").run(fn.graph)
|
||||
FileCheck().check(f"aten::{activation.__name__}(").run(fn.graph)
|
||||
|
||||
@ -151,6 +168,7 @@ class TestInplaceToFunctionalActivation(JitTestCase):
|
||||
torch.sigmoid_,
|
||||
torch.tanh_,
|
||||
]:
|
||||
|
||||
def test_basic(x):
|
||||
y = x + 1
|
||||
activation(y)
|
||||
@ -160,7 +178,7 @@ class TestInplaceToFunctionalActivation(JitTestCase):
|
||||
self.run_pass("inline", fn.graph)
|
||||
self.run_pass("constant_propagation", fn.graph)
|
||||
FileCheck().check(f"aten::{activation.__name__}").run(fn.graph)
|
||||
self.run_pass('inplace_to_functional_activation', fn.graph)
|
||||
self.run_pass("inplace_to_functional_activation", fn.graph)
|
||||
FileCheck().check_not(f"aten::{activation.__name__}").run(fn.graph)
|
||||
FileCheck().check(f"aten::{activation.__name__[:-1]}(").run(fn.graph)
|
||||
|
||||
@ -171,7 +189,17 @@ class TestInplaceToFunctionalActivation(JitTestCase):
|
||||
def test_resnet18_correctness(self):
|
||||
model = torchvision.models.resnet18()
|
||||
frozen_model = torch.jit.freeze(torch.jit.script(model.eval()))
|
||||
N, C, H, W, = 10, 3, 224, 224
|
||||
(
|
||||
N,
|
||||
C,
|
||||
H,
|
||||
W,
|
||||
) = (
|
||||
10,
|
||||
3,
|
||||
224,
|
||||
224,
|
||||
)
|
||||
inp = torch.randn(N, C, H, W)
|
||||
self.run_pass('inplace_to_functional_activation', frozen_model.graph)
|
||||
self.run_pass("inplace_to_functional_activation", frozen_model.graph)
|
||||
self.assertEqual(model(inp), frozen_model(inp))
|
||||
|
||||
@ -1,16 +1,21 @@
|
||||
# Owner(s): ["oncall: jit"]
|
||||
|
||||
import gc
|
||||
import os
|
||||
import sys
|
||||
import gc
|
||||
import unittest
|
||||
from typing import NamedTuple
|
||||
|
||||
import torch
|
||||
from typing import NamedTuple
|
||||
from torch.testing import FileCheck
|
||||
from torch.testing._internal.jit_utils import JitTestCase
|
||||
from torch.testing._internal.common_utils import skipIfRocm, skipCUDANonDefaultStreamIf, NoTest, TEST_CUDA
|
||||
from torch.testing._internal.common_cuda import TEST_MULTIGPU
|
||||
from torch.testing._internal.common_utils import (
|
||||
NoTest,
|
||||
skipCUDANonDefaultStreamIf,
|
||||
skipIfRocm,
|
||||
TEST_CUDA,
|
||||
)
|
||||
from torch.testing._internal.jit_utils import JitTestCase
|
||||
|
||||
# Make the helper files in test/ importable
|
||||
pytorch_test_dir = os.path.dirname(os.path.dirname(os.path.realpath(__file__)))
|
||||
@ -18,7 +23,7 @@ sys.path.append(pytorch_test_dir)
|
||||
|
||||
# If GPU is not available, then do not run the tests
|
||||
if not TEST_CUDA:
|
||||
print('CUDA not available, skipping tests', file=sys.stderr)
|
||||
print("CUDA not available, skipping tests", file=sys.stderr)
|
||||
JitTestCase = NoTest # noqa: F811
|
||||
|
||||
TEST_LARGE_TENSOR = TEST_CUDA
|
||||
@ -36,10 +41,12 @@ if __name__ == "__main__":
|
||||
"instead."
|
||||
)
|
||||
|
||||
|
||||
class TestCUDA(JitTestCase):
|
||||
"""
|
||||
A suite of tests for the CUDA API in TorchScript.
|
||||
"""
|
||||
|
||||
def tearDown(self):
|
||||
gc.collect()
|
||||
torch.cuda.empty_cache()
|
||||
@ -54,10 +61,10 @@ class TestCUDA(JitTestCase):
|
||||
def test_device_synchronize():
|
||||
prev_current_device_index = torch.cuda.current_device()
|
||||
torch.cuda.synchronize()
|
||||
torch.cuda.synchronize('cuda')
|
||||
torch.cuda.synchronize('cuda:0')
|
||||
torch.cuda.synchronize("cuda")
|
||||
torch.cuda.synchronize("cuda:0")
|
||||
torch.cuda.synchronize(0)
|
||||
torch.cuda.synchronize(torch.device('cuda:1'))
|
||||
torch.cuda.synchronize(torch.device("cuda:1"))
|
||||
after_current_device_index = torch.cuda.current_device()
|
||||
|
||||
# Check if the current device index is same as the device index before
|
||||
@ -66,7 +73,7 @@ class TestCUDA(JitTestCase):
|
||||
|
||||
@torch.jit.script
|
||||
def test_multi_device_synchronize():
|
||||
torch.cuda.synchronize(torch.device('cuda:0'))
|
||||
torch.cuda.synchronize(torch.device("cuda:0"))
|
||||
prev_current_device_index = torch.cuda.current_device()
|
||||
torch.cuda.synchronize(1)
|
||||
after_current_device_index = torch.cuda.current_device()
|
||||
@ -76,11 +83,9 @@ class TestCUDA(JitTestCase):
|
||||
return prev_current_device_index == after_current_device_index
|
||||
|
||||
self.assertTrue(test_device_synchronize)
|
||||
FileCheck().check("cuda::synchronize(") \
|
||||
.run(test_device_synchronize.graph)
|
||||
FileCheck().check("cuda::synchronize(").run(test_device_synchronize.graph)
|
||||
self.assertTrue(test_multi_device_synchronize)
|
||||
FileCheck().check("cuda::synchronize(") \
|
||||
.run(test_multi_device_synchronize.graph)
|
||||
FileCheck().check("cuda::synchronize(").run(test_multi_device_synchronize.graph)
|
||||
|
||||
def test_stream_args(self):
|
||||
# Test stream creation with default arguments
|
||||
@ -165,7 +170,6 @@ class TestCUDA(JitTestCase):
|
||||
@unittest.skipIf(not TEST_LARGE_TENSOR, "not enough memory")
|
||||
@skipCUDANonDefaultStreamIf(True)
|
||||
def test_streams_and_events(self):
|
||||
|
||||
# Test default_stream API by passing device ID as an argument and
|
||||
# and check if the stream device index matches with the device ID
|
||||
@torch.jit.script
|
||||
@ -182,14 +186,14 @@ class TestCUDA(JitTestCase):
|
||||
# This test checks for the default stream ID is set to 0 on the device
|
||||
@torch.jit.script
|
||||
def test_default_streams():
|
||||
s0 = torch.cuda.default_stream(torch.device('cuda:0'))
|
||||
s1 = torch.cuda.default_stream(torch.device('cuda:1'))
|
||||
s0 = torch.cuda.default_stream(torch.device("cuda:0"))
|
||||
s1 = torch.cuda.default_stream(torch.device("cuda:1"))
|
||||
|
||||
d = torch.device('cuda:1')
|
||||
d = torch.device("cuda:1")
|
||||
|
||||
# Check the current stream id and default id are same
|
||||
# on the current device. The current device id by default is 0
|
||||
s2 = torch.cuda.current_stream(torch.device('cuda:0'))
|
||||
s2 = torch.cuda.current_stream(torch.device("cuda:0"))
|
||||
check_s2 = s2.id() == s0.id()
|
||||
check_d0 = torch.cuda.current_device() == s2.device_index()
|
||||
|
||||
@ -203,9 +207,25 @@ class TestCUDA(JitTestCase):
|
||||
# Check if the current device was reset to 0
|
||||
is_device_d0 = torch.cuda.current_device() == s2.device_index()
|
||||
|
||||
return s0.device_index(), s1.device_index(), check_s2, check_s3, check_d0, check_d1, is_device_d0
|
||||
return (
|
||||
s0.device_index(),
|
||||
s1.device_index(),
|
||||
check_s2,
|
||||
check_s3,
|
||||
check_d0,
|
||||
check_d1,
|
||||
is_device_d0,
|
||||
)
|
||||
|
||||
d0, d1, check_s2, check_s3, check_d0, check_d1, is_device_d0 = test_default_streams()
|
||||
(
|
||||
d0,
|
||||
d1,
|
||||
check_s2,
|
||||
check_s3,
|
||||
check_d0,
|
||||
check_d1,
|
||||
is_device_d0,
|
||||
) = test_default_streams()
|
||||
|
||||
self.assertEqual(d0, 0)
|
||||
self.assertEqual(d1, 1)
|
||||
@ -228,12 +248,21 @@ class TestCUDA(JitTestCase):
|
||||
with torch.cuda.stream(None):
|
||||
cur_device_index = torch.cuda.current_device()
|
||||
is_device_index_same = cur_device_index == device_index
|
||||
is_current_stream_same = torch.cuda.current_stream(device).id() == current_stream.id()
|
||||
is_default_stream_same = torch.cuda.default_stream(device).id() == default_stream.id()
|
||||
is_current_stream_same = (
|
||||
torch.cuda.current_stream(device).id() == current_stream.id()
|
||||
)
|
||||
is_default_stream_same = (
|
||||
torch.cuda.default_stream(device).id() == default_stream.id()
|
||||
)
|
||||
|
||||
# Check if the device index, current stream and default streams have not changed
|
||||
are_streams_same = is_device_index_same and is_current_stream_same and is_default_stream_same
|
||||
are_streams_same = (
|
||||
is_device_index_same
|
||||
and is_current_stream_same
|
||||
and is_default_stream_same
|
||||
)
|
||||
return are_streams_same
|
||||
|
||||
self.assertTrue(test_set_none_stream())
|
||||
|
||||
# This test checks if the Device Context manager is a no op
|
||||
@ -246,6 +275,7 @@ class TestCUDA(JitTestCase):
|
||||
# Check if the current device is the same
|
||||
is_device_same = torch.cuda.current_device() == device_index
|
||||
return is_device_same
|
||||
|
||||
self.assertTrue(test_set_device_none())
|
||||
|
||||
# Check if a CUDA JIT stream is created
|
||||
@ -260,15 +290,15 @@ class TestCUDA(JitTestCase):
|
||||
|
||||
# Class used to store results for the test: test_get_stream.
|
||||
class Result(NamedTuple):
|
||||
t1 : torch.Tensor
|
||||
t2 : torch.Tensor
|
||||
is_current_and_default_stream_same : bool
|
||||
is_default_and_user_stream_not_same : bool
|
||||
is_stream_set : bool
|
||||
is_stream_reset : bool
|
||||
default_stream_query : bool
|
||||
default_stream_id : int
|
||||
user_stream_id : int
|
||||
t1: torch.Tensor
|
||||
t2: torch.Tensor
|
||||
is_current_and_default_stream_same: bool
|
||||
is_default_and_user_stream_not_same: bool
|
||||
is_stream_set: bool
|
||||
is_stream_reset: bool
|
||||
default_stream_query: bool
|
||||
default_stream_id: int
|
||||
user_stream_id: int
|
||||
|
||||
# The test aims at checking different stream proporties.
|
||||
@torch.jit.script
|
||||
@ -280,15 +310,23 @@ class TestCUDA(JitTestCase):
|
||||
user_stream = torch.cuda.Stream()
|
||||
|
||||
# Check if the current and default streams are the same on the device
|
||||
is_current_and_default_stream_same = current_stream.id() == default_stream.id()
|
||||
is_current_and_default_stream_same = (
|
||||
current_stream.id() == default_stream.id()
|
||||
)
|
||||
# Check if user stream and default stream are not the same on the device
|
||||
is_default_and_user_stream_not_same = default_stream.id() != user_stream.id()
|
||||
is_default_and_user_stream_not_same = (
|
||||
default_stream.id() != user_stream.id()
|
||||
)
|
||||
|
||||
with torch.cuda.stream(user_stream):
|
||||
is_stream_set = torch.cuda.current_stream(device).id() == user_stream.id()
|
||||
is_stream_set = (
|
||||
torch.cuda.current_stream(device).id() == user_stream.id()
|
||||
)
|
||||
|
||||
# Check if the stream was reset to current_stream
|
||||
is_stream_reset = torch.cuda.current_stream(device).id() == current_stream.id()
|
||||
is_stream_reset = (
|
||||
torch.cuda.current_stream(device).id() == current_stream.id()
|
||||
)
|
||||
|
||||
tensor1 = torch.rand(10000, 10000, device="cuda")
|
||||
tensor2 = torch.mm(tensor1, tensor1).to("cuda")
|
||||
@ -297,9 +335,16 @@ class TestCUDA(JitTestCase):
|
||||
|
||||
# Capture all the results in the class Result
|
||||
res = Result(
|
||||
tensor1, tensor2, is_current_and_default_stream_same,
|
||||
is_default_and_user_stream_not_same, is_stream_set,
|
||||
is_stream_reset, default_stream_query, default_stream.id(), user_stream.id())
|
||||
tensor1,
|
||||
tensor2,
|
||||
is_current_and_default_stream_same,
|
||||
is_default_and_user_stream_not_same,
|
||||
is_stream_set,
|
||||
is_stream_reset,
|
||||
default_stream_query,
|
||||
default_stream.id(),
|
||||
user_stream.id(),
|
||||
)
|
||||
return res
|
||||
|
||||
result = test_get_stream()
|
||||
@ -310,8 +355,12 @@ class TestCUDA(JitTestCase):
|
||||
self.assertTrue(result.is_stream_set)
|
||||
self.assertTrue(result.is_stream_reset)
|
||||
self.assertTrue(result.default_stream_query)
|
||||
self.assertEqual(result.default_stream_id, 0) # Check if the default stream ID is always 0
|
||||
self.assertNotEqual(result.user_stream_id, 0) # Check if the user stream is always non zero
|
||||
self.assertEqual(
|
||||
result.default_stream_id, 0
|
||||
) # Check if the default stream ID is always 0
|
||||
self.assertNotEqual(
|
||||
result.user_stream_id, 0
|
||||
) # Check if the user stream is always non zero
|
||||
|
||||
# Test the stream context manager. This test checks if the stream is switched
|
||||
# to the user stream on using the stream context manager.
|
||||
@ -329,14 +378,20 @@ class TestCUDA(JitTestCase):
|
||||
# Wait for B to be computed
|
||||
user_stream.synchronize()
|
||||
# Check if the stream has been reset on the current device
|
||||
is_stream_reset = torch.cuda.current_stream(device).id() == current_stream.id()
|
||||
is_stream_reset = (
|
||||
torch.cuda.current_stream(device).id() == current_stream.id()
|
||||
)
|
||||
|
||||
return A, B, check, is_stream_reset
|
||||
|
||||
A, B, is_stream_set, is_stream_reset = test_stream_context()
|
||||
self.assertEqual(torch.matmul(A, A), B)
|
||||
self.assertTrue(is_stream_set, "Error: Current stream was not set to user stream!")
|
||||
self.assertTrue(is_stream_reset, "Error: The stream was not restored to previous stream!")
|
||||
self.assertTrue(
|
||||
is_stream_set, "Error: Current stream was not set to user stream!"
|
||||
)
|
||||
self.assertTrue(
|
||||
is_stream_reset, "Error: The stream was not restored to previous stream!"
|
||||
)
|
||||
|
||||
# Test multiple nested streams. Check if the operations are computed as expected on the streams
|
||||
# This test has been adapted from the eager mode tests available at test/test_cuda.py
|
||||
@ -372,11 +427,24 @@ class TestCUDA(JitTestCase):
|
||||
|
||||
# Check if the stream and device has been restored to previous stream and device
|
||||
is_device_current = torch.cuda.current_device() == prev_device_index
|
||||
is_stream_current = torch.cuda.current_stream(device).id() == prev_current_stream.id()
|
||||
is_stream_current = (
|
||||
torch.cuda.current_stream(device).id() == prev_current_stream.id()
|
||||
)
|
||||
|
||||
check_stream = is_stream_s1 and is_stream_s2 and is_stream_s1_after and is_stream_current
|
||||
check_device = is_device_s1 and is_device_s2 and is_device_s1_after and is_device_current
|
||||
check_stream = (
|
||||
is_stream_s1
|
||||
and is_stream_s2
|
||||
and is_stream_s1_after
|
||||
and is_stream_current
|
||||
)
|
||||
check_device = (
|
||||
is_device_s1
|
||||
and is_device_s2
|
||||
and is_device_s1_after
|
||||
and is_device_current
|
||||
)
|
||||
return A, B, C, D, check_stream, check_device
|
||||
|
||||
A, B, C, D, check_stream, check_device = test_multiple_stream()
|
||||
|
||||
self.assertEqual(torch.matmul(A, A), C)
|
||||
@ -401,7 +469,9 @@ class TestCUDA(JitTestCase):
|
||||
B = torch.mm(A, A).to("cuda")
|
||||
s1.record_event(event)
|
||||
# Check if the current_stream is reset
|
||||
is_current_stream_1 = torch.cuda.current_stream(device).id() == prev_current_stream.id()
|
||||
is_current_stream_1 = (
|
||||
torch.cuda.current_stream(device).id() == prev_current_stream.id()
|
||||
)
|
||||
# Wait for ops on s1 to be computed
|
||||
s2.wait_event(event)
|
||||
with torch.cuda.stream(s2):
|
||||
@ -410,9 +480,16 @@ class TestCUDA(JitTestCase):
|
||||
# Wait for C to be computed
|
||||
s2.synchronize()
|
||||
# Check if the current_stream is reset
|
||||
is_current_stream_2 = torch.cuda.current_stream(device).id() == prev_current_stream.id()
|
||||
is_current_stream_2 = (
|
||||
torch.cuda.current_stream(device).id() == prev_current_stream.id()
|
||||
)
|
||||
|
||||
check_stream = is_current_stream_1 and is_current_stream_2 and is_stream_s1 and is_stream_s2
|
||||
check_stream = (
|
||||
is_current_stream_1
|
||||
and is_current_stream_2
|
||||
and is_stream_s1
|
||||
and is_stream_s2
|
||||
)
|
||||
return A, B, C, check_stream
|
||||
|
||||
A, B, C, check_stream = test_data_dependency_between_streams()
|
||||
@ -425,6 +502,7 @@ class TestCUDA(JitTestCase):
|
||||
def test_simple_event():
|
||||
e = torch.cuda.Event(True, False, False)
|
||||
return e is not None
|
||||
|
||||
self.assertTrue(test_simple_event(), "Could not create CUDA Event!")
|
||||
|
||||
# Record the CUDA event for operation torch.mm on the current stream
|
||||
@ -474,6 +552,7 @@ class TestCUDA(JitTestCase):
|
||||
# not necessary to check e_tik and e_tok, as elapsed_time would throw
|
||||
# exception if otherwise.
|
||||
return e_tik.elapsed_time(e_tok)
|
||||
|
||||
self.assertGreater(test_stream_synchronize(), 0)
|
||||
|
||||
# Test event synchronization for the event that records a stream doing
|
||||
@ -536,12 +615,13 @@ class TestCUDA(JitTestCase):
|
||||
# not necessary to check e_tik and e_tok, as elapsed_time would throw
|
||||
# exception if otherwise.
|
||||
return e_tik.elapsed_time(e_tok)
|
||||
|
||||
self.assertGreater(test_event_wait(), 0)
|
||||
|
||||
# Test for stream wait_event. Checks if the stream waits on the event
|
||||
@torch.jit.script
|
||||
def test_wait_event():
|
||||
d1 = torch.device('cuda:1')
|
||||
d1 = torch.device("cuda:1")
|
||||
|
||||
with torch.cuda.device(d1):
|
||||
s0 = torch.cuda.current_stream(d1)
|
||||
@ -550,11 +630,12 @@ class TestCUDA(JitTestCase):
|
||||
e0 = torch.cuda.Event(False, False, False)
|
||||
s0.record_event(e0)
|
||||
|
||||
s1 = torch.cuda.current_stream(torch.device('cuda:0'))
|
||||
s1 = torch.cuda.current_stream(torch.device("cuda:0"))
|
||||
s1.wait_event(e0)
|
||||
s1.synchronize()
|
||||
|
||||
return e0.query() and s0.query() and s1.query()
|
||||
|
||||
self.assertTrue(test_wait_event())
|
||||
|
||||
# Test if a scripted module with cuda streams can be saved, loaded and executed
|
||||
|
||||
@ -11,33 +11,37 @@ pytorch_test_dir = os.path.dirname(os.path.dirname(os.path.realpath(__file__)))
|
||||
sys.path.append(pytorch_test_dir)
|
||||
from torch.testing._internal.jit_utils import JitTestCase
|
||||
|
||||
if __name__ == '__main__':
|
||||
raise RuntimeError("This test file is not meant to be run directly, use:\n\n"
|
||||
"\tpython test/test_jit.py TESTNAME\n\n"
|
||||
"instead.")
|
||||
if __name__ == "__main__":
|
||||
raise RuntimeError(
|
||||
"This test file is not meant to be run directly, use:\n\n"
|
||||
"\tpython test/test_jit.py TESTNAME\n\n"
|
||||
"instead."
|
||||
)
|
||||
|
||||
|
||||
def canonical(graph):
|
||||
return torch._C._jit_pass_canonicalize(graph).str(False)
|
||||
|
||||
class TestCustomOperators(JitTestCase):
|
||||
|
||||
class TestCustomOperators(JitTestCase):
|
||||
def test_dynamic_op_registry(self):
|
||||
from torch._ops import _OpNamespace
|
||||
self.assertTrue(hasattr(torch, 'ops'))
|
||||
|
||||
if '_test' in torch.ops.__dict__:
|
||||
torch.ops.__dict__.pop('_test')
|
||||
self.assertTrue(hasattr(torch, "ops"))
|
||||
|
||||
if "_test" in torch.ops.__dict__:
|
||||
torch.ops.__dict__.pop("_test")
|
||||
|
||||
# Don't use `hasattr()` because it will call `__getattr__`.
|
||||
self.assertNotIn('_test', torch.ops.__dict__)
|
||||
self.assertNotIn("_test", torch.ops.__dict__)
|
||||
torch.ops._test
|
||||
self.assertIn('_test', torch.ops.__dict__)
|
||||
self.assertIn("_test", torch.ops.__dict__)
|
||||
self.assertEqual(type(torch.ops._test), _OpNamespace)
|
||||
|
||||
self.assertNotIn('leaky_relu', torch.ops._test.__dict__)
|
||||
self.assertNotIn("leaky_relu", torch.ops._test.__dict__)
|
||||
op = torch.ops._test.leaky_relu
|
||||
self.assertTrue(callable(op))
|
||||
self.assertIn('leaky_relu', torch.ops._test.__dict__)
|
||||
self.assertIn("leaky_relu", torch.ops._test.__dict__)
|
||||
op2 = torch.ops._test.leaky_relu
|
||||
self.assertEqual(op, op2)
|
||||
|
||||
@ -46,7 +50,7 @@ class TestCustomOperators(JitTestCase):
|
||||
with self.assertRaisesRegexWithHighlight(
|
||||
AttributeError,
|
||||
f"Invalid attribute '{attr}' for '_OpNamespace' '_test'",
|
||||
""
|
||||
"",
|
||||
):
|
||||
getattr(torch.ops._test, attr)
|
||||
|
||||
@ -63,15 +67,13 @@ class TestCustomOperators(JitTestCase):
|
||||
with self.assertRaisesRegexWithHighlight(
|
||||
RuntimeError,
|
||||
r"aten::relu\(\) expected at most 1 argument\(s\) but received 2 argument\(s\)",
|
||||
""
|
||||
"",
|
||||
):
|
||||
torch.ops.aten.relu(1, 2)
|
||||
|
||||
def test_passing_too_few_args(self):
|
||||
with self.assertRaisesRegexWithHighlight(
|
||||
RuntimeError,
|
||||
r"aten::relu\(\) is missing value for argument 'self'.",
|
||||
""
|
||||
RuntimeError, r"aten::relu\(\) is missing value for argument 'self'.", ""
|
||||
):
|
||||
torch.ops.aten.relu()
|
||||
|
||||
@ -79,7 +81,7 @@ class TestCustomOperators(JitTestCase):
|
||||
with self.assertRaisesRegexWithHighlight(
|
||||
RuntimeError,
|
||||
r"aten::type_as\(\) is missing value for argument 'other'.",
|
||||
""
|
||||
"",
|
||||
):
|
||||
torch.ops.aten.type_as(torch.ones(5, 5))
|
||||
|
||||
@ -87,7 +89,7 @@ class TestCustomOperators(JitTestCase):
|
||||
with self.assertRaisesRegexWithHighlight(
|
||||
RuntimeError,
|
||||
"Unknown keyword argument 'foo' for operator '_test::leaky_relu'",
|
||||
""
|
||||
"",
|
||||
):
|
||||
torch.ops._test.leaky_relu(torch.ones(5), foo=torch.ones(5))
|
||||
|
||||
@ -102,6 +104,7 @@ class TestCustomOperators(JitTestCase):
|
||||
@torch.jit.script
|
||||
def func(x):
|
||||
return torch.ops.aten.relu(x)
|
||||
|
||||
input = torch.ones(5, 5)
|
||||
self.assertEqual(func(input), input.relu())
|
||||
|
||||
@ -110,28 +113,37 @@ class TestCustomOperators(JitTestCase):
|
||||
func = torch.jit.trace(torch.ops.aten.relu, [input])
|
||||
self.assertEqual(func(input), input.relu())
|
||||
|
||||
@unittest.skip("Need to figure out default dtype differences between fbcode and oss")
|
||||
@unittest.skip(
|
||||
"Need to figure out default dtype differences between fbcode and oss"
|
||||
)
|
||||
def test_script_graph_for_custom_ops_matches_traced_graph(self):
|
||||
input = torch.ones(5, 5)
|
||||
trace = torch.jit.trace(torch.ops.aten.relu, [input])
|
||||
self.assertExpectedInline(canonical(trace.graph), '''\
|
||||
self.assertExpectedInline(
|
||||
canonical(trace.graph),
|
||||
"""\
|
||||
graph(%0 : Float(5, 5)):
|
||||
%1 : Float(5, 5) = aten::relu(%0)
|
||||
return (%1)
|
||||
''')
|
||||
""",
|
||||
)
|
||||
|
||||
def test_script_graph_contains_custom_op(self):
|
||||
@torch.jit.script
|
||||
def func(x):
|
||||
return torch.ops.aten.relu(x)
|
||||
self.assertExpectedInline(canonical(func.graph), '''\
|
||||
|
||||
self.assertExpectedInline(
|
||||
canonical(func.graph),
|
||||
"""\
|
||||
graph(%x.1 : Tensor):
|
||||
%1 : Tensor = aten::relu(%x.1)
|
||||
return (%1)
|
||||
''')
|
||||
""",
|
||||
)
|
||||
|
||||
def test_generic_list(self):
|
||||
self.assertEqual(torch.ops._test.get_first([['hello']]), 'hello')
|
||||
self.assertEqual(torch.ops._test.get_first([["hello"]]), "hello")
|
||||
|
||||
# https://github.com/pytorch/pytorch/issues/80508
|
||||
def test_where_no_scalar(self):
|
||||
|
||||
@ -13,17 +13,21 @@ pytorch_test_dir = os.path.dirname(os.path.dirname(os.path.realpath(__file__)))
|
||||
sys.path.append(pytorch_test_dir)
|
||||
from torch.testing._internal.jit_utils import JitTestCase, RUN_CUDA_MULTI_GPU
|
||||
|
||||
if __name__ == '__main__':
|
||||
raise RuntimeError("This test file is not meant to be run directly, use:\n\n"
|
||||
"\tpython test/test_jit.py TESTNAME\n\n"
|
||||
"instead.")
|
||||
if __name__ == "__main__":
|
||||
raise RuntimeError(
|
||||
"This test file is not meant to be run directly, use:\n\n"
|
||||
"\tpython test/test_jit.py TESTNAME\n\n"
|
||||
"instead."
|
||||
)
|
||||
|
||||
|
||||
class TestDataParallel(JitTestCase):
|
||||
class Mpy(torch.nn.Module):
|
||||
def __init__(self):
|
||||
super(TestDataParallel.Mpy, self).__init__()
|
||||
self.m = nn.Sequential(nn.Linear(2, 2), nn.BatchNorm1d(2),
|
||||
nn.ReLU(), nn.Linear(2, 2))
|
||||
self.m = nn.Sequential(
|
||||
nn.Linear(2, 2), nn.BatchNorm1d(2), nn.ReLU(), nn.Linear(2, 2)
|
||||
)
|
||||
|
||||
@torch.jit.ignore
|
||||
def forward(self, input):
|
||||
@ -50,13 +54,13 @@ class TestDataParallel(JitTestCase):
|
||||
return self.m2(x)
|
||||
|
||||
class Msm(torch.jit.ScriptModule):
|
||||
|
||||
__constants__ = ['m']
|
||||
__constants__ = ["m"]
|
||||
|
||||
def __init__(self):
|
||||
super(TestDataParallel.Msm, self).__init__()
|
||||
self.m = nn.Sequential(nn.Linear(2, 2), nn.BatchNorm1d(2),
|
||||
nn.ReLU(), nn.Linear(2, 2))
|
||||
self.m = nn.Sequential(
|
||||
nn.Linear(2, 2), nn.BatchNorm1d(2), nn.ReLU(), nn.Linear(2, 2)
|
||||
)
|
||||
|
||||
@torch.jit.script_method
|
||||
def forward(self, input):
|
||||
@ -140,7 +144,7 @@ class TestDataParallel(JitTestCase):
|
||||
# Use .data here to avoid version counter bump.
|
||||
# The graph created by the following forward will be wrong but
|
||||
# we never backward through them so it's fine
|
||||
p.data -= 1. * p.grad
|
||||
p.data -= 1.0 * p.grad
|
||||
second_forward = module(x)
|
||||
|
||||
# replica which is on the same GPU has a shallow copy of the original
|
||||
|
||||
@ -1,14 +1,16 @@
|
||||
# Owner(s): ["oncall: jit"]
|
||||
# flake8: noqa
|
||||
|
||||
import sys
|
||||
import unittest
|
||||
from dataclasses import dataclass, field, InitVar
|
||||
from enum import Enum
|
||||
from typing import List, Optional
|
||||
|
||||
import torch
|
||||
from hypothesis import given, settings, strategies as st
|
||||
from torch.testing._internal.jit_utils import JitTestCase
|
||||
from typing import List, Optional
|
||||
import sys
|
||||
import torch
|
||||
import unittest
|
||||
from enum import Enum
|
||||
|
||||
|
||||
# Example jittable dataclass
|
||||
@dataclass(order=True)
|
||||
@ -20,8 +22,8 @@ class Point:
|
||||
def __post_init__(self):
|
||||
self.norm = (torch.tensor(self.x) ** 2 + torch.tensor(self.y) ** 2) ** 0.5
|
||||
|
||||
class MixupScheme(Enum):
|
||||
|
||||
class MixupScheme(Enum):
|
||||
INPUT = ["input"]
|
||||
|
||||
MANIFOLD = [
|
||||
@ -38,6 +40,7 @@ class MixupParams:
|
||||
self.alpha = alpha
|
||||
self.scheme = scheme
|
||||
|
||||
|
||||
class MixupScheme2(Enum):
|
||||
A = 1
|
||||
B = 2
|
||||
@ -49,6 +52,7 @@ class MixupParams2:
|
||||
self.alpha = alpha
|
||||
self.scheme = scheme
|
||||
|
||||
|
||||
@dataclass
|
||||
class MixupParams3:
|
||||
def __init__(self, alpha: float = 0.125, scheme: MixupScheme2 = MixupScheme2.A):
|
||||
@ -59,11 +63,11 @@ class MixupParams3:
|
||||
# Make sure the Meta internal tooling doesn't raise an overflow error
|
||||
NonHugeFloats = st.floats(min_value=-1e4, max_value=1e4, allow_nan=False)
|
||||
|
||||
class TestDataclasses(JitTestCase):
|
||||
|
||||
class TestDataclasses(JitTestCase):
|
||||
@classmethod
|
||||
def tearDownClass(cls):
|
||||
torch._C._jit_clear_class_registry()
|
||||
torch._C._jit_clear_class_registry()
|
||||
|
||||
def test_init_vars(self):
|
||||
@torch.jit.script
|
||||
@ -75,7 +79,9 @@ class TestDataclasses(JitTestCase):
|
||||
norm: Optional[torch.Tensor] = None
|
||||
|
||||
def __post_init__(self, norm_p: int):
|
||||
self.norm = (torch.tensor(self.x) ** norm_p + torch.tensor(self.y) ** norm_p) ** (1 / norm_p)
|
||||
self.norm = (
|
||||
torch.tensor(self.x) ** norm_p + torch.tensor(self.y) ** norm_p
|
||||
) ** (1 / norm_p)
|
||||
|
||||
def fn(x: float, y: float, p: int):
|
||||
pt = Point2(x, y, p)
|
||||
@ -88,6 +94,7 @@ class TestDataclasses(JitTestCase):
|
||||
@given(NonHugeFloats, NonHugeFloats)
|
||||
def test__post_init__(self, x, y):
|
||||
P = torch.jit.script(Point)
|
||||
|
||||
def fn(x: float, y: float):
|
||||
pt = P(x, y)
|
||||
return pt.norm
|
||||
@ -95,7 +102,9 @@ class TestDataclasses(JitTestCase):
|
||||
self.checkScript(fn, [x, y])
|
||||
|
||||
@settings(deadline=None)
|
||||
@given(st.tuples(NonHugeFloats, NonHugeFloats), st.tuples(NonHugeFloats, NonHugeFloats))
|
||||
@given(
|
||||
st.tuples(NonHugeFloats, NonHugeFloats), st.tuples(NonHugeFloats, NonHugeFloats)
|
||||
)
|
||||
def test_comparators(self, pt1, pt2):
|
||||
x1, y1 = pt1
|
||||
x2, y2 = pt2
|
||||
@ -122,6 +131,7 @@ class TestDataclasses(JitTestCase):
|
||||
|
||||
with self.assertRaises(NotImplementedError):
|
||||
torch.jit.script(Foo)
|
||||
|
||||
def fn():
|
||||
foo = Foo()
|
||||
return foo.x
|
||||
@ -137,7 +147,7 @@ class TestDataclasses(JitTestCase):
|
||||
a: int
|
||||
b: int
|
||||
|
||||
def __eq__(self, other: 'CustomEq') -> bool:
|
||||
def __eq__(self, other: "CustomEq") -> bool:
|
||||
return self.a == other.a # ignore the b field
|
||||
|
||||
def fn(a: int, b1: int, b2: int):
|
||||
@ -154,9 +164,7 @@ class TestDataclasses(JitTestCase):
|
||||
|
||||
torch.jit.script(MixupParams2) # don't throw
|
||||
|
||||
|
||||
def test_use_unregistered_dataclass_raises(self):
|
||||
|
||||
def f(a: MixupParams3):
|
||||
return 0
|
||||
|
||||
|
||||
@ -1,12 +1,12 @@
|
||||
# Owner(s): ["oncall: jit"]
|
||||
|
||||
from itertools import product
|
||||
import unittest
|
||||
from itertools import product
|
||||
|
||||
import torch
|
||||
from torch.jit._passes._property_propagation import apply_input_props_using_example
|
||||
from torch.testing._internal.common_utils import TEST_CUDA
|
||||
from torch.testing._internal.jit_utils import JitTestCase
|
||||
from torch.jit._passes._property_propagation import apply_input_props_using_example
|
||||
|
||||
try:
|
||||
from torchvision import models
|
||||
|
||||
@ -7,19 +7,19 @@ from unittest.case import expectedFailure
|
||||
import torch
|
||||
from torch import complex32, float32, float64, int32, int64
|
||||
from torch.jit._passes import _property_propagation
|
||||
from torch.testing._internal.common_device_type import (
|
||||
instantiate_device_type_tests,
|
||||
ops,
|
||||
)
|
||||
from torch.testing._internal.common_methods_invocations import (
|
||||
SampleInput,
|
||||
op_db,
|
||||
sample_inputs_adaptive_avg_pool2d,
|
||||
sample_inputs_conv2d,
|
||||
SampleInput,
|
||||
)
|
||||
from torch.testing._internal.common_utils import set_default_dtype, first_sample
|
||||
from torch.testing._internal.jit_utils import JitTestCase
|
||||
from torch.testing._internal.common_utils import first_sample, set_default_dtype
|
||||
from torch.testing._internal.jit_metaprogramming_utils import create_traced_fn
|
||||
from torch.testing._internal.common_device_type import (
|
||||
ops,
|
||||
instantiate_device_type_tests,
|
||||
)
|
||||
from torch.testing._internal.common_methods_invocations import op_db
|
||||
from torch.testing._internal.jit_utils import JitTestCase
|
||||
|
||||
"""
|
||||
Dtype Analysis relies on symbolic shape analysis, which is still in beta
|
||||
@ -274,7 +274,9 @@ class TestDtypeAnalysis(TestDtypeBase):
|
||||
):
|
||||
for dtype in (torch.int8, torch.float64):
|
||||
# Gets default version for conv2d
|
||||
sample_input: SampleInput = list(inputs_fn(None, "cpu", dtype, False))[-1]
|
||||
sample_input: SampleInput = list(inputs_fn(None, "cpu", dtype, False))[
|
||||
-1
|
||||
]
|
||||
input_args = [sample_input.input, *sample_input.args]
|
||||
self.assert_dtype_equal_custom_args(fn, input_args)
|
||||
|
||||
@ -352,7 +354,9 @@ class TestDtypeCustomRules(TestDtypeBase):
|
||||
# Run the Dtype Analysis
|
||||
graph = traced_fn.graph # Note this is a cached graph
|
||||
input_tensors = [t for t in input_args if isinstance(t, torch.Tensor)]
|
||||
input_tensors += [v for v in sample_input.kwargs.values() if isinstance(v, torch.Tensor)]
|
||||
input_tensors += [
|
||||
v for v in sample_input.kwargs.values() if isinstance(v, torch.Tensor)
|
||||
]
|
||||
self.prop_dtype_on_graph(graph, input_tensors)
|
||||
self.assert_output_dtype_equal(expected_res, graph)
|
||||
|
||||
|
||||
@ -2,21 +2,24 @@
|
||||
|
||||
import os
|
||||
import sys
|
||||
from enum import Enum
|
||||
from typing import Any, List
|
||||
|
||||
import torch
|
||||
from torch.testing import FileCheck
|
||||
from enum import Enum
|
||||
from typing import Any, List
|
||||
|
||||
# Make the helper files in test/ importable
|
||||
pytorch_test_dir = os.path.dirname(os.path.dirname(os.path.realpath(__file__)))
|
||||
sys.path.append(pytorch_test_dir)
|
||||
from torch.testing._internal.jit_utils import JitTestCase, make_global
|
||||
|
||||
if __name__ == '__main__':
|
||||
raise RuntimeError("This test file is not meant to be run directly, use:\n\n"
|
||||
"\tpython test/test_jit.py TESTNAME\n\n"
|
||||
"instead.")
|
||||
if __name__ == "__main__":
|
||||
raise RuntimeError(
|
||||
"This test file is not meant to be run directly, use:\n\n"
|
||||
"\tpython test/test_jit.py TESTNAME\n\n"
|
||||
"instead."
|
||||
)
|
||||
|
||||
|
||||
class TestEnum(JitTestCase):
|
||||
def test_enum_value_types(self):
|
||||
@ -38,11 +41,9 @@ class TestEnum(JitTestCase):
|
||||
def supported_enum_types(a: IntEnum, b: FloatEnum, c: StringEnum):
|
||||
return (a.name, b.name, c.name)
|
||||
|
||||
FileCheck() \
|
||||
.check("IntEnum") \
|
||||
.check("FloatEnum") \
|
||||
.check("StringEnum") \
|
||||
.run(str(supported_enum_types.graph))
|
||||
FileCheck().check("IntEnum").check("FloatEnum").check("StringEnum").run(
|
||||
str(supported_enum_types.graph)
|
||||
)
|
||||
|
||||
class TensorEnum(Enum):
|
||||
FOO = torch.tensor(0)
|
||||
@ -54,7 +55,9 @@ class TestEnum(JitTestCase):
|
||||
return a.name
|
||||
|
||||
# TODO: rewrite code so that the highlight is not empty.
|
||||
with self.assertRaisesRegexWithHighlight(RuntimeError, "Cannot create Enum with value type 'Tensor'", ""):
|
||||
with self.assertRaisesRegexWithHighlight(
|
||||
RuntimeError, "Cannot create Enum with value type 'Tensor'", ""
|
||||
):
|
||||
torch.jit.script(unsupported_enum_types)
|
||||
|
||||
def test_enum_comp(self):
|
||||
@ -88,11 +91,9 @@ class TestEnum(JitTestCase):
|
||||
def enum_comp(x: Foo) -> bool:
|
||||
return x == Bar.ITEM1
|
||||
|
||||
FileCheck() \
|
||||
.check("prim::Constant") \
|
||||
.check_same("Bar.ITEM1") \
|
||||
.check("aten::eq") \
|
||||
.run(str(enum_comp.graph))
|
||||
FileCheck().check("prim::Constant").check_same("Bar.ITEM1").check(
|
||||
"aten::eq"
|
||||
).run(str(enum_comp.graph))
|
||||
|
||||
self.assertEqual(enum_comp(Foo.ITEM1), False)
|
||||
|
||||
@ -107,7 +108,9 @@ class TestEnum(JitTestCase):
|
||||
return x == y
|
||||
|
||||
# TODO: rewrite code so that the highlight is not empty.
|
||||
with self.assertRaisesRegexWithHighlight(RuntimeError, "Could not unify type list", ""):
|
||||
with self.assertRaisesRegexWithHighlight(
|
||||
RuntimeError, "Could not unify type list", ""
|
||||
):
|
||||
torch.jit.script(enum_comp)
|
||||
|
||||
def test_enum_name(self):
|
||||
@ -121,11 +124,9 @@ class TestEnum(JitTestCase):
|
||||
def enum_name(x: Color) -> str:
|
||||
return x.name
|
||||
|
||||
FileCheck() \
|
||||
.check("Color") \
|
||||
.check_next("prim::EnumName") \
|
||||
.check_next("return") \
|
||||
.run(str(enum_name.graph))
|
||||
FileCheck().check("Color").check_next("prim::EnumName").check_next(
|
||||
"return"
|
||||
).run(str(enum_name.graph))
|
||||
|
||||
self.assertEqual(enum_name(Color.RED), Color.RED.name)
|
||||
self.assertEqual(enum_name(Color.GREEN), Color.GREEN.name)
|
||||
@ -141,11 +142,9 @@ class TestEnum(JitTestCase):
|
||||
def enum_value(x: Color) -> int:
|
||||
return x.value
|
||||
|
||||
FileCheck() \
|
||||
.check("Color") \
|
||||
.check_next("prim::EnumValue") \
|
||||
.check_next("return") \
|
||||
.run(str(enum_value.graph))
|
||||
FileCheck().check("Color").check_next("prim::EnumValue").check_next(
|
||||
"return"
|
||||
).run(str(enum_value.graph))
|
||||
|
||||
self.assertEqual(enum_value(Color.RED), Color.RED.value)
|
||||
self.assertEqual(enum_value(Color.GREEN), Color.GREEN.value)
|
||||
@ -161,11 +160,9 @@ class TestEnum(JitTestCase):
|
||||
def enum_const(x: Color) -> bool:
|
||||
return x == Color.RED
|
||||
|
||||
FileCheck() \
|
||||
.check("prim::Constant[value=__torch__.jit.test_enum.Color.RED]") \
|
||||
.check_next("aten::eq") \
|
||||
.check_next("return") \
|
||||
.run(str(enum_const.graph))
|
||||
FileCheck().check(
|
||||
"prim::Constant[value=__torch__.jit.test_enum.Color.RED]"
|
||||
).check_next("aten::eq").check_next("return").run(str(enum_const.graph))
|
||||
|
||||
self.assertEqual(enum_const(Color.RED), True)
|
||||
self.assertEqual(enum_const(Color.GREEN), False)
|
||||
@ -183,7 +180,9 @@ class TestEnum(JitTestCase):
|
||||
else:
|
||||
return False
|
||||
|
||||
with self.assertRaisesRegexWithHighlight(RuntimeError, "has no attribute 'PURPLE'", "Color.PURPLE"):
|
||||
with self.assertRaisesRegexWithHighlight(
|
||||
RuntimeError, "has no attribute 'PURPLE'", "Color.PURPLE"
|
||||
):
|
||||
torch.jit.script(enum_const)
|
||||
|
||||
def test_enum_ivalue_type(self):
|
||||
@ -197,10 +196,9 @@ class TestEnum(JitTestCase):
|
||||
def is_color_enum(x: Any):
|
||||
return isinstance(x, Color)
|
||||
|
||||
FileCheck() \
|
||||
.check("prim::isinstance[types=[Enum<__torch__.jit.test_enum.Color>]]") \
|
||||
.check_next("return") \
|
||||
.run(str(is_color_enum.graph))
|
||||
FileCheck().check(
|
||||
"prim::isinstance[types=[Enum<__torch__.jit.test_enum.Color>]]"
|
||||
).check_next("return").run(str(is_color_enum.graph))
|
||||
|
||||
self.assertEqual(is_color_enum(Color.RED), True)
|
||||
self.assertEqual(is_color_enum(Color.GREEN), True)
|
||||
@ -217,10 +215,9 @@ class TestEnum(JitTestCase):
|
||||
def closed_over_aliased_type():
|
||||
return a.RED.value
|
||||
|
||||
FileCheck() \
|
||||
.check("prim::Constant[value={}]".format(a.RED.value)) \
|
||||
.check_next("return") \
|
||||
.run(str(closed_over_aliased_type.graph))
|
||||
FileCheck().check("prim::Constant[value={}]".format(a.RED.value)).check_next(
|
||||
"return"
|
||||
).run(str(closed_over_aliased_type.graph))
|
||||
|
||||
self.assertEqual(closed_over_aliased_type(), Color.RED.value)
|
||||
|
||||
@ -230,10 +227,9 @@ class TestEnum(JitTestCase):
|
||||
def closed_over_aliased_value():
|
||||
return b.value
|
||||
|
||||
FileCheck() \
|
||||
.check("prim::Constant[value={}]".format(b.value)) \
|
||||
.check_next("return") \
|
||||
.run(str(closed_over_aliased_value.graph))
|
||||
FileCheck().check("prim::Constant[value={}]".format(b.value)).check_next(
|
||||
"return"
|
||||
).run(str(closed_over_aliased_value.graph))
|
||||
|
||||
self.assertEqual(closed_over_aliased_value(), Color.RED.value)
|
||||
|
||||
@ -253,13 +249,9 @@ class TestEnum(JitTestCase):
|
||||
m = TestModule(Color.RED)
|
||||
scripted = torch.jit.script(m)
|
||||
|
||||
FileCheck() \
|
||||
.check("TestModule") \
|
||||
.check_next("Color") \
|
||||
.check_same("prim::GetAttr[name=\"e\"]") \
|
||||
.check_next("prim::EnumValue") \
|
||||
.check_next("return") \
|
||||
.run(str(scripted.graph))
|
||||
FileCheck().check("TestModule").check_next("Color").check_same(
|
||||
'prim::GetAttr[name="e"]'
|
||||
).check_next("prim::EnumValue").check_next("return").run(str(scripted.graph))
|
||||
|
||||
self.assertEqual(scripted(), Color.RED.value)
|
||||
|
||||
@ -316,16 +308,12 @@ class TestEnum(JitTestCase):
|
||||
m = TestModule(Color.RED)
|
||||
scripted = torch.jit.script(m)
|
||||
|
||||
FileCheck() \
|
||||
.check("TestModule") \
|
||||
.check_next("Color") \
|
||||
.check_same("prim::GetAttr[name=\"e\"]") \
|
||||
.check_next("return") \
|
||||
.run(str(scripted.graph))
|
||||
FileCheck().check("TestModule").check_next("Color").check_same(
|
||||
'prim::GetAttr[name="e"]'
|
||||
).check_next("return").run(str(scripted.graph))
|
||||
|
||||
self.assertEqual(scripted(), Color.RED)
|
||||
|
||||
|
||||
def test_enum_iterate(self):
|
||||
class Color(Enum):
|
||||
RED = 1
|
||||
@ -342,12 +330,9 @@ class TestEnum(JitTestCase):
|
||||
make_global(Color)
|
||||
scripted = torch.jit.script(iterate_enum)
|
||||
|
||||
FileCheck() \
|
||||
.check("Enum<__torch__.jit.test_enum.Color>[]") \
|
||||
.check_same("Color.RED") \
|
||||
.check_same("Color.GREEN") \
|
||||
.check_same("Color.BLUE") \
|
||||
.run(str(scripted.graph))
|
||||
FileCheck().check("Enum<__torch__.jit.test_enum.Color>[]").check_same(
|
||||
"Color.RED"
|
||||
).check_same("Color.GREEN").check_same("Color.BLUE").run(str(scripted.graph))
|
||||
|
||||
# PURPLE always appears last because we follow Python's Enum definition order.
|
||||
self.assertEqual(scripted(Color.RED), [Color.GREEN.value, Color.BLUE.value])
|
||||
@ -355,7 +340,6 @@ class TestEnum(JitTestCase):
|
||||
|
||||
# Tests that explicitly and/or repeatedly scripting an Enum class is permitted.
|
||||
def test_enum_explicit_script(self):
|
||||
|
||||
@torch.jit.script
|
||||
class Color(Enum):
|
||||
RED = 1
|
||||
|
||||
@ -1,11 +1,13 @@
|
||||
# Owner(s): ["oncall: jit"]
|
||||
from torch.testing._internal.common_utils import TestCase
|
||||
import torch
|
||||
from torch import nn
|
||||
from torch.testing._internal.common_utils import TestCase
|
||||
|
||||
r"""
|
||||
Test TorchScript exception handling.
|
||||
"""
|
||||
|
||||
|
||||
class TestException(TestCase):
|
||||
def test_pyop_exception_message(self):
|
||||
class Foo(torch.jit.ScriptModule):
|
||||
@ -16,31 +18,40 @@ class TestException(TestCase):
|
||||
@torch.jit.script_method
|
||||
def forward(self, x):
|
||||
return self.conv(x)
|
||||
|
||||
foo = Foo()
|
||||
# testing that the correct error message propagates
|
||||
with self.assertRaisesRegex(RuntimeError, r"Expected 3D \(unbatched\) or 4D \(batched\) input to conv2d"):
|
||||
with self.assertRaisesRegex(
|
||||
RuntimeError, r"Expected 3D \(unbatched\) or 4D \(batched\) input to conv2d"
|
||||
):
|
||||
foo(torch.ones([123])) # wrong size
|
||||
|
||||
def test_builtin_error_messsage(self):
|
||||
with self.assertRaisesRegex(RuntimeError, "Arguments for call are not valid"):
|
||||
|
||||
@torch.jit.script
|
||||
def close_match(x):
|
||||
return x.masked_fill(True)
|
||||
|
||||
with self.assertRaisesRegex(RuntimeError, "This op may not exist or may not be currently "
|
||||
"supported in TorchScript"):
|
||||
with self.assertRaisesRegex(
|
||||
RuntimeError,
|
||||
"This op may not exist or may not be currently " "supported in TorchScript",
|
||||
):
|
||||
|
||||
@torch.jit.script
|
||||
def unknown_op(x):
|
||||
torch.set_anomaly_enabled(True)
|
||||
return x
|
||||
|
||||
def test_exceptions(self):
|
||||
cu = torch.jit.CompilationUnit('''
|
||||
cu = torch.jit.CompilationUnit(
|
||||
"""
|
||||
def foo(cond):
|
||||
if bool(cond):
|
||||
raise ValueError(3)
|
||||
return 1
|
||||
''')
|
||||
"""
|
||||
)
|
||||
|
||||
cu.foo(torch.tensor(0))
|
||||
with self.assertRaisesRegex(torch.jit.Error, "3"):
|
||||
@ -97,6 +108,7 @@ class TestException(TestCase):
|
||||
else:
|
||||
raise Exception("Hi")
|
||||
return a
|
||||
|
||||
self.assertEqual(foo(), 1)
|
||||
|
||||
@torch.jit.script
|
||||
@ -114,11 +126,13 @@ class TestException(TestCase):
|
||||
no_message()
|
||||
|
||||
def test_assertions(self):
|
||||
cu = torch.jit.CompilationUnit('''
|
||||
cu = torch.jit.CompilationUnit(
|
||||
"""
|
||||
def foo(cond):
|
||||
assert bool(cond), "hi"
|
||||
return 0
|
||||
''')
|
||||
"""
|
||||
)
|
||||
|
||||
cu.foo(torch.tensor(1))
|
||||
with self.assertRaisesRegex(torch.jit.Error, "AssertionError: hi"):
|
||||
@ -142,7 +156,9 @@ class TestException(TestCase):
|
||||
def fn(x):
|
||||
return python_op(x)
|
||||
|
||||
with self.assertRaisesRegex(RuntimeError, "operation failed in the TorchScript interpreter"):
|
||||
with self.assertRaisesRegex(
|
||||
RuntimeError, "operation failed in the TorchScript interpreter"
|
||||
):
|
||||
fn(torch.tensor(4))
|
||||
|
||||
def test_dict_expansion_raises_error(self):
|
||||
@ -150,8 +166,9 @@ class TestException(TestCase):
|
||||
d = {"foo": 1, "bar": 2, "baz": 3}
|
||||
return {**d}
|
||||
|
||||
with self.assertRaisesRegex(torch.jit.frontend.NotSupportedError,
|
||||
"Dict expansion "):
|
||||
with self.assertRaisesRegex(
|
||||
torch.jit.frontend.NotSupportedError, "Dict expansion "
|
||||
):
|
||||
torch.jit.script(fn)
|
||||
|
||||
def test_custom_python_exception(self):
|
||||
@ -162,7 +179,9 @@ class TestException(TestCase):
|
||||
def fn():
|
||||
raise MyValueError("test custom exception")
|
||||
|
||||
with self.assertRaisesRegex(torch.jit.Error, "jit.test_exception.MyValueError: test custom exception"):
|
||||
with self.assertRaisesRegex(
|
||||
torch.jit.Error, "jit.test_exception.MyValueError: test custom exception"
|
||||
):
|
||||
fn()
|
||||
|
||||
def test_custom_python_exception_defined_elsewhere(self):
|
||||
@ -171,5 +190,9 @@ class TestException(TestCase):
|
||||
@torch.jit.script
|
||||
def fn():
|
||||
raise MyKeyError("This is a user defined key error")
|
||||
with self.assertRaisesRegex(torch.jit.Error, "jit.myexception.MyKeyError: This is a user defined key error"):
|
||||
|
||||
with self.assertRaisesRegex(
|
||||
torch.jit.Error,
|
||||
"jit.myexception.MyKeyError: This is a user defined key error",
|
||||
):
|
||||
fn()
|
||||
|
||||
File diff suppressed because it is too large
Load Diff
@ -11,10 +11,13 @@ pytorch_test_dir = os.path.dirname(os.path.dirname(os.path.realpath(__file__)))
|
||||
sys.path.append(pytorch_test_dir)
|
||||
from torch.testing._internal.jit_utils import JitTestCase
|
||||
|
||||
if __name__ == '__main__':
|
||||
raise RuntimeError("This test file is not meant to be run directly, use:\n\n"
|
||||
"\tpython test/test_jit.py TESTNAME\n\n"
|
||||
"instead.")
|
||||
if __name__ == "__main__":
|
||||
raise RuntimeError(
|
||||
"This test file is not meant to be run directly, use:\n\n"
|
||||
"\tpython test/test_jit.py TESTNAME\n\n"
|
||||
"instead."
|
||||
)
|
||||
|
||||
|
||||
class TestFunctionalBlocks(JitTestCase):
|
||||
def test_subgraph_creation(self):
|
||||
@ -30,14 +33,22 @@ class TestFunctionalBlocks(JitTestCase):
|
||||
return x + y + z
|
||||
|
||||
graph = torch.jit.script(fn).graph
|
||||
self.run_pass('create_functional_graphs', graph)
|
||||
self.run_pass("create_functional_graphs", graph)
|
||||
|
||||
# all uses of x and y should be sunk
|
||||
FileCheck().check(r"%x").check_not(r"%x").check("FunctionalGraph").check(r"%x").run(graph)
|
||||
FileCheck().check(r"%y").check_not(r"%y").check("FunctionalGraph").check(r"%y").run(graph)
|
||||
FileCheck().check(r"%x").check_not(r"%x").check("FunctionalGraph").check(
|
||||
r"%x"
|
||||
).run(graph)
|
||||
FileCheck().check(r"%y").check_not(r"%y").check("FunctionalGraph").check(
|
||||
r"%y"
|
||||
).run(graph)
|
||||
|
||||
# Don't allow any outputs which escape scope, so there is one final addition in the graph
|
||||
FileCheck().check("Tensor = prim::Functional").check_next("aten::add").run(graph)
|
||||
FileCheck().check("Tensor = prim::Functional").check_next("aten::add").run(
|
||||
graph
|
||||
)
|
||||
|
||||
# z + 1, z.add_(2) considered non functional, z = z * z should be considered functional
|
||||
FileCheck().check("add").check("add_").check_not("mul").check("FunctionalGraph").run(graph)
|
||||
FileCheck().check("add").check("add_").check_not("mul").check(
|
||||
"FunctionalGraph"
|
||||
).run(graph)
|
||||
|
||||
@ -3,9 +3,11 @@
|
||||
import torch
|
||||
from torch.testing._internal.jit_utils import JitTestCase
|
||||
|
||||
|
||||
class TestFuserCommon(JitTestCase):
|
||||
def test_autodiff_fallback(self):
|
||||
for rq in [True, False]:
|
||||
|
||||
@torch.jit.script
|
||||
def fn(x):
|
||||
return torch.max(x**2.0, x**3.0)
|
||||
|
||||
@ -1,9 +1,9 @@
|
||||
# Owner(s): ["oncall: jit"]
|
||||
|
||||
from torch.testing._internal.jit_utils import JitTestCase
|
||||
import torch
|
||||
import torch._C
|
||||
from torch.testing import FileCheck
|
||||
from torch.testing._internal.jit_utils import JitTestCase
|
||||
|
||||
|
||||
class TestGraphRewritePasses(JitTestCase):
|
||||
|
||||
@ -3,9 +3,9 @@
|
||||
import os
|
||||
import sys
|
||||
|
||||
import torch
|
||||
from typing import List, Tuple
|
||||
|
||||
from typing import Tuple, List
|
||||
import torch
|
||||
|
||||
# Make the helper files in test/ importable
|
||||
pytorch_test_dir = os.path.dirname(os.path.dirname(os.path.realpath(__file__)))
|
||||
@ -13,9 +13,12 @@ sys.path.append(pytorch_test_dir)
|
||||
from torch.testing._internal.jit_utils import JitTestCase
|
||||
|
||||
if __name__ == "__main__":
|
||||
raise RuntimeError("This test file is not meant to be run directly, use:\n\n"
|
||||
"\tpython test/test_jit.py TESTNAME\n\n"
|
||||
"instead.")
|
||||
raise RuntimeError(
|
||||
"This test file is not meant to be run directly, use:\n\n"
|
||||
"\tpython test/test_jit.py TESTNAME\n\n"
|
||||
"instead."
|
||||
)
|
||||
|
||||
|
||||
class TestHash(JitTestCase):
|
||||
def test_hash_tuple(self):
|
||||
@ -38,6 +41,7 @@ class TestHash(JitTestCase):
|
||||
|
||||
def test_hash_tensor(self):
|
||||
"""Tensors should hash by identity"""
|
||||
|
||||
def fn(t1, t2):
|
||||
return hash(t1) == hash(t2)
|
||||
|
||||
@ -74,7 +78,7 @@ class TestHash(JitTestCase):
|
||||
self.checkScript(fn, (1.2345, 6.789))
|
||||
self.checkScript(fn, (1.2345, float("inf")))
|
||||
self.checkScript(fn, (float("inf"), float("inf")))
|
||||
self.checkScript(fn, (1.2345, float('nan')))
|
||||
self.checkScript(fn, (1.2345, float("nan")))
|
||||
if sys.version_info < (3, 10):
|
||||
# Hash of two nans are not guaranteed to be equal. From https://docs.python.org/3/whatsnew/3.10.html :
|
||||
# Hashes of NaN values of both float type and decimal.Decimal type now depend on object identity.
|
||||
@ -103,9 +107,9 @@ class TestHash(JitTestCase):
|
||||
def fn(d1: torch.device, d2: torch.device):
|
||||
return hash(d1) == hash(d2)
|
||||
|
||||
gpu0 = torch.device('cuda:0')
|
||||
gpu1 = torch.device('cuda:1')
|
||||
cpu = torch.device('cpu')
|
||||
gpu0 = torch.device("cuda:0")
|
||||
gpu1 = torch.device("cuda:1")
|
||||
cpu = torch.device("cpu")
|
||||
self.checkScript(fn, (gpu0, gpu0))
|
||||
self.checkScript(fn, (gpu0, gpu1))
|
||||
self.checkScript(fn, (gpu0, cpu))
|
||||
|
||||
@ -6,21 +6,29 @@ import unittest
|
||||
from typing import Tuple
|
||||
|
||||
import torch
|
||||
|
||||
from jit.test_hooks_modules import (
|
||||
ModuleDirectforwardSubmodCall, ModuleForwardSingleInput,
|
||||
ModuleForwardTupleInput, create_forward_tuple_input,
|
||||
create_module_forward_multiple_inputs, create_module_forward_single_input,
|
||||
create_forward_tuple_input,
|
||||
create_module_forward_multiple_inputs,
|
||||
create_module_forward_single_input,
|
||||
create_module_hook_return_nothing,
|
||||
create_module_multiple_hooks_multiple_inputs,
|
||||
create_module_multiple_hooks_single_input, create_module_no_forward_input,
|
||||
create_module_same_hook_repeated, create_submodule_forward_multiple_inputs,
|
||||
create_module_multiple_hooks_single_input,
|
||||
create_module_no_forward_input,
|
||||
create_module_same_hook_repeated,
|
||||
create_submodule_forward_multiple_inputs,
|
||||
create_submodule_forward_single_input,
|
||||
create_submodule_forward_single_input_return_not_tupled,
|
||||
create_submodule_hook_return_nothing,
|
||||
create_submodule_multiple_hooks_multiple_inputs,
|
||||
create_submodule_multiple_hooks_single_input,
|
||||
create_submodule_no_forward_input, create_submodule_same_hook_repeated,
|
||||
create_submodule_to_call_directly_with_hooks)
|
||||
create_submodule_no_forward_input,
|
||||
create_submodule_same_hook_repeated,
|
||||
create_submodule_to_call_directly_with_hooks,
|
||||
ModuleDirectforwardSubmodCall,
|
||||
ModuleForwardSingleInput,
|
||||
ModuleForwardTupleInput,
|
||||
)
|
||||
|
||||
# Make the helper files in test/ importable
|
||||
pytorch_test_dir = os.path.dirname(os.path.dirname(os.path.realpath(__file__)))
|
||||
@ -37,7 +45,6 @@ if __name__ == "__main__":
|
||||
|
||||
# Tests for JIT forward hooks and pre-hooks
|
||||
class TestHooks(JitTestCase):
|
||||
|
||||
def test_module_no_forward_input(self):
|
||||
self.checkModule(create_module_no_forward_input(), ())
|
||||
|
||||
@ -73,7 +80,8 @@ class TestHooks(JitTestCase):
|
||||
|
||||
def test_submodule_multiple_hooks_multiple_inputs(self):
|
||||
self.checkModule(
|
||||
create_submodule_multiple_hooks_multiple_inputs(), (["a"], "no_pre_hook"),
|
||||
create_submodule_multiple_hooks_multiple_inputs(),
|
||||
(["a"], "no_pre_hook"),
|
||||
)
|
||||
|
||||
def test_submodule_forward_single_input(self):
|
||||
@ -242,7 +250,8 @@ class TestHooks(JitTestCase):
|
||||
m.register_forward_pre_hook(pre_hook_wrong_input1)
|
||||
|
||||
with self.assertRaisesRegex(
|
||||
RuntimeError, "has the wrong inner types for the input tuple argument",
|
||||
RuntimeError,
|
||||
"has the wrong inner types for the input tuple argument",
|
||||
):
|
||||
torch.jit.script(m)
|
||||
|
||||
@ -278,7 +287,8 @@ class TestHooks(JitTestCase):
|
||||
m.register_forward_pre_hook(pre_hook_wrong_output)
|
||||
|
||||
with self.assertRaisesRegex(
|
||||
RuntimeError, "returned the wrong type of: 'int'",
|
||||
RuntimeError,
|
||||
"returned the wrong type of: 'int'",
|
||||
):
|
||||
torch.jit.script(m)
|
||||
|
||||
|
||||
@ -1,8 +1,9 @@
|
||||
# Owner(s): ["oncall: jit"]
|
||||
|
||||
import torch
|
||||
from typing import List, Tuple
|
||||
|
||||
import torch
|
||||
|
||||
|
||||
class SubmoduleNoForwardInputs(torch.nn.Module):
|
||||
def __init__(self, name):
|
||||
|
||||
@ -2,6 +2,7 @@
|
||||
|
||||
import os
|
||||
import sys
|
||||
|
||||
import torch
|
||||
from torch._C import parse_ir
|
||||
from torch.testing import FileCheck
|
||||
@ -11,10 +12,13 @@ pytorch_test_dir = os.path.dirname(os.path.dirname(os.path.realpath(__file__)))
|
||||
sys.path.append(pytorch_test_dir)
|
||||
from torch.testing._internal.jit_utils import JitTestCase
|
||||
|
||||
if __name__ == '__main__':
|
||||
raise RuntimeError("This test file is not meant to be run directly, use:\n\n"
|
||||
"\tpython test/test_jit.py TESTNAME\n\n"
|
||||
"instead.")
|
||||
if __name__ == "__main__":
|
||||
raise RuntimeError(
|
||||
"This test file is not meant to be run directly, use:\n\n"
|
||||
"\tpython test/test_jit.py TESTNAME\n\n"
|
||||
"instead."
|
||||
)
|
||||
|
||||
|
||||
# Tests that Python slice class is supported in TorchScript
|
||||
class TestIgnorableArgs(JitTestCase):
|
||||
@ -44,11 +48,14 @@ class TestIgnorableArgs(JitTestCase):
|
||||
# We ignore trailing arguments after start=2 for dim 0
|
||||
# and after end=1 for dim 1
|
||||
# because in %16, %15 and %0 are default values for the schema.
|
||||
FileCheck().check("torch.slice(torch.slice(torch.tensor(_0), 0, 2), 1, None, 1)").run(src)
|
||||
FileCheck().check(
|
||||
"torch.slice(torch.slice(torch.tensor(_0), 0, 2), 1, None, 1)"
|
||||
).run(src)
|
||||
self.assertEqual(function(), function_copy())
|
||||
|
||||
def test_add_out_ignorable_args(self):
|
||||
@torch.jit.script
|
||||
def fn(x: torch.Tensor, y: torch.Tensor):
|
||||
torch.add(x, y, out=y)
|
||||
|
||||
FileCheck().check("torch.add(x, y, out=y)").run(fn.code)
|
||||
|
||||
@ -9,13 +9,16 @@ import torch
|
||||
# Make the helper files in test/ importable
|
||||
pytorch_test_dir = os.path.dirname(os.path.dirname(os.path.realpath(__file__)))
|
||||
sys.path.append(pytorch_test_dir)
|
||||
from torch.testing._internal.jit_utils import JitTestCase
|
||||
from torch.jit.frontend import _IS_ASTUNPARSE_INSTALLED
|
||||
from torch.testing._internal.jit_utils import JitTestCase
|
||||
|
||||
if __name__ == "__main__":
|
||||
raise RuntimeError("This test file is not meant to be run directly, use:\n\n"
|
||||
"\tpython test/test_jit.py TESTNAME\n\n"
|
||||
"instead.")
|
||||
raise RuntimeError(
|
||||
"This test file is not meant to be run directly, use:\n\n"
|
||||
"\tpython test/test_jit.py TESTNAME\n\n"
|
||||
"instead."
|
||||
)
|
||||
|
||||
|
||||
class TestIgnoreContextManager(JitTestCase):
|
||||
@unittest.skipUnless(_IS_ASTUNPARSE_INSTALLED, "astunparse package is required")
|
||||
@ -26,11 +29,14 @@ class TestIgnoreContextManager(JitTestCase):
|
||||
b: int = 5
|
||||
c: int = 0
|
||||
d: int = 6
|
||||
with torch.jit._IgnoreContextManager(a="inp:int", b="inp:int", c="out:int", d="out:int"):
|
||||
with torch.jit._IgnoreContextManager(
|
||||
a="inp:int", b="inp:int", c="out:int", d="out:int"
|
||||
):
|
||||
l = [2 for i in range(a) if i > 2]
|
||||
c = l[0] + a + b
|
||||
d = 9
|
||||
return c + d
|
||||
|
||||
model = A()
|
||||
s = torch.jit.script(model)
|
||||
self.assertEqual(s(), model())
|
||||
@ -41,10 +47,13 @@ class TestIgnoreContextManager(JitTestCase):
|
||||
a: int = 4
|
||||
b: int = 5
|
||||
c: int = 0
|
||||
with torch.jit._IgnoreContextManager(a="inp:int", b="inp:int", c="out:int"):
|
||||
with torch.jit._IgnoreContextManager(
|
||||
a="inp:int", b="inp:int", c="out:int"
|
||||
):
|
||||
l = [2 for i in range(a) if i > 2]
|
||||
c = l[0] + a + b
|
||||
return c
|
||||
|
||||
model = B()
|
||||
s = torch.jit.script(model)
|
||||
self.assertEqual(s(), 11)
|
||||
@ -58,6 +67,7 @@ class TestIgnoreContextManager(JitTestCase):
|
||||
l = [2 for i in range(a) if i > 2]
|
||||
b = l[0] + a
|
||||
return b
|
||||
|
||||
model = C()
|
||||
s = torch.jit.script(model)
|
||||
self.assertEqual(s(), 6)
|
||||
@ -72,6 +82,7 @@ class TestIgnoreContextManager(JitTestCase):
|
||||
with torch.jit._IgnoreContextManager(a="inp:int", b="inp:int"):
|
||||
l = [2 + b for i in range(a) if i > 2]
|
||||
return a
|
||||
|
||||
model = A()
|
||||
s = torch.jit.script(model)
|
||||
self.assertEqual(s(), 4)
|
||||
@ -85,6 +96,7 @@ class TestIgnoreContextManager(JitTestCase):
|
||||
c = [2 for i in range(7) if i > 2]
|
||||
c[0] = 3
|
||||
return c[0] + c[1]
|
||||
|
||||
model = A()
|
||||
s = torch.jit.script(model)
|
||||
self.assertEqual(s(), 5)
|
||||
|
||||
@ -2,10 +2,10 @@
|
||||
|
||||
import os
|
||||
import sys
|
||||
import warnings
|
||||
from typing import Any, Dict, List, Optional, Tuple
|
||||
|
||||
import torch
|
||||
import warnings
|
||||
from typing import List, Any, Dict, Tuple, Optional
|
||||
|
||||
# Make the helper files in test/ importable
|
||||
pytorch_test_dir = os.path.dirname(os.path.dirname(os.path.realpath(__file__)))
|
||||
@ -19,6 +19,7 @@ if __name__ == "__main__":
|
||||
"instead."
|
||||
)
|
||||
|
||||
|
||||
# Tests for torch.jit.isinstance
|
||||
class TestIsinstance(JitTestCase):
|
||||
def test_int(self):
|
||||
@ -223,28 +224,42 @@ class TestIsinstance(JitTestCase):
|
||||
|
||||
x = ["1", "2", "3"]
|
||||
|
||||
err_msg = "Attempted to use List without a contained type. " \
|
||||
err_msg = (
|
||||
"Attempted to use List without a contained type. "
|
||||
r"Please add a contained type, e.g. List\[int\]"
|
||||
)
|
||||
|
||||
with self.assertRaisesRegex(RuntimeError, err_msg,):
|
||||
with self.assertRaisesRegex(
|
||||
RuntimeError,
|
||||
err_msg,
|
||||
):
|
||||
torch.jit.script(list_no_contained_type)
|
||||
with self.assertRaisesRegex(RuntimeError, err_msg,):
|
||||
with self.assertRaisesRegex(
|
||||
RuntimeError,
|
||||
err_msg,
|
||||
):
|
||||
list_no_contained_type(x)
|
||||
|
||||
|
||||
|
||||
def test_tuple_no_contained_type(self):
|
||||
def tuple_no_contained_type(x: Any):
|
||||
assert torch.jit.isinstance(x, Tuple)
|
||||
|
||||
x = ("1", "2", "3")
|
||||
|
||||
err_msg = "Attempted to use Tuple without a contained type. " \
|
||||
err_msg = (
|
||||
"Attempted to use Tuple without a contained type. "
|
||||
r"Please add a contained type, e.g. Tuple\[int\]"
|
||||
)
|
||||
|
||||
with self.assertRaisesRegex(RuntimeError, err_msg,):
|
||||
with self.assertRaisesRegex(
|
||||
RuntimeError,
|
||||
err_msg,
|
||||
):
|
||||
torch.jit.script(tuple_no_contained_type)
|
||||
with self.assertRaisesRegex(RuntimeError, err_msg,):
|
||||
with self.assertRaisesRegex(
|
||||
RuntimeError,
|
||||
err_msg,
|
||||
):
|
||||
tuple_no_contained_type(x)
|
||||
|
||||
def test_optional_no_contained_type(self):
|
||||
@ -253,12 +268,20 @@ class TestIsinstance(JitTestCase):
|
||||
|
||||
x = ("1", "2", "3")
|
||||
|
||||
err_msg = "Attempted to use Optional without a contained type. " \
|
||||
err_msg = (
|
||||
"Attempted to use Optional without a contained type. "
|
||||
r"Please add a contained type, e.g. Optional\[int\]"
|
||||
)
|
||||
|
||||
with self.assertRaisesRegex(RuntimeError, err_msg,):
|
||||
with self.assertRaisesRegex(
|
||||
RuntimeError,
|
||||
err_msg,
|
||||
):
|
||||
torch.jit.script(optional_no_contained_type)
|
||||
with self.assertRaisesRegex(RuntimeError, err_msg,):
|
||||
with self.assertRaisesRegex(
|
||||
RuntimeError,
|
||||
err_msg,
|
||||
):
|
||||
optional_no_contained_type(x)
|
||||
|
||||
def test_dict_no_contained_type(self):
|
||||
@ -267,12 +290,20 @@ class TestIsinstance(JitTestCase):
|
||||
|
||||
x = {"a": "aa"}
|
||||
|
||||
err_msg = "Attempted to use Dict without contained types. " \
|
||||
err_msg = (
|
||||
"Attempted to use Dict without contained types. "
|
||||
r"Please add contained type, e.g. Dict\[int, int\]"
|
||||
)
|
||||
|
||||
with self.assertRaisesRegex(RuntimeError, err_msg,):
|
||||
with self.assertRaisesRegex(
|
||||
RuntimeError,
|
||||
err_msg,
|
||||
):
|
||||
torch.jit.script(dict_no_contained_type)
|
||||
with self.assertRaisesRegex(RuntimeError, err_msg,):
|
||||
with self.assertRaisesRegex(
|
||||
RuntimeError,
|
||||
err_msg,
|
||||
):
|
||||
dict_no_contained_type(x)
|
||||
|
||||
def test_tuple_rhs(self):
|
||||
|
||||
@ -13,10 +13,13 @@ pytorch_test_dir = os.path.dirname(os.path.dirname(os.path.realpath(__file__)))
|
||||
sys.path.append(pytorch_test_dir)
|
||||
from torch.testing._internal.jit_utils import JitTestCase
|
||||
|
||||
if __name__ == '__main__':
|
||||
raise RuntimeError("This test file is not meant to be run directly, use:\n\n"
|
||||
"\tpython test/test_jit.py TESTNAME\n\n"
|
||||
"instead.")
|
||||
if __name__ == "__main__":
|
||||
raise RuntimeError(
|
||||
"This test file is not meant to be run directly, use:\n\n"
|
||||
"\tpython test/test_jit.py TESTNAME\n\n"
|
||||
"instead."
|
||||
)
|
||||
|
||||
|
||||
# Tests various JIT-related utility functions.
|
||||
class TestJitUtils(JitTestCase):
|
||||
@ -24,58 +27,71 @@ class TestJitUtils(JitTestCase):
|
||||
def test_get_callable_argument_names_positional_or_keyword(self):
|
||||
def fn_positional_or_keyword_args_only(x, y):
|
||||
return x + y
|
||||
|
||||
self.assertEqual(
|
||||
["x", "y"],
|
||||
torch._jit_internal.get_callable_argument_names(fn_positional_or_keyword_args_only))
|
||||
torch._jit_internal.get_callable_argument_names(
|
||||
fn_positional_or_keyword_args_only
|
||||
),
|
||||
)
|
||||
|
||||
# Tests that POSITIONAL_ONLY arguments are ignored.
|
||||
def test_get_callable_argument_names_positional_only(self):
|
||||
code = dedent('''
|
||||
code = dedent(
|
||||
"""
|
||||
def fn_positional_only_arg(x, /, y):
|
||||
return x + y
|
||||
''')
|
||||
"""
|
||||
)
|
||||
|
||||
fn_positional_only_arg = jit_utils._get_py3_code(code, 'fn_positional_only_arg')
|
||||
fn_positional_only_arg = jit_utils._get_py3_code(code, "fn_positional_only_arg")
|
||||
self.assertEqual(
|
||||
["y"],
|
||||
torch._jit_internal.get_callable_argument_names(fn_positional_only_arg))
|
||||
torch._jit_internal.get_callable_argument_names(fn_positional_only_arg),
|
||||
)
|
||||
|
||||
# Tests that VAR_POSITIONAL arguments are ignored.
|
||||
def test_get_callable_argument_names_var_positional(self):
|
||||
# Tests that VAR_POSITIONAL arguments are ignored.
|
||||
def fn_var_positional_arg(x, *arg):
|
||||
return x + arg[0]
|
||||
|
||||
self.assertEqual(
|
||||
["x"],
|
||||
torch._jit_internal.get_callable_argument_names(fn_var_positional_arg))
|
||||
torch._jit_internal.get_callable_argument_names(fn_var_positional_arg),
|
||||
)
|
||||
|
||||
# Tests that KEYWORD_ONLY arguments are ignored.
|
||||
def test_get_callable_argument_names_keyword_only(self):
|
||||
def fn_keyword_only_arg(x, *, y):
|
||||
return x + y
|
||||
|
||||
self.assertEqual(
|
||||
["x"],
|
||||
torch._jit_internal.get_callable_argument_names(fn_keyword_only_arg))
|
||||
["x"], torch._jit_internal.get_callable_argument_names(fn_keyword_only_arg)
|
||||
)
|
||||
|
||||
# Tests that VAR_KEYWORD arguments are ignored.
|
||||
def test_get_callable_argument_names_var_keyword(self):
|
||||
def fn_var_keyword_arg(**args):
|
||||
return args['x'] + args['y']
|
||||
return args["x"] + args["y"]
|
||||
|
||||
self.assertEqual(
|
||||
[],
|
||||
torch._jit_internal.get_callable_argument_names(fn_var_keyword_arg))
|
||||
[], torch._jit_internal.get_callable_argument_names(fn_var_keyword_arg)
|
||||
)
|
||||
|
||||
# Tests that a function signature containing various different types of
|
||||
# arguments are ignored.
|
||||
def test_get_callable_argument_names_hybrid(self):
|
||||
code = dedent('''
|
||||
code = dedent(
|
||||
"""
|
||||
def fn_hybrid_args(x, /, y, *args, **kwargs):
|
||||
return x + y + args[0] + kwargs['z']
|
||||
''')
|
||||
fn_hybrid_args = jit_utils._get_py3_code(code, 'fn_hybrid_args')
|
||||
"""
|
||||
)
|
||||
fn_hybrid_args = jit_utils._get_py3_code(code, "fn_hybrid_args")
|
||||
self.assertEqual(
|
||||
["y"],
|
||||
torch._jit_internal.get_callable_argument_names(fn_hybrid_args))
|
||||
["y"], torch._jit_internal.get_callable_argument_names(fn_hybrid_args)
|
||||
)
|
||||
|
||||
def test_checkscriptassertraisesregex(self):
|
||||
def fn():
|
||||
@ -84,22 +100,18 @@ class TestJitUtils(JitTestCase):
|
||||
|
||||
self.checkScriptRaisesRegex(fn, (), Exception, "range", name="fn")
|
||||
|
||||
s = dedent("""
|
||||
s = dedent(
|
||||
"""
|
||||
def fn():
|
||||
tup = (1, 2)
|
||||
return tup[2]
|
||||
""")
|
||||
"""
|
||||
)
|
||||
|
||||
self.checkScriptRaisesRegex(s, (), Exception, "range", name="fn")
|
||||
|
||||
def test_no_tracer_warn_context_manager(self):
|
||||
torch._C._jit_set_tracer_state_warn(True)
|
||||
with jit_utils.NoTracerWarnContextManager() as no_warn:
|
||||
self.assertEqual(
|
||||
False,
|
||||
torch._C._jit_get_tracer_state_warn()
|
||||
)
|
||||
self.assertEqual(
|
||||
True,
|
||||
torch._C._jit_get_tracer_state_warn()
|
||||
)
|
||||
self.assertEqual(False, torch._C._jit_get_tracer_state_warn())
|
||||
self.assertEqual(True, torch._C._jit_get_tracer_state_warn())
|
||||
|
||||
File diff suppressed because it is too large
Load Diff
@ -10,10 +10,13 @@ pytorch_test_dir = os.path.dirname(os.path.dirname(os.path.realpath(__file__)))
|
||||
sys.path.append(pytorch_test_dir)
|
||||
from torch.testing._internal.jit_utils import JitTestCase
|
||||
|
||||
if __name__ == '__main__':
|
||||
raise RuntimeError("This test file is not meant to be run directly, use:\n\n"
|
||||
"\tpython test/test_jit.py TESTNAME\n\n"
|
||||
"instead.")
|
||||
if __name__ == "__main__":
|
||||
raise RuntimeError(
|
||||
"This test file is not meant to be run directly, use:\n\n"
|
||||
"\tpython test/test_jit.py TESTNAME\n\n"
|
||||
"instead."
|
||||
)
|
||||
|
||||
|
||||
class TestLogging(JitTestCase):
|
||||
def test_bump_numeric_counter(self):
|
||||
@ -22,30 +25,29 @@ class TestLogging(JitTestCase):
|
||||
def forward(self, x):
|
||||
for i in range(x.size(0)):
|
||||
x += 1.0
|
||||
torch.jit._logging.add_stat_value('foo', 1)
|
||||
torch.jit._logging.add_stat_value("foo", 1)
|
||||
|
||||
if bool(x.sum() > 0.0):
|
||||
torch.jit._logging.add_stat_value('positive', 1)
|
||||
torch.jit._logging.add_stat_value("positive", 1)
|
||||
else:
|
||||
torch.jit._logging.add_stat_value('negative', 1)
|
||||
torch.jit._logging.add_stat_value("negative", 1)
|
||||
return x
|
||||
|
||||
logger = torch.jit._logging.LockingLogger()
|
||||
old_logger = torch.jit._logging.set_logger(logger)
|
||||
try:
|
||||
|
||||
mtl = ModuleThatLogs()
|
||||
for i in range(5):
|
||||
mtl(torch.rand(3, 4, 5))
|
||||
|
||||
self.assertEqual(logger.get_counter_val('foo'), 15)
|
||||
self.assertEqual(logger.get_counter_val('positive'), 5)
|
||||
self.assertEqual(logger.get_counter_val("foo"), 15)
|
||||
self.assertEqual(logger.get_counter_val("positive"), 5)
|
||||
finally:
|
||||
torch.jit._logging.set_logger(old_logger)
|
||||
|
||||
def test_trace_numeric_counter(self):
|
||||
def foo(x):
|
||||
torch.jit._logging.add_stat_value('foo', 1)
|
||||
torch.jit._logging.add_stat_value("foo", 1)
|
||||
return x + 1.0
|
||||
|
||||
traced = torch.jit.trace(foo, torch.rand(3, 4))
|
||||
@ -54,7 +56,7 @@ class TestLogging(JitTestCase):
|
||||
try:
|
||||
traced(torch.rand(3, 4))
|
||||
|
||||
self.assertEqual(logger.get_counter_val('foo'), 1)
|
||||
self.assertEqual(logger.get_counter_val("foo"), 1)
|
||||
finally:
|
||||
torch.jit._logging.set_logger(old_logger)
|
||||
|
||||
@ -65,7 +67,7 @@ class TestLogging(JitTestCase):
|
||||
for i in range(30):
|
||||
x += 1.0
|
||||
tp_end = torch.jit._logging.time_point()
|
||||
torch.jit._logging.add_stat_value('mytimer', tp_end - tp_start)
|
||||
torch.jit._logging.add_stat_value("mytimer", tp_end - tp_start)
|
||||
return x
|
||||
|
||||
mtm = ModuleThatTimes()
|
||||
@ -73,7 +75,7 @@ class TestLogging(JitTestCase):
|
||||
old_logger = torch.jit._logging.set_logger(logger)
|
||||
try:
|
||||
mtm(torch.rand(3, 4))
|
||||
self.assertGreater(logger.get_counter_val('mytimer'), 0)
|
||||
self.assertGreater(logger.get_counter_val("mytimer"), 0)
|
||||
finally:
|
||||
torch.jit._logging.set_logger(old_logger)
|
||||
|
||||
@ -85,7 +87,7 @@ class TestLogging(JitTestCase):
|
||||
for i in range(30):
|
||||
x += 1.0
|
||||
tp_end = torch.jit._logging.time_point()
|
||||
torch.jit._logging.add_stat_value('mytimer', tp_end - tp_start)
|
||||
torch.jit._logging.add_stat_value("mytimer", tp_end - tp_start)
|
||||
return x
|
||||
|
||||
mtm = ModuleThatTimes()
|
||||
@ -93,27 +95,27 @@ class TestLogging(JitTestCase):
|
||||
old_logger = torch.jit._logging.set_logger(logger)
|
||||
try:
|
||||
mtm(torch.rand(3, 4))
|
||||
self.assertGreater(logger.get_counter_val('mytimer'), 0)
|
||||
self.assertGreater(logger.get_counter_val("mytimer"), 0)
|
||||
finally:
|
||||
torch.jit._logging.set_logger(old_logger)
|
||||
|
||||
def test_counter_aggregation(self):
|
||||
def foo(x):
|
||||
for i in range(3):
|
||||
torch.jit._logging.add_stat_value('foo', 1)
|
||||
torch.jit._logging.add_stat_value("foo", 1)
|
||||
return x + 1.0
|
||||
|
||||
traced = torch.jit.trace(foo, torch.rand(3, 4))
|
||||
logger = torch.jit._logging.LockingLogger()
|
||||
logger.set_aggregation_type('foo', torch.jit._logging.AggregationType.AVG)
|
||||
logger.set_aggregation_type("foo", torch.jit._logging.AggregationType.AVG)
|
||||
old_logger = torch.jit._logging.set_logger(logger)
|
||||
try:
|
||||
traced(torch.rand(3, 4))
|
||||
|
||||
self.assertEqual(logger.get_counter_val('foo'), 1)
|
||||
self.assertEqual(logger.get_counter_val("foo"), 1)
|
||||
finally:
|
||||
torch.jit._logging.set_logger(old_logger)
|
||||
|
||||
def test_logging_levels_set(self):
|
||||
torch._C._jit_set_logging_option('foo')
|
||||
self.assertEqual('foo', torch._C._jit_get_logging_option())
|
||||
torch._C._jit_set_logging_option("foo")
|
||||
self.assertEqual("foo", torch._C._jit_get_logging_option())
|
||||
|
||||
@ -1,28 +1,32 @@
|
||||
# Owner(s): ["oncall: jit"]
|
||||
|
||||
from typing import Any, Dict, List, Optional, Tuple
|
||||
|
||||
from torch.testing._internal.jit_utils import JitTestCase, make_global
|
||||
from torch.testing import FileCheck
|
||||
from torch import jit
|
||||
from jit.test_module_interface import TestModuleInterface # noqa: F401
|
||||
import os
|
||||
import sys
|
||||
import torch
|
||||
import torch.testing._internal.jit_utils
|
||||
import torch.nn as nn
|
||||
import unittest
|
||||
from typing import Any, Dict, List, Optional, Tuple
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.testing._internal.jit_utils
|
||||
from torch import jit
|
||||
from torch.testing import FileCheck
|
||||
from torch.testing._internal.common_utils import freeze_rng_state
|
||||
from torch.testing._internal.jit_utils import RUN_CUDA_HALF
|
||||
|
||||
from torch.testing._internal.jit_utils import JitTestCase, make_global, RUN_CUDA_HALF
|
||||
|
||||
from jit.test_module_interface import TestModuleInterface # noqa: F401
|
||||
|
||||
# Make the helper files in test/ importable
|
||||
pytorch_test_dir = os.path.dirname(os.path.dirname(os.path.realpath(__file__)))
|
||||
sys.path.append(pytorch_test_dir)
|
||||
|
||||
if __name__ == '__main__':
|
||||
raise RuntimeError("This test file is not meant to be run directly, use:\n\n"
|
||||
"\tpython test/test_jit.py TESTNAME\n\n"
|
||||
"instead.")
|
||||
if __name__ == "__main__":
|
||||
raise RuntimeError(
|
||||
"This test file is not meant to be run directly, use:\n\n"
|
||||
"\tpython test/test_jit.py TESTNAME\n\n"
|
||||
"instead."
|
||||
)
|
||||
|
||||
|
||||
class TestMisc(JitTestCase):
|
||||
def test_joined_str(self):
|
||||
@ -30,12 +34,12 @@ class TestMisc(JitTestCase):
|
||||
hello, test = "Hello", "test"
|
||||
print(f"{hello + ' ' + test}, I'm a {test}")
|
||||
print("format blank")
|
||||
hi = 'hi'
|
||||
hi = "hi"
|
||||
print(f"stuff before {hi}")
|
||||
print(f"{hi} stuff after")
|
||||
return x + 1
|
||||
|
||||
x = torch.arange(4., requires_grad=True)
|
||||
x = torch.arange(4.0, requires_grad=True)
|
||||
# TODO: Add support for f-strings in string parser frontend
|
||||
# self.checkScript(func, [x], optimize=True, capture_output=True)
|
||||
|
||||
@ -50,10 +54,14 @@ class TestMisc(JitTestCase):
|
||||
self.assertEqual(captured, captured_script)
|
||||
|
||||
def test_kwarg_support(self):
|
||||
with self.assertRaisesRegex(torch.jit.frontend.NotSupportedError, "variable number of arguments"):
|
||||
with self.assertRaisesRegex(
|
||||
torch.jit.frontend.NotSupportedError, "variable number of arguments"
|
||||
):
|
||||
|
||||
class M(torch.nn.Module):
|
||||
def forward(self, *, n_tokens: int, device_name: str = 2):
|
||||
pass
|
||||
|
||||
torch.jit.script(M())
|
||||
|
||||
class M(torch.nn.Module):
|
||||
@ -62,32 +70,35 @@ class TestMisc(JitTestCase):
|
||||
|
||||
sm = torch.jit.script(M())
|
||||
|
||||
with self.assertRaisesRegex(RuntimeError, "missing value for argument 'n_tokens'"):
|
||||
with self.assertRaisesRegex(
|
||||
RuntimeError, "missing value for argument 'n_tokens'"
|
||||
):
|
||||
sm()
|
||||
|
||||
with self.assertRaisesRegex(RuntimeError, "positional arg"):
|
||||
sm(3, 'hello')
|
||||
sm(3, "hello")
|
||||
|
||||
self.assertEqual(sm(n_tokens=3, device_name='hello'), (3, 'hello'))
|
||||
self.assertEqual(sm(n_tokens=3, device_name="hello"), (3, "hello"))
|
||||
|
||||
def test_tuple_subscripted_assign(self):
|
||||
with self.assertRaisesRegex(RuntimeError, "subscripted assignment"):
|
||||
|
||||
@torch.jit.script
|
||||
def foo(a: Tuple[int, int]) -> None:
|
||||
a[0] = a[1]
|
||||
|
||||
with self.assertRaisesRegex(RuntimeError, "augmented assignment"):
|
||||
|
||||
@torch.jit.script
|
||||
def bar(a: Tuple[int, int]) -> None:
|
||||
a[0] += a[1]
|
||||
|
||||
def test_subexpression_List_Future(self):
|
||||
|
||||
@torch.jit.script
|
||||
def fn(x: List[torch.jit.Future[int]]) -> torch.jit.Future[int]:
|
||||
return x[0]
|
||||
|
||||
FileCheck().check('Future[int]').check('Future[int]').run(fn.graph)
|
||||
FileCheck().check("Future[int]").check("Future[int]").run(fn.graph)
|
||||
|
||||
def test_subexpression_Future_annotate(self):
|
||||
@torch.jit.script
|
||||
@ -110,36 +121,40 @@ class TestMisc(JitTestCase):
|
||||
if isinstance(x, str):
|
||||
return x
|
||||
return "foo"
|
||||
|
||||
forward = torch.jit.script(forward)
|
||||
self.assertEqual(forward(1), "foo")
|
||||
self.assertEqual(forward("bar"), "bar")
|
||||
|
||||
def test_subexpression_Tuple_int_int_Future(self):
|
||||
|
||||
@torch.jit.script
|
||||
def fn(x: Tuple[int, int, torch.jit.Future[int]]) -> Tuple[int, torch.jit.Future[int]]:
|
||||
def fn(
|
||||
x: Tuple[int, int, torch.jit.Future[int]]
|
||||
) -> Tuple[int, torch.jit.Future[int]]:
|
||||
return x[0], x[2]
|
||||
|
||||
FileCheck().check('(int, int, Future[int])').check('(int, Future[int])').run(fn.graph)
|
||||
FileCheck().check("(int, int, Future[int])").check("(int, Future[int])").run(
|
||||
fn.graph
|
||||
)
|
||||
|
||||
def test_subexpression_Dict_int_Future(self):
|
||||
|
||||
@torch.jit.script
|
||||
def fn(x: Dict[int, torch.jit.Future[int]], y: int) -> torch.jit.Future[int]:
|
||||
return x[y]
|
||||
|
||||
FileCheck().check('Dict(int, Future(int))').check('Future[int]').run(fn.graph)
|
||||
FileCheck().check("Dict(int, Future(int))").check("Future[int]").run(fn.graph)
|
||||
|
||||
def test_subexpression_Optional(self):
|
||||
|
||||
@torch.jit.script
|
||||
def fn(x: Optional[Dict[int, torch.jit.Future[int]]]) -> Optional[torch.jit.Future[int]]:
|
||||
def fn(
|
||||
x: Optional[Dict[int, torch.jit.Future[int]]]
|
||||
) -> Optional[torch.jit.Future[int]]:
|
||||
if x is not None:
|
||||
return x[0]
|
||||
else:
|
||||
return None
|
||||
|
||||
FileCheck().check('Dict(int, Future(int))?').run(fn.graph)
|
||||
FileCheck().check("Dict(int, Future(int))?").run(fn.graph)
|
||||
|
||||
def test_if_returning_any(self):
|
||||
"""
|
||||
@ -147,6 +162,7 @@ class TestMisc(JitTestCase):
|
||||
types early from each branch when the return
|
||||
type of the function is Any.
|
||||
"""
|
||||
|
||||
def if_function(inp: torch.Tensor) -> Any:
|
||||
if inp.shape[0] == 1:
|
||||
return inp * inp
|
||||
@ -156,14 +172,23 @@ class TestMisc(JitTestCase):
|
||||
self.checkScript(if_function, (torch.randn(5),))
|
||||
|
||||
def test_hacked_twin(self):
|
||||
|
||||
def gen_data():
|
||||
with freeze_rng_state():
|
||||
return torch.randn(10), torch.randint(10, (20,)), torch.randn(20)
|
||||
|
||||
input, index, value, = gen_data()
|
||||
input1, index1, value1, = gen_data()
|
||||
out1 = torch.ops.aten.index_put.hacked_twin(input, [index], value, accumulate=False)
|
||||
(
|
||||
input,
|
||||
index,
|
||||
value,
|
||||
) = gen_data()
|
||||
(
|
||||
input1,
|
||||
index1,
|
||||
value1,
|
||||
) = gen_data()
|
||||
out1 = torch.ops.aten.index_put.hacked_twin(
|
||||
input, [index], value, accumulate=False
|
||||
)
|
||||
out2 = torch.index_put(input1, [index1], value1, accumulate=False)
|
||||
self.assertEqual(out1, out2)
|
||||
|
||||
@ -172,14 +197,23 @@ class TestMisc(JitTestCase):
|
||||
self.assertEqual(input, input1)
|
||||
|
||||
def test_unsafe_hacked_twin(self):
|
||||
|
||||
def gen_data():
|
||||
with freeze_rng_state():
|
||||
return torch.randn(10), torch.randint(10, (20,)), torch.randn(20)
|
||||
|
||||
input, index, value, = gen_data()
|
||||
input1, index1, value1, = gen_data()
|
||||
out1 = torch.ops.aten._unsafe_index_put.hacked_twin(input, [index], value, accumulate=False)
|
||||
(
|
||||
input,
|
||||
index,
|
||||
value,
|
||||
) = gen_data()
|
||||
(
|
||||
input1,
|
||||
index1,
|
||||
value1,
|
||||
) = gen_data()
|
||||
out1 = torch.ops.aten._unsafe_index_put.hacked_twin(
|
||||
input, [index], value, accumulate=False
|
||||
)
|
||||
out2 = torch.index_put(input1, [index1], value1, accumulate=False)
|
||||
self.assertEqual(out1, out2)
|
||||
|
||||
@ -188,7 +222,9 @@ class TestMisc(JitTestCase):
|
||||
self.assertEqual(input, input1)
|
||||
|
||||
def index_put_fn(input, index, value):
|
||||
return torch.ops.aten._unsafe_index_put(input, [index], value, accumulate=False)
|
||||
return torch.ops.aten._unsafe_index_put(
|
||||
input, [index], value, accumulate=False
|
||||
)
|
||||
|
||||
input2, index2, value2 = gen_data()
|
||||
script_index_put_fn = torch.jit.script(index_put_fn)
|
||||
@ -197,7 +233,9 @@ class TestMisc(JitTestCase):
|
||||
self.assertEqual(expect, actual)
|
||||
|
||||
def index_fn(input, index, value):
|
||||
return torch.ops.aten._unsafe_index_put(input, [index], value, accumulate=False)
|
||||
return torch.ops.aten._unsafe_index_put(
|
||||
input, [index], value, accumulate=False
|
||||
)
|
||||
|
||||
script_index_fn = torch.jit.script(index_fn)
|
||||
expect = index_fn(input2.clone(), index2, value2)
|
||||
@ -205,7 +243,6 @@ class TestMisc(JitTestCase):
|
||||
self.assertEqual(expect, actual)
|
||||
|
||||
def test_export_opnames_interface(self):
|
||||
|
||||
@torch.jit.interface
|
||||
class OneTwoModule(nn.Module):
|
||||
def one(self, x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
|
||||
@ -240,7 +277,7 @@ class TestMisc(JitTestCase):
|
||||
make_global(OneTwoModule)
|
||||
|
||||
class M(nn.Module):
|
||||
sub : OneTwoModule
|
||||
sub: OneTwoModule
|
||||
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
@ -254,12 +291,18 @@ class TestMisc(JitTestCase):
|
||||
|
||||
torch._C._enable_mobile_interface_call_export()
|
||||
scripted_M_mod = torch.jit.script(M())
|
||||
self.assertTrue({'aten::mul.Scalar', 'aten::mul.Tensor', 'aten::reciprocal'}.issubset(
|
||||
set(torch.jit.export_opnames(scripted_M_mod))))
|
||||
self.assertTrue(
|
||||
{"aten::mul.Scalar", "aten::mul.Tensor", "aten::reciprocal"}.issubset(
|
||||
set(torch.jit.export_opnames(scripted_M_mod))
|
||||
)
|
||||
)
|
||||
|
||||
scripted_M_mod.sub = torch.jit.script(FooMod())
|
||||
self.assertTrue({'aten::add.Tensor', 'aten::mul.Scalar'}.issubset(
|
||||
set(torch.jit.export_opnames(scripted_M_mod))))
|
||||
self.assertTrue(
|
||||
{"aten::add.Tensor", "aten::mul.Scalar"}.issubset(
|
||||
set(torch.jit.export_opnames(scripted_M_mod))
|
||||
)
|
||||
)
|
||||
|
||||
def test_math_inf(self):
|
||||
from math import inf
|
||||
@ -292,7 +335,6 @@ class TestMisc(JitTestCase):
|
||||
with self.assertRaises(RuntimeError):
|
||||
torch.jit.script(non_temporary_fail)
|
||||
|
||||
|
||||
@torch.jit.script
|
||||
def test_return():
|
||||
return []
|
||||
@ -335,7 +377,9 @@ class TestMisc(JitTestCase):
|
||||
def multiple_args():
|
||||
return torch.LongTensor(1, [2])
|
||||
|
||||
with self.assertRaisesRegex(RuntimeError, "multiple positional arguments that were not all integers"):
|
||||
with self.assertRaisesRegex(
|
||||
RuntimeError, "multiple positional arguments that were not all integers"
|
||||
):
|
||||
torch.jit.script(multiple_args)
|
||||
|
||||
# kwarg bad schema
|
||||
@ -345,7 +389,6 @@ class TestMisc(JitTestCase):
|
||||
with self.assertRaisesRegex(RuntimeError, "hello"):
|
||||
torch.jit.script(bad_kwarg)
|
||||
|
||||
|
||||
def test_broadcasting_list(self):
|
||||
"""
|
||||
Test BroadcastingList and torch.nn._size_N_t alias
|
||||
@ -360,7 +403,7 @@ class TestMisc(JitTestCase):
|
||||
return x[0] + x[1]
|
||||
|
||||
self.assertTrue(torch.jit.script(sum_i)(4) == 8)
|
||||
self.assertTrue(torch.jit.script(sum_f)(4.5) == 9.)
|
||||
self.assertTrue(torch.jit.script(sum_f)(4.5) == 9.0)
|
||||
|
||||
def test_parse_ir_annotate(self):
|
||||
ir = """
|
||||
@ -397,7 +440,6 @@ class TestMisc(JitTestCase):
|
||||
self.assertTrue(ret.numel() == 1)
|
||||
self.assertTrue(len(ret.size()) == 1)
|
||||
|
||||
|
||||
def test_script_many_decorators(self):
|
||||
def no_op_decorator(f):
|
||||
return f
|
||||
@ -410,7 +452,9 @@ class TestMisc(JitTestCase):
|
||||
def foo(x, dim: int):
|
||||
return x.unsqueeze(dim)
|
||||
|
||||
x = torch.randn(1,)
|
||||
x = torch.randn(
|
||||
1,
|
||||
)
|
||||
expected = foo(x, 0)
|
||||
scripted = torch.jit.script(foo)
|
||||
actual = scripted(x, 0)
|
||||
@ -421,10 +465,10 @@ class TestMisc(JitTestCase):
|
||||
# https://github.com/pytorch/pytorch/issues/75476
|
||||
def fn(p: torch.Tensor, gamma: float = 2.0) -> torch.Tensor:
|
||||
p = torch.sigmoid(p)
|
||||
result = p ** gamma
|
||||
result = p**gamma
|
||||
return result
|
||||
|
||||
x = torch.rand((2, 2), dtype=torch.half, device='cuda')
|
||||
x = torch.rand((2, 2), dtype=torch.half, device="cuda")
|
||||
|
||||
ref = fn(x)
|
||||
|
||||
@ -450,8 +494,12 @@ class TestMisc(JitTestCase):
|
||||
# We want "Scalar" to come before "complex".
|
||||
op, override_names = torch._C._jit_get_operation("aten::add")
|
||||
print(override_names)
|
||||
complex_indices = [i for i, name in enumerate(override_names) if name == "complex"]
|
||||
Scalar_indices = [i for i, name in enumerate(override_names) if name == "Scalar"]
|
||||
complex_indices = [
|
||||
i for i, name in enumerate(override_names) if name == "complex"
|
||||
]
|
||||
Scalar_indices = [
|
||||
i for i, name in enumerate(override_names) if name == "Scalar"
|
||||
]
|
||||
|
||||
self.assertTrue(len(complex_indices) > 0)
|
||||
self.assertTrue(len(Scalar_indices) > 0)
|
||||
|
||||
@ -3,27 +3,33 @@
|
||||
import os
|
||||
import sys
|
||||
import unittest
|
||||
from torch.testing._internal.common_utils import (
|
||||
enable_profiling_mode_for_profiling_tests, GRAPH_EXECUTOR, ProfilingMode,
|
||||
set_default_dtype,
|
||||
)
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
from torch.testing._internal.common_utils import (
|
||||
enable_profiling_mode_for_profiling_tests,
|
||||
GRAPH_EXECUTOR,
|
||||
ProfilingMode,
|
||||
set_default_dtype,
|
||||
)
|
||||
|
||||
# Make the helper files in test/ importable
|
||||
pytorch_test_dir = os.path.dirname(os.path.dirname(os.path.realpath(__file__)))
|
||||
sys.path.append(pytorch_test_dir)
|
||||
from torch.testing._internal.jit_utils import JitTestCase, RUN_CUDA
|
||||
from torch.testing._internal.common_utils import slowTest, suppress_warnings
|
||||
from torch.testing._internal.jit_utils import JitTestCase, RUN_CUDA
|
||||
|
||||
if __name__ == '__main__':
|
||||
raise RuntimeError("This test file is not meant to be run directly, use:\n\n"
|
||||
"\tpython test/test_jit.py TESTNAME\n\n"
|
||||
"instead.")
|
||||
if __name__ == "__main__":
|
||||
raise RuntimeError(
|
||||
"This test file is not meant to be run directly, use:\n\n"
|
||||
"\tpython test/test_jit.py TESTNAME\n\n"
|
||||
"instead."
|
||||
)
|
||||
|
||||
try:
|
||||
import torchvision
|
||||
|
||||
HAS_TORCHVISION = True
|
||||
except ImportError:
|
||||
HAS_TORCHVISION = False
|
||||
@ -31,6 +37,7 @@ except RuntimeError:
|
||||
HAS_TORCHVISION = False
|
||||
skipIfNoTorchVision = unittest.skipIf(not HAS_TORCHVISION, "no torchvision")
|
||||
|
||||
|
||||
class MnistNet(nn.Module):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
@ -49,6 +56,7 @@ class MnistNet(nn.Module):
|
||||
x = self.fc2(x)
|
||||
return F.log_softmax(x, dim=1)
|
||||
|
||||
|
||||
class TestModels(JitTestCase):
|
||||
@staticmethod
|
||||
def _test_dcgan_models(self, device, check_export_import=True):
|
||||
@ -102,31 +110,38 @@ class TestModels(JitTestCase):
|
||||
nn.LeakyReLU(0.2, inplace=True),
|
||||
# state size. (ndf*8) x 4 x 4
|
||||
nn.Conv2d(ndf * 8, 1, 4, 1, 0, bias=False),
|
||||
nn.Sigmoid()
|
||||
nn.Sigmoid(),
|
||||
)
|
||||
|
||||
def forward(self, input):
|
||||
return self.main(input).view(-1, 1).squeeze(1)
|
||||
|
||||
bs, nz, ngf, nc, ndf = 5, 6, 9, 3, 10
|
||||
self.checkTrace(DCGANGenerator(nz, ngf, nc).to(device),
|
||||
(torch.rand(bs, nz, 1, 1, device=device),),
|
||||
export_import=check_export_import)
|
||||
example_input = DCGANGenerator(nz, ngf, nc).to(device)(torch.rand(bs, nz, 1, 1, device=device))
|
||||
self.checkTrace(DCGANDiscriminator(nc, ndf).to(device), (example_input,),
|
||||
export_import=check_export_import)
|
||||
self.checkTrace(
|
||||
DCGANGenerator(nz, ngf, nc).to(device),
|
||||
(torch.rand(bs, nz, 1, 1, device=device),),
|
||||
export_import=check_export_import,
|
||||
)
|
||||
example_input = DCGANGenerator(nz, ngf, nc).to(device)(
|
||||
torch.rand(bs, nz, 1, 1, device=device)
|
||||
)
|
||||
self.checkTrace(
|
||||
DCGANDiscriminator(nc, ndf).to(device),
|
||||
(example_input,),
|
||||
export_import=check_export_import,
|
||||
)
|
||||
|
||||
def test_dcgan_models(self):
|
||||
# Note: Can sometimes fail with low precision if run with float dtype
|
||||
with set_default_dtype(torch.double):
|
||||
self._test_dcgan_models(self, device='cpu')
|
||||
self._test_dcgan_models(self, device="cpu")
|
||||
|
||||
@unittest.skipIf(not RUN_CUDA, "no CUDA")
|
||||
def test_dcgan_models_cuda(self):
|
||||
# Note: Can sometimes fail with low precision if run with float dtype
|
||||
with set_default_dtype(torch.double):
|
||||
# XXX: export_import on CUDA modules doesn't work (#11480)
|
||||
self._test_dcgan_models(self, device='cuda', check_export_import=False)
|
||||
self._test_dcgan_models(self, device="cuda", check_export_import=False)
|
||||
|
||||
@staticmethod
|
||||
def _test_neural_style(self, device, check_export_import=True):
|
||||
@ -147,9 +162,13 @@ class TestModels(JitTestCase):
|
||||
self.res4 = ResidualBlock(128)
|
||||
self.res5 = ResidualBlock(128)
|
||||
# Upsampling Layers
|
||||
self.deconv1 = UpsampleConvLayer(128, 64, kernel_size=3, stride=1, upsample=2)
|
||||
self.deconv1 = UpsampleConvLayer(
|
||||
128, 64, kernel_size=3, stride=1, upsample=2
|
||||
)
|
||||
self.in4 = torch.nn.InstanceNorm2d(64, affine=True)
|
||||
self.deconv2 = UpsampleConvLayer(64, 32, kernel_size=3, stride=1, upsample=2)
|
||||
self.deconv2 = UpsampleConvLayer(
|
||||
64, 32, kernel_size=3, stride=1, upsample=2
|
||||
)
|
||||
self.in5 = torch.nn.InstanceNorm2d(32, affine=True)
|
||||
self.deconv3 = ConvLayer(32, 3, kernel_size=9, stride=1)
|
||||
# Non-linearities
|
||||
@ -174,7 +193,9 @@ class TestModels(JitTestCase):
|
||||
super().__init__()
|
||||
reflection_padding = kernel_size // 2
|
||||
self.reflection_pad = torch.nn.ReflectionPad2d(reflection_padding)
|
||||
self.conv2d = torch.nn.Conv2d(in_channels, out_channels, kernel_size, stride)
|
||||
self.conv2d = torch.nn.Conv2d(
|
||||
in_channels, out_channels, kernel_size, stride
|
||||
)
|
||||
|
||||
def forward(self, x):
|
||||
out = self.reflection_pad(x)
|
||||
@ -209,14 +230,20 @@ class TestModels(JitTestCase):
|
||||
ref: http://distill.pub/2016/deconv-checkerboard/
|
||||
"""
|
||||
|
||||
def __init__(self, in_channels, out_channels, kernel_size, stride, upsample=None):
|
||||
def __init__(
|
||||
self, in_channels, out_channels, kernel_size, stride, upsample=None
|
||||
):
|
||||
super().__init__()
|
||||
self.upsample = upsample
|
||||
if upsample:
|
||||
self.upsample_layer = torch.nn.Upsample(mode='nearest', scale_factor=upsample)
|
||||
self.upsample_layer = torch.nn.Upsample(
|
||||
mode="nearest", scale_factor=upsample
|
||||
)
|
||||
reflection_padding = kernel_size // 2
|
||||
self.reflection_pad = torch.nn.ReflectionPad2d(reflection_padding)
|
||||
self.conv2d = torch.nn.Conv2d(in_channels, out_channels, kernel_size, stride)
|
||||
self.conv2d = torch.nn.Conv2d(
|
||||
in_channels, out_channels, kernel_size, stride
|
||||
)
|
||||
|
||||
def forward(self, x):
|
||||
x_in = x
|
||||
@ -226,44 +253,54 @@ class TestModels(JitTestCase):
|
||||
out = self.conv2d(out)
|
||||
return out
|
||||
|
||||
self.checkTrace(TransformerNet(), (torch.rand(5, 3, 16, 16),), export_import=check_export_import)
|
||||
self.checkTrace(
|
||||
TransformerNet(),
|
||||
(torch.rand(5, 3, 16, 16),),
|
||||
export_import=check_export_import,
|
||||
)
|
||||
|
||||
@slowTest
|
||||
def test_neural_style(self):
|
||||
self._test_neural_style(self, device='cpu')
|
||||
self._test_neural_style(self, device="cpu")
|
||||
|
||||
@unittest.skipIf(not RUN_CUDA, "no CUDA")
|
||||
def test_neural_style_cuda(self):
|
||||
# XXX: export_import on CUDA modules doesn't work (#11480)
|
||||
self._test_neural_style(self, device='cuda', check_export_import=False)
|
||||
self._test_neural_style(self, device="cuda", check_export_import=False)
|
||||
|
||||
@unittest.skipIf(GRAPH_EXECUTOR == ProfilingMode.LEGACY, "Bug found in deprecated executor")
|
||||
@unittest.skipIf(
|
||||
GRAPH_EXECUTOR == ProfilingMode.LEGACY, "Bug found in deprecated executor"
|
||||
)
|
||||
@staticmethod
|
||||
def _test_mnist(self, device, check_export_import=True):
|
||||
# eval() is present because dropout makes this nondeterministic
|
||||
with enable_profiling_mode_for_profiling_tests():
|
||||
self.checkTrace(MnistNet().to(device).eval(), (torch.rand(5, 1, 28, 28, device=device),),
|
||||
export_import=check_export_import)
|
||||
self.checkTrace(
|
||||
MnistNet().to(device).eval(),
|
||||
(torch.rand(5, 1, 28, 28, device=device),),
|
||||
export_import=check_export_import,
|
||||
)
|
||||
|
||||
def test_mnist(self):
|
||||
self._test_mnist(self, device='cpu')
|
||||
self._test_mnist(self, device="cpu")
|
||||
|
||||
@unittest.skipIf(not RUN_CUDA, "no CUDA")
|
||||
def test_mnist_cuda(self):
|
||||
# XXX: export_import on CUDA modules doesn't work (#11480)
|
||||
self._test_mnist(self, device='cuda', check_export_import=False)
|
||||
self._test_mnist(self, device="cuda", check_export_import=False)
|
||||
|
||||
@unittest.skipIf(not RUN_CUDA, "no CUDA")
|
||||
def test_mnist_training_leaks_no_memory_cuda(self):
|
||||
net = MnistNet().cuda()
|
||||
# MnistNet uses dropout, don't check its trace
|
||||
traced_net = torch.jit.trace(net, [torch.randn(5, 1, 28, 28, device='cuda')],
|
||||
check_trace=False)
|
||||
traced_net = torch.jit.trace(
|
||||
net, [torch.randn(5, 1, 28, 28, device="cuda")], check_trace=False
|
||||
)
|
||||
|
||||
def train(iters):
|
||||
for _ in range(iters):
|
||||
# Get some fake data
|
||||
inp = torch.randn(5, 1, 28, 28, device='cuda')
|
||||
inp = torch.randn(5, 1, 28, 28, device="cuda")
|
||||
out = traced_net(inp)
|
||||
|
||||
# Here's some fake loss
|
||||
@ -292,21 +329,23 @@ class TestModels(JitTestCase):
|
||||
return F.softmax(action_scores, dim=1)
|
||||
|
||||
with enable_profiling_mode_for_profiling_tests():
|
||||
self.checkTrace(Policy().to(device), (torch.rand(1, 4, device=device),),
|
||||
export_import=test_export_import)
|
||||
self.checkTrace(
|
||||
Policy().to(device),
|
||||
(torch.rand(1, 4, device=device),),
|
||||
export_import=test_export_import,
|
||||
)
|
||||
|
||||
def test_reinforcement_learning(self):
|
||||
self._test_reinforcement_learning(self, device='cpu')
|
||||
self._test_reinforcement_learning(self, device="cpu")
|
||||
|
||||
@unittest.skipIf(not RUN_CUDA, "no CUDA")
|
||||
def test_reinforcement_learning_cuda(self):
|
||||
# XXX: export_import on CUDA modules doesn't work (#11480)
|
||||
self._test_reinforcement_learning(self, device='cuda', test_export_import=False)
|
||||
self._test_reinforcement_learning(self, device="cuda", test_export_import=False)
|
||||
|
||||
@staticmethod
|
||||
def _test_snli(self, device, check_export_import=True):
|
||||
class Bottle(nn.Module):
|
||||
|
||||
def forward(self, input):
|
||||
if len(input.size()) <= 2:
|
||||
return super().forward(input)
|
||||
@ -318,25 +357,31 @@ class TestModels(JitTestCase):
|
||||
pass
|
||||
|
||||
class Encoder(nn.Module):
|
||||
|
||||
def __init__(self, config):
|
||||
super().__init__()
|
||||
self.config = config
|
||||
input_size = config.d_proj if config.projection else config.d_embed
|
||||
dropout = 0 if config.n_layers == 1 else config.dp_ratio
|
||||
self.rnn = nn.LSTM(input_size=input_size, hidden_size=config.d_hidden,
|
||||
num_layers=config.n_layers, dropout=dropout,
|
||||
bidirectional=config.birnn)
|
||||
self.rnn = nn.LSTM(
|
||||
input_size=input_size,
|
||||
hidden_size=config.d_hidden,
|
||||
num_layers=config.n_layers,
|
||||
dropout=dropout,
|
||||
bidirectional=config.birnn,
|
||||
)
|
||||
|
||||
def forward(self, inputs):
|
||||
batch_size = inputs.size()[1]
|
||||
state_shape = self.config.n_cells, batch_size, self.config.d_hidden
|
||||
h0 = c0 = inputs.new_zeros(state_shape)
|
||||
outputs, (ht, ct) = self.rnn(inputs, (h0, c0))
|
||||
return ht[-1] if not self.config.birnn else ht[-2:].transpose(0, 1).contiguous().view(batch_size, -1)
|
||||
return (
|
||||
ht[-1]
|
||||
if not self.config.birnn
|
||||
else ht[-2:].transpose(0, 1).contiguous().view(batch_size, -1)
|
||||
)
|
||||
|
||||
class SNLIClassifier(nn.Module):
|
||||
|
||||
def __init__(self, config):
|
||||
super().__init__()
|
||||
self.config = config
|
||||
@ -359,7 +404,8 @@ class TestModels(JitTestCase):
|
||||
Linear(*lin_config),
|
||||
self.relu,
|
||||
self.dropout,
|
||||
Linear(seq_in_size, config.d_out))
|
||||
Linear(seq_in_size, config.d_out),
|
||||
)
|
||||
|
||||
def forward(self, premise, hypothesis):
|
||||
prem_embed = self.embed(premise)
|
||||
@ -391,22 +437,25 @@ class TestModels(JitTestCase):
|
||||
premise = torch.LongTensor(48, 64).random_(0, 100).to(device)
|
||||
hypothesis = torch.LongTensor(24, 64).random_(0, 100).to(device)
|
||||
|
||||
self.checkTrace(SNLIClassifier(Config()).to(device), (premise, hypothesis),
|
||||
inputs_require_grads=False, export_import=check_export_import)
|
||||
self.checkTrace(
|
||||
SNLIClassifier(Config()).to(device),
|
||||
(premise, hypothesis),
|
||||
inputs_require_grads=False,
|
||||
export_import=check_export_import,
|
||||
)
|
||||
|
||||
@slowTest
|
||||
def test_snli(self):
|
||||
self._test_snli(self, device='cpu')
|
||||
self._test_snli(self, device="cpu")
|
||||
|
||||
@unittest.skipIf(not RUN_CUDA, "no CUDA")
|
||||
def test_snli_cuda(self):
|
||||
# XXX: export_import on CUDA modules doesn't work (#11480)
|
||||
self._test_snli(self, device='cuda', check_export_import=False)
|
||||
self._test_snli(self, device="cuda", check_export_import=False)
|
||||
|
||||
@staticmethod
|
||||
def _test_super_resolution(self, device, check_export_import=True):
|
||||
class Net(nn.Module):
|
||||
|
||||
def __init__(self, upscale_factor):
|
||||
super().__init__()
|
||||
|
||||
@ -414,7 +463,7 @@ class TestModels(JitTestCase):
|
||||
self.conv1 = nn.Conv2d(1, 64, (5, 5), (1, 1), (2, 2))
|
||||
self.conv2 = nn.Conv2d(64, 64, (3, 3), (1, 1), (1, 1))
|
||||
self.conv3 = nn.Conv2d(64, 32, (3, 3), (1, 1), (1, 1))
|
||||
self.conv4 = nn.Conv2d(32, upscale_factor ** 2, (3, 3), (1, 1), (1, 1))
|
||||
self.conv4 = nn.Conv2d(32, upscale_factor**2, (3, 3), (1, 1), (1, 1))
|
||||
self.pixel_shuffle = nn.PixelShuffle(upscale_factor)
|
||||
|
||||
def forward(self, x):
|
||||
@ -425,17 +474,20 @@ class TestModels(JitTestCase):
|
||||
return x
|
||||
|
||||
net = Net(upscale_factor=4).to(device)
|
||||
self.checkTrace(net, (torch.rand(5, 1, 32, 32, device=device),),
|
||||
export_import=check_export_import)
|
||||
self.checkTrace(
|
||||
net,
|
||||
(torch.rand(5, 1, 32, 32, device=device),),
|
||||
export_import=check_export_import,
|
||||
)
|
||||
|
||||
@slowTest
|
||||
def test_super_resolution(self):
|
||||
self._test_super_resolution(self, device='cpu')
|
||||
self._test_super_resolution(self, device="cpu")
|
||||
|
||||
@unittest.skipIf(not RUN_CUDA, 'no CUDA')
|
||||
@unittest.skipIf(not RUN_CUDA, "no CUDA")
|
||||
def test_super_resolution_cuda(self):
|
||||
# XXX: export_import on CUDA modules doesn't work (#11480)
|
||||
self._test_super_resolution(self, device='cuda', check_export_import=False)
|
||||
self._test_super_resolution(self, device="cuda", check_export_import=False)
|
||||
|
||||
@suppress_warnings
|
||||
def test_time_sequence_prediction(self):
|
||||
@ -485,8 +537,7 @@ class TestModels(JitTestCase):
|
||||
# disabled due to a jitter issues that will be fixed by using load/store in the compiler
|
||||
with torch._jit_internal._disable_emit_hooks():
|
||||
# TODO: toggle export_import once above issues are fixed
|
||||
self.checkTrace(Traced(), (torch.rand(3, 4),),
|
||||
export_import=False)
|
||||
self.checkTrace(Traced(), (torch.rand(3, 4),), export_import=False)
|
||||
|
||||
@staticmethod
|
||||
def _test_vae(self, device, check_export_import=True):
|
||||
@ -523,22 +574,27 @@ class TestModels(JitTestCase):
|
||||
|
||||
with enable_profiling_mode_for_profiling_tests():
|
||||
# eval() is present because randn_like makes this nondeterministic
|
||||
self.checkTrace(VAE().to(device).eval(), (torch.rand(128, 1, 28, 28, device=device),),
|
||||
export_import=check_export_import)
|
||||
self.checkTrace(
|
||||
VAE().to(device).eval(),
|
||||
(torch.rand(128, 1, 28, 28, device=device),),
|
||||
export_import=check_export_import,
|
||||
)
|
||||
|
||||
def test_vae(self):
|
||||
self._test_vae(self, device='cpu')
|
||||
self._test_vae(self, device="cpu")
|
||||
|
||||
@unittest.skipIf(not RUN_CUDA, "no CUDA")
|
||||
def test_vae_cuda(self):
|
||||
# XXX: export_import on CUDA modules doesn't work (#11480)
|
||||
self._test_vae(self, device='cuda', check_export_import=False)
|
||||
self._test_vae(self, device="cuda", check_export_import=False)
|
||||
|
||||
@slowTest
|
||||
@skipIfNoTorchVision
|
||||
def test_script_module_trace_resnet18(self):
|
||||
x = torch.ones(1, 3, 224, 224)
|
||||
m_orig = torch.jit.trace(torchvision.models.resnet18(), torch.ones(1, 3, 224, 224))
|
||||
m_orig = torch.jit.trace(
|
||||
torchvision.models.resnet18(), torch.ones(1, 3, 224, 224)
|
||||
)
|
||||
m_import = self.getExportImportCopy(m_orig)
|
||||
|
||||
input = torch.randn(1, 3, 224, 224, requires_grad=True)
|
||||
@ -559,16 +615,24 @@ class TestModels(JitTestCase):
|
||||
def test_script_module_script_resnet(self):
|
||||
def conv1x1(in_planes, out_planes, stride=1):
|
||||
"""1x1 convolution"""
|
||||
return nn.Conv2d(in_planes, out_planes, kernel_size=1, stride=stride, bias=False)
|
||||
return nn.Conv2d(
|
||||
in_planes, out_planes, kernel_size=1, stride=stride, bias=False
|
||||
)
|
||||
|
||||
def conv3x3(in_planes, out_planes, stride=1):
|
||||
"""3x3 convolution with padding"""
|
||||
return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride,
|
||||
padding=1, bias=False)
|
||||
return nn.Conv2d(
|
||||
in_planes,
|
||||
out_planes,
|
||||
kernel_size=3,
|
||||
stride=stride,
|
||||
padding=1,
|
||||
bias=False,
|
||||
)
|
||||
|
||||
class BasicBlock(torch.jit.ScriptModule):
|
||||
expansion = 1
|
||||
__constants__ = ['downsample']
|
||||
__constants__ = ["downsample"]
|
||||
|
||||
def __init__(self, inplanes, planes, stride=1, downsample=None):
|
||||
super().__init__()
|
||||
@ -600,13 +664,14 @@ class TestModels(JitTestCase):
|
||||
return out
|
||||
|
||||
class ResNet(torch.jit.ScriptModule):
|
||||
__constants__ = ['layer1', 'layer2', 'layer3', 'layer4']
|
||||
__constants__ = ["layer1", "layer2", "layer3", "layer4"]
|
||||
|
||||
def __init__(self, block, layers, num_classes=1000):
|
||||
super().__init__()
|
||||
self.inplanes = 64
|
||||
self.conv1 = nn.Conv2d(3, 64, kernel_size=7, stride=2, padding=3,
|
||||
bias=False)
|
||||
self.conv1 = nn.Conv2d(
|
||||
3, 64, kernel_size=7, stride=2, padding=3, bias=False
|
||||
)
|
||||
self.bn1 = nn.BatchNorm2d(64)
|
||||
self.relu = nn.ReLU(inplace=True)
|
||||
self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
|
||||
@ -619,7 +684,9 @@ class TestModels(JitTestCase):
|
||||
|
||||
for m in self.modules():
|
||||
if isinstance(m, nn.Conv2d):
|
||||
nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
|
||||
nn.init.kaiming_normal_(
|
||||
m.weight, mode="fan_out", nonlinearity="relu"
|
||||
)
|
||||
elif isinstance(m, nn.BatchNorm2d):
|
||||
nn.init.constant_(m.weight, 1)
|
||||
nn.init.constant_(m.bias, 0)
|
||||
@ -679,8 +746,10 @@ class TestModels(JitTestCase):
|
||||
x = torch.ones(1, 3, 224, 224)
|
||||
model = torchvision.models.AlexNet()
|
||||
with torch.random.fork_rng(devices=[]):
|
||||
g, outputs, inputs = torch.jit._get_trace_graph(model, x, return_inputs=True)
|
||||
self.run_pass('cse', g)
|
||||
g, outputs, inputs = torch.jit._get_trace_graph(
|
||||
model, x, return_inputs=True
|
||||
)
|
||||
self.run_pass("cse", g)
|
||||
m = self.createFunctionFromGraph(g)
|
||||
with torch.random.fork_rng(devices=[]):
|
||||
self.assertEqual(outputs, m(*inputs))
|
||||
|
||||
@ -1,19 +1,23 @@
|
||||
# Owner(s): ["oncall: jit"]
|
||||
|
||||
import torch
|
||||
import os
|
||||
import sys
|
||||
from typing import Any, Dict, List
|
||||
|
||||
import torch
|
||||
from torch.testing._internal.jit_utils import JitTestCase
|
||||
from typing import Dict, Any, List
|
||||
|
||||
# Make the helper files in test/ importable
|
||||
pytorch_test_dir = os.path.dirname(os.path.dirname(os.path.realpath(__file__)))
|
||||
sys.path.append(pytorch_test_dir)
|
||||
|
||||
if __name__ == '__main__':
|
||||
raise RuntimeError("This test file is not meant to be run directly, use:\n\n"
|
||||
"\tpython test/test_jit.py TESTNAME\n\n"
|
||||
"instead.")
|
||||
if __name__ == "__main__":
|
||||
raise RuntimeError(
|
||||
"This test file is not meant to be run directly, use:\n\n"
|
||||
"\tpython test/test_jit.py TESTNAME\n\n"
|
||||
"instead."
|
||||
)
|
||||
|
||||
|
||||
class TestModuleAPIs(JitTestCase):
|
||||
def test_default_state_dict_methods(self):
|
||||
@ -52,18 +56,23 @@ class TestModuleAPIs(JitTestCase):
|
||||
return x
|
||||
|
||||
@torch.jit.export
|
||||
def _save_to_state_dict(self, destination: Dict[str, torch.Tensor],
|
||||
prefix: str, keep_vars: bool):
|
||||
def _save_to_state_dict(
|
||||
self, destination: Dict[str, torch.Tensor], prefix: str, keep_vars: bool
|
||||
):
|
||||
self.customized_save_state_dict_called = True
|
||||
return {"dummy": torch.ones(1)}
|
||||
|
||||
@torch.jit.export
|
||||
def _load_from_state_dict(self,
|
||||
state_dict: Dict[str, torch.Tensor],
|
||||
prefix: str, local_metadata: Any,
|
||||
strict: bool, missing_keys: List[str],
|
||||
unexpected_keys: List[str],
|
||||
error_msgs: List[str]):
|
||||
def _load_from_state_dict(
|
||||
self,
|
||||
state_dict: Dict[str, torch.Tensor],
|
||||
prefix: str,
|
||||
local_metadata: Any,
|
||||
strict: bool,
|
||||
missing_keys: List[str],
|
||||
unexpected_keys: List[str],
|
||||
error_msgs: List[str],
|
||||
):
|
||||
self.customized_load_state_dict_called = True
|
||||
return
|
||||
|
||||
@ -94,18 +103,23 @@ class TestModuleAPIs(JitTestCase):
|
||||
return x
|
||||
|
||||
@torch.jit.export
|
||||
def _save_to_state_dict(self, destination: Dict[str, torch.Tensor],
|
||||
prefix: str, keep_vars: bool):
|
||||
def _save_to_state_dict(
|
||||
self, destination: Dict[str, torch.Tensor], prefix: str, keep_vars: bool
|
||||
):
|
||||
self.customized_save_state_dict_called = True
|
||||
return {"dummy": torch.ones(1)}
|
||||
|
||||
@torch.jit.export
|
||||
def _load_from_state_dict(self,
|
||||
state_dict: Dict[str, torch.Tensor],
|
||||
prefix: str, local_metadata: Any,
|
||||
strict: bool, missing_keys: List[str],
|
||||
unexpected_keys: List[str],
|
||||
error_msgs: List[str]):
|
||||
def _load_from_state_dict(
|
||||
self,
|
||||
state_dict: Dict[str, torch.Tensor],
|
||||
prefix: str,
|
||||
local_metadata: Any,
|
||||
strict: bool,
|
||||
missing_keys: List[str],
|
||||
unexpected_keys: List[str],
|
||||
error_msgs: List[str],
|
||||
):
|
||||
self.customized_load_state_dict_called = True
|
||||
return
|
||||
|
||||
|
||||
@ -2,9 +2,10 @@
|
||||
|
||||
import os
|
||||
import sys
|
||||
from collections import OrderedDict
|
||||
|
||||
from typing import Any, List, Tuple
|
||||
from collections import OrderedDict
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
from torch.testing._internal.jit_utils import JitTestCase
|
||||
@ -13,10 +14,13 @@ from torch.testing._internal.jit_utils import JitTestCase
|
||||
pytorch_test_dir = os.path.dirname(os.path.dirname(os.path.realpath(__file__)))
|
||||
sys.path.append(pytorch_test_dir)
|
||||
|
||||
if __name__ == '__main__':
|
||||
raise RuntimeError("This test file is not meant to be run directly, use:\n\n"
|
||||
"\tpython test/test_jit.py TESTNAME\n\n"
|
||||
"instead.")
|
||||
if __name__ == "__main__":
|
||||
raise RuntimeError(
|
||||
"This test file is not meant to be run directly, use:\n\n"
|
||||
"\tpython test/test_jit.py TESTNAME\n\n"
|
||||
"instead."
|
||||
)
|
||||
|
||||
|
||||
class TestModuleContainers(JitTestCase):
|
||||
def test_sequential_intermediary_types(self):
|
||||
@ -54,11 +58,13 @@ class TestModuleContainers(JitTestCase):
|
||||
class M(torch.nn.Module):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
modules = OrderedDict([
|
||||
('one', Inner()),
|
||||
('two', Inner2()),
|
||||
('three', Inner3()),
|
||||
])
|
||||
modules = OrderedDict(
|
||||
[
|
||||
("one", Inner()),
|
||||
("two", Inner2()),
|
||||
("three", Inner3()),
|
||||
]
|
||||
)
|
||||
self.moduledict = nn.ModuleDict(modules)
|
||||
|
||||
def forward(self, x, skip_name):
|
||||
@ -115,7 +121,6 @@ class TestModuleContainers(JitTestCase):
|
||||
|
||||
return x, x2, names, iter
|
||||
|
||||
|
||||
for name in ["", "one", "two", "three"]:
|
||||
inp = torch.tensor(1)
|
||||
self.checkModule(M(), (inp, name))
|
||||
@ -136,7 +141,7 @@ class TestModuleContainers(JitTestCase):
|
||||
x = mod(x)
|
||||
return x - 5
|
||||
|
||||
self.checkModule(CustomSequential(), (torch.tensor(.5),))
|
||||
self.checkModule(CustomSequential(), (torch.tensor(0.5),))
|
||||
|
||||
class CustomModuleList(nn.ModuleList):
|
||||
def __init__(self):
|
||||
@ -148,16 +153,19 @@ class TestModuleContainers(JitTestCase):
|
||||
x = mod(x)
|
||||
return x - 5
|
||||
|
||||
self.checkModule(CustomModuleList(), (torch.tensor(.5),))
|
||||
self.checkModule(CustomModuleList(), (torch.tensor(0.5),))
|
||||
|
||||
class CustomModuleDict(nn.ModuleDict):
|
||||
def __init__(self):
|
||||
super().__init__(
|
||||
OrderedDict([
|
||||
('one', Inner()),
|
||||
('two', nn.ReLU()),
|
||||
('three', Inner()),
|
||||
]))
|
||||
OrderedDict(
|
||||
[
|
||||
("one", Inner()),
|
||||
("two", nn.ReLU()),
|
||||
("three", Inner()),
|
||||
]
|
||||
)
|
||||
)
|
||||
|
||||
def forward(self, x):
|
||||
x = x + 3
|
||||
@ -167,7 +175,7 @@ class TestModuleContainers(JitTestCase):
|
||||
names.append(name)
|
||||
return names, x - 5
|
||||
|
||||
self.checkModule(CustomModuleDict(), (torch.tensor(.5),))
|
||||
self.checkModule(CustomModuleDict(), (torch.tensor(0.5),))
|
||||
|
||||
def test_script_module_list_sequential(self):
|
||||
class M(torch.jit.ScriptModule):
|
||||
@ -225,7 +233,9 @@ class TestModuleContainers(JitTestCase):
|
||||
def forward(self, v):
|
||||
return self.mods[-11].forward(v)
|
||||
|
||||
with self.assertRaisesRegexWithHighlight(Exception, "Index -11 out of range", "self.mods[-11]"):
|
||||
with self.assertRaisesRegexWithHighlight(
|
||||
Exception, "Index -11 out of range", "self.mods[-11]"
|
||||
):
|
||||
torch.jit.script(M2())
|
||||
|
||||
class M3(M):
|
||||
@ -233,7 +243,9 @@ class TestModuleContainers(JitTestCase):
|
||||
i = 3
|
||||
return self.mods[i].forward(v)
|
||||
|
||||
with self.assertRaisesRegexWithHighlight(Exception, "Enumeration is supported", "self.mods[i]"):
|
||||
with self.assertRaisesRegexWithHighlight(
|
||||
Exception, "Enumeration is supported", "self.mods[i]"
|
||||
):
|
||||
torch.jit.script(M3())
|
||||
|
||||
class M4(M):
|
||||
@ -273,17 +285,23 @@ class TestModuleContainers(JitTestCase):
|
||||
self.moduledict = CustomModuleDict({"submod": self.submod})
|
||||
|
||||
def forward(self, inputs):
|
||||
assert self.modulelist[0] is self.submod, "__getitem__ failing for ModuleList"
|
||||
assert (
|
||||
self.modulelist[0] is self.submod
|
||||
), "__getitem__ failing for ModuleList"
|
||||
assert len(self.modulelist) == 1, "__len__ failing for ModuleList"
|
||||
for module in self.modulelist:
|
||||
assert module is self.submod, "__iter__ failing for ModuleList"
|
||||
|
||||
assert self.sequential[0] is self.submod, "__getitem__ failing for Sequential"
|
||||
assert (
|
||||
self.sequential[0] is self.submod
|
||||
), "__getitem__ failing for Sequential"
|
||||
assert len(self.sequential) == 1, "__len__ failing for Sequential"
|
||||
for module in self.sequential:
|
||||
assert module is self.submod, "__iter__ failing for Sequential"
|
||||
|
||||
assert self.moduledict["submod"] is self.submod, "__getitem__ failing for ModuleDict"
|
||||
assert (
|
||||
self.moduledict["submod"] is self.submod
|
||||
), "__getitem__ failing for ModuleDict"
|
||||
assert len(self.moduledict) == 1, "__len__ failing for ModuleDict"
|
||||
|
||||
# note: unable to index moduledict with a string variable currently
|
||||
@ -345,12 +363,13 @@ class TestModuleContainers(JitTestCase):
|
||||
super().__init__()
|
||||
self.relu = torch.jit.script(torch.nn.ReLU())
|
||||
self.tanh = torch.jit.script(torch.nn.Tanh())
|
||||
self.moduledict = torch.nn.ModuleDict({"relu": self.relu,
|
||||
"tanh": self.tanh})
|
||||
self.moduledict = torch.nn.ModuleDict(
|
||||
{"relu": self.relu, "tanh": self.tanh}
|
||||
)
|
||||
|
||||
def forward(self, input):
|
||||
assert self.moduledict['relu'] is self.relu
|
||||
assert self.moduledict['tanh'] is self.tanh
|
||||
assert self.moduledict["relu"] is self.relu
|
||||
assert self.moduledict["tanh"] is self.tanh
|
||||
return input
|
||||
|
||||
m = MyModule()
|
||||
@ -360,31 +379,34 @@ class TestModuleContainers(JitTestCase):
|
||||
class BadModule(torch.nn.Module):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
self.moduledict = torch.nn.ModuleDict({"foo": None,
|
||||
"bar": None})
|
||||
self.moduledict = torch.nn.ModuleDict({"foo": None, "bar": None})
|
||||
|
||||
def forward(self, input):
|
||||
assert self.moduledict['blah'] == "blah", "this is a keyerror"
|
||||
assert self.moduledict["blah"] == "blah", "this is a keyerror"
|
||||
|
||||
with self.assertRaisesRegexWithHighlight(RuntimeError, "Key Error, blah", "self.moduledict['blah'"):
|
||||
with self.assertRaisesRegexWithHighlight(
|
||||
RuntimeError, "Key Error, blah", 'self.moduledict["blah"'
|
||||
):
|
||||
b = BadModule()
|
||||
torch.jit.script(b)
|
||||
|
||||
class AnotherBadModule(torch.nn.Module):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
self.moduledict = torch.nn.ModuleDict({"foo": None,
|
||||
"bar": None})
|
||||
self.moduledict = torch.nn.ModuleDict({"foo": None, "bar": None})
|
||||
|
||||
def forward(self, input):
|
||||
idx = 'blah'
|
||||
idx = "blah"
|
||||
assert self.moduledict[idx] == "blah", "this is a string literal error"
|
||||
|
||||
with self.assertRaisesRegexWithHighlight(RuntimeError, "Unable to extract string literal index. "
|
||||
"ModuleDict indexing is only supported with string literals. "
|
||||
"For example, 'i = \"a\"; self.layers\\[i\\]\\(x\\)' will fail "
|
||||
"because i is not a literal.",
|
||||
"self.moduledict[idx]"):
|
||||
with self.assertRaisesRegexWithHighlight(
|
||||
RuntimeError,
|
||||
"Unable to extract string literal index. "
|
||||
"ModuleDict indexing is only supported with string literals. "
|
||||
"For example, 'i = \"a\"; self.layers\\[i\\]\\(x\\)' will fail "
|
||||
"because i is not a literal.",
|
||||
"self.moduledict[idx]",
|
||||
):
|
||||
b = AnotherBadModule()
|
||||
torch.jit.script(b)
|
||||
|
||||
@ -393,6 +415,7 @@ class TestModuleContainers(JitTestCase):
|
||||
Test that an attempt to script a module with a regular list attribute
|
||||
containing other modules fails with a relevant error message.
|
||||
"""
|
||||
|
||||
class Mod(torch.nn.Module):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
@ -422,7 +445,9 @@ class TestModuleContainers(JitTestCase):
|
||||
self.moduledict = CustomModuleDict()
|
||||
|
||||
def forward(self, inputs):
|
||||
assert "submod" not in self.moduledict, "__contains__ fails for ModuleDict"
|
||||
assert (
|
||||
"submod" not in self.moduledict
|
||||
), "__contains__ fails for ModuleDict"
|
||||
return inputs
|
||||
|
||||
m = MyModule()
|
||||
@ -433,6 +458,7 @@ class TestModuleContainers(JitTestCase):
|
||||
Test that a type annotation can be provided for a ModuleDict that allows
|
||||
non-static indexing.
|
||||
"""
|
||||
|
||||
@torch.jit.interface
|
||||
class ModuleInterface(torch.nn.Module):
|
||||
def forward(self, inp: Any) -> Any:
|
||||
@ -485,7 +511,9 @@ class TestModuleContainers(JitTestCase):
|
||||
submodule: ModuleInterface = self.d[key]
|
||||
return submodule.forward(x)
|
||||
|
||||
with self.assertRaisesRegexWithHighlight(RuntimeError, r"Attribute module is not of annotated type", "self.d[key]"):
|
||||
with self.assertRaisesRegexWithHighlight(
|
||||
RuntimeError, r"Attribute module is not of annotated type", "self.d[key]"
|
||||
):
|
||||
torch.jit.script(ModWithWrongAnnotation())
|
||||
|
||||
def test_typed_module_list(self):
|
||||
@ -493,6 +521,7 @@ class TestModuleContainers(JitTestCase):
|
||||
Test that a type annotation can be provided for a ModuleList that allows
|
||||
non-static indexing.
|
||||
"""
|
||||
|
||||
@torch.jit.interface
|
||||
class ModuleInterface(torch.nn.Module):
|
||||
def forward(self, inp: Any) -> Any:
|
||||
@ -545,7 +574,9 @@ class TestModuleContainers(JitTestCase):
|
||||
submodule: ModuleInterface = self.l[idx]
|
||||
return submodule.forward(x)
|
||||
|
||||
with self.assertRaisesRegexWithHighlight(RuntimeError, r"Attribute 0 is not of annotated type", "self.l[idx]"):
|
||||
with self.assertRaisesRegexWithHighlight(
|
||||
RuntimeError, r"Attribute 0 is not of annotated type", "self.l[idx]"
|
||||
):
|
||||
torch.jit.script(ModWithWrongAnnotation())
|
||||
|
||||
def test_module_properties(self):
|
||||
@ -596,10 +627,34 @@ class TestModuleContainers(JitTestCase):
|
||||
def attr(self):
|
||||
return self.a + 1
|
||||
|
||||
self.checkModule(ModuleWithProperties(5), (5, 6,))
|
||||
self.checkModule(ModuleWithProperties(5), (-5, -6,))
|
||||
self.checkModule(ModuleWithNoSetter(5), (5, 6,))
|
||||
self.checkModule(ModuleWithNoSetter(5), (-5, -6,))
|
||||
self.checkModule(
|
||||
ModuleWithProperties(5),
|
||||
(
|
||||
5,
|
||||
6,
|
||||
),
|
||||
)
|
||||
self.checkModule(
|
||||
ModuleWithProperties(5),
|
||||
(
|
||||
-5,
|
||||
-6,
|
||||
),
|
||||
)
|
||||
self.checkModule(
|
||||
ModuleWithNoSetter(5),
|
||||
(
|
||||
5,
|
||||
6,
|
||||
),
|
||||
)
|
||||
self.checkModule(
|
||||
ModuleWithNoSetter(5),
|
||||
(
|
||||
-5,
|
||||
-6,
|
||||
),
|
||||
)
|
||||
|
||||
mod = ModuleWithProperties(3)
|
||||
scripted_mod = torch.jit.script(mod)
|
||||
@ -625,7 +680,6 @@ class TestModuleContainers(JitTestCase):
|
||||
def forward(self, x):
|
||||
return self.linear(self.linear(x))
|
||||
|
||||
|
||||
class N(nn.Module):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
@ -659,7 +713,9 @@ class TestModuleContainers(JitTestCase):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
self.module_list = nn.ModuleList([nn.Linear(1, 1) for _ in range(10)])
|
||||
self.parameter_list = nn.ParameterList([nn.Parameter(torch.zeros(1)) for _ in range(10)])
|
||||
self.parameter_list = nn.ParameterList(
|
||||
[nn.Parameter(torch.zeros(1)) for _ in range(10)]
|
||||
)
|
||||
|
||||
def forward(self, x):
|
||||
self.module_list[0]
|
||||
@ -673,7 +729,9 @@ class TestModuleContainers(JitTestCase):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
self.module_list = nn.ModuleList([nn.Linear(1, 1) for _ in range(10)])
|
||||
self.parameter_list = nn.ParameterList([nn.Parameter(torch.zeros(1)) for _ in range(10)])
|
||||
self.parameter_list = nn.ParameterList(
|
||||
[nn.Parameter(torch.zeros(1)) for _ in range(10)]
|
||||
)
|
||||
|
||||
def forward(self, x):
|
||||
r = x
|
||||
@ -687,9 +745,14 @@ class TestModuleContainers(JitTestCase):
|
||||
class MyModule(nn.Module):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
self.parameter_dict = nn.ParameterDict({k: nn.Parameter(torch.zeros(1)) for k in ['a', 'b', 'c']})
|
||||
self.parameter_dict = nn.ParameterDict(
|
||||
{k: nn.Parameter(torch.zeros(1)) for k in ["a", "b", "c"]}
|
||||
)
|
||||
|
||||
def forward(self, x):
|
||||
return self.parameter_dict['a'] * x + self.parameter_dict['b'] * self.parameter_dict['c']
|
||||
return (
|
||||
self.parameter_dict["a"] * x
|
||||
+ self.parameter_dict["b"] * self.parameter_dict["c"]
|
||||
)
|
||||
|
||||
self.checkModule(MyModule(), (torch.ones(1),))
|
||||
|
||||
@ -1,10 +1,11 @@
|
||||
# Owner(s): ["oncall: jit"]
|
||||
|
||||
from typing import List, Any
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import os
|
||||
import sys
|
||||
from typing import Any, List
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
from torch import Tensor
|
||||
from torch.testing._internal.jit_utils import JitTestCase, make_global
|
||||
|
||||
@ -12,10 +13,13 @@ from torch.testing._internal.jit_utils import JitTestCase, make_global
|
||||
pytorch_test_dir = os.path.dirname(os.path.dirname(os.path.realpath(__file__)))
|
||||
sys.path.append(pytorch_test_dir)
|
||||
|
||||
if __name__ == '__main__':
|
||||
raise RuntimeError("This test file is not meant to be run directly, use:\n\n"
|
||||
"\tpython test/test_jit.py TESTNAME\n\n"
|
||||
"instead.")
|
||||
if __name__ == "__main__":
|
||||
raise RuntimeError(
|
||||
"This test file is not meant to be run directly, use:\n\n"
|
||||
"\tpython test/test_jit.py TESTNAME\n\n"
|
||||
"instead."
|
||||
)
|
||||
|
||||
|
||||
class OrigModule(nn.Module):
|
||||
def one(self, inp1: Tensor, inp2: Tensor) -> Tensor:
|
||||
@ -27,6 +31,7 @@ class OrigModule(nn.Module):
|
||||
def forward(self, input: Tensor) -> Tensor:
|
||||
return input + self.one(input, input) + 1
|
||||
|
||||
|
||||
class NewModule(nn.Module):
|
||||
def one(self, inp1: Tensor, inp2: Tensor) -> Tensor:
|
||||
return inp1 * inp2 + 1
|
||||
@ -34,6 +39,7 @@ class NewModule(nn.Module):
|
||||
def forward(self, input: Tensor) -> Tensor:
|
||||
return self.one(input, input + 1)
|
||||
|
||||
|
||||
class TestModuleInterface(JitTestCase):
|
||||
def test_not_submodule_interface_call(self):
|
||||
@torch.jit.interface
|
||||
@ -42,7 +48,7 @@ class TestModuleInterface(JitTestCase):
|
||||
pass
|
||||
|
||||
class TestNotModuleInterfaceCall(nn.Module):
|
||||
proxy_mod : ModuleInterface
|
||||
proxy_mod: ModuleInterface
|
||||
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
@ -51,7 +57,9 @@ class TestModuleInterface(JitTestCase):
|
||||
def forward(self, input: Tensor) -> Tensor:
|
||||
return self.proxy_mod.two(input)
|
||||
|
||||
with self.assertRaisesRegexWithHighlight(RuntimeError, "object has no attribute or method", "self.proxy_mod.two"):
|
||||
with self.assertRaisesRegexWithHighlight(
|
||||
RuntimeError, "object has no attribute or method", "self.proxy_mod.two"
|
||||
):
|
||||
torch.jit.script(TestNotModuleInterfaceCall())
|
||||
|
||||
def test_module_interface(self):
|
||||
@ -108,17 +116,37 @@ class TestModuleInterface(JitTestCase):
|
||||
|
||||
scripted_foo_mod = torch.jit.script(FooMod())
|
||||
scripted_bar_mod = torch.jit.script(BarMod())
|
||||
self.checkScript(use_module_interface,
|
||||
([scripted_foo_mod, scripted_bar_mod], torch.rand(3, 4),))
|
||||
self.checkScript(use_class_interface,
|
||||
([scripted_foo_mod, scripted_bar_mod], torch.rand(3, 4),))
|
||||
self.checkScript(
|
||||
use_module_interface,
|
||||
(
|
||||
[scripted_foo_mod, scripted_bar_mod],
|
||||
torch.rand(3, 4),
|
||||
),
|
||||
)
|
||||
self.checkScript(
|
||||
use_class_interface,
|
||||
(
|
||||
[scripted_foo_mod, scripted_bar_mod],
|
||||
torch.rand(3, 4),
|
||||
),
|
||||
)
|
||||
|
||||
def call_module_interface_on_other_method(mod_interface: OneTwoModule, x: Tensor) -> Tensor:
|
||||
def call_module_interface_on_other_method(
|
||||
mod_interface: OneTwoModule, x: Tensor
|
||||
) -> Tensor:
|
||||
return mod_interface.forward2(x)
|
||||
|
||||
# ensure error out when we call the module on the method other than the interface specified.
|
||||
with self.assertRaisesRegexWithHighlight(RuntimeError, "object has no attribute or method", "mod_interface.forward2"):
|
||||
self.checkScript(call_module_interface_on_other_method, (scripted_bar_mod, torch.rand(3, 4),))
|
||||
with self.assertRaisesRegexWithHighlight(
|
||||
RuntimeError, "object has no attribute or method", "mod_interface.forward2"
|
||||
):
|
||||
self.checkScript(
|
||||
call_module_interface_on_other_method,
|
||||
(
|
||||
scripted_bar_mod,
|
||||
torch.rand(3, 4),
|
||||
),
|
||||
)
|
||||
|
||||
def test_module_doc_string(self):
|
||||
@torch.jit.interface
|
||||
@ -135,7 +163,7 @@ class TestModuleInterface(JitTestCase):
|
||||
r"""stuff 3"""
|
||||
|
||||
class TestModule(nn.Module):
|
||||
proxy_mod : TestInterface
|
||||
proxy_mod: TestInterface
|
||||
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
@ -178,7 +206,9 @@ class TestModuleInterface(JitTestCase):
|
||||
return self.one(self.two(x), x)
|
||||
|
||||
# check class object is not a subtype of module interface
|
||||
with self.assertRaisesRegex(RuntimeError, "ScriptModule class can be subtype of module interface"):
|
||||
with self.assertRaisesRegex(
|
||||
RuntimeError, "ScriptModule class can be subtype of module interface"
|
||||
):
|
||||
as_module_interface(Foo())
|
||||
|
||||
class WrongMod(nn.Module):
|
||||
@ -233,9 +263,11 @@ class TestModuleInterface(JitTestCase):
|
||||
as_tensor_to_any(torch.jit.script(TensorToAnyImplB()))
|
||||
as_any_to_any(torch.jit.script(AnyToAnyImpl()))
|
||||
|
||||
|
||||
def test_module_interface_inheritance(self):
|
||||
with self.assertRaisesRegex(RuntimeError, "does not support inheritance yet. Please directly"):
|
||||
with self.assertRaisesRegex(
|
||||
RuntimeError, "does not support inheritance yet. Please directly"
|
||||
):
|
||||
|
||||
@torch.jit.interface
|
||||
class InheritMod(nn.ReLU):
|
||||
def three(self, x: Tensor) -> Tensor:
|
||||
@ -251,7 +283,7 @@ class TestModuleInterface(JitTestCase):
|
||||
pass
|
||||
|
||||
class TestModule(nn.Module):
|
||||
proxy_mod : ModuleInterface
|
||||
proxy_mod: ModuleInterface
|
||||
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
@ -269,7 +301,9 @@ class TestModuleInterface(JitTestCase):
|
||||
self.assertEqual(scripted_mod(input), input * (input + 1) + 1)
|
||||
|
||||
# module swap with non-scripted module should throw error
|
||||
with self.assertRaisesRegex(RuntimeError, "a ScriptModule with non-scripted module"):
|
||||
with self.assertRaisesRegex(
|
||||
RuntimeError, "a ScriptModule with non-scripted module"
|
||||
):
|
||||
scripted_mod.proxy_mod = NewModule()
|
||||
|
||||
def test_module_swap_wrong_module(self):
|
||||
@ -286,7 +320,7 @@ class TestModuleInterface(JitTestCase):
|
||||
return input + 1
|
||||
|
||||
class TestModule(nn.Module):
|
||||
proxy_mod : ModuleInterface
|
||||
proxy_mod: ModuleInterface
|
||||
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
@ -310,7 +344,7 @@ class TestModuleInterface(JitTestCase):
|
||||
pass
|
||||
|
||||
class TestModule(nn.Module):
|
||||
proxy_mod : ModuleInterface
|
||||
proxy_mod: ModuleInterface
|
||||
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
@ -358,9 +392,11 @@ class TestModuleInterface(JitTestCase):
|
||||
# proxy mod is swapped with the new ScriptModule that share the same JIT type, should succeed.
|
||||
scripted_no_module_interface.proxy_mod = torch.jit.script(OrigModule())
|
||||
# proxy_mod is neither a module interface or have the same JIT type, should fail
|
||||
with self.assertRaisesRegex(RuntimeError,
|
||||
r"Expected a value of type '__torch__.jit.test_module_interface.OrigModule \(.*\)' " +
|
||||
r"for field 'proxy_mod', but found '__torch__.jit.test_module_interface.NewModule \(.*\)'"):
|
||||
with self.assertRaisesRegex(
|
||||
RuntimeError,
|
||||
r"Expected a value of type '__torch__.jit.test_module_interface.OrigModule \(.*\)' "
|
||||
+ r"for field 'proxy_mod', but found '__torch__.jit.test_module_interface.NewModule \(.*\)'",
|
||||
):
|
||||
scripted_no_module_interface.proxy_mod = torch.jit.script(NewModule())
|
||||
|
||||
def test_script_module_as_interface_swap(self):
|
||||
@ -391,7 +427,7 @@ class TestModuleInterface(JitTestCase):
|
||||
return self.one(input, input + 1)
|
||||
|
||||
class TestNNModuleWithScriptModule(nn.Module):
|
||||
proxy_mod : ModuleInterface
|
||||
proxy_mod: ModuleInterface
|
||||
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
@ -432,7 +468,7 @@ class TestModuleInterface(JitTestCase):
|
||||
pass
|
||||
|
||||
class TestModule(torch.nn.Module):
|
||||
proxy_mod : ModInterface
|
||||
proxy_mod: ModInterface
|
||||
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
@ -480,7 +516,7 @@ class TestModuleInterface(JitTestCase):
|
||||
pass
|
||||
|
||||
class TestModule(torch.nn.Module):
|
||||
proxy_mod : ModInterface
|
||||
proxy_mod: ModInterface
|
||||
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
@ -523,7 +559,7 @@ class TestModuleInterface(JitTestCase):
|
||||
pass
|
||||
|
||||
class TestModule(torch.nn.Module):
|
||||
proxy_mod : ModInterface
|
||||
proxy_mod: ModInterface
|
||||
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
@ -568,7 +604,7 @@ class TestModuleInterface(JitTestCase):
|
||||
pass
|
||||
|
||||
class TestModule(torch.nn.Module):
|
||||
proxy_mod : ModInterface
|
||||
proxy_mod: ModInterface
|
||||
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
@ -583,7 +619,9 @@ class TestModuleInterface(JitTestCase):
|
||||
|
||||
m = torch.jit.script(TestModule())
|
||||
m.eval()
|
||||
with self.assertRaisesRegex(RuntimeError, "Freezing does not support SetAttr on an interface type."):
|
||||
with self.assertRaisesRegex(
|
||||
RuntimeError, "Freezing does not support SetAttr on an interface type."
|
||||
):
|
||||
mf = torch._C._freeze_module(m._c, freezeInterfaces=True)
|
||||
|
||||
def test_freeze_module_with_interface_and_fork(self):
|
||||
@ -610,7 +648,7 @@ class TestModuleInterface(JitTestCase):
|
||||
pass
|
||||
|
||||
class TestModule(torch.nn.Module):
|
||||
proxy_mod : ModInterface
|
||||
proxy_mod: ModInterface
|
||||
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
@ -644,7 +682,7 @@ class TestModuleInterface(JitTestCase):
|
||||
pass
|
||||
|
||||
class TestModule(nn.Module):
|
||||
proxy_mod : ModuleInterface
|
||||
proxy_mod: ModuleInterface
|
||||
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
|
||||
@ -1,18 +1,22 @@
|
||||
# Owner(s): ["oncall: jit"]
|
||||
|
||||
import torch
|
||||
import os
|
||||
import sys
|
||||
|
||||
import torch
|
||||
from torch.testing._internal.jit_utils import JitTestCase
|
||||
|
||||
# Make the helper files in test/ importable
|
||||
pytorch_test_dir = os.path.dirname(os.path.dirname(os.path.realpath(__file__)))
|
||||
sys.path.append(pytorch_test_dir)
|
||||
|
||||
if __name__ == '__main__':
|
||||
raise RuntimeError("This test file is not meant to be run directly, use:\n\n"
|
||||
"\tpython test/test_jit.py TESTNAME\n\n"
|
||||
"instead.")
|
||||
if __name__ == "__main__":
|
||||
raise RuntimeError(
|
||||
"This test file is not meant to be run directly, use:\n\n"
|
||||
"\tpython test/test_jit.py TESTNAME\n\n"
|
||||
"instead."
|
||||
)
|
||||
|
||||
|
||||
class TestModules(JitTestCase):
|
||||
def test_script_module_with_constants_list(self):
|
||||
|
||||
@ -4,10 +4,13 @@ import torch
|
||||
from torch.testing import FileCheck
|
||||
from torch.testing._internal.jit_utils import JitTestCase
|
||||
|
||||
if __name__ == '__main__':
|
||||
raise RuntimeError("This test file is not meant to be run directly, use:\n\n"
|
||||
"\tpython test/test_jit.py TESTNAME\n\n"
|
||||
"instead.")
|
||||
if __name__ == "__main__":
|
||||
raise RuntimeError(
|
||||
"This test file is not meant to be run directly, use:\n\n"
|
||||
"\tpython test/test_jit.py TESTNAME\n\n"
|
||||
"instead."
|
||||
)
|
||||
|
||||
|
||||
class TestOpDecompositions(JitTestCase):
|
||||
def test_op_decomposition(self):
|
||||
@ -31,7 +34,9 @@ class TestOpDecompositions(JitTestCase):
|
||||
def square_decomp(x):
|
||||
return torch.pow(x, 2)
|
||||
|
||||
torch.jit._register_decomposition(torch.ops.aten.square.default, square_decomp.graph)
|
||||
torch.jit._register_decomposition(
|
||||
torch.ops.aten.square.default, square_decomp.graph
|
||||
)
|
||||
torch._C._jit_pass_run_decompositions(foo.graph)
|
||||
FileCheck().check_not("aten::square").check("aten::pow").run(foo.graph)
|
||||
x = torch.rand([4])
|
||||
|
||||
@ -3,8 +3,9 @@
|
||||
import torch
|
||||
import torch._C
|
||||
import torch.nn.functional as F
|
||||
from torch.testing._internal.jit_utils import JitTestCase
|
||||
from torch.testing._internal.common_utils import skipIfNoXNNPACK
|
||||
from torch.testing._internal.jit_utils import JitTestCase
|
||||
|
||||
|
||||
class TestOptimizeForMobilePreserveDebugInfo(JitTestCase):
|
||||
def check_replacement(
|
||||
@ -133,10 +134,8 @@ class TestOptimizeForMobilePreserveDebugInfo(JitTestCase):
|
||||
"prepacked::linear_clamp_run": "aten::linear",
|
||||
"prepacked::conv2d_clamp_prepack": "aten::conv2d",
|
||||
"prepacked::conv2d_clamp_run": "aten::conv2d",
|
||||
"prepacked::conv2d_transpose_clamp_prepack":
|
||||
"aten::conv_transpose2d",
|
||||
"prepacked::conv2d_transpose_clamp_run":
|
||||
"aten::conv_transpose2d",
|
||||
"prepacked::conv2d_transpose_clamp_prepack": "aten::conv_transpose2d",
|
||||
"prepacked::conv2d_transpose_clamp_run": "aten::conv_transpose2d",
|
||||
},
|
||||
jit_pass=torch._C._jit_pass_insert_prepacked_ops,
|
||||
)
|
||||
@ -147,7 +146,7 @@ class TestOptimizeForMobilePreserveDebugInfo(JitTestCase):
|
||||
model=torch.jit.trace(torch.nn.Linear(5, 4), torch.rand(3, 2, 5)),
|
||||
replacements={
|
||||
"prepacked::linear_clamp_prepack": "aten::linear",
|
||||
"prepacked::linear_clamp_run": "aten::linear"
|
||||
"prepacked::linear_clamp_run": "aten::linear",
|
||||
},
|
||||
jit_pass=torch._C._jit_pass_insert_prepacked_ops,
|
||||
)
|
||||
@ -223,11 +222,9 @@ class TestOptimizeForMobilePreserveDebugInfo(JitTestCase):
|
||||
self.check_replacement(
|
||||
model=model,
|
||||
replacements={
|
||||
"prepacked::linear_clamp_prepack":
|
||||
"prepacked::linear_clamp_prepack",
|
||||
"prepacked::linear_clamp_prepack": "prepacked::linear_clamp_prepack",
|
||||
"prepacked::linear_clamp_run": linear_activation_kind,
|
||||
"prepacked::conv2d_clamp_prepack":
|
||||
"prepacked::conv2d_clamp_prepack",
|
||||
"prepacked::conv2d_clamp_prepack": "prepacked::conv2d_clamp_prepack",
|
||||
"prepacked::conv2d_clamp_run": conv2d_activation_kind,
|
||||
},
|
||||
jit_pass=torch._C._jit_pass_fuse_clamp_w_prepacked_linear_conv,
|
||||
@ -239,7 +236,7 @@ class TestOptimizeForMobilePreserveDebugInfo(JitTestCase):
|
||||
linear_activation=F.hardtanh,
|
||||
linear_activation_kind="aten::hardtanh",
|
||||
conv2d_activation=F.hardtanh_,
|
||||
conv2d_activation_kind="aten::hardtanh_"
|
||||
conv2d_activation_kind="aten::hardtanh_",
|
||||
)
|
||||
|
||||
@skipIfNoXNNPACK
|
||||
@ -248,7 +245,7 @@ class TestOptimizeForMobilePreserveDebugInfo(JitTestCase):
|
||||
linear_activation=F.hardtanh_,
|
||||
linear_activation_kind="aten::hardtanh_",
|
||||
conv2d_activation=F.hardtanh,
|
||||
conv2d_activation_kind="aten::hardtanh"
|
||||
conv2d_activation_kind="aten::hardtanh",
|
||||
)
|
||||
|
||||
@skipIfNoXNNPACK
|
||||
@ -257,7 +254,7 @@ class TestOptimizeForMobilePreserveDebugInfo(JitTestCase):
|
||||
linear_activation=F.relu,
|
||||
linear_activation_kind="aten::relu",
|
||||
conv2d_activation=F.relu_,
|
||||
conv2d_activation_kind="aten::relu_"
|
||||
conv2d_activation_kind="aten::relu_",
|
||||
)
|
||||
|
||||
@skipIfNoXNNPACK
|
||||
@ -266,5 +263,5 @@ class TestOptimizeForMobilePreserveDebugInfo(JitTestCase):
|
||||
linear_activation=F.relu_,
|
||||
linear_activation_kind="aten::relu_",
|
||||
conv2d_activation=F.relu,
|
||||
conv2d_activation_kind="aten::relu"
|
||||
conv2d_activation_kind="aten::relu",
|
||||
)
|
||||
|
||||
@ -2,16 +2,19 @@
|
||||
|
||||
|
||||
import torch
|
||||
from torch import nn
|
||||
import torch.nn.utils.parametrize as parametrize
|
||||
from torch import nn
|
||||
|
||||
from torch.testing._internal.jit_utils import JitTestCase
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
raise RuntimeError("This test file is not meant to be run directly, use:\n\n"
|
||||
"\tpython test/test_jit.py TESTNAME\n\n"
|
||||
"instead.")
|
||||
if __name__ == "__main__":
|
||||
raise RuntimeError(
|
||||
"This test file is not meant to be run directly, use:\n\n"
|
||||
"\tpython test/test_jit.py TESTNAME\n\n"
|
||||
"instead."
|
||||
)
|
||||
|
||||
|
||||
class TestParametrization(JitTestCase):
|
||||
# Define some parametrization
|
||||
@ -29,7 +32,7 @@ class TestParametrization(JitTestCase):
|
||||
|
||||
# Check the tracing works. Because traced functions cannot be called
|
||||
# directly, we run the comparison on the activations.
|
||||
traced_model = torch.jit.trace_module(model, {'forward': x})
|
||||
traced_model = torch.jit.trace_module(model, {"forward": x})
|
||||
y_hat = traced_model(x)
|
||||
self.assertEqual(y, y_hat)
|
||||
|
||||
@ -39,10 +42,9 @@ class TestParametrization(JitTestCase):
|
||||
self.assertEqual(y, y_hat)
|
||||
|
||||
# Check the tracing throws an error when caching
|
||||
with self.assertRaisesRegex(RuntimeError,
|
||||
'Cannot trace a model while caching'):
|
||||
with self.assertRaisesRegex(RuntimeError, "Cannot trace a model while caching"):
|
||||
with parametrize.cached():
|
||||
traced_model = torch.jit.trace_module(model, {'forward': x})
|
||||
traced_model = torch.jit.trace_module(model, {"forward": x})
|
||||
|
||||
def test_scriptable(self):
|
||||
# TODO: Need to fix the scripting in parametrizations
|
||||
@ -65,5 +67,5 @@ class TestParametrization(JitTestCase):
|
||||
self.assertEqual(y, y_hat)
|
||||
|
||||
# Check the scripting process throws an error when caching
|
||||
with self.assertRaisesRegex(RuntimeError, 'Caching is not implemented'):
|
||||
with self.assertRaisesRegex(RuntimeError, "Caching is not implemented"):
|
||||
scripted_model = torch.jit.trace_module(model)
|
||||
|
||||
@ -2,18 +2,22 @@
|
||||
|
||||
import os
|
||||
import sys
|
||||
from typing import Any, Dict, List, NamedTuple, Optional, Tuple # noqa: F401
|
||||
|
||||
import torch
|
||||
from torch.testing._internal.jit_utils import JitTestCase, make_global
|
||||
from torch.jit._monkeytype_config import _IS_MONKEYTYPE_INSTALLED
|
||||
from typing import List, Dict, Tuple, Any, Optional, NamedTuple # noqa: F401
|
||||
from torch.testing._internal.common_utils import NoTest
|
||||
from torch.testing._internal.jit_utils import JitTestCase, make_global
|
||||
|
||||
# Make the helper files in test/ importable
|
||||
pytorch_test_dir = os.path.dirname(os.path.dirname(os.path.realpath(__file__)))
|
||||
sys.path.append(pytorch_test_dir)
|
||||
|
||||
if not _IS_MONKEYTYPE_INSTALLED:
|
||||
print("monkeytype is not installed. Skipping tests for Profile-Directed Typing", file=sys.stderr)
|
||||
print(
|
||||
"monkeytype is not installed. Skipping tests for Profile-Directed Typing",
|
||||
file=sys.stderr,
|
||||
)
|
||||
JitTestCase = NoTest # type: ignore[misc, assignment] # noqa: F811
|
||||
|
||||
if __name__ == "__main__":
|
||||
@ -23,10 +27,12 @@ if __name__ == "__main__":
|
||||
"instead."
|
||||
)
|
||||
|
||||
|
||||
class TestPDT(JitTestCase):
|
||||
"""
|
||||
A suite of tests for profile directed typing in TorchScript.
|
||||
"""
|
||||
|
||||
def test_nn_module(self):
|
||||
class TestPDTModel(torch.nn.Module):
|
||||
def forward(self, x) -> Any:
|
||||
@ -39,8 +45,14 @@ class TestPDT(JitTestCase):
|
||||
|
||||
make_global(TestPDTModel)
|
||||
pdt_model = TestPDTModel()
|
||||
inp: List[Tuple[Any, ...]] = [(20, ), (2.7, ), (False, ), ]
|
||||
scripted_pdt_model = torch.jit.script(pdt_model, example_inputs={pdt_model: inp})
|
||||
inp: List[Tuple[Any, ...]] = [
|
||||
(20,),
|
||||
(2.7,),
|
||||
(False,),
|
||||
]
|
||||
scripted_pdt_model = torch.jit.script(
|
||||
pdt_model, example_inputs={pdt_model: inp}
|
||||
)
|
||||
self.assertEqual(scripted_pdt_model(50), pdt_model(50))
|
||||
self.assertEqual(scripted_pdt_model(1.8), pdt_model(1.8))
|
||||
self.assertTrue(scripted_pdt_model(True), pdt_model(True))
|
||||
@ -63,8 +75,10 @@ class TestPDT(JitTestCase):
|
||||
make_global(NestedPDTInner, NestedModulePDTWrapper)
|
||||
inner_pdt_model = NestedPDTInner()
|
||||
wrapped_pdt_model = NestedModulePDTWrapper(inner_pdt_model)
|
||||
inp: List[Tuple[Any, ...]] = [(20, ), (False, )]
|
||||
scripted_pdt_model = torch.jit.script(wrapped_pdt_model, example_inputs={wrapped_pdt_model: inp})
|
||||
inp: List[Tuple[Any, ...]] = [(20,), (False,)]
|
||||
scripted_pdt_model = torch.jit.script(
|
||||
wrapped_pdt_model, example_inputs={wrapped_pdt_model: inp}
|
||||
)
|
||||
self.assertEqual(scripted_pdt_model(30), wrapped_pdt_model(30))
|
||||
self.assertEqual(scripted_pdt_model(1.9), wrapped_pdt_model(1.9))
|
||||
self.assertTrue(scripted_pdt_model(True), wrapped_pdt_model(True))
|
||||
@ -87,10 +101,18 @@ class TestPDT(JitTestCase):
|
||||
make_global(NestedModulePDTInner, NestedModulePDTOuter)
|
||||
inner_pdt_model = NestedModulePDTInner()
|
||||
outer_pdt_model = NestedModulePDTOuter(inner_pdt_model)
|
||||
inner_input: List[Tuple[Any, ...]] = [(10, 10), (1.9, 20), ]
|
||||
outer_input: List[Tuple[Any, ...]] = [(20, ), (False, )]
|
||||
scripted_pdt_model = torch.jit.script(outer_pdt_model, example_inputs={inner_pdt_model: inner_input,
|
||||
outer_pdt_model: outer_input, })
|
||||
inner_input: List[Tuple[Any, ...]] = [
|
||||
(10, 10),
|
||||
(1.9, 20),
|
||||
]
|
||||
outer_input: List[Tuple[Any, ...]] = [(20,), (False,)]
|
||||
scripted_pdt_model = torch.jit.script(
|
||||
outer_pdt_model,
|
||||
example_inputs={
|
||||
inner_pdt_model: inner_input,
|
||||
outer_pdt_model: outer_input,
|
||||
},
|
||||
)
|
||||
self.assertEqual(scripted_pdt_model(30), outer_pdt_model(30))
|
||||
self.assertEqual(scripted_pdt_model(1.9), outer_pdt_model(1.9))
|
||||
self.assertTrue(scripted_pdt_model(True), outer_pdt_model(True))
|
||||
@ -109,8 +131,10 @@ class TestPDT(JitTestCase):
|
||||
|
||||
make_global(NestedFunctionInForward)
|
||||
pdt_model = NestedFunctionInForward()
|
||||
inp: List[Tuple[Any, ...]] = [(-1, ), (False, )]
|
||||
scripted_pdt_model = torch.jit.script(pdt_model, example_inputs={pdt_model: inp})
|
||||
inp: List[Tuple[Any, ...]] = [(-1,), (False,)]
|
||||
scripted_pdt_model = torch.jit.script(
|
||||
pdt_model, example_inputs={pdt_model: inp}
|
||||
)
|
||||
self.assertEqual(scripted_pdt_model(30), pdt_model(30))
|
||||
self.assertEqual(scripted_pdt_model(True), pdt_model(True))
|
||||
|
||||
@ -126,14 +150,26 @@ class TestPDT(JitTestCase):
|
||||
else:
|
||||
return -1
|
||||
|
||||
|
||||
make_global(TestModelWithExport)
|
||||
pdt_model = TestModelWithExport()
|
||||
inp: List[Tuple[Any, ...]] = [(20, 10, ), (2.7, 8.9, ), ]
|
||||
scripted_pdt_model = torch.jit.script(pdt_model, example_inputs={pdt_model.fn: inp})
|
||||
inp: List[Tuple[Any, ...]] = [
|
||||
(
|
||||
20,
|
||||
10,
|
||||
),
|
||||
(
|
||||
2.7,
|
||||
8.9,
|
||||
),
|
||||
]
|
||||
scripted_pdt_model = torch.jit.script(
|
||||
pdt_model, example_inputs={pdt_model.fn: inp}
|
||||
)
|
||||
self.assertEqual(scripted_pdt_model.fn(10, 90), pdt_model.fn(10, 90))
|
||||
self.assertEqual(scripted_pdt_model.fn(1.8, 2.2), pdt_model.fn(1.8, 2.2))
|
||||
self.assertTrue(scripted_pdt_model.fn(torch.ones(1), 2), pdt_model.fn(torch.ones(1), 2))
|
||||
self.assertTrue(
|
||||
scripted_pdt_model.fn(torch.ones(1), 2), pdt_model.fn(torch.ones(1), 2)
|
||||
)
|
||||
|
||||
def test_class_methods(self):
|
||||
class PDTModel:
|
||||
@ -142,10 +178,34 @@ class TestPDT(JitTestCase):
|
||||
|
||||
make_global(PDTModel)
|
||||
pdt_model = PDTModel()
|
||||
inp: List[Tuple[Any, ...]] = [([10, 20, ], ), ]
|
||||
scripted_pdt_model = torch.jit.script(PDTModel, example_inputs={pdt_model.test_sum: inp})
|
||||
inp: List[Tuple[Any, ...]] = [
|
||||
(
|
||||
[
|
||||
10,
|
||||
20,
|
||||
],
|
||||
),
|
||||
]
|
||||
scripted_pdt_model = torch.jit.script(
|
||||
PDTModel, example_inputs={pdt_model.test_sum: inp}
|
||||
)
|
||||
script_model = scripted_pdt_model()
|
||||
self.assertEqual(script_model.test_sum([10, 20, 30, ], ), pdt_model.test_sum([10, 20, 30, ], ))
|
||||
self.assertEqual(
|
||||
script_model.test_sum(
|
||||
[
|
||||
10,
|
||||
20,
|
||||
30,
|
||||
],
|
||||
),
|
||||
pdt_model.test_sum(
|
||||
[
|
||||
10,
|
||||
20,
|
||||
30,
|
||||
],
|
||||
),
|
||||
)
|
||||
|
||||
def test_class_with_multiple_methods(self):
|
||||
class PDTModelWithManyMethods:
|
||||
@ -160,14 +220,64 @@ class TestPDT(JitTestCase):
|
||||
|
||||
make_global(PDTModelWithManyMethods)
|
||||
pdt_model = PDTModelWithManyMethods()
|
||||
list_inp: List[Tuple[Any, ...]] = [([1.2, 2.3, ], ), ]
|
||||
str_inp: List[Tuple[Any, ...]] = [("abc", "b", ), ]
|
||||
scripted_pdt_model = torch.jit.script(PDTModelWithManyMethods, example_inputs={pdt_model.test_list_to_dict: list_inp,
|
||||
pdt_model.test_substring: str_inp})
|
||||
list_inp: List[Tuple[Any, ...]] = [
|
||||
(
|
||||
[
|
||||
1.2,
|
||||
2.3,
|
||||
],
|
||||
),
|
||||
]
|
||||
str_inp: List[Tuple[Any, ...]] = [
|
||||
(
|
||||
"abc",
|
||||
"b",
|
||||
),
|
||||
]
|
||||
scripted_pdt_model = torch.jit.script(
|
||||
PDTModelWithManyMethods,
|
||||
example_inputs={
|
||||
pdt_model.test_list_to_dict: list_inp,
|
||||
pdt_model.test_substring: str_inp,
|
||||
},
|
||||
)
|
||||
script_model = scripted_pdt_model()
|
||||
self.assertEqual(script_model.test_list_to_dict([1.1, 2.2, 3.3, ], ), pdt_model.test_list_to_dict([1.1, 2.2, 3.3, ], ))
|
||||
self.assertEqual(script_model.test_substring("helloworld", "world", ), pdt_model.test_substring("helloworld", "world", ))
|
||||
self.assertEqual(script_model.test_substring("helloworld", "def", ), pdt_model.test_substring("helloworld", "def", ))
|
||||
self.assertEqual(
|
||||
script_model.test_list_to_dict(
|
||||
[
|
||||
1.1,
|
||||
2.2,
|
||||
3.3,
|
||||
],
|
||||
),
|
||||
pdt_model.test_list_to_dict(
|
||||
[
|
||||
1.1,
|
||||
2.2,
|
||||
3.3,
|
||||
],
|
||||
),
|
||||
)
|
||||
self.assertEqual(
|
||||
script_model.test_substring(
|
||||
"helloworld",
|
||||
"world",
|
||||
),
|
||||
pdt_model.test_substring(
|
||||
"helloworld",
|
||||
"world",
|
||||
),
|
||||
)
|
||||
self.assertEqual(
|
||||
script_model.test_substring(
|
||||
"helloworld",
|
||||
"def",
|
||||
),
|
||||
pdt_model.test_substring(
|
||||
"helloworld",
|
||||
"def",
|
||||
),
|
||||
)
|
||||
|
||||
def test_multiple_class_with_same_method(self):
|
||||
class PDTModelOne:
|
||||
@ -181,16 +291,69 @@ class TestPDT(JitTestCase):
|
||||
make_global(PDTModelOne, PDTModelTwo)
|
||||
pdt_model_one = PDTModelOne()
|
||||
pdt_model_two = PDTModelTwo()
|
||||
dict_inp: List[Tuple[Any, ...]] = [({1.2: True, 2.3: False, }, 1.2), ]
|
||||
list_inp: List[Tuple[Any, ...]] = [(["abc", "b", ], "c"), ]
|
||||
scripted_pdt_model_one = torch.jit.script(PDTModelOne, example_inputs={pdt_model_one.test_find: dict_inp})
|
||||
scripted_pdt_model_two = torch.jit.script(PDTModelTwo, example_inputs={pdt_model_two.test_find: list_inp})
|
||||
dict_inp: List[Tuple[Any, ...]] = [
|
||||
(
|
||||
{
|
||||
1.2: True,
|
||||
2.3: False,
|
||||
},
|
||||
1.2,
|
||||
),
|
||||
]
|
||||
list_inp: List[Tuple[Any, ...]] = [
|
||||
(
|
||||
[
|
||||
"abc",
|
||||
"b",
|
||||
],
|
||||
"c",
|
||||
),
|
||||
]
|
||||
scripted_pdt_model_one = torch.jit.script(
|
||||
PDTModelOne, example_inputs={pdt_model_one.test_find: dict_inp}
|
||||
)
|
||||
scripted_pdt_model_two = torch.jit.script(
|
||||
PDTModelTwo, example_inputs={pdt_model_two.test_find: list_inp}
|
||||
)
|
||||
|
||||
script_model_one, script_model_two = scripted_pdt_model_one(), scripted_pdt_model_two()
|
||||
self.assertEqual(script_model_one.test_find({1.1: True, 2.2: True, 3.3: False, }, 4.4),
|
||||
pdt_model_one.test_find({1.1: True, 2.2: True, 3.3: False, }, 4.4))
|
||||
self.assertEqual(script_model_two.test_find(["hello", "world", ], "world"),
|
||||
pdt_model_two.test_find(["hello", "world", ], "world"))
|
||||
script_model_one, script_model_two = (
|
||||
scripted_pdt_model_one(),
|
||||
scripted_pdt_model_two(),
|
||||
)
|
||||
self.assertEqual(
|
||||
script_model_one.test_find(
|
||||
{
|
||||
1.1: True,
|
||||
2.2: True,
|
||||
3.3: False,
|
||||
},
|
||||
4.4,
|
||||
),
|
||||
pdt_model_one.test_find(
|
||||
{
|
||||
1.1: True,
|
||||
2.2: True,
|
||||
3.3: False,
|
||||
},
|
||||
4.4,
|
||||
),
|
||||
)
|
||||
self.assertEqual(
|
||||
script_model_two.test_find(
|
||||
[
|
||||
"hello",
|
||||
"world",
|
||||
],
|
||||
"world",
|
||||
),
|
||||
pdt_model_two.test_find(
|
||||
[
|
||||
"hello",
|
||||
"world",
|
||||
],
|
||||
"world",
|
||||
),
|
||||
)
|
||||
|
||||
def test_pdt(self):
|
||||
def test_sum(a, b):
|
||||
@ -218,7 +381,9 @@ class TestPDT(JitTestCase):
|
||||
return torch.complex(real, img)
|
||||
|
||||
make_global(test_args_complex)
|
||||
scripted_fn_complex = torch.jit.script(test_args_complex, example_inputs=[(torch.rand(3, 4), torch.rand(3, 4))])
|
||||
scripted_fn_complex = torch.jit.script(
|
||||
test_args_complex, example_inputs=[(torch.rand(3, 4), torch.rand(3, 4))]
|
||||
)
|
||||
arg1, arg2 = torch.rand(3, 4), torch.rand(3, 4)
|
||||
self.assertEqual(scripted_fn_complex(arg1, arg2), test_args_complex(arg1, arg2))
|
||||
|
||||
@ -248,25 +413,49 @@ class TestPDT(JitTestCase):
|
||||
|
||||
make_global(test_list_and_tuple)
|
||||
|
||||
scripted_fn_float_list_input = torch.jit.script(test_list_and_tuple, example_inputs=[([4.9, 8.9],)])
|
||||
self.assertEqual(scripted_fn_float_list_input([11.9, 7.6]), test_list_and_tuple([11.9, 7.6]))
|
||||
scripted_fn_float_list_input = torch.jit.script(
|
||||
test_list_and_tuple, example_inputs=[([4.9, 8.9],)]
|
||||
)
|
||||
self.assertEqual(
|
||||
scripted_fn_float_list_input([11.9, 7.6]), test_list_and_tuple([11.9, 7.6])
|
||||
)
|
||||
|
||||
scripted_fn_bool_list_input = torch.jit.script(test_list_and_tuple, example_inputs=[([True, False, True],)])
|
||||
self.assertEqual(scripted_fn_bool_list_input([True, True, True]), test_list_and_tuple([True, True, True]))
|
||||
scripted_fn_bool_list_input = torch.jit.script(
|
||||
test_list_and_tuple, example_inputs=[([True, False, True],)]
|
||||
)
|
||||
self.assertEqual(
|
||||
scripted_fn_bool_list_input([True, True, True]),
|
||||
test_list_and_tuple([True, True, True]),
|
||||
)
|
||||
|
||||
scripted_fn_int_list_input = torch.jit.script(test_list_and_tuple, example_inputs=[([3, 4, 5], )])
|
||||
self.assertEqual(scripted_fn_int_list_input([1, 2, 3]), test_list_and_tuple([1, 2, 3]))
|
||||
scripted_fn_int_list_input = torch.jit.script(
|
||||
test_list_and_tuple, example_inputs=[([3, 4, 5],)]
|
||||
)
|
||||
self.assertEqual(
|
||||
scripted_fn_int_list_input([1, 2, 3]), test_list_and_tuple([1, 2, 3])
|
||||
)
|
||||
|
||||
scripted_fn_float_tuple_input = torch.jit.script(test_list_and_tuple, example_inputs=[((4.9, 8.9),)])
|
||||
self.assertEqual(scripted_fn_float_tuple_input((11.9, 7.6)), test_list_and_tuple((11.9, 7.6)))
|
||||
scripted_fn_float_tuple_input = torch.jit.script(
|
||||
test_list_and_tuple, example_inputs=[((4.9, 8.9),)]
|
||||
)
|
||||
self.assertEqual(
|
||||
scripted_fn_float_tuple_input((11.9, 7.6)), test_list_and_tuple((11.9, 7.6))
|
||||
)
|
||||
|
||||
scripted_fn_bool_tuple_input = torch.jit.script(test_list_and_tuple,
|
||||
example_inputs=[((True, False, True),)])
|
||||
self.assertEqual(scripted_fn_bool_tuple_input((True, True, True)),
|
||||
test_list_and_tuple((True, True, True)))
|
||||
scripted_fn_bool_tuple_input = torch.jit.script(
|
||||
test_list_and_tuple, example_inputs=[((True, False, True),)]
|
||||
)
|
||||
self.assertEqual(
|
||||
scripted_fn_bool_tuple_input((True, True, True)),
|
||||
test_list_and_tuple((True, True, True)),
|
||||
)
|
||||
|
||||
scripted_fn_int_tuple_input = torch.jit.script(test_list_and_tuple, example_inputs=[((3, 4, 5), )])
|
||||
self.assertEqual(scripted_fn_int_tuple_input((1, 2, 3)), test_list_and_tuple((1, 2, 3)))
|
||||
scripted_fn_int_tuple_input = torch.jit.script(
|
||||
test_list_and_tuple, example_inputs=[((3, 4, 5),)]
|
||||
)
|
||||
self.assertEqual(
|
||||
scripted_fn_int_tuple_input((1, 2, 3)), test_list_and_tuple((1, 2, 3))
|
||||
)
|
||||
|
||||
def test_nested_list_and_tuple(self):
|
||||
def test_nested_list(inp):
|
||||
@ -282,43 +471,207 @@ class TestPDT(JitTestCase):
|
||||
|
||||
make_global(test_nested_list, test_nested_tuple)
|
||||
|
||||
list_inp = [[1, 2, 3, ], [5, 6, 7, ]]
|
||||
scripted_fn = torch.jit.script(test_nested_list, example_inputs=[(list_inp, ), ])
|
||||
inp = [[0, 4, 7, ], [8, 11, ], [6, -1, -20, ]]
|
||||
self.assertEqual(scripted_fn(inp, ), test_nested_list(inp, ))
|
||||
list_inp = [
|
||||
[
|
||||
1,
|
||||
2,
|
||||
3,
|
||||
],
|
||||
[
|
||||
5,
|
||||
6,
|
||||
7,
|
||||
],
|
||||
]
|
||||
scripted_fn = torch.jit.script(
|
||||
test_nested_list,
|
||||
example_inputs=[
|
||||
(list_inp,),
|
||||
],
|
||||
)
|
||||
inp = [
|
||||
[
|
||||
0,
|
||||
4,
|
||||
7,
|
||||
],
|
||||
[
|
||||
8,
|
||||
11,
|
||||
],
|
||||
[
|
||||
6,
|
||||
-1,
|
||||
-20,
|
||||
],
|
||||
]
|
||||
self.assertEqual(
|
||||
scripted_fn(
|
||||
inp,
|
||||
),
|
||||
test_nested_list(
|
||||
inp,
|
||||
),
|
||||
)
|
||||
|
||||
list_inp = ([1, 2, 3, ], [5, 6, 7, ])
|
||||
scripted_fn = torch.jit.script(test_nested_list, example_inputs=[(list_inp, ), ])
|
||||
inp = ([0, 4, 7, ], [8, 11, ], [6, -1, -20, ])
|
||||
self.assertEqual(scripted_fn(inp, ), test_nested_list(inp, ))
|
||||
list_inp = (
|
||||
[
|
||||
1,
|
||||
2,
|
||||
3,
|
||||
],
|
||||
[
|
||||
5,
|
||||
6,
|
||||
7,
|
||||
],
|
||||
)
|
||||
scripted_fn = torch.jit.script(
|
||||
test_nested_list,
|
||||
example_inputs=[
|
||||
(list_inp,),
|
||||
],
|
||||
)
|
||||
inp = (
|
||||
[
|
||||
0,
|
||||
4,
|
||||
7,
|
||||
],
|
||||
[
|
||||
8,
|
||||
11,
|
||||
],
|
||||
[
|
||||
6,
|
||||
-1,
|
||||
-20,
|
||||
],
|
||||
)
|
||||
self.assertEqual(
|
||||
scripted_fn(
|
||||
inp,
|
||||
),
|
||||
test_nested_list(
|
||||
inp,
|
||||
),
|
||||
)
|
||||
|
||||
tup_inp = [(1.0, 2.6, 3.7, ), (5.7, 6.1, 1.7, )]
|
||||
scripted_fn = torch.jit.script(test_nested_tuple, example_inputs=[(tup_inp, ), ])
|
||||
inp = [(1.0, 4.1, 7.4, ), (4.8, 1.1, -1.2, ), (6.3, -1.3, -2.0, )]
|
||||
self.assertEqual(scripted_fn(inp, ), test_nested_tuple(inp, ))
|
||||
tup_inp = [
|
||||
(
|
||||
1.0,
|
||||
2.6,
|
||||
3.7,
|
||||
),
|
||||
(
|
||||
5.7,
|
||||
6.1,
|
||||
1.7,
|
||||
),
|
||||
]
|
||||
scripted_fn = torch.jit.script(
|
||||
test_nested_tuple,
|
||||
example_inputs=[
|
||||
(tup_inp,),
|
||||
],
|
||||
)
|
||||
inp = [
|
||||
(
|
||||
1.0,
|
||||
4.1,
|
||||
7.4,
|
||||
),
|
||||
(
|
||||
4.8,
|
||||
1.1,
|
||||
-1.2,
|
||||
),
|
||||
(
|
||||
6.3,
|
||||
-1.3,
|
||||
-2.0,
|
||||
),
|
||||
]
|
||||
self.assertEqual(
|
||||
scripted_fn(
|
||||
inp,
|
||||
),
|
||||
test_nested_tuple(
|
||||
inp,
|
||||
),
|
||||
)
|
||||
|
||||
tup_inp = ((True, False, True, ), (False, False, False, ))
|
||||
scripted_fn = torch.jit.script(test_nested_tuple, example_inputs=[(tup_inp, ), ])
|
||||
inp = ((True, True, True, ), (False, False, True, ))
|
||||
self.assertEqual(scripted_fn(inp, ), test_nested_tuple(inp, ))
|
||||
tup_inp = (
|
||||
(
|
||||
True,
|
||||
False,
|
||||
True,
|
||||
),
|
||||
(
|
||||
False,
|
||||
False,
|
||||
False,
|
||||
),
|
||||
)
|
||||
scripted_fn = torch.jit.script(
|
||||
test_nested_tuple,
|
||||
example_inputs=[
|
||||
(tup_inp,),
|
||||
],
|
||||
)
|
||||
inp = (
|
||||
(
|
||||
True,
|
||||
True,
|
||||
True,
|
||||
),
|
||||
(
|
||||
False,
|
||||
False,
|
||||
True,
|
||||
),
|
||||
)
|
||||
self.assertEqual(
|
||||
scripted_fn(
|
||||
inp,
|
||||
),
|
||||
test_nested_tuple(
|
||||
inp,
|
||||
),
|
||||
)
|
||||
|
||||
def test_pdt_dict(self):
|
||||
def test_dict(a):
|
||||
return a['foo']
|
||||
return a["foo"]
|
||||
|
||||
def test_dict_int_list(a):
|
||||
return a[1]
|
||||
|
||||
make_global(test_dict, test_dict_int_list)
|
||||
|
||||
str_bool_inp = {'foo' : True, 'bar': False}
|
||||
str_bool_inp = {"foo": True, "bar": False}
|
||||
scripted_fn = torch.jit.script(test_dict, example_inputs=[(str_bool_inp,)])
|
||||
self.assertEqual(scripted_fn({'foo' : False, 'bar': True}, ), test_dict({'foo' : False, 'bar': True}, ))
|
||||
self.assertEqual(
|
||||
scripted_fn(
|
||||
{"foo": False, "bar": True},
|
||||
),
|
||||
test_dict(
|
||||
{"foo": False, "bar": True},
|
||||
),
|
||||
)
|
||||
|
||||
str_list_inp = {0 : [True, False], 1: [False, True]}
|
||||
scripted_fn = torch.jit.script(test_dict_int_list, example_inputs=[(str_list_inp,)])
|
||||
self.assertEqual(scripted_fn({0 : [False, False], 1: [True, True]}, ),
|
||||
test_dict_int_list({0 : [False, False], 1: [True, True]}, ))
|
||||
str_list_inp = {0: [True, False], 1: [False, True]}
|
||||
scripted_fn = torch.jit.script(
|
||||
test_dict_int_list, example_inputs=[(str_list_inp,)]
|
||||
)
|
||||
self.assertEqual(
|
||||
scripted_fn(
|
||||
{0: [False, False], 1: [True, True]},
|
||||
),
|
||||
test_dict_int_list(
|
||||
{0: [False, False], 1: [True, True]},
|
||||
),
|
||||
)
|
||||
|
||||
def test_any(self):
|
||||
def test_multiple_types(a):
|
||||
@ -337,20 +690,36 @@ class TestPDT(JitTestCase):
|
||||
|
||||
make_global(test_multiple_types, test_multiple_type_refinement)
|
||||
|
||||
scripted_fn = torch.jit.script(test_multiple_types, example_inputs=[(1,), ("abc", ), (8.9,), ([3, 4, 5], )])
|
||||
scripted_fn = torch.jit.script(
|
||||
test_multiple_types, example_inputs=[(1,), ("abc",), (8.9,), ([3, 4, 5],)]
|
||||
)
|
||||
self.assertEqual(scripted_fn(10), test_multiple_types(10))
|
||||
self.assertEqual(scripted_fn("def"), test_multiple_types("def"))
|
||||
self.assertEqual(scripted_fn(7.89999), test_multiple_types(7.89999))
|
||||
self.assertEqual(scripted_fn([10, 11, 14]), test_multiple_types([10, 11, 14]))
|
||||
|
||||
scripted_fn = torch.jit.script(test_multiple_type_refinement, example_inputs=[(1,), ("abc", ), (8.9,),
|
||||
([3, 4, 5],), (True, ), ({"a": True}, ), ])
|
||||
scripted_fn = torch.jit.script(
|
||||
test_multiple_type_refinement,
|
||||
example_inputs=[
|
||||
(1,),
|
||||
("abc",),
|
||||
(8.9,),
|
||||
([3, 4, 5],),
|
||||
(True,),
|
||||
({"a": True},),
|
||||
],
|
||||
)
|
||||
self.assertEqual(scripted_fn(10), test_multiple_type_refinement(10))
|
||||
self.assertEqual(scripted_fn("def"), test_multiple_type_refinement("def"))
|
||||
self.assertEqual(scripted_fn(7.89999), test_multiple_type_refinement(7.89999))
|
||||
self.assertEqual(scripted_fn([10, 11, 14]), test_multiple_type_refinement([10, 11, 14]))
|
||||
self.assertEqual(
|
||||
scripted_fn([10, 11, 14]), test_multiple_type_refinement([10, 11, 14])
|
||||
)
|
||||
self.assertEqual(scripted_fn(False), test_multiple_type_refinement(False))
|
||||
self.assertEqual(scripted_fn({"abc" : True, "def": False}), test_multiple_type_refinement({"abc" : True, "def": False}))
|
||||
self.assertEqual(
|
||||
scripted_fn({"abc": True, "def": False}),
|
||||
test_multiple_type_refinement({"abc": True, "def": False}),
|
||||
)
|
||||
|
||||
def test_class_as_profiled_types(self):
|
||||
class UserDefinedClass:
|
||||
@ -369,9 +738,33 @@ class TestPDT(JitTestCase):
|
||||
make_global(UserDefinedClass, test_model)
|
||||
|
||||
user_class = UserDefinedClass()
|
||||
scripted_fn = torch.jit.script(test_model, example_inputs=[(10, user_class, ), (10.9, user_class, ), ])
|
||||
self.assertEqual(scripted_fn(100, user_class, ), test_model(100, user_class))
|
||||
self.assertEqual(scripted_fn(1.9, user_class, ), test_model(1.9, user_class))
|
||||
scripted_fn = torch.jit.script(
|
||||
test_model,
|
||||
example_inputs=[
|
||||
(
|
||||
10,
|
||||
user_class,
|
||||
),
|
||||
(
|
||||
10.9,
|
||||
user_class,
|
||||
),
|
||||
],
|
||||
)
|
||||
self.assertEqual(
|
||||
scripted_fn(
|
||||
100,
|
||||
user_class,
|
||||
),
|
||||
test_model(100, user_class),
|
||||
)
|
||||
self.assertEqual(
|
||||
scripted_fn(
|
||||
1.9,
|
||||
user_class,
|
||||
),
|
||||
test_model(1.9, user_class),
|
||||
)
|
||||
|
||||
def test_class_with_args_as_profiled_types(self):
|
||||
class ClassWithArgs:
|
||||
@ -391,8 +784,26 @@ class TestPDT(JitTestCase):
|
||||
make_global(ClassWithArgs, test_model_with_args)
|
||||
|
||||
user_class = ClassWithArgs(False)
|
||||
scripted_fn = torch.jit.script(test_model_with_args, example_inputs=[(10, user_class, ), (10.9, user_class, ), ])
|
||||
self.assertEqual(scripted_fn(100, ClassWithArgs(True), ), test_model_with_args(100, ClassWithArgs(True)))
|
||||
scripted_fn = torch.jit.script(
|
||||
test_model_with_args,
|
||||
example_inputs=[
|
||||
(
|
||||
10,
|
||||
user_class,
|
||||
),
|
||||
(
|
||||
10.9,
|
||||
user_class,
|
||||
),
|
||||
],
|
||||
)
|
||||
self.assertEqual(
|
||||
scripted_fn(
|
||||
100,
|
||||
ClassWithArgs(True),
|
||||
),
|
||||
test_model_with_args(100, ClassWithArgs(True)),
|
||||
)
|
||||
|
||||
def test_nn_parameter_as_arg(self):
|
||||
class TestNNParameter(torch.nn.Module):
|
||||
@ -408,7 +819,14 @@ class TestPDT(JitTestCase):
|
||||
|
||||
make_global(TestNNParameter)
|
||||
pdt_model = TestNNParameter()
|
||||
scripted_fn = torch.jit.script(pdt_model, example_inputs={pdt_model: [(10, ), ], })
|
||||
scripted_fn = torch.jit.script(
|
||||
pdt_model,
|
||||
example_inputs={
|
||||
pdt_model: [
|
||||
(10,),
|
||||
],
|
||||
},
|
||||
)
|
||||
self.assertEqual(scripted_fn(20), pdt_model(20))
|
||||
|
||||
def test_fx_tracing_with_typing(self):
|
||||
@ -422,7 +840,19 @@ class TestPDT(JitTestCase):
|
||||
|
||||
make_global(FXModel, FXModelOutput)
|
||||
pdt_model = FXModel()
|
||||
scripted_fn = torch.jit.script(pdt_model, example_inputs={pdt_model: [([10, 20, ], ), ], })
|
||||
scripted_fn = torch.jit.script(
|
||||
pdt_model,
|
||||
example_inputs={
|
||||
pdt_model: [
|
||||
(
|
||||
[
|
||||
10,
|
||||
20,
|
||||
],
|
||||
),
|
||||
],
|
||||
},
|
||||
)
|
||||
self.assertEqual(scripted_fn([20]), pdt_model([20]))
|
||||
|
||||
def test_nonetype_as_optional_of_type(self):
|
||||
@ -434,11 +864,34 @@ class TestPDT(JitTestCase):
|
||||
|
||||
make_global(test_none)
|
||||
|
||||
scripted_fn = torch.jit.script(test_none, example_inputs=[(None, ), (10.6, )])
|
||||
self.assertEqual(scripted_fn(30.9, ), test_none(30.9, ))
|
||||
scripted_fn = torch.jit.script(test_none, example_inputs=[(None,), (10.6,)])
|
||||
self.assertEqual(
|
||||
scripted_fn(
|
||||
30.9,
|
||||
),
|
||||
test_none(
|
||||
30.9,
|
||||
),
|
||||
)
|
||||
|
||||
scripted_fn = torch.jit.script(test_none, example_inputs=[(None, ), (10, )])
|
||||
self.assertEqual(scripted_fn(2, ), test_none(2, ))
|
||||
scripted_fn = torch.jit.script(test_none, example_inputs=[(None,), (10,)])
|
||||
self.assertEqual(
|
||||
scripted_fn(
|
||||
2,
|
||||
),
|
||||
test_none(
|
||||
2,
|
||||
),
|
||||
)
|
||||
|
||||
scripted_fn = torch.jit.script(test_none, example_inputs=[(None, ), (torch.Tensor(1), )])
|
||||
self.assertEqual(scripted_fn(torch.ones(1), ), test_none(torch.ones(1), ))
|
||||
scripted_fn = torch.jit.script(
|
||||
test_none, example_inputs=[(None,), (torch.Tensor(1),)]
|
||||
)
|
||||
self.assertEqual(
|
||||
scripted_fn(
|
||||
torch.ones(1),
|
||||
),
|
||||
test_none(
|
||||
torch.ones(1),
|
||||
),
|
||||
)
|
||||
|
||||
@ -1,17 +1,20 @@
|
||||
# Owner(s): ["oncall: jit"]
|
||||
|
||||
import torch
|
||||
from torch.testing._internal.jit_utils import JitTestCase, RUN_CUDA, _inline_everything
|
||||
from torch import nn
|
||||
from torch.testing import FileCheck
|
||||
import unittest
|
||||
from typing import Callable, List
|
||||
|
||||
import unittest
|
||||
import torch
|
||||
from torch import nn
|
||||
from torch.testing import FileCheck
|
||||
from torch.testing._internal.jit_utils import _inline_everything, JitTestCase, RUN_CUDA
|
||||
|
||||
if __name__ == "__main__":
|
||||
raise RuntimeError(
|
||||
"This test file is not meant to be run directly, use:\n\n"
|
||||
"\tpython test/test_jit.py TESTNAME\n\n"
|
||||
"instead."
|
||||
)
|
||||
|
||||
if __name__ == '__main__':
|
||||
raise RuntimeError("This test file is not meant to be run directly, use:\n\n"
|
||||
"\tpython test/test_jit.py TESTNAME\n\n"
|
||||
"instead.")
|
||||
|
||||
class TestPeephole(JitTestCase):
|
||||
def test_peephole_with_writes(self):
|
||||
@ -62,11 +65,11 @@ class TestPeephole(JitTestCase):
|
||||
|
||||
tf = torch.jit.trace(f, (a, b))
|
||||
FileCheck().check("type_as").run(str(tf.graph))
|
||||
self.run_pass('peephole', tf.graph)
|
||||
self.run_pass("peephole", tf.graph)
|
||||
FileCheck().check_not("type_as").run(str(tf.graph))
|
||||
tf2 = torch.jit.trace(f, (a, c))
|
||||
s = str(tf2.graph)
|
||||
self.run_pass('peephole', tf2.graph)
|
||||
self.run_pass("peephole", tf2.graph)
|
||||
self.assertEqual(s, str(s))
|
||||
|
||||
def test_peephole_dynamic(self):
|
||||
@ -83,7 +86,7 @@ class TestPeephole(JitTestCase):
|
||||
def foo(x, y, z):
|
||||
return len([x, y, z])
|
||||
|
||||
self.run_pass('peephole', foo.graph)
|
||||
self.run_pass("peephole", foo.graph)
|
||||
FileCheck().check("value=3").check_next("return").run(foo.graph)
|
||||
|
||||
@torch.jit.script
|
||||
@ -93,7 +96,7 @@ class TestPeephole(JitTestCase):
|
||||
li.append(x)
|
||||
return len([x, y, z])
|
||||
|
||||
self.run_pass('peephole', foo.graph)
|
||||
self.run_pass("peephole", foo.graph)
|
||||
FileCheck().check_not("aten::len").run(foo.graph)
|
||||
|
||||
@torch.jit.script
|
||||
@ -102,7 +105,7 @@ class TestPeephole(JitTestCase):
|
||||
return li[1], li[-2]
|
||||
|
||||
FileCheck().check("aten::__getitem__").run(foo.graph)
|
||||
self.run_pass('peephole', foo.graph)
|
||||
self.run_pass("peephole", foo.graph)
|
||||
FileCheck().check_not("aten::__getitem__").run(foo.graph)
|
||||
|
||||
@torch.jit.script
|
||||
@ -110,7 +113,7 @@ class TestPeephole(JitTestCase):
|
||||
li = [x, y, z]
|
||||
return li[-7]
|
||||
|
||||
self.run_pass('peephole', foo.graph)
|
||||
self.run_pass("peephole", foo.graph)
|
||||
FileCheck().check("aten::__getitem__").run(foo.graph)
|
||||
|
||||
@torch.jit.script
|
||||
@ -120,25 +123,25 @@ class TestPeephole(JitTestCase):
|
||||
li.append(x)
|
||||
return li[-2]
|
||||
|
||||
self.run_pass('peephole', foo.graph)
|
||||
self.run_pass("peephole", foo.graph)
|
||||
FileCheck().check("aten::__getitem__").run(foo.graph)
|
||||
|
||||
@unittest.skipIf(not RUN_CUDA, "cpp tests require CUDA")
|
||||
def test_peephole_cuda(self):
|
||||
a = torch.tensor([0.4], device='cpu')
|
||||
b = torch.tensor([0.7], device='cuda')
|
||||
c = torch.tensor([0.7], device='cuda')
|
||||
a = torch.tensor([0.4], device="cpu")
|
||||
b = torch.tensor([0.7], device="cuda")
|
||||
c = torch.tensor([0.7], device="cuda")
|
||||
|
||||
def f(x, y):
|
||||
return x.type_as(y)
|
||||
|
||||
trace = torch.jit.trace(f, (a, c))
|
||||
s = str(trace.graph)
|
||||
self.run_pass('peephole', trace.graph)
|
||||
self.run_pass("peephole", trace.graph)
|
||||
self.assertEqual(s, str(trace.graph))
|
||||
trace = torch.jit.trace(f, (b, c))
|
||||
self.run_pass('peephole', trace.graph)
|
||||
self.run_pass('dce', trace.graph)
|
||||
self.run_pass("peephole", trace.graph)
|
||||
self.run_pass("dce", trace.graph)
|
||||
FileCheck().check_not("type_as").run(str(trace.graph))
|
||||
|
||||
@_inline_everything
|
||||
@ -152,7 +155,7 @@ class TestPeephole(JitTestCase):
|
||||
return refine(torch.tensor(4))
|
||||
|
||||
FileCheck().check("prim::unchecked_cast").run(test.graph)
|
||||
self.run_pass('peephole', test.graph)
|
||||
self.run_pass("peephole", test.graph)
|
||||
FileCheck().check_not("prim::unchecked_cast").run(test.graph)
|
||||
|
||||
# refinement not optimzied out
|
||||
@ -166,7 +169,7 @@ class TestPeephole(JitTestCase):
|
||||
self.checkScript(is_int_tensor, (torch.tensor(2),))
|
||||
self.checkScript(is_int_tensor, (torch.tensor(2.5),))
|
||||
graph = torch.jit.script(is_int_tensor).graph
|
||||
self.run_pass('peephole', graph)
|
||||
self.run_pass("peephole", graph)
|
||||
FileCheck().check("prim::unchecked_cast").run(graph)
|
||||
|
||||
def test_short_circuit_optimization(self):
|
||||
@ -174,8 +177,11 @@ class TestPeephole(JitTestCase):
|
||||
def const_expressions(x):
|
||||
# type: (int) -> Tuple[bool, bool]
|
||||
return x == 1 and False, x == 1 or True
|
||||
self.run_pass('constant_propagation', const_expressions.graph)
|
||||
FileCheck().check_not("prim::If").check_not("aten::eq").run(const_expressions.graph)
|
||||
|
||||
self.run_pass("constant_propagation", const_expressions.graph)
|
||||
FileCheck().check_not("prim::If").check_not("aten::eq").run(
|
||||
const_expressions.graph
|
||||
)
|
||||
self.assertEqual(const_expressions(1), (False, True))
|
||||
|
||||
@torch.jit.script
|
||||
@ -183,15 +189,18 @@ class TestPeephole(JitTestCase):
|
||||
# type: (int) -> Tuple[bool, bool]
|
||||
return x == 1 and True, x == 1 or False
|
||||
|
||||
self.run_pass('peephole', redundant_expressions.graph)
|
||||
self.run_pass("peephole", redundant_expressions.graph)
|
||||
self.assertEqual(redundant_expressions(1), (True, True))
|
||||
self.assertEqual(redundant_expressions(0), (False, False))
|
||||
# and True / or False are removed from graph
|
||||
FileCheck().check("aten::eq").check_not("prim::If").run(redundant_expressions.graph)
|
||||
FileCheck().check("aten::eq").check_not("prim::If").run(
|
||||
redundant_expressions.graph
|
||||
)
|
||||
|
||||
def test_conv_dim_folding(self):
|
||||
modules = [nn.Conv1d, nn.Conv2d, nn.Conv3d]
|
||||
for mod in modules:
|
||||
|
||||
class ConvDim(torch.nn.Module):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
@ -233,7 +242,6 @@ class TestPeephole(JitTestCase):
|
||||
FileCheck().check_count("aten::sub", 2, exactly=True).run(op_graph)
|
||||
FileCheck().check_count("aten::rsub", 0, exactly=True).run(op_graph)
|
||||
|
||||
|
||||
def test_normalized_is_op(self):
|
||||
def convertible_is_op(x: bool, y: bool):
|
||||
return x is True, False is x, x is y
|
||||
@ -558,7 +566,7 @@ class TestPeephole(JitTestCase):
|
||||
|
||||
def foo4():
|
||||
x = torch.zeros([2, 2])
|
||||
return x + 0.
|
||||
return x + 0.0
|
||||
|
||||
funcs = foo1, foo2, foo3, foo4
|
||||
inps = (torch.ones([2]),), (), (), ()
|
||||
@ -582,7 +590,7 @@ class TestPeephole(JitTestCase):
|
||||
self.assertEqual(func(torch.ones([2, 2])), func_s(torch.ones([2, 2])))
|
||||
|
||||
def func(x):
|
||||
return (x + 0.) - 5
|
||||
return (x + 0.0) - 5
|
||||
|
||||
func_s = torch.jit.script(func)
|
||||
inp = next(func_s.graph.inputs())
|
||||
@ -640,6 +648,7 @@ class TestPeephole(JitTestCase):
|
||||
return z
|
||||
else:
|
||||
return z2
|
||||
|
||||
out = next(foo.graph.findNode("prim::If").outputs())
|
||||
out.setType(torch._C.OptionalType(torch._C.IntType.get()))
|
||||
self.run_pass("peephole", foo.graph)
|
||||
@ -665,12 +674,13 @@ class TestPeephole(JitTestCase):
|
||||
_6 = torch.add(1 * torch.sub(_3, 3) // 1, 1) / 1
|
||||
return [_5, int(_6)]
|
||||
|
||||
FileCheck().check("aten::add").check("aten::sub") \
|
||||
.check("aten::mul").check("aten::floordiv") \
|
||||
.check("aten::div").run(foo.graph)
|
||||
FileCheck().check("aten::add").check("aten::sub").check("aten::mul").check(
|
||||
"aten::floordiv"
|
||||
).check("aten::div").run(foo.graph)
|
||||
self.run_pass("peephole", foo.graph)
|
||||
FileCheck().check("graph").check("):") \
|
||||
.check_next("ListConstruct").check_next("return").run(foo.graph)
|
||||
FileCheck().check("graph").check("):").check_next("ListConstruct").check_next(
|
||||
"return"
|
||||
).run(foo.graph)
|
||||
self.assertEqual(foo(0, 1, 2, 3), [1, 3])
|
||||
|
||||
def test_peephole_dict_getitem_simple(self):
|
||||
@ -687,9 +697,9 @@ class TestPeephole(JitTestCase):
|
||||
|
||||
@torch.jit.script
|
||||
def foo(a: int, b: int):
|
||||
d = {'0': a, '1': b}
|
||||
x = d['1']
|
||||
y = d['0']
|
||||
d = {"0": a, "1": b}
|
||||
x = d["1"]
|
||||
y = d["0"]
|
||||
return x, y
|
||||
|
||||
self.run_pass("peephole", foo.graph)
|
||||
@ -815,14 +825,14 @@ class TestPeephole(JitTestCase):
|
||||
graph = torch.jit.script(foo).graph
|
||||
self.run_pass("peephole", graph)
|
||||
FileCheck().check_not("aten::slice").run(graph)
|
||||
self.checkScript(foo, (3, ))
|
||||
self.checkScript(foo, (3,))
|
||||
|
||||
def test_peephole_slice_one_empty_arg(self):
|
||||
def check_helper(fn: Callable[[int], None]) -> None:
|
||||
graph = torch.jit.script(fn).graph
|
||||
self.run_pass("peephole", graph)
|
||||
FileCheck().check_not("aten::slice").run(graph)
|
||||
self.checkScript(fn, (3, ))
|
||||
self.checkScript(fn, (3,))
|
||||
|
||||
def foo(x: int):
|
||||
return [1, 2, x, 4, 5, 6, 7][1::2]
|
||||
@ -844,7 +854,7 @@ class TestPeephole(JitTestCase):
|
||||
graph = torch.jit.script(fn).graph
|
||||
self.run_pass("peephole", graph)
|
||||
FileCheck().check_not("aten::slice").run(graph)
|
||||
self.checkScript(fn, (3, ))
|
||||
self.checkScript(fn, (3,))
|
||||
|
||||
def foo(x: int):
|
||||
return [1, 2, x, 4, 5, 6, 7][::2]
|
||||
|
||||
@ -9,12 +9,15 @@ from torch.testing._internal.common_utils import skipIfTorchDynamo
|
||||
# Make the helper files in test/ importable
|
||||
pytorch_test_dir = os.path.dirname(os.path.dirname(os.path.realpath(__file__)))
|
||||
sys.path.append(pytorch_test_dir)
|
||||
from torch.testing._internal.jit_utils import JitTestCase, warmup_backward, FileCheck
|
||||
from torch.testing._internal.jit_utils import FileCheck, JitTestCase, warmup_backward
|
||||
|
||||
if __name__ == "__main__":
|
||||
raise RuntimeError(
|
||||
"This test file is not meant to be run directly, use:\n\n"
|
||||
"\tpython test/test_jit.py TESTNAME\n\n"
|
||||
"instead."
|
||||
)
|
||||
|
||||
if __name__ == '__main__':
|
||||
raise RuntimeError("This test file is not meant to be run directly, use:\n\n"
|
||||
"\tpython test/test_jit.py TESTNAME\n\n"
|
||||
"instead.")
|
||||
|
||||
@skipIfTorchDynamo()
|
||||
class TestProfiler(JitTestCase):
|
||||
@ -58,8 +61,9 @@ class TestProfiler(JitTestCase):
|
||||
|
||||
# item & add should not get pulled into the fusion group -
|
||||
# we expect to see Fusion Group (item / add) Fusion Group in ir dump
|
||||
FileCheck().check("TensorExpr").check("Scalar = aten::item").check_next("Tensor = aten::add").check("TensorExpr").run(g)
|
||||
|
||||
FileCheck().check("TensorExpr").check("Scalar = aten::item").check_next(
|
||||
"Tensor = aten::add"
|
||||
).check("TensorExpr").run(g)
|
||||
|
||||
@torch.jit.script
|
||||
def non_const_dtype(x, y, cond: bool):
|
||||
@ -70,7 +74,9 @@ class TestProfiler(JitTestCase):
|
||||
non_const_dtype(x, x, True)
|
||||
g = torch.jit.last_executed_optimized_graph()
|
||||
# because dtype is non-const, sum should not get pulled into the Fusion Group
|
||||
FileCheck().check("TensorExpr").check("TensorExpr").check_not("aten::sum").run(g)
|
||||
FileCheck().check("TensorExpr").check("TensorExpr").check_not("aten::sum").run(
|
||||
g
|
||||
)
|
||||
|
||||
def test_specialize_backward(self):
|
||||
def test_fuse(a, b):
|
||||
@ -118,13 +124,15 @@ class TestProfiler(JitTestCase):
|
||||
d = c * b
|
||||
return d
|
||||
|
||||
x = torch.tensor([.5])
|
||||
x = torch.tensor([0.5])
|
||||
for _ in range(3):
|
||||
test_fuse(x, x)
|
||||
|
||||
g = torch.jit.last_executed_optimized_graph()
|
||||
# Types should remain specialized for typecheck outputs & fusion outputs
|
||||
FileCheck().check("Double(").check_same("prim::TypeCheck").check_same("\n").check("Double").check_same("TensorExpr").run(g)
|
||||
FileCheck().check("Double(").check_same("prim::TypeCheck").check_same(
|
||||
"\n"
|
||||
).check("Double").check_same("TensorExpr").run(g)
|
||||
|
||||
# other outputs should not be specialized
|
||||
FileCheck().check("Tensor = prim::If").run(g)
|
||||
@ -201,7 +209,9 @@ class TestProfiler(JitTestCase):
|
||||
foo(x, y)
|
||||
foo(x, y)
|
||||
g = torch.jit.last_executed_optimized_graph()
|
||||
FileCheck().check("CallFunction").check_next("Tensor = prim::TupleUnpack").run(g)
|
||||
FileCheck().check("CallFunction").check_next("Tensor = prim::TupleUnpack").run(
|
||||
g
|
||||
)
|
||||
|
||||
def test_autograd_fallback_graph(self):
|
||||
@torch.jit.script
|
||||
|
||||
@ -1,13 +1,13 @@
|
||||
# Owner(s): ["oncall: jit"]
|
||||
|
||||
import os
|
||||
import random
|
||||
import sys
|
||||
import tempfile
|
||||
import random
|
||||
from textwrap import dedent
|
||||
|
||||
import torch
|
||||
from torch.testing._internal.jit_utils import JitTestCase, execWrapper
|
||||
from torch.testing._internal.jit_utils import execWrapper, JitTestCase
|
||||
|
||||
# Make the helper files in test/ importable
|
||||
pytorch_test_dir = os.path.dirname(os.path.dirname(os.path.realpath(__file__)))
|
||||
@ -20,14 +20,17 @@ if __name__ == "__main__":
|
||||
"instead."
|
||||
)
|
||||
|
||||
|
||||
def get_fn(file_name, script_path):
|
||||
import importlib.util
|
||||
|
||||
spec = importlib.util.spec_from_file_location(file_name, script_path)
|
||||
module = importlib.util.module_from_spec(spec)
|
||||
spec.loader.exec_module(module)
|
||||
fn = module.fn
|
||||
return fn
|
||||
|
||||
|
||||
class TestPythonBuiltinOP(JitTestCase):
|
||||
def test_add(self):
|
||||
def func(a, b):
|
||||
@ -48,16 +51,18 @@ class TestPythonBuiltinOP(JitTestCase):
|
||||
self.checkScript(func, (a, b), optimize=True)
|
||||
|
||||
def test_matmul_py3(self):
|
||||
code = dedent("""
|
||||
code = dedent(
|
||||
"""
|
||||
def fn(a, b):
|
||||
return a @ b
|
||||
""")
|
||||
"""
|
||||
)
|
||||
|
||||
with tempfile.TemporaryDirectory() as tmp_dir:
|
||||
script_path = os.path.join(tmp_dir, 'script.py')
|
||||
with open(script_path, 'w') as f:
|
||||
script_path = os.path.join(tmp_dir, "script.py")
|
||||
with open(script_path, "w") as f:
|
||||
f.write(code)
|
||||
fn = get_fn('test_matmul_py3', script_path)
|
||||
fn = get_fn("test_matmul_py3", script_path)
|
||||
|
||||
a = torch.rand(4, 3, requires_grad=True)
|
||||
b = torch.rand(3, 2, requires_grad=True)
|
||||
@ -65,18 +70,18 @@ class TestPythonBuiltinOP(JitTestCase):
|
||||
|
||||
def test_pow(self):
|
||||
def func(a, b):
|
||||
return a ** b
|
||||
return a**b
|
||||
|
||||
def func2(a, b, c, d):
|
||||
return c + a ** b ** d
|
||||
return c + a**b**d
|
||||
|
||||
def func3(a, b):
|
||||
# type: (int, float) -> float
|
||||
return a ** b
|
||||
return a**b
|
||||
|
||||
def func4():
|
||||
# type: () -> float
|
||||
return 2 ** -2
|
||||
return 2**-2
|
||||
|
||||
def func5(x, y):
|
||||
return x.item() ** y.item()
|
||||
@ -90,7 +95,12 @@ class TestPythonBuiltinOP(JitTestCase):
|
||||
self.checkScript(func3, (4, -0.5), optimize=True)
|
||||
self.checkScript(func4, ())
|
||||
|
||||
inputs = [torch.tensor(2), torch.tensor(-2), torch.tensor(.5), torch.tensor(.2)]
|
||||
inputs = [
|
||||
torch.tensor(2),
|
||||
torch.tensor(-2),
|
||||
torch.tensor(0.5),
|
||||
torch.tensor(0.2),
|
||||
]
|
||||
for x in inputs:
|
||||
for y in inputs:
|
||||
if x < 0:
|
||||
@ -100,7 +110,7 @@ class TestPythonBuiltinOP(JitTestCase):
|
||||
|
||||
def test_triple(self):
|
||||
def func(x):
|
||||
return 3. * x
|
||||
return 3.0 * x
|
||||
|
||||
x = torch.rand(1, dtype=torch.float, requires_grad=True)
|
||||
self.checkScript(func, [x], optimize=True)
|
||||
@ -154,22 +164,36 @@ class TestPythonBuiltinOP(JitTestCase):
|
||||
|
||||
def test_stepped_tuple_slicing(self):
|
||||
def check_slicing_tuple(slicing, tuple_type, tuple):
|
||||
template = dedent("""
|
||||
template = dedent(
|
||||
"""
|
||||
def func(x):
|
||||
# type: ({}) -> Any
|
||||
return x{}
|
||||
""")
|
||||
"""
|
||||
)
|
||||
self._check_code(template.format(tuple_type, slicing), "func", [tuple])
|
||||
|
||||
check_slicing_tuple("[-3:3:2]", "Tuple[int, int, int]", (0, 1, 2))
|
||||
check_slicing_tuple("[::55]", "Tuple[int, int, int, int, int]", (0, 1, 2, 3, 4))
|
||||
check_slicing_tuple("[:4:4]", "Tuple[int, int, int, int, int]", (0, 1, 2, 3, 4))
|
||||
check_slicing_tuple("[::-1]", "Tuple[int, int, int, int, int, int, int]", (0, 1, 2, 3, 4, 5, 6))
|
||||
check_slicing_tuple("[7:5:2]", "Tuple[int, int, int, int, int, int, int]", (0, 1, 2, 3, 4, 5, 6))
|
||||
check_slicing_tuple("[5:7:-2]", "Tuple[int, int, int, int, int, int, int]", (0, 1, 2, 3, 4, 5, 6))
|
||||
check_slicing_tuple(
|
||||
"[::-1]", "Tuple[int, int, int, int, int, int, int]", (0, 1, 2, 3, 4, 5, 6)
|
||||
)
|
||||
check_slicing_tuple(
|
||||
"[7:5:2]", "Tuple[int, int, int, int, int, int, int]", (0, 1, 2, 3, 4, 5, 6)
|
||||
)
|
||||
check_slicing_tuple(
|
||||
"[5:7:-2]",
|
||||
"Tuple[int, int, int, int, int, int, int]",
|
||||
(0, 1, 2, 3, 4, 5, 6),
|
||||
)
|
||||
check_slicing_tuple("[::-2]", "Tuple[int, int, int, int, int]", (0, 1, 2, 3, 4))
|
||||
check_slicing_tuple("[:4:-3]", "Tuple[int, int, int, int, int, int]", (0, 1, 2, 3, 4, 5))
|
||||
check_slicing_tuple("[3::-2]", "Tuple[int, int, int, int, int]", (0, 1, 2, 3, 4))
|
||||
check_slicing_tuple(
|
||||
"[:4:-3]", "Tuple[int, int, int, int, int, int]", (0, 1, 2, 3, 4, 5)
|
||||
)
|
||||
check_slicing_tuple(
|
||||
"[3::-2]", "Tuple[int, int, int, int, int]", (0, 1, 2, 3, 4)
|
||||
)
|
||||
|
||||
def test_index(self):
|
||||
def consec(size, start=0):
|
||||
@ -177,10 +201,12 @@ class TestPythonBuiltinOP(JitTestCase):
|
||||
return torch.arange(numel).view(size)
|
||||
|
||||
def check_indexing(indexing, tensor):
|
||||
template = dedent("""
|
||||
template = dedent(
|
||||
"""
|
||||
def func(x):
|
||||
return x{}
|
||||
""")
|
||||
"""
|
||||
)
|
||||
|
||||
self._check_code(template.format(indexing), "func", [tensor])
|
||||
|
||||
@ -188,62 +214,66 @@ class TestPythonBuiltinOP(JitTestCase):
|
||||
value1 = torch.tensor(value1)
|
||||
value2 = torch.tensor(value2)
|
||||
|
||||
template = dedent("""
|
||||
template = dedent(
|
||||
"""
|
||||
def func(x, value1, value2):
|
||||
i = int(value1)
|
||||
j = int(value2)
|
||||
return x{}
|
||||
""")
|
||||
"""
|
||||
)
|
||||
|
||||
self._check_code(template.format(indexing), "func", [tensor, value1, value2])
|
||||
self._check_code(
|
||||
template.format(indexing), "func", [tensor, value1, value2]
|
||||
)
|
||||
|
||||
# basic slices
|
||||
check_indexing('[0]', consec((3, 3)))
|
||||
check_indexing('[1]', consec((3, 3), 10))
|
||||
check_indexing('[2]', consec((3, 3), 19))
|
||||
check_indexing('[2]', consec((3,)))
|
||||
check_indexing('[-1]', consec((3, 3), 19))
|
||||
check_indexing('[0:2]', consec((3, 3, 3)))
|
||||
check_indexing('[1:-1]', consec((3, 3, 3)))
|
||||
check_indexing('[-3:-1]', consec((6, 3)))
|
||||
check_indexing('[1:]', consec((3, 3)))
|
||||
check_indexing('[:1]', consec((3, 3)))
|
||||
check_indexing('[:]', consec((3, 2)))
|
||||
check_indexing("[0]", consec((3, 3)))
|
||||
check_indexing("[1]", consec((3, 3), 10))
|
||||
check_indexing("[2]", consec((3, 3), 19))
|
||||
check_indexing("[2]", consec((3,)))
|
||||
check_indexing("[-1]", consec((3, 3), 19))
|
||||
check_indexing("[0:2]", consec((3, 3, 3)))
|
||||
check_indexing("[1:-1]", consec((3, 3, 3)))
|
||||
check_indexing("[-3:-1]", consec((6, 3)))
|
||||
check_indexing("[1:]", consec((3, 3)))
|
||||
check_indexing("[:1]", consec((3, 3)))
|
||||
check_indexing("[:]", consec((3, 2)))
|
||||
|
||||
# multi-dim: indexes
|
||||
check_indexing('[0, 1]', consec((3, 3)))
|
||||
check_indexing('[0, 1]', consec((3, 3, 2)))
|
||||
check_indexing('[1, 0, 2]', consec((3, 3, 3)))
|
||||
check_indexing('[2, -1]', consec((3, 3)))
|
||||
check_indexing("[0, 1]", consec((3, 3)))
|
||||
check_indexing("[0, 1]", consec((3, 3, 2)))
|
||||
check_indexing("[1, 0, 2]", consec((3, 3, 3)))
|
||||
check_indexing("[2, -1]", consec((3, 3)))
|
||||
|
||||
# multi-dim: mixed slicing and indexing
|
||||
check_indexing('[0, 1:2]', consec((3, 3)))
|
||||
check_indexing('[0, :1]', consec((3, 3, 2)))
|
||||
check_indexing('[1, 2:]', consec((3, 3, 3)))
|
||||
check_indexing('[-1, 1:, 0]', consec((3, 3, 3, 3)))
|
||||
check_indexing('[1:, -1, 0]', consec((3, 3, 3, 3)))
|
||||
check_indexing('[-1, 2:, 1:2]', consec((3, 3, 3, 3)))
|
||||
check_indexing('[-1, 1:, 0]', consec((3, 3, 3, 3)))
|
||||
check_indexing('[-1, :, 0, 2]', consec((3, 3, 3, 3)))
|
||||
check_indexing("[0, 1:2]", consec((3, 3)))
|
||||
check_indexing("[0, :1]", consec((3, 3, 2)))
|
||||
check_indexing("[1, 2:]", consec((3, 3, 3)))
|
||||
check_indexing("[-1, 1:, 0]", consec((3, 3, 3, 3)))
|
||||
check_indexing("[1:, -1, 0]", consec((3, 3, 3, 3)))
|
||||
check_indexing("[-1, 2:, 1:2]", consec((3, 3, 3, 3)))
|
||||
check_indexing("[-1, 1:, 0]", consec((3, 3, 3, 3)))
|
||||
check_indexing("[-1, :, 0, 2]", consec((3, 3, 3, 3)))
|
||||
|
||||
# zero-sized slices
|
||||
check_indexing('[0:0]', consec((2, 2)))
|
||||
check_indexing('[0:0, 1]', consec((3, 3)))
|
||||
check_indexing("[0:0]", consec((2, 2)))
|
||||
check_indexing("[0:0, 1]", consec((3, 3)))
|
||||
|
||||
# trivial expression usage
|
||||
check_indexing('[1+1]', consec((3, 3)))
|
||||
check_indexing('[1:(0 + 2)]', consec((3, 3, 3)))
|
||||
check_indexing("[1+1]", consec((3, 3)))
|
||||
check_indexing("[1:(0 + 2)]", consec((3, 3, 3)))
|
||||
|
||||
# None for new dimensions
|
||||
check_indexing('[None, 0]', consec((3, 3)))
|
||||
check_indexing('[1, None]', consec((3, 3), 10))
|
||||
check_indexing('[None, None, 2]', consec((3, 3), 19))
|
||||
check_indexing('[None, 2, None]', consec((3,)))
|
||||
check_indexing('[0:2, None]', consec((3, 3, 3)))
|
||||
check_indexing('[None, 1:-1]', consec((3, 3, 3)))
|
||||
check_indexing('[None, -3:-1, None]', consec((6, 3)))
|
||||
check_indexing('[-1, None, 2:, None, 1:2]', consec((3, 3, 3, 3)))
|
||||
check_indexing('[None, -1, None, 2:, None, 1:2, None]', consec((3, 3, 3, 3)))
|
||||
check_indexing("[None, 0]", consec((3, 3)))
|
||||
check_indexing("[1, None]", consec((3, 3), 10))
|
||||
check_indexing("[None, None, 2]", consec((3, 3), 19))
|
||||
check_indexing("[None, 2, None]", consec((3,)))
|
||||
check_indexing("[0:2, None]", consec((3, 3, 3)))
|
||||
check_indexing("[None, 1:-1]", consec((3, 3, 3)))
|
||||
check_indexing("[None, -3:-1, None]", consec((6, 3)))
|
||||
check_indexing("[-1, None, 2:, None, 1:2]", consec((3, 3, 3, 3)))
|
||||
check_indexing("[None, -1, None, 2:, None, 1:2, None]", consec((3, 3, 3, 3)))
|
||||
|
||||
# dynamic expression usage
|
||||
check_dynamic_indexing("[i + j]", consec((3, 3)), 0, 1)
|
||||
@ -257,10 +287,12 @@ class TestPythonBuiltinOP(JitTestCase):
|
||||
def check_indexing(indexing, tensor, **kwargs):
|
||||
indices_dict = kwargs
|
||||
|
||||
template = dedent("""
|
||||
template = dedent(
|
||||
"""
|
||||
def func(x{formals}):
|
||||
return x{expr}
|
||||
""")
|
||||
"""
|
||||
)
|
||||
|
||||
formals = []
|
||||
values = []
|
||||
@ -268,17 +300,18 @@ class TestPythonBuiltinOP(JitTestCase):
|
||||
formals.append(formal)
|
||||
values.append(value)
|
||||
|
||||
formals = ''.join(map(', {}'.format, formals))
|
||||
formals = "".join(map(", {}".format, formals))
|
||||
inputs = [tensor] + values
|
||||
self._check_code(template.format(formals=formals, expr=indexing),
|
||||
"func", inputs)
|
||||
self._check_code(
|
||||
template.format(formals=formals, expr=indexing), "func", inputs
|
||||
)
|
||||
|
||||
# Indexing with tensor (basic)
|
||||
check_indexing('[i]', consec((3, 3)), i=torch.tensor([0]))
|
||||
check_indexing('[i]', consec((3, 3)), i=torch.tensor(1))
|
||||
check_indexing('[i]', consec((3, 3)), i=torch.tensor([-2]))
|
||||
check_indexing('[i]', consec((3, 3), 2), i=torch.tensor([0, 0]))
|
||||
check_indexing('[i]', consec((3, 3, 2, 2)), i=torch.tensor([0, -2, 1]))
|
||||
check_indexing("[i]", consec((3, 3)), i=torch.tensor([0]))
|
||||
check_indexing("[i]", consec((3, 3)), i=torch.tensor(1))
|
||||
check_indexing("[i]", consec((3, 3)), i=torch.tensor([-2]))
|
||||
check_indexing("[i]", consec((3, 3), 2), i=torch.tensor([0, 0]))
|
||||
check_indexing("[i]", consec((3, 3, 2, 2)), i=torch.tensor([0, -2, 1]))
|
||||
|
||||
# NB: indexing with tensors and indexing with sequences can be implemented
|
||||
# in a very similar way (sequences are converted to tensors), so only one
|
||||
@ -290,49 +323,49 @@ class TestPythonBuiltinOP(JitTestCase):
|
||||
inp = consec((4, 8, 5))
|
||||
to_check = [
|
||||
# [[0, 1, 3]]
|
||||
['[i]', {'i': [0, 1, 3]}],
|
||||
["[i]", {"i": [0, 1, 3]}],
|
||||
# [[0, 2], [1, 3]]
|
||||
['[i, j]', {'i': [0, 2], 'j': [1, 3]}],
|
||||
["[i, j]", {"i": [0, 2], "j": [1, 3]}],
|
||||
# [[[0, 1], [0, 1]], [[0, 1], [0, 1]]]
|
||||
['[i, j]', {'i': [[0, 1], [0, 1]], 'j': [[0, 1], [0, 1]]}],
|
||||
["[i, j]", {"i": [[0, 1], [0, 1]], "j": [[0, 1], [0, 1]]}],
|
||||
# [[0, 2], [1, 3], [1, 1]]
|
||||
['[i, j, k]', {'i': [0, 2], 'j': [1, 3], 'k': [1, 1]}],
|
||||
["[i, j, k]", {"i": [0, 2], "j": [1, 3], "k": [1, 1]}],
|
||||
# [[0, 2], 1, [1, 1]]
|
||||
['[i, j, k]', {'i': [0, 2], 'j': 1, 'k': [1, 1]}],
|
||||
["[i, j, k]", {"i": [0, 2], "j": 1, "k": [1, 1]}],
|
||||
# [:, :, [0, 3, 4]]
|
||||
['[:, :, i]', {'i': [0, 3, 4]}],
|
||||
["[:, :, i]", {"i": [0, 3, 4]}],
|
||||
# [:, [2, 4, 5, 7], 2:4]
|
||||
['[:, i, 2:4]', {'i': [0, 2, 3]}],
|
||||
["[:, i, 2:4]", {"i": [0, 2, 3]}],
|
||||
# [[2, 3], :, :]
|
||||
['[i, :, :]', {'i': [2, 3]}],
|
||||
["[i, :, :]", {"i": [2, 3]}],
|
||||
# [:, [0, 2, 3], [1, 3, 4]]
|
||||
['[:, i, j]', {'i': [0, 2, 3], 'j': [1, 3, 4]}],
|
||||
["[:, i, j]", {"i": [0, 2, 3], "j": [1, 3, 4]}],
|
||||
# [:, [0], [1, 2, 4]]
|
||||
['[:, i, j]', {'i': [0], 'j': [1, 2, 4]}],
|
||||
["[:, i, j]", {"i": [0], "j": [1, 2, 4]}],
|
||||
# [:, [0, 1, 3], [4]]
|
||||
['[:, i, j]', {'i': [0, 1, 3], 'j': [4]}],
|
||||
["[:, i, j]", {"i": [0, 1, 3], "j": [4]}],
|
||||
# [:, [[0, 1], [1, 0]], [[2, 3]]]
|
||||
['[:, i, j]', {'i': [[0, 1], [1, 0]], 'j': [[2, 3]]}],
|
||||
["[:, i, j]", {"i": [[0, 1], [1, 0]], "j": [[2, 3]]}],
|
||||
# [:, [[0, 1], [2, 3]], [[0]]]
|
||||
['[:, i, j]', {'i': [[0, 1], [2, 3]], 'j': [[0]]}],
|
||||
["[:, i, j]", {"i": [[0, 1], [2, 3]], "j": [[0]]}],
|
||||
# [:, [[5, 6]], [[0, 3], [4, 4]]]
|
||||
['[:, i, j]', {'i': [[5, 6]], 'j': [[0, 3], [4, 4]]}],
|
||||
["[:, i, j]", {"i": [[5, 6]], "j": [[0, 3], [4, 4]]}],
|
||||
# [[0, 2, 3], [1, 3, 4], :]
|
||||
['[i, j, :]', {'i': [0, 2, 3], 'j': [1, 3, 4]}],
|
||||
["[i, j, :]", {"i": [0, 2, 3], "j": [1, 3, 4]}],
|
||||
# [0, [1, 2, 4], :]
|
||||
['[i, j, :]', {'i': 0, 'j': [1, 2, 4]}],
|
||||
["[i, j, :]", {"i": 0, "j": [1, 2, 4]}],
|
||||
# [[0, 1, 3], 4, :]
|
||||
['[i, j, :]', {'i': [0, 1, 3], 'j': 4}],
|
||||
["[i, j, :]", {"i": [0, 1, 3], "j": 4}],
|
||||
# [[[0, 1], [1, 0]], [[2, 1], [3, 5]], :]
|
||||
['[i, j, :]', {'i': [[0, 1], [1, 0]], 'j': [[2, 1], [3, 5]]}],
|
||||
["[i, j, :]", {"i": [[0, 1], [1, 0]], "j": [[2, 1], [3, 5]]}],
|
||||
# [[[0, 1], [1, 0]], [[2, 3]], :]
|
||||
['[i, j, :]', {'i': [[0, 1], [1, 0]], 'j': [[2, 3]]}],
|
||||
["[i, j, :]", {"i": [[0, 1], [1, 0]], "j": [[2, 3]]}],
|
||||
# [[[0, 1], [2, 3]], [[0]], :]
|
||||
['[i, j, :]', {'i': [[0, 1], [2, 3]], 'j': [[0]]}],
|
||||
["[i, j, :]", {"i": [[0, 1], [2, 3]], "j": [[0]]}],
|
||||
# [[[2, 1]], [[0, 3], [4, 4]], :]
|
||||
['[i, j, :]', {'i': [[2, 1]], 'j': [[0, 3], [4, 4]]}],
|
||||
["[i, j, :]", {"i": [[2, 1]], "j": [[0, 3], [4, 4]]}],
|
||||
# [[[2]], [[0, 3], [4, 1]], 0:2]
|
||||
['[i, j, 0:2]', {'i': [[2]], 'j': [[0, 3], [4, 1]]}],
|
||||
["[i, j, 0:2]", {"i": [[2]], "j": [[0, 3], [4, 1]]}],
|
||||
]
|
||||
|
||||
for expr, argdict in to_check:
|
||||
@ -372,29 +405,35 @@ class TestPythonBuiltinOP(JitTestCase):
|
||||
for _ in range(100):
|
||||
indices = [random.choice(vals) for _ in range(4)]
|
||||
indices[random.randint(0, len(indices) - 1)] = "..."
|
||||
test_str = dedent("""
|
||||
test_str = dedent(
|
||||
"""
|
||||
def f():
|
||||
x = torch.ones(10, 9, 8, 7, 6)
|
||||
return x{indices}.shape
|
||||
""".format(indices=indices))
|
||||
test_str = test_str.replace(r"'", r'')
|
||||
""".format(
|
||||
indices=indices
|
||||
)
|
||||
)
|
||||
test_str = test_str.replace(r"'", r"")
|
||||
scope = {}
|
||||
execWrapper(test_str, globals(), scope)
|
||||
cu = torch.jit.CompilationUnit(test_str)
|
||||
res1 = cu.f()
|
||||
res2 = scope['f']()
|
||||
res2 = scope["f"]()
|
||||
self.assertEqual(res1, res2)
|
||||
|
||||
def test_inf(self):
|
||||
@torch.jit.script
|
||||
def foo(a):
|
||||
return a < float('inf')
|
||||
return a < float("inf")
|
||||
|
||||
s = torch.rand(1)
|
||||
self.assertTrue(foo(s))
|
||||
|
||||
@torch.jit.script
|
||||
def bar(a):
|
||||
return a > float('-inf')
|
||||
return a > float("-inf")
|
||||
|
||||
s = torch.rand(1)
|
||||
self.assertTrue(foo(s))
|
||||
|
||||
@ -414,19 +453,22 @@ class TestPythonBuiltinOP(JitTestCase):
|
||||
def test_str_to_float(self):
|
||||
@torch.jit.script
|
||||
def foo(a):
|
||||
return 0.5 == float('0.5 hello')
|
||||
return 0.5 == float("0.5 hello")
|
||||
|
||||
s = torch.rand(1)
|
||||
with self.assertRaisesRegex(RuntimeError, "could not convert string to float"):
|
||||
self.assertTrue(foo(s))
|
||||
|
||||
@torch.jit.script
|
||||
def foo(a):
|
||||
return 0.5 == float('0.5')
|
||||
return 0.5 == float("0.5")
|
||||
|
||||
s = torch.rand(1)
|
||||
self.assertTrue(foo(s))
|
||||
|
||||
@torch.jit.script
|
||||
def foo(a):
|
||||
return 0. == float('0')
|
||||
return 0.0 == float("0")
|
||||
|
||||
s = torch.rand(1)
|
||||
self.assertTrue(foo(s))
|
||||
|
||||
@ -1,22 +1,26 @@
|
||||
# Owner(s): ["oncall: jit"]
|
||||
|
||||
import torch
|
||||
from torch.testing import FileCheck
|
||||
from torch.testing._internal.jit_utils import JitTestCase
|
||||
from torch.testing._internal.common_utils import IS_MACOS
|
||||
|
||||
import numpy as np
|
||||
import unittest
|
||||
|
||||
if __name__ == '__main__':
|
||||
raise RuntimeError("This test file is not meant to be run directly, use:\n\n"
|
||||
"\tpython test/test_jit.py TESTNAME\n\n"
|
||||
"instead.")
|
||||
import numpy as np
|
||||
import torch
|
||||
from torch.testing import FileCheck
|
||||
from torch.testing._internal.common_utils import IS_MACOS
|
||||
from torch.testing._internal.jit_utils import JitTestCase
|
||||
|
||||
if __name__ == "__main__":
|
||||
raise RuntimeError(
|
||||
"This test file is not meant to be run directly, use:\n\n"
|
||||
"\tpython test/test_jit.py TESTNAME\n\n"
|
||||
"instead."
|
||||
)
|
||||
|
||||
|
||||
class TestPythonIr(JitTestCase):
|
||||
def test_param_strides(self):
|
||||
def trace_me(arg):
|
||||
return arg
|
||||
|
||||
t = torch.zeros(1, 3, 16, 16)
|
||||
traced = torch.jit.trace(trace_me, t)
|
||||
value = list(traced.graph.param_node().outputs())[0]
|
||||
@ -78,8 +82,12 @@ class TestPythonIr(JitTestCase):
|
||||
|
||||
g = foo.graph
|
||||
muls = g.findAllNodes("aten::mul")
|
||||
scalar_muls = filter(lambda x: x.matches("aten::mul(Tensor self, Scalar other) -> Tensor"), muls)
|
||||
mul_constant_int = filter(lambda x: isinstance(list(x.inputs())[1].toIValue(), int), scalar_muls)
|
||||
scalar_muls = filter(
|
||||
lambda x: x.matches("aten::mul(Tensor self, Scalar other) -> Tensor"), muls
|
||||
)
|
||||
mul_constant_int = filter(
|
||||
lambda x: isinstance(list(x.inputs())[1].toIValue(), int), scalar_muls
|
||||
)
|
||||
for mul in mul_constant_int:
|
||||
with g.insert_point_guard(mul):
|
||||
outputs = g.insertGraph(unrolled_mul.graph, list(mul.inputs()))
|
||||
|
||||
@ -5,25 +5,31 @@ import re
|
||||
import sys
|
||||
import types
|
||||
import typing
|
||||
import typing_extensions
|
||||
from typing import List, Dict, Optional, Tuple
|
||||
from collections import OrderedDict
|
||||
from typing import Dict, List, Optional, Tuple
|
||||
|
||||
import torch
|
||||
import torch.jit.frontend
|
||||
import torch.nn as nn
|
||||
import typing_extensions
|
||||
from torch import Tensor
|
||||
from torch.testing import FileCheck
|
||||
from collections import OrderedDict
|
||||
|
||||
# Make the helper files in test/ importable
|
||||
pytorch_test_dir = os.path.dirname(os.path.dirname(os.path.realpath(__file__)))
|
||||
sys.path.append(pytorch_test_dir)
|
||||
from torch.testing._internal.jit_utils import JitTestCase, _tmp_donotuse_dont_inline_everything
|
||||
from torch.testing._internal.jit_utils import (
|
||||
_tmp_donotuse_dont_inline_everything,
|
||||
JitTestCase,
|
||||
)
|
||||
|
||||
if __name__ == "__main__":
|
||||
raise RuntimeError(
|
||||
"This test file is not meant to be run directly, use:\n\n"
|
||||
"\tpython test/test_jit.py TESTNAME\n\n"
|
||||
"instead."
|
||||
)
|
||||
|
||||
if __name__ == '__main__':
|
||||
raise RuntimeError("This test file is not meant to be run directly, use:\n\n"
|
||||
"\tpython test/test_jit.py TESTNAME\n\n"
|
||||
"instead.")
|
||||
|
||||
class TestRecursiveScript(JitTestCase):
|
||||
def test_inferred_nonetype(self):
|
||||
@ -87,7 +93,9 @@ class TestRecursiveScript(JitTestCase):
|
||||
return self.fn(x)
|
||||
|
||||
m = M(fn)
|
||||
with self.assertRaisesRegexWithHighlight(RuntimeError, "failed to compile", "i_dont_exist"):
|
||||
with self.assertRaisesRegexWithHighlight(
|
||||
RuntimeError, "failed to compile", "i_dont_exist"
|
||||
):
|
||||
torch.jit.script(m)
|
||||
|
||||
def test_init_error(self):
|
||||
@ -119,12 +127,12 @@ class TestRecursiveScript(JitTestCase):
|
||||
|
||||
# sm1 was created while m had training = True
|
||||
self.assertTrue(sm1.training)
|
||||
self.assertEqual(sm1.training, sm1._c.getattr('training'))
|
||||
self.assertEqual(sm1.training, sm1._c.getattr("training"))
|
||||
self.assertEqual(sm1(), 2)
|
||||
|
||||
# sm2 was created after m was eval'ed
|
||||
self.assertFalse(sm2.training)
|
||||
self.assertEqual(sm2.training, sm2._c.getattr('training'))
|
||||
self.assertEqual(sm2.training, sm2._c.getattr("training"))
|
||||
self.assertEqual(sm2(), 0)
|
||||
|
||||
def test_module_name(self):
|
||||
@ -165,7 +173,7 @@ class TestRecursiveScript(JitTestCase):
|
||||
|
||||
def test_constants_with_final(self):
|
||||
class M1(torch.nn.Module):
|
||||
x : torch.jit.Final[int]
|
||||
x: torch.jit.Final[int]
|
||||
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
@ -177,7 +185,7 @@ class TestRecursiveScript(JitTestCase):
|
||||
self.checkModule(M1(), (torch.randn(2, 2),))
|
||||
|
||||
class M2(torch.nn.Module):
|
||||
x : typing_extensions.Final[int]
|
||||
x: typing_extensions.Final[int]
|
||||
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
@ -189,7 +197,7 @@ class TestRecursiveScript(JitTestCase):
|
||||
self.checkModule(M2(), (torch.randn(2, 2),))
|
||||
|
||||
class M3(torch.nn.Module):
|
||||
x : typing.Final[int]
|
||||
x: typing.Final[int]
|
||||
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
@ -206,12 +214,15 @@ class TestRecursiveScript(JitTestCase):
|
||||
def unscriptable(self):
|
||||
return "a" + 200
|
||||
|
||||
|
||||
class TestModule(torch.nn.Module):
|
||||
def forward(self, x):
|
||||
return MyScriptClass()
|
||||
|
||||
with self.assertRaisesRegexWithHighlight(torch.jit.frontend.FrontendError, "Cannot instantiate class", "MyScriptClass"):
|
||||
with self.assertRaisesRegexWithHighlight(
|
||||
torch.jit.frontend.FrontendError,
|
||||
"Cannot instantiate class",
|
||||
"MyScriptClass",
|
||||
):
|
||||
t = torch.jit.script(TestModule())
|
||||
|
||||
def test_method_call(self):
|
||||
@ -246,13 +257,13 @@ class TestRecursiveScript(JitTestCase):
|
||||
print(m)
|
||||
|
||||
f = FileCheck()
|
||||
f.check('MyModule')
|
||||
f.check('Conv2d')
|
||||
f.check('Linear')
|
||||
f.check('Submodule')
|
||||
f.check("MyModule")
|
||||
f.check("Conv2d")
|
||||
f.check("Linear")
|
||||
f.check("Submodule")
|
||||
f.run(out[0])
|
||||
|
||||
self.assertEqual(m.original_name, 'MyModule')
|
||||
self.assertEqual(m.original_name, "MyModule")
|
||||
|
||||
def test_dir(self):
|
||||
def test_module_dir(mod):
|
||||
@ -260,8 +271,17 @@ class TestRecursiveScript(JitTestCase):
|
||||
scripted_mod = torch.jit.script(mod)
|
||||
dir_scripted = set(dir(scripted_mod))
|
||||
# set not currently copied over
|
||||
ignore_set = ["training", "__delitem__", "__setitem__", "clear", "items",
|
||||
"keys", "pop", "update", "values"]
|
||||
ignore_set = [
|
||||
"training",
|
||||
"__delitem__",
|
||||
"__setitem__",
|
||||
"clear",
|
||||
"items",
|
||||
"keys",
|
||||
"pop",
|
||||
"update",
|
||||
"values",
|
||||
]
|
||||
for attr in dir_set:
|
||||
if attr in ignore_set:
|
||||
continue
|
||||
@ -283,7 +303,9 @@ class TestRecursiveScript(JitTestCase):
|
||||
linear = nn.Linear(10, 10)
|
||||
|
||||
test_module_dir(nn.Sequential(conv, linear))
|
||||
test_module_dir(nn.ModuleDict(OrderedDict([("conv", conv), ("linear", linear)])))
|
||||
test_module_dir(
|
||||
nn.ModuleDict(OrderedDict([("conv", conv), ("linear", linear)]))
|
||||
)
|
||||
|
||||
def test_class_compile(self):
|
||||
def other_fn(a: int, b: Tensor) -> Tensor:
|
||||
@ -296,7 +318,6 @@ class TestRecursiveScript(JitTestCase):
|
||||
def helper(self, a):
|
||||
return self.x + a + other_fn(self.x, a)
|
||||
|
||||
|
||||
class N(torch.nn.Module):
|
||||
def forward(self, x):
|
||||
b = B(x)
|
||||
@ -411,7 +432,7 @@ class TestRecursiveScript(JitTestCase):
|
||||
|
||||
def test_module_basic(self):
|
||||
class Other(torch.nn.Module):
|
||||
__constants__ = ['x']
|
||||
__constants__ = ["x"]
|
||||
|
||||
def __init__(self, x):
|
||||
super().__init__()
|
||||
@ -426,7 +447,6 @@ class TestRecursiveScript(JitTestCase):
|
||||
def forward(self, t):
|
||||
return t + self.x + self.param
|
||||
|
||||
|
||||
class M(torch.nn.Module):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
@ -439,7 +459,7 @@ class TestRecursiveScript(JitTestCase):
|
||||
|
||||
def test_module_function_export(self):
|
||||
class Other(torch.nn.Module):
|
||||
__constants__ = ['x']
|
||||
__constants__ = ["x"]
|
||||
|
||||
def __init__(self, x):
|
||||
super().__init__()
|
||||
@ -453,7 +473,6 @@ class TestRecursiveScript(JitTestCase):
|
||||
def forward(self, t):
|
||||
return t + self.x + self.param
|
||||
|
||||
|
||||
class M(torch.nn.Module):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
@ -473,9 +492,7 @@ class TestRecursiveScript(JitTestCase):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
self.sequential = nn.Sequential(
|
||||
Inner(),
|
||||
Inner(),
|
||||
nn.Sequential(Inner(), Inner())
|
||||
Inner(), Inner(), nn.Sequential(Inner(), Inner())
|
||||
)
|
||||
self.module_list = nn.ModuleList([Inner(), Inner()])
|
||||
|
||||
@ -511,12 +528,14 @@ class TestRecursiveScript(JitTestCase):
|
||||
self.sequential = nn.Sequential(
|
||||
SeluButReluWhenScripted(),
|
||||
SeluButReluWhenScripted(),
|
||||
nn.Sequential(SeluButReluWhenScripted(), shared, SeluButReluWhenScripted()),
|
||||
nn.Sequential(
|
||||
SeluButReluWhenScripted(), shared, SeluButReluWhenScripted()
|
||||
),
|
||||
shared,
|
||||
)
|
||||
self.module_list = nn.ModuleList([SeluButReluWhenScripted(),
|
||||
shared,
|
||||
SeluButReluWhenScripted()])
|
||||
self.module_list = nn.ModuleList(
|
||||
[SeluButReluWhenScripted(), shared, SeluButReluWhenScripted()]
|
||||
)
|
||||
|
||||
def forward(self, x):
|
||||
for mod in self.module_list:
|
||||
@ -553,7 +572,8 @@ class TestRecursiveScript(JitTestCase):
|
||||
self.assertEqual(obj(1, 2), 3)
|
||||
self.assertEqual(obj(1, 2, 3, 4), 10)
|
||||
with self.assertRaisesRegex(
|
||||
torch.jit.frontend.NotSupportedError, expected_regex="can't take variable number of arguments"
|
||||
torch.jit.frontend.NotSupportedError,
|
||||
expected_regex="can't take variable number of arguments",
|
||||
):
|
||||
torch.jit.script(obj)
|
||||
|
||||
@ -568,7 +588,10 @@ class TestRecursiveScript(JitTestCase):
|
||||
|
||||
self.assertEqual(jit_obj(1, 2), 3)
|
||||
with self.assertRaisesRegex(
|
||||
RuntimeError, expected_regex=re.escape("expected at most 2 argument(s) but received 4 argument(s)")
|
||||
RuntimeError,
|
||||
expected_regex=re.escape(
|
||||
"expected at most 2 argument(s) but received 4 argument(s)"
|
||||
),
|
||||
):
|
||||
jit_obj(1, 2, 3, 4)
|
||||
|
||||
@ -598,27 +621,26 @@ class TestRecursiveScript(JitTestCase):
|
||||
def __getstate__(self):
|
||||
return (self.a, self.inner)
|
||||
|
||||
|
||||
untyped_values = (
|
||||
('my_dict', {"I": "am", "a test": "test"}),
|
||||
('my_float', 2.3),
|
||||
('my_int', 99),
|
||||
('my_bool', False),
|
||||
('my_tuple', (1, 2, 3, 4)),
|
||||
('my_list', [(1, 2), (3, 4)]),
|
||||
("my_dict", {"I": "am", "a test": "test"}),
|
||||
("my_float", 2.3),
|
||||
("my_int", 99),
|
||||
("my_bool", False),
|
||||
("my_tuple", (1, 2, 3, 4)),
|
||||
("my_list", [(1, 2), (3, 4)]),
|
||||
# ('my_tensor', torch.randn(2, 2)),
|
||||
('my_int_list', [1, 2, 3, 4]),
|
||||
("my_int_list", [1, 2, 3, 4]),
|
||||
# ('my_tensor_list', [torch.ones(2, 2) + i for i in range(4)]),
|
||||
('my_bool_list', [True, True, False, True]),
|
||||
('my_float_list', [1., 2., 3., 4.]),
|
||||
('my_str_list', ['hello', 'bye']),
|
||||
("my_bool_list", [True, True, False, True]),
|
||||
("my_float_list", [1.0, 2.0, 3.0, 4.0]),
|
||||
("my_str_list", ["hello", "bye"]),
|
||||
)
|
||||
typed_values = (
|
||||
('my_empty_list', []),
|
||||
('my_empty_dict', {}),
|
||||
('my_none', None),
|
||||
('my_object', Foo()),
|
||||
('my_object2', SFoo()),
|
||||
("my_empty_list", []),
|
||||
("my_empty_dict", {}),
|
||||
("my_none", None),
|
||||
("my_object", Foo()),
|
||||
("my_object2", SFoo()),
|
||||
)
|
||||
|
||||
class M(torch.nn.Module):
|
||||
@ -659,11 +681,11 @@ class TestRecursiveScript(JitTestCase):
|
||||
# since there's no string frontend for Python classes (so the `define`)
|
||||
# trick doesn't work.
|
||||
M.__annotations__ = {
|
||||
'my_empty_list': List[int],
|
||||
'my_empty_dict': Dict[str, int],
|
||||
'my_none': Optional[int],
|
||||
'my_object': Foo,
|
||||
'my_object2': SFoo,
|
||||
"my_empty_list": List[int],
|
||||
"my_empty_dict": Dict[str, int],
|
||||
"my_none": Optional[int],
|
||||
"my_object": Foo,
|
||||
"my_object2": SFoo,
|
||||
}
|
||||
|
||||
m = M()
|
||||
@ -694,7 +716,7 @@ class TestRecursiveScript(JitTestCase):
|
||||
return self.encoder(x)
|
||||
|
||||
m = M()
|
||||
self.checkModule(m, (torch.randn(5, 5), ))
|
||||
self.checkModule(m, (torch.randn(5, 5),))
|
||||
|
||||
def test_inner_traced_module(self):
|
||||
class Dummy(nn.Module):
|
||||
@ -715,12 +737,13 @@ class TestRecursiveScript(JitTestCase):
|
||||
dummy = torch.jit.trace(Dummy(), torch.randn(1, 2))
|
||||
dummies = nn.ModuleList([dummy])
|
||||
model = Model(dummies)
|
||||
self.checkModule(model, (torch.rand(5, 5), ))
|
||||
self.checkModule(model, (torch.rand(5, 5),))
|
||||
|
||||
def test_script_loaded_module(self):
|
||||
"""
|
||||
Test that we can hold a loaded ScriptModule as a submodule.
|
||||
"""
|
||||
|
||||
class Dummy(nn.Module):
|
||||
def forward(self, x):
|
||||
return x
|
||||
@ -736,7 +759,7 @@ class TestRecursiveScript(JitTestCase):
|
||||
def forward(self, input):
|
||||
return self.encoder(input)
|
||||
|
||||
self.checkModule(ContainsLoaded(), (torch.rand(2, 3), ))
|
||||
self.checkModule(ContainsLoaded(), (torch.rand(2, 3),))
|
||||
|
||||
def test_optional_module(self):
|
||||
class Dummy(nn.Module):
|
||||
|
||||
@ -2,20 +2,23 @@
|
||||
|
||||
import os
|
||||
import sys
|
||||
from typing import List
|
||||
|
||||
import torch
|
||||
from torch.testing import FileCheck
|
||||
from typing import List
|
||||
|
||||
# Make the helper files in test/ importable
|
||||
pytorch_test_dir = os.path.dirname(os.path.dirname(os.path.realpath(__file__)))
|
||||
sys.path.append(pytorch_test_dir)
|
||||
from torch.testing._internal.jit_utils import JitTestCase, freeze_rng_state
|
||||
from torch.testing._internal.jit_utils import freeze_rng_state, JitTestCase
|
||||
|
||||
if __name__ == "__main__":
|
||||
raise RuntimeError(
|
||||
"This test file is not meant to be run directly, use:\n\n"
|
||||
"\tpython test/test_jit.py TESTNAME\n\n"
|
||||
"instead."
|
||||
)
|
||||
|
||||
if __name__ == '__main__':
|
||||
raise RuntimeError("This test file is not meant to be run directly, use:\n\n"
|
||||
"\tpython test/test_jit.py TESTNAME\n\n"
|
||||
"instead.")
|
||||
|
||||
class TestRemoveMutation(JitTestCase):
|
||||
def test_aten_inplace(self):
|
||||
@ -26,7 +29,7 @@ class TestRemoveMutation(JitTestCase):
|
||||
|
||||
fn = torch.jit.script(test_not_new_alias)
|
||||
graph = fn.graph
|
||||
self.run_pass('remove_mutation', graph)
|
||||
self.run_pass("remove_mutation", graph)
|
||||
FileCheck().check("aten::add_").run(graph)
|
||||
self.assertEqual(fn(torch.ones([2, 2])), test_not_new_alias(torch.ones([2, 2])))
|
||||
|
||||
@ -38,7 +41,7 @@ class TestRemoveMutation(JitTestCase):
|
||||
# there is no functional equivalent of x[0] = ...
|
||||
fn = torch.jit.script(test_no_lowering)
|
||||
graph = fn.graph
|
||||
self.run_pass('remove_mutation', graph)
|
||||
self.run_pass("remove_mutation", graph)
|
||||
FileCheck().check("aten::copy_").run(graph)
|
||||
self.assertEqual(fn(), test_no_lowering())
|
||||
|
||||
@ -50,7 +53,7 @@ class TestRemoveMutation(JitTestCase):
|
||||
|
||||
fn = torch.jit.script(test_move_before_not_valid)
|
||||
graph = fn.graph
|
||||
self.run_pass('remove_mutation', graph)
|
||||
self.run_pass("remove_mutation", graph)
|
||||
FileCheck().check("aten::add_").run(graph)
|
||||
self.assertEqual(fn(), test_move_before_not_valid())
|
||||
|
||||
@ -63,7 +66,7 @@ class TestRemoveMutation(JitTestCase):
|
||||
|
||||
fn = torch.jit.script(test_successful)
|
||||
graph = fn.graph
|
||||
self.run_pass('remove_mutation', graph)
|
||||
self.run_pass("remove_mutation", graph)
|
||||
FileCheck().check_not("aten::add_").run(graph)
|
||||
self.assertEqual(test_successful(), fn())
|
||||
|
||||
@ -77,7 +80,7 @@ class TestRemoveMutation(JitTestCase):
|
||||
fn = torch.jit.script(test_intermediary_use)
|
||||
graph = fn.graph
|
||||
FileCheck().check_count("aten::add_", 2).run(graph)
|
||||
self.run_pass('remove_mutation', graph)
|
||||
self.run_pass("remove_mutation", graph)
|
||||
# Unable to remove the second add_ because of the y = x + 4 use
|
||||
# In the future we could duplicating the value of x as a temporary and replacing
|
||||
# its intermediary use (so long as aliasing is safe)
|
||||
@ -96,7 +99,7 @@ class TestRemoveMutation(JitTestCase):
|
||||
out_eager = foo(torch.tensor(5), True)
|
||||
foo_script = torch.jit.script(foo)
|
||||
FileCheck().check("aten::add_").run(foo_script.graph)
|
||||
self.run_pass('remove_mutation', foo_script.graph)
|
||||
self.run_pass("remove_mutation", foo_script.graph)
|
||||
FileCheck().check_not("aten::add_").run(foo_script.graph)
|
||||
|
||||
self.assertEqual(out_eager, foo_script(torch.tensor(5), True))
|
||||
@ -113,8 +116,8 @@ class TestRemoveMutation(JitTestCase):
|
||||
y = x.add_(2)
|
||||
return y, li
|
||||
|
||||
self.run_pass('inline', foo.graph)
|
||||
self.run_pass('remove_mutation', foo.graph)
|
||||
self.run_pass("inline", foo.graph)
|
||||
self.run_pass("remove_mutation", foo.graph)
|
||||
FileCheck().check("aten::add_").run(foo.graph)
|
||||
|
||||
@torch.jit.script
|
||||
@ -126,8 +129,8 @@ class TestRemoveMutation(JitTestCase):
|
||||
z = x.add_(2)
|
||||
return z
|
||||
|
||||
self.run_pass('inline', foo.graph)
|
||||
self.run_pass('remove_mutation', foo.graph)
|
||||
self.run_pass("inline", foo.graph)
|
||||
self.run_pass("remove_mutation", foo.graph)
|
||||
FileCheck().check("aten::add_").run(foo.graph)
|
||||
|
||||
def test_special_mapped_op(self):
|
||||
@ -140,7 +143,7 @@ class TestRemoveMutation(JitTestCase):
|
||||
|
||||
fn = torch.jit.script(test_successful)
|
||||
graph = fn.graph
|
||||
self.run_pass('remove_mutation', graph)
|
||||
self.run_pass("remove_mutation", graph)
|
||||
FileCheck().check_not("aten::zero_").check_not("aten::fill_").run(graph)
|
||||
self.assertEqual(test_successful(), fn())
|
||||
|
||||
@ -154,8 +157,8 @@ class TestRemoveMutation(JitTestCase):
|
||||
|
||||
fn = torch.jit.script(test_successful)
|
||||
graph = fn.graph
|
||||
self.run_pass('remove_mutation', graph)
|
||||
FileCheck().check_not('aten::fill_').run(graph)
|
||||
self.run_pass("remove_mutation", graph)
|
||||
FileCheck().check_not("aten::fill_").run(graph)
|
||||
|
||||
def normal():
|
||||
# NOTE: For some unknown reason, the
|
||||
@ -167,7 +170,7 @@ class TestRemoveMutation(JitTestCase):
|
||||
|
||||
fn = torch.jit.script(normal)
|
||||
graph = fn.graph
|
||||
self.run_pass('remove_mutation', graph)
|
||||
self.run_pass("remove_mutation", graph)
|
||||
FileCheck().check_not("normal_").run(graph)
|
||||
with freeze_rng_state():
|
||||
out_eager = normal()
|
||||
@ -181,10 +184,12 @@ class TestRemoveMutation(JitTestCase):
|
||||
|
||||
fn = torch.jit.script(successful_remove)
|
||||
graph = fn.graph
|
||||
self.run_pass('loop_unrolling', graph)
|
||||
self.run_pass('remove_mutation', graph)
|
||||
self.run_pass('constant_propagation', graph)
|
||||
FileCheck().check("graph").check_next("Constant").check_next("return").run(graph)
|
||||
self.run_pass("loop_unrolling", graph)
|
||||
self.run_pass("remove_mutation", graph)
|
||||
self.run_pass("constant_propagation", graph)
|
||||
FileCheck().check("graph").check_next("Constant").check_next("return").run(
|
||||
graph
|
||||
)
|
||||
self.assertEqual(successful_remove(), successful_remove())
|
||||
|
||||
def intermediary_use():
|
||||
@ -196,14 +201,14 @@ class TestRemoveMutation(JitTestCase):
|
||||
fn = torch.jit.script(intermediary_use)
|
||||
graph = fn.graph
|
||||
FileCheck().check("append").run(graph)
|
||||
self.run_pass('remove_mutation', graph)
|
||||
self.run_pass("remove_mutation", graph)
|
||||
# it is possible to remove the append here but don't currently have the logic for it
|
||||
FileCheck().check_not("append").run(graph)
|
||||
self.assertEqual(intermediary_use(), fn())
|
||||
|
||||
def test_lists_insert(self):
|
||||
def successful_remove():
|
||||
a : List[int] = []
|
||||
a: List[int] = []
|
||||
a.insert(0, 1)
|
||||
a.insert(0, 2)
|
||||
a.insert(-10, 3)
|
||||
@ -215,7 +220,9 @@ class TestRemoveMutation(JitTestCase):
|
||||
graph = fn.graph
|
||||
torch._C._jit_pass_remove_mutation(graph)
|
||||
torch._C._jit_pass_constant_propagation(graph)
|
||||
FileCheck().check("graph").check_next("Constant").check_next("return").run(graph)
|
||||
FileCheck().check("graph").check_next("Constant").check_next("return").run(
|
||||
graph
|
||||
)
|
||||
self.assertEqual(successful_remove(), fn())
|
||||
|
||||
def test_list_indexing_removal(self):
|
||||
@ -271,6 +278,7 @@ class TestRemoveMutation(JitTestCase):
|
||||
|
||||
def test_common_pytorch_list_ops(self):
|
||||
for op in ["cat", "stack", "vstack", "hstack", "dstack"]:
|
||||
|
||||
class OpMod(torch.nn.Module):
|
||||
def __init__(self, op):
|
||||
super().__init__()
|
||||
@ -285,7 +293,7 @@ class TestRemoveMutation(JitTestCase):
|
||||
torch_op = getattr(torch, op)
|
||||
mod = OpMod(torch_op)
|
||||
mod_script = torch.jit.script(mod)
|
||||
self.run_pass('remove_mutation', mod_script.forward.graph)
|
||||
self.run_pass("remove_mutation", mod_script.forward.graph)
|
||||
FileCheck().check_not("aten::add_").run(mod_script.forward.graph)
|
||||
self.assertEqual(mod(), mod_script())
|
||||
|
||||
@ -299,7 +307,6 @@ class TestRemoveMutation(JitTestCase):
|
||||
|
||||
self.assertEqual(sums, [ten.sum() for ten in result])
|
||||
|
||||
|
||||
@torch.jit.script
|
||||
def test_multiple_uses():
|
||||
x = torch.tensor([1, 2, 3, 4])
|
||||
@ -307,5 +314,5 @@ class TestRemoveMutation(JitTestCase):
|
||||
y = [x, x]
|
||||
return torch.cat(y), y
|
||||
|
||||
self.run_pass('remove_mutation', mod_script.forward.graph)
|
||||
self.run_pass("remove_mutation", mod_script.forward.graph)
|
||||
FileCheck().check("aten::add_").run(test_multiple_uses.graph)
|
||||
|
||||
@ -8,12 +8,12 @@ from typing import NamedTuple, Optional
|
||||
|
||||
import torch
|
||||
from torch import Tensor
|
||||
from torch.testing._internal.common_utils import TemporaryFileName, skipIfTorchDynamo
|
||||
from torch.testing._internal.common_utils import skipIfTorchDynamo, TemporaryFileName
|
||||
|
||||
# Make the helper files in test/ importable
|
||||
pytorch_test_dir = os.path.dirname(os.path.dirname(os.path.realpath(__file__)))
|
||||
sys.path.append(pytorch_test_dir)
|
||||
from torch.testing._internal.jit_utils import JitTestCase, clear_class_registry
|
||||
from torch.testing._internal.jit_utils import clear_class_registry, JitTestCase
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
@ -439,7 +439,7 @@ class TestSaveLoad(JitTestCase):
|
||||
global FooTuple # see [local resolution in python]
|
||||
|
||||
class FooTuple(NamedTuple):
|
||||
a: 'int'
|
||||
a: "int"
|
||||
|
||||
class MyModule(torch.nn.Module):
|
||||
def forward(self, x: FooTuple) -> torch.Tensor:
|
||||
@ -608,7 +608,6 @@ class TestSaveLoad(JitTestCase):
|
||||
self.assertTrue(m_params["bar.bias"].is_cpu)
|
||||
self.assertTrue(m_loaded_params["bar.bias"].is_cpu)
|
||||
|
||||
|
||||
def test_save_load_with_saved_traced_inputs(self):
|
||||
"""
|
||||
Check that saving and loading with traced inputs works as expected
|
||||
@ -637,14 +636,18 @@ class TestSaveLoad(JitTestCase):
|
||||
# Validate that with no input specified the traced inputs are stored
|
||||
traced_module = torch.jit.trace(module, input_tensor)
|
||||
traced_inputs = list(traced_module.graph.inputs())
|
||||
self.assertEqual(traced_module._c._retrieve_traced_inputs()['forward'], [input_tensor])
|
||||
self.assertEqual(
|
||||
traced_module._c._retrieve_traced_inputs()["forward"], [input_tensor]
|
||||
)
|
||||
with TemporaryFileName() as fname:
|
||||
path = pathlib.Path(fname)
|
||||
traced_module.save(path)
|
||||
loaded_module = torch.jit.load(path, _restore_shapes=True)
|
||||
loaded_inputs = list(loaded_module.graph.inputs())
|
||||
self.assertEqual(traced_inputs[1].type(), loaded_inputs[1].type())
|
||||
self.assertEqual(traced_inputs[1].type().sizes(), loaded_inputs[1].type().sizes())
|
||||
self.assertEqual(
|
||||
traced_inputs[1].type().sizes(), loaded_inputs[1].type().sizes()
|
||||
)
|
||||
# Validate that if no shapes are requested previous functionality remains
|
||||
loaded_module = torch.jit.load(path)
|
||||
loaded_inputs = list(loaded_module.graph.inputs())
|
||||
@ -672,7 +675,7 @@ class TestSaveLoad(JitTestCase):
|
||||
"1000": (
|
||||
torch.tensor([0]),
|
||||
torch.tensor([], dtype=torch.int64),
|
||||
torch.tensor([])
|
||||
torch.tensor([]),
|
||||
)
|
||||
}
|
||||
traced_inputs, loaded_inputs = get_loaded_inputs(input1)
|
||||
@ -683,28 +686,32 @@ class TestSaveLoad(JitTestCase):
|
||||
"1000": (
|
||||
torch.tensor([0]),
|
||||
torch.tensor([1500000, 1500004], dtype=torch.int64),
|
||||
torch.tensor([2.0, 3.0])
|
||||
torch.tensor([2.0, 3.0]),
|
||||
)
|
||||
}
|
||||
traced_inputs, loaded_inputs = get_loaded_inputs(input2)
|
||||
self.assertEqual(traced_inputs[1].type(), loaded_inputs[1].type())
|
||||
|
||||
# Testing list
|
||||
input3 = [torch.tensor([0]),
|
||||
torch.tensor([1500000, 1500004], dtype=torch.int64),
|
||||
torch.tensor([2.0, 3.0])]
|
||||
input3 = [
|
||||
torch.tensor([0]),
|
||||
torch.tensor([1500000, 1500004], dtype=torch.int64),
|
||||
torch.tensor([2.0, 3.0]),
|
||||
]
|
||||
|
||||
traced_inputs, loaded_inputs = get_loaded_inputs(input3)
|
||||
self.assertEqual(traced_inputs[1].type(), loaded_inputs[1].type())
|
||||
|
||||
# Testing list of dict of list
|
||||
input4 = [{
|
||||
"1000": (
|
||||
torch.tensor([0]),
|
||||
torch.tensor([1500000, 1500004], dtype=torch.int64),
|
||||
torch.tensor([2.0, 3.0])
|
||||
)
|
||||
}]
|
||||
input4 = [
|
||||
{
|
||||
"1000": (
|
||||
torch.tensor([0]),
|
||||
torch.tensor([1500000, 1500004], dtype=torch.int64),
|
||||
torch.tensor([2.0, 3.0]),
|
||||
)
|
||||
}
|
||||
]
|
||||
|
||||
traced_inputs, loaded_inputs = get_loaded_inputs(input4)
|
||||
self.assertEqual(traced_inputs[1].type(), loaded_inputs[1].type())
|
||||
@ -715,14 +722,17 @@ class TestSaveLoad(JitTestCase):
|
||||
Check if the model with string > 4GB can be loaded.
|
||||
"""
|
||||
import psutil
|
||||
|
||||
if psutil.virtual_memory().available < 60 * 1024 * 1024 * 1024:
|
||||
# Profiled the test execution, and got this number to be safe to run the test
|
||||
self.skipTest("Doesn't have enough memory to run test_save_load_large_string_attribute")
|
||||
self.skipTest(
|
||||
"Doesn't have enough memory to run test_save_load_large_string_attribute"
|
||||
)
|
||||
|
||||
class Model(torch.nn.Module):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
self.x = "x" * (2 ** 32 + 1)
|
||||
self.x = "x" * (2**32 + 1)
|
||||
|
||||
def forward(self, i) -> int:
|
||||
return len(self.x) + i.numel()
|
||||
@ -793,12 +803,8 @@ class TestSaveLoadFlatbuffer(JitTestCase):
|
||||
class ContainsBoth(torch.nn.Module):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
self.add_module(
|
||||
"second", torch.jit.load(second_saved_module)
|
||||
)
|
||||
self.add_module(
|
||||
"first", torch.jit.load(first_saved_module)
|
||||
)
|
||||
self.add_module("second", torch.jit.load(second_saved_module))
|
||||
self.add_module("first", torch.jit.load(first_saved_module))
|
||||
|
||||
def forward(self, x):
|
||||
x = self.first(x)
|
||||
@ -846,12 +852,8 @@ class TestSaveLoadFlatbuffer(JitTestCase):
|
||||
class ContainsBoth(torch.nn.Module):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
self.add_module(
|
||||
"second", torch.jit.load(second_saved_module)
|
||||
)
|
||||
self.add_module(
|
||||
"first", torch.jit.load(first_saved_module)
|
||||
)
|
||||
self.add_module("second", torch.jit.load(second_saved_module))
|
||||
self.add_module("first", torch.jit.load(first_saved_module))
|
||||
|
||||
def forward(self, x):
|
||||
x = self.first(x)
|
||||
@ -931,12 +933,8 @@ class TestSaveLoadFlatbuffer(JitTestCase):
|
||||
class ContainsBoth(torch.nn.Module):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
self.add_module(
|
||||
"second", torch.jit.load(second_saved_module)
|
||||
)
|
||||
self.add_module(
|
||||
"first", torch.jit.load(first_saved_module)
|
||||
)
|
||||
self.add_module("second", torch.jit.load(second_saved_module))
|
||||
self.add_module("first", torch.jit.load(first_saved_module))
|
||||
|
||||
def forward(self, x):
|
||||
x = self.first(x)
|
||||
@ -1035,12 +1033,8 @@ class TestSaveLoadFlatbuffer(JitTestCase):
|
||||
class ContainsBoth(torch.nn.Module):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
self.add_module(
|
||||
"second", torch.jit.load(second_saved_module)
|
||||
)
|
||||
self.add_module(
|
||||
"first", torch.jit.load(first_saved_module)
|
||||
)
|
||||
self.add_module("second", torch.jit.load(second_saved_module))
|
||||
self.add_module("first", torch.jit.load(first_saved_module))
|
||||
|
||||
def forward(self, x):
|
||||
x, named_tuple_1 = self.first(x)
|
||||
@ -1118,18 +1112,18 @@ class TestSaveLoadFlatbuffer(JitTestCase):
|
||||
|
||||
first_script_module = torch.jit.script(Foo())
|
||||
first_saved_module = io.BytesIO()
|
||||
torch.jit.save_jit_module_to_flatbuffer(
|
||||
first_script_module, first_saved_module)
|
||||
torch.jit.save_jit_module_to_flatbuffer(first_script_module, first_saved_module)
|
||||
first_saved_module.seek(0)
|
||||
ff_info = torch.jit._serialization.get_flatbuffer_module_info(first_saved_module)
|
||||
self.assertEqual(ff_info['bytecode_version'], 9)
|
||||
self.assertEqual(ff_info['operator_version'], 1)
|
||||
self.assertEqual(ff_info['type_names'], set())
|
||||
self.assertEqual(ff_info['opname_to_num_args'], {'aten::linear': 3})
|
||||
|
||||
self.assertEqual(len(ff_info['function_names']), 1)
|
||||
self.assertTrue(next(iter(ff_info['function_names'])).endswith('forward'))
|
||||
ff_info = torch.jit._serialization.get_flatbuffer_module_info(
|
||||
first_saved_module
|
||||
)
|
||||
self.assertEqual(ff_info["bytecode_version"], 9)
|
||||
self.assertEqual(ff_info["operator_version"], 1)
|
||||
self.assertEqual(ff_info["type_names"], set())
|
||||
self.assertEqual(ff_info["opname_to_num_args"], {"aten::linear": 3})
|
||||
|
||||
self.assertEqual(len(ff_info["function_names"]), 1)
|
||||
self.assertTrue(next(iter(ff_info["function_names"])).endswith("forward"))
|
||||
|
||||
def test_save_load_params_buffers_submodules(self):
|
||||
"""
|
||||
@ -1179,7 +1173,6 @@ class TestSaveLoadFlatbuffer(JitTestCase):
|
||||
self.assertEqual(m_name, loaded_name)
|
||||
self.assertEqual(m_buffer, loaded_buffer)
|
||||
|
||||
|
||||
def test_save_load_with_extra_files(self):
|
||||
"""
|
||||
Check that parameters, buffers, and submodules are the same after loading.
|
||||
@ -1194,7 +1187,8 @@ class TestSaveLoadFlatbuffer(JitTestCase):
|
||||
|
||||
extra_files = {"abc.json": b"[1,2,3]"}
|
||||
script_module_io = script_module._save_to_buffer_for_lite_interpreter(
|
||||
_extra_files=extra_files, _use_flatbuffer=True)
|
||||
_extra_files=extra_files, _use_flatbuffer=True
|
||||
)
|
||||
|
||||
re_extra_files = {}
|
||||
torch._C._get_model_extra_files_from_buffer(script_module_io, re_extra_files)
|
||||
|
||||
@ -1,20 +1,21 @@
|
||||
# Owner(s): ["oncall: jit"]
|
||||
|
||||
from itertools import product as product
|
||||
import io
|
||||
import os
|
||||
import sys
|
||||
import hypothesis.strategies as st
|
||||
from hypothesis import example, settings, given
|
||||
from itertools import product as product
|
||||
from typing import Union
|
||||
|
||||
import hypothesis.strategies as st
|
||||
|
||||
import torch
|
||||
from hypothesis import example, given, settings
|
||||
|
||||
# Make the helper files in test/ importable
|
||||
pytorch_test_dir = os.path.dirname(os.path.dirname(os.path.realpath(__file__)))
|
||||
sys.path.append(pytorch_test_dir)
|
||||
from torch.testing._internal.jit_utils import JitTestCase
|
||||
from torch.jit.mobile import _load_for_lite_interpreter
|
||||
from torch.testing._internal.jit_utils import JitTestCase
|
||||
|
||||
if __name__ == "__main__":
|
||||
raise RuntimeError(
|
||||
@ -23,6 +24,7 @@ if __name__ == "__main__":
|
||||
"instead."
|
||||
)
|
||||
|
||||
|
||||
class TestSaveLoadForOpVersion(JitTestCase):
|
||||
# Helper that returns the module after saving and loading
|
||||
def _save_load_module(self, m):
|
||||
@ -53,7 +55,6 @@ class TestSaveLoadForOpVersion(JitTestCase):
|
||||
node_count = sum(str(n).count(kind) for n in m.graph.nodes())
|
||||
self.assertEqual(node_count, count)
|
||||
|
||||
|
||||
"""
|
||||
Tests that verify Torchscript remaps aten::div(_) from versions 0-3
|
||||
to call either aten::true_divide(_), if an input is a float type,
|
||||
@ -62,16 +63,21 @@ class TestSaveLoadForOpVersion(JitTestCase):
|
||||
div behavior has not yet been updated.
|
||||
"""
|
||||
|
||||
@settings(max_examples=10, deadline=200000) # A total of 10 examples will be generated
|
||||
@settings(
|
||||
max_examples=10, deadline=200000
|
||||
) # A total of 10 examples will be generated
|
||||
@given(
|
||||
sample_input=st.tuples(st.integers(min_value=5, max_value=199), st.floats(min_value=5.0, max_value=199.0))
|
||||
sample_input=st.tuples(
|
||||
st.integers(min_value=5, max_value=199),
|
||||
st.floats(min_value=5.0, max_value=199.0),
|
||||
)
|
||||
) # Generate a pair (integer, float)
|
||||
@example((2, 3, 2.0, 3.0)) # Ensure this example will be covered
|
||||
def test_versioned_div_tensor(self, sample_input):
|
||||
def historic_div(self, other):
|
||||
if self.is_floating_point() or other.is_floating_point():
|
||||
return self.true_divide(other)
|
||||
return self.divide(other, rounding_mode='trunc')
|
||||
return self.divide(other, rounding_mode="trunc")
|
||||
|
||||
# Tensor x Tensor
|
||||
class MyModule(torch.nn.Module):
|
||||
@ -85,7 +91,9 @@ class TestSaveLoadForOpVersion(JitTestCase):
|
||||
# Loads historic module
|
||||
try:
|
||||
v3_mobile_module = _load_for_lite_interpreter(
|
||||
pytorch_test_dir + "/cpp/jit/upgrader_models/test_versioned_div_tensor_v2.ptl")
|
||||
pytorch_test_dir
|
||||
+ "/cpp/jit/upgrader_models/test_versioned_div_tensor_v2.ptl"
|
||||
)
|
||||
except Exception as e:
|
||||
self.skipTest("Failed to load fixture!")
|
||||
|
||||
@ -108,16 +116,21 @@ class TestSaveLoadForOpVersion(JitTestCase):
|
||||
_helper(v3_mobile_module, historic_div)
|
||||
_helper(current_mobile_module, torch.div)
|
||||
|
||||
@settings(max_examples=10, deadline=200000) # A total of 10 examples will be generated
|
||||
@settings(
|
||||
max_examples=10, deadline=200000
|
||||
) # A total of 10 examples will be generated
|
||||
@given(
|
||||
sample_input=st.tuples(st.integers(min_value=5, max_value=199), st.floats(min_value=5.0, max_value=199.0))
|
||||
sample_input=st.tuples(
|
||||
st.integers(min_value=5, max_value=199),
|
||||
st.floats(min_value=5.0, max_value=199.0),
|
||||
)
|
||||
) # Generate a pair (integer, float)
|
||||
@example((2, 3, 2.0, 3.0)) # Ensure this example will be covered
|
||||
def test_versioned_div_tensor_inplace(self, sample_input):
|
||||
def historic_div_(self, other):
|
||||
if self.is_floating_point() or other.is_floating_point():
|
||||
return self.true_divide_(other)
|
||||
return self.divide_(other, rounding_mode='trunc')
|
||||
return self.divide_(other, rounding_mode="trunc")
|
||||
|
||||
class MyModule(torch.nn.Module):
|
||||
def forward(self, a, b):
|
||||
@ -126,7 +139,9 @@ class TestSaveLoadForOpVersion(JitTestCase):
|
||||
|
||||
try:
|
||||
v3_mobile_module = _load_for_lite_interpreter(
|
||||
pytorch_test_dir + "/cpp/jit/upgrader_models/test_versioned_div_tensor_inplace_v2.ptl")
|
||||
pytorch_test_dir
|
||||
+ "/cpp/jit/upgrader_models/test_versioned_div_tensor_inplace_v2.ptl"
|
||||
)
|
||||
except Exception as e:
|
||||
self.skipTest("Failed to load fixture!")
|
||||
|
||||
@ -151,16 +166,25 @@ class TestSaveLoadForOpVersion(JitTestCase):
|
||||
a = torch.tensor((val_a,))
|
||||
_helper(current_mobile_module, torch.Tensor.div_)
|
||||
|
||||
@settings(max_examples=10, deadline=200000) # A total of 10 examples will be generated
|
||||
@settings(
|
||||
max_examples=10, deadline=200000
|
||||
) # A total of 10 examples will be generated
|
||||
@given(
|
||||
sample_input=st.tuples(st.integers(min_value=5, max_value=199), st.floats(min_value=5.0, max_value=199.0))
|
||||
sample_input=st.tuples(
|
||||
st.integers(min_value=5, max_value=199),
|
||||
st.floats(min_value=5.0, max_value=199.0),
|
||||
)
|
||||
) # Generate a pair (integer, float)
|
||||
@example((2, 3, 2.0, 3.0)) # Ensure this example will be covered
|
||||
def test_versioned_div_tensor_out(self, sample_input):
|
||||
def historic_div_out(self, other, out):
|
||||
if self.is_floating_point() or other.is_floating_point() or out.is_floating_point():
|
||||
if (
|
||||
self.is_floating_point()
|
||||
or other.is_floating_point()
|
||||
or out.is_floating_point()
|
||||
):
|
||||
return torch.true_divide(self, other, out=out)
|
||||
return torch.divide(self, other, out=out, rounding_mode='trunc')
|
||||
return torch.divide(self, other, out=out, rounding_mode="trunc")
|
||||
|
||||
class MyModule(torch.nn.Module):
|
||||
def forward(self, a, b, out):
|
||||
@ -168,7 +192,9 @@ class TestSaveLoadForOpVersion(JitTestCase):
|
||||
|
||||
try:
|
||||
v3_mobile_module = _load_for_lite_interpreter(
|
||||
pytorch_test_dir + "/cpp/jit/upgrader_models/test_versioned_div_tensor_out_v2.ptl")
|
||||
pytorch_test_dir
|
||||
+ "/cpp/jit/upgrader_models/test_versioned_div_tensor_out_v2.ptl"
|
||||
)
|
||||
except Exception as e:
|
||||
self.skipTest("Failed to load fixture!")
|
||||
|
||||
@ -179,6 +205,7 @@ class TestSaveLoadForOpVersion(JitTestCase):
|
||||
b = torch.tensor((val_b,))
|
||||
|
||||
for out in (torch.empty((1,)), torch.empty((1,), dtype=torch.long)):
|
||||
|
||||
def _helper(m, fn):
|
||||
fn_result = None
|
||||
if fn is torch.div:
|
||||
@ -196,9 +223,14 @@ class TestSaveLoadForOpVersion(JitTestCase):
|
||||
_helper(v3_mobile_module, historic_div_out)
|
||||
_helper(current_mobile_module, torch.div)
|
||||
|
||||
@settings(max_examples=10, deadline=200000) # A total of 10 examples will be generated
|
||||
@settings(
|
||||
max_examples=10, deadline=200000
|
||||
) # A total of 10 examples will be generated
|
||||
@given(
|
||||
sample_input=st.tuples(st.integers(min_value=5, max_value=199), st.floats(min_value=5.0, max_value=199.0))
|
||||
sample_input=st.tuples(
|
||||
st.integers(min_value=5, max_value=199),
|
||||
st.floats(min_value=5.0, max_value=199.0),
|
||||
)
|
||||
) # Generate a pair (integer, float)
|
||||
@example((2, 3, 2.0, 3.0)) # Ensure this example will be covered
|
||||
def test_versioned_div_scalar(self, sample_input):
|
||||
@ -208,7 +240,7 @@ class TestSaveLoadForOpVersion(JitTestCase):
|
||||
def historic_div_scalar_int(self, other: int):
|
||||
if self.is_floating_point():
|
||||
return torch.true_divide(self, other)
|
||||
return torch.divide(self, other, rounding_mode='trunc')
|
||||
return torch.divide(self, other, rounding_mode="trunc")
|
||||
|
||||
class MyModuleFloat(torch.nn.Module):
|
||||
def forward(self, a, b: float):
|
||||
@ -220,9 +252,13 @@ class TestSaveLoadForOpVersion(JitTestCase):
|
||||
|
||||
try:
|
||||
v3_mobile_module_float = _load_for_lite_interpreter(
|
||||
pytorch_test_dir + "/jit/fixtures/test_versioned_div_scalar_float_v2.ptl")
|
||||
pytorch_test_dir
|
||||
+ "/jit/fixtures/test_versioned_div_scalar_float_v2.ptl"
|
||||
)
|
||||
v3_mobile_module_int = _load_for_lite_interpreter(
|
||||
pytorch_test_dir + "/cpp/jit/upgrader_models/test_versioned_div_scalar_int_v2.ptl")
|
||||
pytorch_test_dir
|
||||
+ "/cpp/jit/upgrader_models/test_versioned_div_scalar_int_v2.ptl"
|
||||
)
|
||||
except Exception as e:
|
||||
self.skipTest("Failed to load fixture!")
|
||||
|
||||
@ -249,9 +285,14 @@ class TestSaveLoadForOpVersion(JitTestCase):
|
||||
_helper(v3_mobile_module_int, historic_div_scalar_int)
|
||||
_helper(current_mobile_module_int, torch.div)
|
||||
|
||||
@settings(max_examples=10, deadline=200000) # A total of 10 examples will be generated
|
||||
@settings(
|
||||
max_examples=10, deadline=200000
|
||||
) # A total of 10 examples will be generated
|
||||
@given(
|
||||
sample_input=st.tuples(st.integers(min_value=5, max_value=199), st.floats(min_value=5.0, max_value=199.0))
|
||||
sample_input=st.tuples(
|
||||
st.integers(min_value=5, max_value=199),
|
||||
st.floats(min_value=5.0, max_value=199.0),
|
||||
)
|
||||
) # Generate a pair (integer, float)
|
||||
@example((2, 3, 2.0, 3.0)) # Ensure this example will be covered
|
||||
def test_versioned_div_scalar_reciprocal(self, sample_input):
|
||||
@ -261,7 +302,7 @@ class TestSaveLoadForOpVersion(JitTestCase):
|
||||
def historic_div_scalar_int_reciprocal(self, other: int):
|
||||
if self.is_floating_point():
|
||||
return other / self
|
||||
return torch.divide(other, self, rounding_mode='trunc')
|
||||
return torch.divide(other, self, rounding_mode="trunc")
|
||||
|
||||
class MyModuleFloat(torch.nn.Module):
|
||||
def forward(self, a, b: float):
|
||||
@ -273,9 +314,13 @@ class TestSaveLoadForOpVersion(JitTestCase):
|
||||
|
||||
try:
|
||||
v3_mobile_module_float = _load_for_lite_interpreter(
|
||||
pytorch_test_dir + "/cpp/jit/upgrader_models/test_versioned_div_scalar_reciprocal_float_v2.ptl")
|
||||
pytorch_test_dir
|
||||
+ "/cpp/jit/upgrader_models/test_versioned_div_scalar_reciprocal_float_v2.ptl"
|
||||
)
|
||||
v3_mobile_module_int = _load_for_lite_interpreter(
|
||||
pytorch_test_dir + "/cpp/jit/upgrader_models/test_versioned_div_scalar_reciprocal_int_v2.ptl")
|
||||
pytorch_test_dir
|
||||
+ "/cpp/jit/upgrader_models/test_versioned_div_scalar_reciprocal_int_v2.ptl"
|
||||
)
|
||||
except Exception as e:
|
||||
self.skipTest("Failed to load fixture!")
|
||||
|
||||
@ -311,9 +356,14 @@ class TestSaveLoadForOpVersion(JitTestCase):
|
||||
_helper(v3_mobile_module_int, current_mobile_module_int)
|
||||
_helper(current_mobile_module_int, torch.div)
|
||||
|
||||
@settings(max_examples=10, deadline=200000) # A total of 10 examples will be generated
|
||||
@settings(
|
||||
max_examples=10, deadline=200000
|
||||
) # A total of 10 examples will be generated
|
||||
@given(
|
||||
sample_input=st.tuples(st.integers(min_value=5, max_value=199), st.floats(min_value=5.0, max_value=199.0))
|
||||
sample_input=st.tuples(
|
||||
st.integers(min_value=5, max_value=199),
|
||||
st.floats(min_value=5.0, max_value=199.0),
|
||||
)
|
||||
) # Generate a pair (integer, float)
|
||||
@example((2, 3, 2.0, 3.0)) # Ensure this example will be covered
|
||||
def test_versioned_div_scalar_inplace(self, sample_input):
|
||||
@ -324,7 +374,7 @@ class TestSaveLoadForOpVersion(JitTestCase):
|
||||
if self.is_floating_point():
|
||||
return self.true_divide_(other)
|
||||
|
||||
return self.divide_(other, rounding_mode='trunc')
|
||||
return self.divide_(other, rounding_mode="trunc")
|
||||
|
||||
class MyModuleFloat(torch.nn.Module):
|
||||
def forward(self, a, b: float):
|
||||
@ -338,9 +388,13 @@ class TestSaveLoadForOpVersion(JitTestCase):
|
||||
|
||||
try:
|
||||
v3_mobile_module_float = _load_for_lite_interpreter(
|
||||
pytorch_test_dir + "/cpp/jit/upgrader_models/test_versioned_div_scalar_inplace_float_v2.ptl")
|
||||
pytorch_test_dir
|
||||
+ "/cpp/jit/upgrader_models/test_versioned_div_scalar_inplace_float_v2.ptl"
|
||||
)
|
||||
v3_mobile_module_int = _load_for_lite_interpreter(
|
||||
pytorch_test_dir + "/cpp/jit/upgrader_models/test_versioned_div_scalar_inplace_int_v2.ptl")
|
||||
pytorch_test_dir
|
||||
+ "/cpp/jit/upgrader_models/test_versioned_div_scalar_inplace_int_v2.ptl"
|
||||
)
|
||||
except Exception as e:
|
||||
self.skipTest("Failed to load fixture!")
|
||||
|
||||
@ -378,14 +432,16 @@ class TestSaveLoadForOpVersion(JitTestCase):
|
||||
|
||||
try:
|
||||
v3_mobile_module = _load_for_lite_interpreter(
|
||||
pytorch_test_dir + "/cpp/jit/upgrader_models/test_versioned_div_scalar_scalar_v2.ptl")
|
||||
pytorch_test_dir
|
||||
+ "/cpp/jit/upgrader_models/test_versioned_div_scalar_scalar_v2.ptl"
|
||||
)
|
||||
except Exception as e:
|
||||
self.skipTest("Failed to load fixture!")
|
||||
|
||||
current_mobile_module = self._save_load_mobile_module(MyModule)
|
||||
|
||||
def _helper(m, fn):
|
||||
vals = (5., 3, 2., 7)
|
||||
vals = (5.0, 3, 2.0, 7)
|
||||
m_result = m(*vals)
|
||||
fn_result = fn(*vals)
|
||||
for mr, hr in zip(m_result, fn_result):
|
||||
@ -395,13 +451,16 @@ class TestSaveLoadForOpVersion(JitTestCase):
|
||||
|
||||
def test_versioned_linspace(self):
|
||||
class Module(torch.nn.Module):
|
||||
def forward(self, a: Union[int, float, complex], b: Union[int, float, complex]):
|
||||
def forward(
|
||||
self, a: Union[int, float, complex], b: Union[int, float, complex]
|
||||
):
|
||||
c = torch.linspace(a, b, steps=5)
|
||||
d = torch.linspace(a, b, steps=100)
|
||||
return c, d
|
||||
|
||||
scripted_module = torch.jit.load(
|
||||
pytorch_test_dir + "/jit/fixtures/test_versioned_linspace_v7.ptl")
|
||||
pytorch_test_dir + "/jit/fixtures/test_versioned_linspace_v7.ptl"
|
||||
)
|
||||
|
||||
buffer = io.BytesIO(scripted_module._save_to_buffer_for_lite_interpreter())
|
||||
buffer.seek(0)
|
||||
@ -410,7 +469,7 @@ class TestSaveLoadForOpVersion(JitTestCase):
|
||||
current_mobile_module = self._save_load_mobile_module(Module)
|
||||
|
||||
sample_inputs = ((3, 10), (-10, 10), (4.0, 6.0), (3 + 4j, 4 + 5j))
|
||||
for (a, b) in sample_inputs:
|
||||
for a, b in sample_inputs:
|
||||
(output_with_step, output_without_step) = v7_mobile_module(a, b)
|
||||
(current_with_step, current_without_step) = current_mobile_module(a, b)
|
||||
# when no step is given, should have used 100
|
||||
@ -422,10 +481,17 @@ class TestSaveLoadForOpVersion(JitTestCase):
|
||||
|
||||
def test_versioned_linspace_out(self):
|
||||
class Module(torch.nn.Module):
|
||||
def forward(self, a: Union[int, float, complex], b: Union[int, float, complex], out: torch.Tensor):
|
||||
def forward(
|
||||
self,
|
||||
a: Union[int, float, complex],
|
||||
b: Union[int, float, complex],
|
||||
out: torch.Tensor,
|
||||
):
|
||||
return torch.linspace(a, b, steps=100, out=out)
|
||||
|
||||
model_path = pytorch_test_dir + "/jit/fixtures/test_versioned_linspace_out_v7.ptl"
|
||||
model_path = (
|
||||
pytorch_test_dir + "/jit/fixtures/test_versioned_linspace_out_v7.ptl"
|
||||
)
|
||||
loaded_model = torch.jit.load(model_path)
|
||||
buffer = io.BytesIO(loaded_model._save_to_buffer_for_lite_interpreter())
|
||||
buffer.seek(0)
|
||||
@ -433,12 +499,32 @@ class TestSaveLoadForOpVersion(JitTestCase):
|
||||
current_mobile_module = self._save_load_mobile_module(Module)
|
||||
|
||||
sample_inputs = (
|
||||
(3, 10, torch.empty((100,), dtype=torch.int64), torch.empty((100,), dtype=torch.int64)),
|
||||
(-10, 10, torch.empty((100,), dtype=torch.int64), torch.empty((100,), dtype=torch.int64)),
|
||||
(4.0, 6.0, torch.empty((100,), dtype=torch.float64), torch.empty((100,), dtype=torch.float64)),
|
||||
(3 + 4j, 4 + 5j, torch.empty((100,), dtype=torch.complex64), torch.empty((100,), dtype=torch.complex64)),
|
||||
(
|
||||
3,
|
||||
10,
|
||||
torch.empty((100,), dtype=torch.int64),
|
||||
torch.empty((100,), dtype=torch.int64),
|
||||
),
|
||||
(
|
||||
-10,
|
||||
10,
|
||||
torch.empty((100,), dtype=torch.int64),
|
||||
torch.empty((100,), dtype=torch.int64),
|
||||
),
|
||||
(
|
||||
4.0,
|
||||
6.0,
|
||||
torch.empty((100,), dtype=torch.float64),
|
||||
torch.empty((100,), dtype=torch.float64),
|
||||
),
|
||||
(
|
||||
3 + 4j,
|
||||
4 + 5j,
|
||||
torch.empty((100,), dtype=torch.complex64),
|
||||
torch.empty((100,), dtype=torch.complex64),
|
||||
),
|
||||
)
|
||||
for (start, end, out_for_old, out_for_new) in sample_inputs:
|
||||
for start, end, out_for_old, out_for_new in sample_inputs:
|
||||
output = v7_mobile_module(start, end, out_for_old)
|
||||
output_current = current_mobile_module(start, end, out_for_new)
|
||||
# when no step is given, should have used 100
|
||||
@ -448,13 +534,16 @@ class TestSaveLoadForOpVersion(JitTestCase):
|
||||
|
||||
def test_versioned_logspace(self):
|
||||
class Module(torch.nn.Module):
|
||||
def forward(self, a: Union[int, float, complex], b: Union[int, float, complex]):
|
||||
def forward(
|
||||
self, a: Union[int, float, complex], b: Union[int, float, complex]
|
||||
):
|
||||
c = torch.logspace(a, b, steps=5)
|
||||
d = torch.logspace(a, b, steps=100)
|
||||
return c, d
|
||||
|
||||
scripted_module = torch.jit.load(
|
||||
pytorch_test_dir + "/jit/fixtures/test_versioned_logspace_v8.ptl")
|
||||
pytorch_test_dir + "/jit/fixtures/test_versioned_logspace_v8.ptl"
|
||||
)
|
||||
|
||||
buffer = io.BytesIO(scripted_module._save_to_buffer_for_lite_interpreter())
|
||||
buffer.seek(0)
|
||||
@ -463,7 +552,7 @@ class TestSaveLoadForOpVersion(JitTestCase):
|
||||
current_mobile_module = self._save_load_mobile_module(Module)
|
||||
|
||||
sample_inputs = ((3, 10), (-10, 10), (4.0, 6.0), (3 + 4j, 4 + 5j))
|
||||
for (a, b) in sample_inputs:
|
||||
for a, b in sample_inputs:
|
||||
(output_with_step, output_without_step) = v8_mobile_module(a, b)
|
||||
(current_with_step, current_without_step) = current_mobile_module(a, b)
|
||||
# when no step is given, should have used 100
|
||||
@ -475,10 +564,17 @@ class TestSaveLoadForOpVersion(JitTestCase):
|
||||
|
||||
def test_versioned_logspace_out(self):
|
||||
class Module(torch.nn.Module):
|
||||
def forward(self, a: Union[int, float, complex], b: Union[int, float, complex], out: torch.Tensor):
|
||||
def forward(
|
||||
self,
|
||||
a: Union[int, float, complex],
|
||||
b: Union[int, float, complex],
|
||||
out: torch.Tensor,
|
||||
):
|
||||
return torch.logspace(a, b, steps=100, out=out)
|
||||
|
||||
model_path = pytorch_test_dir + "/jit/fixtures/test_versioned_logspace_out_v8.ptl"
|
||||
model_path = (
|
||||
pytorch_test_dir + "/jit/fixtures/test_versioned_logspace_out_v8.ptl"
|
||||
)
|
||||
loaded_model = torch.jit.load(model_path)
|
||||
buffer = io.BytesIO(loaded_model._save_to_buffer_for_lite_interpreter())
|
||||
buffer.seek(0)
|
||||
@ -486,12 +582,32 @@ class TestSaveLoadForOpVersion(JitTestCase):
|
||||
current_mobile_module = self._save_load_mobile_module(Module)
|
||||
|
||||
sample_inputs = (
|
||||
(3, 10, torch.empty((100,), dtype=torch.int64), torch.empty((100,), dtype=torch.int64)),
|
||||
(-10, 10, torch.empty((100,), dtype=torch.int64), torch.empty((100,), dtype=torch.int64)),
|
||||
(4.0, 6.0, torch.empty((100,), dtype=torch.float64), torch.empty((100,), dtype=torch.float64)),
|
||||
(3 + 4j, 4 + 5j, torch.empty((100,), dtype=torch.complex64), torch.empty((100,), dtype=torch.complex64)),
|
||||
(
|
||||
3,
|
||||
10,
|
||||
torch.empty((100,), dtype=torch.int64),
|
||||
torch.empty((100,), dtype=torch.int64),
|
||||
),
|
||||
(
|
||||
-10,
|
||||
10,
|
||||
torch.empty((100,), dtype=torch.int64),
|
||||
torch.empty((100,), dtype=torch.int64),
|
||||
),
|
||||
(
|
||||
4.0,
|
||||
6.0,
|
||||
torch.empty((100,), dtype=torch.float64),
|
||||
torch.empty((100,), dtype=torch.float64),
|
||||
),
|
||||
(
|
||||
3 + 4j,
|
||||
4 + 5j,
|
||||
torch.empty((100,), dtype=torch.complex64),
|
||||
torch.empty((100,), dtype=torch.complex64),
|
||||
),
|
||||
)
|
||||
for (start, end, out_for_old, out_for_new) in sample_inputs:
|
||||
for start, end, out_for_old, out_for_new in sample_inputs:
|
||||
output = v8_mobile_module(start, end, out_for_old)
|
||||
output_current = current_mobile_module(start, end, out_for_new)
|
||||
# when no step is given, should have used 100
|
||||
|
||||
@ -11,10 +11,13 @@ pytorch_test_dir = os.path.dirname(os.path.dirname(os.path.realpath(__file__)))
|
||||
sys.path.append(pytorch_test_dir)
|
||||
from torch.testing._internal.jit_utils import JitTestCase
|
||||
|
||||
if __name__ == '__main__':
|
||||
raise RuntimeError("This test file is not meant to be run directly, use:\n\n"
|
||||
"\tpython test/test_jit.py TESTNAME\n\n"
|
||||
"instead.")
|
||||
if __name__ == "__main__":
|
||||
raise RuntimeError(
|
||||
"This test file is not meant to be run directly, use:\n\n"
|
||||
"\tpython test/test_jit.py TESTNAME\n\n"
|
||||
"instead."
|
||||
)
|
||||
|
||||
|
||||
class Sequence(nn.Module):
|
||||
def __init__(self):
|
||||
@ -38,8 +41,8 @@ class Sequence(nn.Module):
|
||||
outputs = torch.cat(outputs, dim=1)
|
||||
return outputs
|
||||
|
||||
class TestScriptProfile(JitTestCase):
|
||||
|
||||
class TestScriptProfile(JitTestCase):
|
||||
def test_basic(self):
|
||||
seq = torch.jit.script(Sequence())
|
||||
p = torch.jit._ScriptProfile()
|
||||
@ -57,6 +60,7 @@ class TestScriptProfile(JitTestCase):
|
||||
@torch.jit.script
|
||||
def fn():
|
||||
_ = seq(torch.rand((10, 100)))
|
||||
|
||||
fn()
|
||||
p.disable()
|
||||
|
||||
@ -83,7 +87,7 @@ class TestScriptProfile(JitTestCase):
|
||||
seq = Sequence()
|
||||
|
||||
@torch.jit.script
|
||||
def fn(max : int):
|
||||
def fn(max: int):
|
||||
_ = seq(torch.rand((10, max)))
|
||||
|
||||
p = torch.jit._ScriptProfile()
|
||||
|
||||
@ -3,22 +3,24 @@
|
||||
import os
|
||||
import sys
|
||||
import warnings
|
||||
from typing import Dict, List, Optional
|
||||
|
||||
import torch
|
||||
from typing import List, Dict, Optional
|
||||
|
||||
# Make the helper files in test/ importable
|
||||
pytorch_test_dir = os.path.dirname(os.path.dirname(os.path.realpath(__file__)))
|
||||
sys.path.append(pytorch_test_dir)
|
||||
from torch.testing._internal.jit_utils import JitTestCase
|
||||
|
||||
if __name__ == '__main__':
|
||||
raise RuntimeError("This test file is not meant to be run directly, use:\n\n"
|
||||
"\tpython test/test_jit.py TESTNAME\n\n"
|
||||
"instead.")
|
||||
if __name__ == "__main__":
|
||||
raise RuntimeError(
|
||||
"This test file is not meant to be run directly, use:\n\n"
|
||||
"\tpython test/test_jit.py TESTNAME\n\n"
|
||||
"instead."
|
||||
)
|
||||
|
||||
|
||||
class TestScriptModuleInstanceAttributeTypeAnnotation(JitTestCase):
|
||||
|
||||
# NB: There are no tests for `Tuple` or `NamedTuple` here. In fact,
|
||||
# reassigning a non-empty Tuple to an attribute previously typed
|
||||
# as containing an empty Tuple SHOULD fail. See note in `_check.py`
|
||||
@ -81,7 +83,6 @@ class TestScriptModuleInstanceAttributeTypeAnnotation(JitTestCase):
|
||||
|
||||
def test_annotated_class_level_annotation_only(self):
|
||||
class M(torch.nn.Module):
|
||||
|
||||
x: List[int]
|
||||
|
||||
def __init__(self):
|
||||
@ -96,10 +97,8 @@ class TestScriptModuleInstanceAttributeTypeAnnotation(JitTestCase):
|
||||
self.checkModule(M(), ([1, 2, 3],))
|
||||
assert len(w) == 0
|
||||
|
||||
|
||||
def test_annotated_class_level_annotation_and_init_annotation(self):
|
||||
class M(torch.nn.Module):
|
||||
|
||||
x: List[int]
|
||||
|
||||
def __init__(self):
|
||||
@ -116,7 +115,6 @@ class TestScriptModuleInstanceAttributeTypeAnnotation(JitTestCase):
|
||||
|
||||
def test_annotated_class_level_jit_annotation(self):
|
||||
class M(torch.nn.Module):
|
||||
|
||||
x: List[int]
|
||||
|
||||
def __init__(self):
|
||||
@ -141,12 +139,15 @@ class TestScriptModuleInstanceAttributeTypeAnnotation(JitTestCase):
|
||||
self.x = x
|
||||
return 1
|
||||
|
||||
with self.assertRaisesRegexWithHighlight(RuntimeError,
|
||||
"Tried to set nonexistent attribute",
|
||||
"self.x = x"):
|
||||
with self.assertWarnsRegex(UserWarning, "doesn't support "
|
||||
"instance-level annotations on "
|
||||
"empty non-base types"):
|
||||
with self.assertRaisesRegexWithHighlight(
|
||||
RuntimeError, "Tried to set nonexistent attribute", "self.x = x"
|
||||
):
|
||||
with self.assertWarnsRegex(
|
||||
UserWarning,
|
||||
"doesn't support "
|
||||
"instance-level annotations on "
|
||||
"empty non-base types",
|
||||
):
|
||||
torch.jit.script(M())
|
||||
|
||||
def test_annotated_empty_dict(self):
|
||||
@ -159,12 +160,15 @@ class TestScriptModuleInstanceAttributeTypeAnnotation(JitTestCase):
|
||||
self.x = x
|
||||
return 1
|
||||
|
||||
with self.assertRaisesRegexWithHighlight(RuntimeError,
|
||||
"Tried to set nonexistent attribute",
|
||||
"self.x = x"):
|
||||
with self.assertWarnsRegex(UserWarning, "doesn't support "
|
||||
"instance-level annotations on "
|
||||
"empty non-base types"):
|
||||
with self.assertRaisesRegexWithHighlight(
|
||||
RuntimeError, "Tried to set nonexistent attribute", "self.x = x"
|
||||
):
|
||||
with self.assertWarnsRegex(
|
||||
UserWarning,
|
||||
"doesn't support "
|
||||
"instance-level annotations on "
|
||||
"empty non-base types",
|
||||
):
|
||||
torch.jit.script(M())
|
||||
|
||||
def test_annotated_empty_optional(self):
|
||||
@ -177,12 +181,15 @@ class TestScriptModuleInstanceAttributeTypeAnnotation(JitTestCase):
|
||||
self.x = x
|
||||
return 1
|
||||
|
||||
with self.assertRaisesRegexWithHighlight(RuntimeError,
|
||||
"Wrong type for attribute assignment",
|
||||
"self.x = x"):
|
||||
with self.assertWarnsRegex(UserWarning, "doesn't support "
|
||||
"instance-level annotations on "
|
||||
"empty non-base types"):
|
||||
with self.assertRaisesRegexWithHighlight(
|
||||
RuntimeError, "Wrong type for attribute assignment", "self.x = x"
|
||||
):
|
||||
with self.assertWarnsRegex(
|
||||
UserWarning,
|
||||
"doesn't support "
|
||||
"instance-level annotations on "
|
||||
"empty non-base types",
|
||||
):
|
||||
torch.jit.script(M())
|
||||
|
||||
def test_annotated_with_jit_empty_list(self):
|
||||
@ -195,12 +202,15 @@ class TestScriptModuleInstanceAttributeTypeAnnotation(JitTestCase):
|
||||
self.x = x
|
||||
return 1
|
||||
|
||||
with self.assertRaisesRegexWithHighlight(RuntimeError,
|
||||
"Tried to set nonexistent attribute",
|
||||
"self.x = x"):
|
||||
with self.assertWarnsRegex(UserWarning, "doesn't support "
|
||||
"instance-level annotations on "
|
||||
"empty non-base types"):
|
||||
with self.assertRaisesRegexWithHighlight(
|
||||
RuntimeError, "Tried to set nonexistent attribute", "self.x = x"
|
||||
):
|
||||
with self.assertWarnsRegex(
|
||||
UserWarning,
|
||||
"doesn't support "
|
||||
"instance-level annotations on "
|
||||
"empty non-base types",
|
||||
):
|
||||
torch.jit.script(M())
|
||||
|
||||
def test_annotated_with_jit_empty_dict(self):
|
||||
@ -213,12 +223,15 @@ class TestScriptModuleInstanceAttributeTypeAnnotation(JitTestCase):
|
||||
self.x = x
|
||||
return 1
|
||||
|
||||
with self.assertRaisesRegexWithHighlight(RuntimeError,
|
||||
"Tried to set nonexistent attribute",
|
||||
"self.x = x"):
|
||||
with self.assertWarnsRegex(UserWarning, "doesn't support "
|
||||
"instance-level annotations on "
|
||||
"empty non-base types"):
|
||||
with self.assertRaisesRegexWithHighlight(
|
||||
RuntimeError, "Tried to set nonexistent attribute", "self.x = x"
|
||||
):
|
||||
with self.assertWarnsRegex(
|
||||
UserWarning,
|
||||
"doesn't support "
|
||||
"instance-level annotations on "
|
||||
"empty non-base types",
|
||||
):
|
||||
torch.jit.script(M())
|
||||
|
||||
def test_annotated_with_jit_empty_optional(self):
|
||||
@ -231,12 +244,15 @@ class TestScriptModuleInstanceAttributeTypeAnnotation(JitTestCase):
|
||||
self.x = x
|
||||
return 1
|
||||
|
||||
with self.assertRaisesRegexWithHighlight(RuntimeError,
|
||||
"Wrong type for attribute assignment",
|
||||
"self.x = x"):
|
||||
with self.assertWarnsRegex(UserWarning, "doesn't support "
|
||||
"instance-level annotations on "
|
||||
"empty non-base types"):
|
||||
with self.assertRaisesRegexWithHighlight(
|
||||
RuntimeError, "Wrong type for attribute assignment", "self.x = x"
|
||||
):
|
||||
with self.assertWarnsRegex(
|
||||
UserWarning,
|
||||
"doesn't support "
|
||||
"instance-level annotations on "
|
||||
"empty non-base types",
|
||||
):
|
||||
torch.jit.script(M())
|
||||
|
||||
def test_annotated_with_torch_jit_import(self):
|
||||
@ -251,10 +267,13 @@ class TestScriptModuleInstanceAttributeTypeAnnotation(JitTestCase):
|
||||
self.x = x
|
||||
return 1
|
||||
|
||||
with self.assertRaisesRegexWithHighlight(RuntimeError,
|
||||
"Wrong type for attribute assignment",
|
||||
"self.x = x"):
|
||||
with self.assertWarnsRegex(UserWarning, "doesn't support "
|
||||
"instance-level annotations on "
|
||||
"empty non-base types"):
|
||||
with self.assertRaisesRegexWithHighlight(
|
||||
RuntimeError, "Wrong type for attribute assignment", "self.x = x"
|
||||
):
|
||||
with self.assertWarnsRegex(
|
||||
UserWarning,
|
||||
"doesn't support "
|
||||
"instance-level annotations on "
|
||||
"empty non-base types",
|
||||
):
|
||||
torch.jit.script(M())
|
||||
|
||||
@ -2,19 +2,22 @@
|
||||
|
||||
import os
|
||||
import sys
|
||||
from typing import List
|
||||
|
||||
import torch
|
||||
from typing import List
|
||||
|
||||
# Make the helper files in test/ importable
|
||||
pytorch_test_dir = os.path.dirname(os.path.dirname(os.path.realpath(__file__)))
|
||||
sys.path.append(pytorch_test_dir)
|
||||
from torch.testing._internal.jit_utils import JitTestCase
|
||||
|
||||
if __name__ == '__main__':
|
||||
raise RuntimeError("This test file is not meant to be run directly, use:\n\n"
|
||||
"\tpython test/test_jit.py TESTNAME\n\n"
|
||||
"instead.")
|
||||
if __name__ == "__main__":
|
||||
raise RuntimeError(
|
||||
"This test file is not meant to be run directly, use:\n\n"
|
||||
"\tpython test/test_jit.py TESTNAME\n\n"
|
||||
"instead."
|
||||
)
|
||||
|
||||
|
||||
# Tests that Python slice class is supported in TorchScript
|
||||
class TestSlice(JitTestCase):
|
||||
@ -22,7 +25,9 @@ class TestSlice(JitTestCase):
|
||||
def slice_kwarg(x: List[int]):
|
||||
return x[slice(1, stop=2)]
|
||||
|
||||
with self.assertRaisesRegex(RuntimeError, "Slice does not accept any keyword arguments"):
|
||||
with self.assertRaisesRegex(
|
||||
RuntimeError, "Slice does not accept any keyword arguments"
|
||||
):
|
||||
torch.jit.script(slice_kwarg)
|
||||
|
||||
def test_slice_three_nones(self):
|
||||
@ -46,11 +51,13 @@ class TestSlice(JitTestCase):
|
||||
def test_slice_stop_only(self):
|
||||
def fn(x: List[int]):
|
||||
return x[slice(5)]
|
||||
|
||||
self.checkScript(fn, (range(10),))
|
||||
|
||||
def test_slice_stop_only_with_nones(self):
|
||||
def fn(x: List[int]):
|
||||
return x[slice(None, 5, None)]
|
||||
|
||||
self.checkScript(fn, (range(10),))
|
||||
|
||||
def test_slice_start_stop(self):
|
||||
@ -136,8 +143,8 @@ class TestSlice(JitTestCase):
|
||||
num_outputs = {len(x.output().type().elements()) for x in slices}
|
||||
# there should be only one tupleSlice with length of 2
|
||||
self.assertTrue(num_outputs == {2})
|
||||
self.run_pass('lower_all_tuples', tuple_graph)
|
||||
self.assertTrue('Tuple' not in str(tuple_graph))
|
||||
self.run_pass("lower_all_tuples", tuple_graph)
|
||||
self.assertTrue("Tuple" not in str(tuple_graph))
|
||||
|
||||
def test_module_list_slicing(self):
|
||||
class Bar(torch.nn.Module):
|
||||
|
||||
@ -1,8 +1,9 @@
|
||||
# Owner(s): ["oncall: jit"]
|
||||
|
||||
import io
|
||||
import torch
|
||||
import unittest
|
||||
|
||||
import torch
|
||||
from torch.testing._internal.common_utils import IS_WINDOWS, TEST_MKL
|
||||
from torch.testing._internal.jit_utils import JitTestCase
|
||||
|
||||
@ -70,9 +71,7 @@ class TestSparse(JitTestCase):
|
||||
self.a = torch.rand(4, 4).to_sparse_csr()
|
||||
self.b = torch.rand(4, 4).to_sparse_csr()
|
||||
|
||||
|
||||
def forward(self, x):
|
||||
|
||||
return x.matmul(self.a).matmul(self.b)
|
||||
|
||||
x = torch.rand(4, 4).to_sparse_csr()
|
||||
|
||||
@ -2,65 +2,76 @@
|
||||
|
||||
import os
|
||||
import sys
|
||||
from typing import List
|
||||
|
||||
import torch
|
||||
from typing import List
|
||||
|
||||
# Make the helper files in test/ importable
|
||||
pytorch_test_dir = os.path.dirname(os.path.dirname(os.path.realpath(__file__)))
|
||||
sys.path.append(pytorch_test_dir)
|
||||
from torch.testing._internal.jit_utils import JitTestCase
|
||||
|
||||
if __name__ == '__main__':
|
||||
raise RuntimeError("This test file is not meant to be run directly, use:\n\n"
|
||||
"\tpython test/test_jit.py TESTNAME\n\n"
|
||||
"instead.")
|
||||
if __name__ == "__main__":
|
||||
raise RuntimeError(
|
||||
"This test file is not meant to be run directly, use:\n\n"
|
||||
"\tpython test/test_jit.py TESTNAME\n\n"
|
||||
"instead."
|
||||
)
|
||||
|
||||
|
||||
class TestStringFormatting(JitTestCase):
|
||||
|
||||
def test_modulo_operator(self):
|
||||
def fn(dividend: int, divisor: int) -> int:
|
||||
return dividend % divisor
|
||||
|
||||
self.checkScript(fn, (5, 2))
|
||||
|
||||
def test_string_interpolation_with_string_placeholder_and_string_variable(self):
|
||||
def fn(arg1: str):
|
||||
return "%s in template" % arg1
|
||||
|
||||
self.checkScript(fn, ("foo",))
|
||||
|
||||
def test_string_interpolation_with_string_placeholder_and_format_string_variable(self):
|
||||
def test_string_interpolation_with_string_placeholder_and_format_string_variable(
|
||||
self,
|
||||
):
|
||||
def fn(arg1: str):
|
||||
return arg1 % "foo"
|
||||
|
||||
self.checkScript(fn, ("%s in template",))
|
||||
|
||||
def test_string_interpolation_with_double_percent_in_string(self):
|
||||
def fn(arg1: str):
|
||||
return "%s in template %%" % arg1
|
||||
|
||||
self.checkScript(fn, ("foo",))
|
||||
|
||||
def test_string_interpolation_with_percent_in_string(self):
|
||||
@torch.jit.script
|
||||
def fn(arg1: str) -> str:
|
||||
return "%s in template %" % arg1 # noqa: F501
|
||||
return "%s in template %" % arg1 # noqa: F501
|
||||
|
||||
with self.assertRaisesRegexWithHighlight(RuntimeError,
|
||||
"Incomplete format specifier",
|
||||
"\"%s in template %\" % arg1"):
|
||||
with self.assertRaisesRegexWithHighlight(
|
||||
RuntimeError, "Incomplete format specifier", '"%s in template %" % arg1'
|
||||
):
|
||||
fn("foo")
|
||||
|
||||
def test_string_interpolation_with_string_placeholder_and_digit_variable(self):
|
||||
def fn(arg1: int) -> str:
|
||||
return "%s in template" % arg1
|
||||
|
||||
self.checkScript(fn, (1,))
|
||||
|
||||
def test_string_interpolation_with_digit_placeholder_and_digit_variable(self):
|
||||
def fn(arg1: int) -> str:
|
||||
return "%d in template" % arg1
|
||||
|
||||
self.checkScript(fn, (1,))
|
||||
|
||||
def test_string_interpolation_with_alternate_digit_placeholder(self):
|
||||
def fn(arg1: int) -> str:
|
||||
return "%i in template" % arg1
|
||||
|
||||
self.checkScript(fn, (1,))
|
||||
|
||||
def test_string_interpolation_with_digit_placeholder_and_string_variable(self):
|
||||
@ -68,9 +79,11 @@ class TestStringFormatting(JitTestCase):
|
||||
def fn(arg1: str) -> str:
|
||||
return "%d in template" % arg1
|
||||
|
||||
with self.assertRaisesRegexWithHighlight(RuntimeError,
|
||||
"%d requires a number for formatting, but got String",
|
||||
"\"%d in template\" % arg1"):
|
||||
with self.assertRaisesRegexWithHighlight(
|
||||
RuntimeError,
|
||||
"%d requires a number for formatting, but got String",
|
||||
'"%d in template" % arg1',
|
||||
):
|
||||
fn("1")
|
||||
|
||||
def test_string_interpolation_with_exponent_placeholder_and_string_variable(self):
|
||||
@ -78,39 +91,51 @@ class TestStringFormatting(JitTestCase):
|
||||
def fn(arg1: str) -> str:
|
||||
return "%e in template" % arg1
|
||||
|
||||
with self.assertRaisesRegexWithHighlight(RuntimeError,
|
||||
"%e requires a number for formatting, but got String",
|
||||
"\"%e in template\" % arg1"):
|
||||
with self.assertRaisesRegexWithHighlight(
|
||||
RuntimeError,
|
||||
"%e requires a number for formatting, but got String",
|
||||
'"%e in template" % arg1',
|
||||
):
|
||||
fn("1")
|
||||
|
||||
def test_string_interpolation_with_lowercase_exponent_placeholder_and_digit_variable(self):
|
||||
def test_string_interpolation_with_lowercase_exponent_placeholder_and_digit_variable(
|
||||
self,
|
||||
):
|
||||
def fn(arg1: int) -> str:
|
||||
return "%e in template" % arg1
|
||||
|
||||
self.checkScript(fn, (1,))
|
||||
|
||||
def test_string_interpolation_with_capital_exponent_placeholder_and_digit_variable(self):
|
||||
def test_string_interpolation_with_capital_exponent_placeholder_and_digit_variable(
|
||||
self,
|
||||
):
|
||||
def fn(arg1: int) -> str:
|
||||
return "%E in template" % arg1
|
||||
|
||||
self.checkScript(fn, (1,))
|
||||
|
||||
def test_string_interpolation_with_float_placeholder_and_float_variable(self):
|
||||
def fn(arg1: float) -> str:
|
||||
return "%f in template" % arg1
|
||||
|
||||
self.checkScript(fn, (1.0,))
|
||||
|
||||
def test_string_interpolation_with_float_placeholder_and_digit_variable(self):
|
||||
def fn(arg1: int) -> str:
|
||||
return "%f in template" % arg1
|
||||
|
||||
self.checkScript(fn, (1,))
|
||||
|
||||
def test_string_interpolation_with_char_placeholder_and_char_variable(self):
|
||||
def fn(arg1: str) -> str:
|
||||
return "%c in template" % arg1
|
||||
|
||||
self.checkScript(fn, ("a",))
|
||||
|
||||
def test_string_interpolation_with_char_placeholder_and_digit_variable(self):
|
||||
def fn(arg1: int) -> str:
|
||||
return "%c in template" % arg1
|
||||
|
||||
self.checkScript(fn, (97,))
|
||||
|
||||
def test_string_interpolation_with_char_placeholder_and_true_string_variable(self):
|
||||
@ -118,19 +143,23 @@ class TestStringFormatting(JitTestCase):
|
||||
def fn(arg1: str) -> str:
|
||||
return "%c in template" % arg1
|
||||
|
||||
with self.assertRaisesRegexWithHighlight(RuntimeError,
|
||||
"%c requires an int or char for formatting, but got String",
|
||||
"\"%c in template\" % arg1"):
|
||||
with self.assertRaisesRegexWithHighlight(
|
||||
RuntimeError,
|
||||
"%c requires an int or char for formatting, but got String",
|
||||
'"%c in template" % arg1',
|
||||
):
|
||||
fn("foo")
|
||||
|
||||
def test_string_interpolation_with_multiple_placeholders(self):
|
||||
def fn(arg1: str, arg2: int, arg3: float) -> str:
|
||||
return "%s %d %f in template" % (arg1, arg2, arg3)
|
||||
|
||||
self.checkScript(fn, ("foo", 1, 1))
|
||||
|
||||
def test_string_interpolation_with_subscript(self):
|
||||
def fn(arg1: List[str]) -> str:
|
||||
return "%s in template" % arg1[0]
|
||||
|
||||
self.checkScript(fn, (["foo", "bar"],))
|
||||
|
||||
def test_string_interpolation_with_too_few_arguments(self):
|
||||
@ -138,27 +167,33 @@ class TestStringFormatting(JitTestCase):
|
||||
def fn(arg1: str) -> str:
|
||||
return "%s %s in template" % arg1
|
||||
|
||||
with self.assertRaisesRegexWithHighlight(RuntimeError,
|
||||
"Too few arguments for format string",
|
||||
"\"%s %s in template\" % arg1"):
|
||||
with self.assertRaisesRegexWithHighlight(
|
||||
RuntimeError,
|
||||
"Too few arguments for format string",
|
||||
'"%s %s in template" % arg1',
|
||||
):
|
||||
fn("foo")
|
||||
|
||||
def test_string_interpolation_with_too_many_arguments(self):
|
||||
@torch.jit.script
|
||||
def fn(arg1: str, arg2: str) -> str:
|
||||
return "%s in template" % (arg1, arg2) # noqa: F507
|
||||
return "%s in template" % (arg1, arg2) # noqa: F507
|
||||
|
||||
with self.assertRaisesRegexWithHighlight(RuntimeError,
|
||||
"Too many arguments for format string",
|
||||
"\"%s in template\" % (arg1, arg2"):
|
||||
with self.assertRaisesRegexWithHighlight(
|
||||
RuntimeError,
|
||||
"Too many arguments for format string",
|
||||
'"%s in template" % (arg1, arg2',
|
||||
):
|
||||
fn("foo", "bar")
|
||||
|
||||
def test_string_interpolation_with_unknown_format_specifier(self):
|
||||
@torch.jit.script
|
||||
def fn(arg1: str) -> str:
|
||||
return "%a in template" % arg1 # noqa: F501
|
||||
return "%a in template" % arg1 # noqa: F501
|
||||
|
||||
with self.assertRaisesRegexWithHighlight(RuntimeError,
|
||||
"The specifier %a is not supported in TorchScript format strings",
|
||||
"\"%a in template\" % arg1"):
|
||||
with self.assertRaisesRegexWithHighlight(
|
||||
RuntimeError,
|
||||
"The specifier %a is not supported in TorchScript format strings",
|
||||
'"%a in template" % arg1',
|
||||
):
|
||||
fn("foo")
|
||||
|
||||
@ -3,29 +3,36 @@
|
||||
import operator
|
||||
import unittest
|
||||
from textwrap import dedent
|
||||
from typing import Any, List
|
||||
|
||||
import torch
|
||||
from torch import nn, Tensor
|
||||
from torch.testing import FileCheck
|
||||
from torch.testing._internal.common_methods_invocations import sample_inputs_cat_concat
|
||||
from torch.testing._internal.common_utils import make_tensor
|
||||
from torch.testing._internal.jit_utils import JitTestCase, execWrapper
|
||||
from typing import List, Any
|
||||
from torch.testing._internal.jit_utils import execWrapper, JitTestCase
|
||||
|
||||
if __name__ == "__main__":
|
||||
raise RuntimeError(
|
||||
"This test file is not meant to be run directly, use:\n\n"
|
||||
"\tpython test/test_jit.py TESTNAME\n\n"
|
||||
"instead."
|
||||
)
|
||||
|
||||
if __name__ == '__main__':
|
||||
raise RuntimeError("This test file is not meant to be run directly, use:\n\n"
|
||||
"\tpython test/test_jit.py TESTNAME\n\n"
|
||||
"instead.")
|
||||
|
||||
# XXX: still in prototype
|
||||
class TestSymbolicShapeAnalysis(JitTestCase):
|
||||
def setUp(self):
|
||||
super(JitTestCase, self).setUp()
|
||||
self.prev_symbolic_shapes_test_enabled = torch._C._jit_symbolic_shapes_test_mode_enabled()
|
||||
self.prev_symbolic_shapes_test_enabled = (
|
||||
torch._C._jit_symbolic_shapes_test_mode_enabled()
|
||||
)
|
||||
torch._C._jit_set_symbolic_shapes_test_mode(True)
|
||||
|
||||
def tearDown(self):
|
||||
torch._C._jit_set_symbolic_shapes_test_mode(self.prev_symbolic_shapes_test_enabled)
|
||||
torch._C._jit_set_symbolic_shapes_test_mode(
|
||||
self.prev_symbolic_shapes_test_enabled
|
||||
)
|
||||
|
||||
def test_shape_analysis(self):
|
||||
@torch.jit.script
|
||||
@ -115,7 +122,9 @@ class TestSymbolicShapeAnalysis(JitTestCase):
|
||||
def neg_to_one(li):
|
||||
return [elem if elem >= 0 else -1 for elem in li]
|
||||
|
||||
self.assertEqual(neg_to_one(view.output().type().symbolic_sizes()), [-1, 3, 2, -1])
|
||||
self.assertEqual(
|
||||
neg_to_one(view.output().type().symbolic_sizes()), [-1, 3, 2, -1]
|
||||
)
|
||||
if_out = next(foo.graph.findNode("prim::If").outputs())
|
||||
self.assertEqual(neg_to_one(if_out.type().symbolic_sizes()), [-1, 3, -1, -1])
|
||||
|
||||
@ -135,9 +144,7 @@ class TestSymbolicShapeAnalysis(JitTestCase):
|
||||
y = x.mul_(2)
|
||||
return y
|
||||
|
||||
unary_ops = [
|
||||
mul_inplace
|
||||
]
|
||||
unary_ops = [mul_inplace]
|
||||
for fn in unary_ops:
|
||||
# t = torch.jit.trace(fn, torch.rand([4, 4])) # For some reason tracing is erroring out.
|
||||
t = torch.jit.script(fn)
|
||||
@ -202,7 +209,9 @@ class TestSymbolicShapeAnalysis(JitTestCase):
|
||||
|
||||
inputs[1].setType(inputs[1].type().with_sizes([5, 8, sym1]))
|
||||
torch._C._jit_pass_propagate_shapes_on_graph(graph)
|
||||
self.assertEqual(next(graph.outputs()).type().symbolic_sizes(), [5, 8, sym1])
|
||||
self.assertEqual(
|
||||
next(graph.outputs()).type().symbolic_sizes(), [5, 8, sym1]
|
||||
)
|
||||
|
||||
def test_adaptive_avg_pool2d(self):
|
||||
inps = [
|
||||
@ -227,25 +236,105 @@ class TestSymbolicShapeAnalysis(JitTestCase):
|
||||
self.checkShapeAnalysis(out_size, fn.graph, assert_propagation=True)
|
||||
|
||||
def test_conv_deconv(self):
|
||||
for inp_shape, weight_shape, bias, stride, padding, output_padding, dilation, groups, mod in [
|
||||
([32, 6, 10], [16, 3, 3], None, 2, 2, 1, 1, 2, torch.nn.functional.conv1d),
|
||||
([32, 16, 10], [16, 3, 3], None, 2, 2, 1, 1, 2, torch.nn.functional.conv_transpose1d),
|
||||
([1, 32, 5, 10], [30, 16, 3, 3], None, [2, 2], [0, 0], 0, 1, 2, torch.nn.functional.conv2d),
|
||||
([1, 30, 5, 10], [30, 16, 3, 3], None, [2, 2], [0, 0], 0, 1, 2, torch.nn.functional.conv_transpose2d),
|
||||
([3, 14, 10, 66, 55], [2, 7, 7, 4, 4], None, 1, 1, 2, 1, 2, torch.nn.functional.conv3d),
|
||||
([3, 2, 10, 66, 55], [2, 7, 7, 4, 4], None, 1, 1, 0, 1, 2, torch.nn.functional.conv_transpose3d)]:
|
||||
for (
|
||||
inp_shape,
|
||||
weight_shape,
|
||||
bias,
|
||||
stride,
|
||||
padding,
|
||||
output_padding,
|
||||
dilation,
|
||||
groups,
|
||||
mod,
|
||||
) in [
|
||||
([32, 6, 10], [16, 3, 3], None, 2, 2, 1, 1, 2, torch.nn.functional.conv1d),
|
||||
(
|
||||
[32, 16, 10],
|
||||
[16, 3, 3],
|
||||
None,
|
||||
2,
|
||||
2,
|
||||
1,
|
||||
1,
|
||||
2,
|
||||
torch.nn.functional.conv_transpose1d,
|
||||
),
|
||||
(
|
||||
[1, 32, 5, 10],
|
||||
[30, 16, 3, 3],
|
||||
None,
|
||||
[2, 2],
|
||||
[0, 0],
|
||||
0,
|
||||
1,
|
||||
2,
|
||||
torch.nn.functional.conv2d,
|
||||
),
|
||||
(
|
||||
[1, 30, 5, 10],
|
||||
[30, 16, 3, 3],
|
||||
None,
|
||||
[2, 2],
|
||||
[0, 0],
|
||||
0,
|
||||
1,
|
||||
2,
|
||||
torch.nn.functional.conv_transpose2d,
|
||||
),
|
||||
(
|
||||
[3, 14, 10, 66, 55],
|
||||
[2, 7, 7, 4, 4],
|
||||
None,
|
||||
1,
|
||||
1,
|
||||
2,
|
||||
1,
|
||||
2,
|
||||
torch.nn.functional.conv3d,
|
||||
),
|
||||
(
|
||||
[3, 2, 10, 66, 55],
|
||||
[2, 7, 7, 4, 4],
|
||||
None,
|
||||
1,
|
||||
1,
|
||||
0,
|
||||
1,
|
||||
2,
|
||||
torch.nn.functional.conv_transpose3d,
|
||||
),
|
||||
]:
|
||||
inp = torch.rand(inp_shape)
|
||||
weight = torch.rand(weight_shape)
|
||||
if mod in [torch.nn.functional.conv1d, torch.nn.functional.conv2d, torch.nn.functional.conv3d]:
|
||||
if mod in [
|
||||
torch.nn.functional.conv1d,
|
||||
torch.nn.functional.conv2d,
|
||||
torch.nn.functional.conv3d,
|
||||
]:
|
||||
res = mod(inp, weight, bias, stride, padding, dilation, groups).size()
|
||||
else:
|
||||
res = mod(inp, weight, bias, stride, padding, output_padding, dilation, groups).size()
|
||||
res = mod(
|
||||
inp, weight, bias, stride, padding, output_padding, dilation, groups
|
||||
).size()
|
||||
|
||||
def foo(inp, weight):
|
||||
if mod in [torch.nn.functional.conv1d, torch.nn.functional.conv2d, torch.nn.functional.conv3d]:
|
||||
if mod in [
|
||||
torch.nn.functional.conv1d,
|
||||
torch.nn.functional.conv2d,
|
||||
torch.nn.functional.conv3d,
|
||||
]:
|
||||
return mod(inp, weight, bias, stride, padding, dilation, groups)
|
||||
else:
|
||||
return mod(inp, weight, bias, stride, padding, output_padding, dilation, groups)
|
||||
return mod(
|
||||
inp,
|
||||
weight,
|
||||
bias,
|
||||
stride,
|
||||
padding,
|
||||
output_padding,
|
||||
dilation,
|
||||
groups,
|
||||
)
|
||||
|
||||
fn = torch.jit.trace(foo, (inp, weight))
|
||||
torch._C._jit_erase_non_input_shape_information(fn.graph)
|
||||
@ -280,33 +369,58 @@ class TestSymbolicShapeAnalysis(JitTestCase):
|
||||
]
|
||||
|
||||
for inp in inps:
|
||||
funcs_template = dedent('''
|
||||
funcs_template = dedent(
|
||||
"""
|
||||
def func():
|
||||
return torch.arange({args})
|
||||
''')
|
||||
"""
|
||||
)
|
||||
|
||||
inp_s = str(inp)[1:-1] # remove tuple parens
|
||||
funcs_str = funcs_template.format(args=inp_s)
|
||||
scope = {}
|
||||
execWrapper(funcs_str, globals(), scope)
|
||||
cu = torch.jit.CompilationUnit(funcs_str)
|
||||
self.checkShapeAnalysis(list(cu.func().size()), cu.func.graph, assert_propagation=True, constant_prop=False)
|
||||
self.checkShapeAnalysis(
|
||||
list(cu.func().size()),
|
||||
cu.func.graph,
|
||||
assert_propagation=True,
|
||||
constant_prop=False,
|
||||
)
|
||||
|
||||
def test_shape_embedding_bag(self):
|
||||
# TODO: merge into opinfos, having difficulties there
|
||||
with torch.no_grad():
|
||||
|
||||
def make_arg(shape, low=None, high=None):
|
||||
return make_tensor(shape, device='cpu', dtype=torch.int64,
|
||||
low=low, high=high, requires_grad=False)
|
||||
return make_tensor(
|
||||
shape,
|
||||
device="cpu",
|
||||
dtype=torch.int64,
|
||||
low=low,
|
||||
high=high,
|
||||
requires_grad=False,
|
||||
)
|
||||
|
||||
nn_inps = (
|
||||
(make_arg((40,), 0, 9), torch.nn.Embedding(20, embedding_dim=64, max_norm=1.0)),
|
||||
(
|
||||
make_arg((40,), 0, 9),
|
||||
torch.nn.Embedding(20, embedding_dim=64, max_norm=1.0),
|
||||
),
|
||||
(make_arg((2, 4), 0, 9), torch.nn.Embedding(10, 20, sparse=True)),
|
||||
(make_arg((0,)), torch.nn.Embedding(0, 0, sparse=True)),
|
||||
(make_arg((2, 4), 0, 9), torch.nn.Embedding(10, 0, sparse=True)),
|
||||
(make_arg((4,), 0, 21), torch.nn.Embedding(22, 5, max_norm=1.0)),
|
||||
(make_arg((2,), 0, 1), torch.nn.Embedding.from_pretrained(torch.arange(6.).view(2, 3), max_norm=2.,
|
||||
norm_type=.5, scale_grad_by_freq=False, sparse=True)),
|
||||
(
|
||||
make_arg((2,), 0, 1),
|
||||
torch.nn.Embedding.from_pretrained(
|
||||
torch.arange(6.0).view(2, 3),
|
||||
max_norm=2.0,
|
||||
norm_type=0.5,
|
||||
scale_grad_by_freq=False,
|
||||
sparse=True,
|
||||
),
|
||||
),
|
||||
)
|
||||
|
||||
for inp, module in nn_inps:
|
||||
@ -326,14 +440,16 @@ class TestSymbolicShapeAnalysis(JitTestCase):
|
||||
|
||||
fn = torch.jit.trace(foo, (inp.detach(),), check_trace=False)
|
||||
|
||||
self.checkShapeAnalysis(out_size, fn.graph, assert_propagation=True, constant_prop=False)
|
||||
self.checkShapeAnalysis(
|
||||
out_size, fn.graph, assert_propagation=True, constant_prop=False
|
||||
)
|
||||
|
||||
def test_shape_concat(self):
|
||||
# TODO: unify with opinfo tests, traces of lists dont preserve sizes in IR
|
||||
sample_inputs = sample_inputs_cat_concat(None, "cpu", torch.float, False)
|
||||
|
||||
class CatMod(nn.Module):
|
||||
__constants__ = ['dim']
|
||||
__constants__ = ["dim"]
|
||||
|
||||
def __init__(self, dim=0):
|
||||
super().__init__()
|
||||
@ -374,16 +490,23 @@ class TestSymbolicShapeAnalysis(JitTestCase):
|
||||
# Also, as the return shapes are the input, weight, and bias shape, there is no point
|
||||
# in a really complicated test
|
||||
|
||||
input = torch.randn((16, 16, 8, 8), dtype=torch.float32, device="cpu", requires_grad=True)
|
||||
weight = torch.randn((8, 4, 3, 3), dtype=torch.float32, device="cpu", requires_grad=True)
|
||||
input = torch.randn(
|
||||
(16, 16, 8, 8), dtype=torch.float32, device="cpu", requires_grad=True
|
||||
)
|
||||
weight = torch.randn(
|
||||
(8, 4, 3, 3), dtype=torch.float32, device="cpu", requires_grad=True
|
||||
)
|
||||
out_grad = torch.randn((16, 8, 8, 8), dtype=torch.float32, device="cpu")
|
||||
|
||||
|
||||
@torch.jit.script
|
||||
def conv_bwd(input, weight, grad):
|
||||
bias_sizes = [8, ]
|
||||
bias_sizes = [
|
||||
8,
|
||||
]
|
||||
args = ([1, 1], [1, 1], [1, 1], False, [0, 0], 4, [True, True, True])
|
||||
return torch.ops.aten.convolution_backward(grad, input, weight, bias_sizes, *args)
|
||||
return torch.ops.aten.convolution_backward(
|
||||
grad, input, weight, bias_sizes, *args
|
||||
)
|
||||
|
||||
self.assert_shape_equal_scripted(conv_bwd, (input, weight, out_grad))
|
||||
|
||||
@ -391,15 +514,19 @@ class TestSymbolicShapeAnalysis(JitTestCase):
|
||||
def conv_bwd_2(input, weight, grad):
|
||||
bias_sizes = None
|
||||
args = ([1, 1], [1, 1], [1, 1], False, [0, 0], 4, [True, True, True])
|
||||
return torch.ops.aten.convolution_backward(grad, input, weight, bias_sizes, *args)
|
||||
self.assert_shape_equal_scripted(conv_bwd_2, (input, weight, out_grad))
|
||||
return torch.ops.aten.convolution_backward(
|
||||
grad, input, weight, bias_sizes, *args
|
||||
)
|
||||
|
||||
self.assert_shape_equal_scripted(conv_bwd_2, (input, weight, out_grad))
|
||||
|
||||
def test_returning_input_symbolic_shapes(self):
|
||||
mm = torch.jit.freeze(torch.jit.script(nn.Conv2d(16, 33, 3, stride=2).eval()))
|
||||
inps = list(mm.graph.inputs())
|
||||
inps[1].setType(inps[1].type().with_sizes([None, None, None, None]))
|
||||
shape_compute_graph = torch._C._jit_pass_propagate_shapes_on_graph_and_build_compute(mm.graph)
|
||||
shape_compute_graph = (
|
||||
torch._C._jit_pass_propagate_shapes_on_graph_and_build_compute(mm.graph)
|
||||
)
|
||||
g = shape_compute_graph.partial_eval_shape_graph()
|
||||
# to make into a jit function cant have multiple outputs
|
||||
g.makeMultiOutputIntoTuple()
|
||||
@ -412,8 +539,12 @@ class TestSymbolicShapeAnalysis(JitTestCase):
|
||||
|
||||
def test_partial_eval_graph_conv(self):
|
||||
mm = torch.jit.freeze(torch.jit.script(nn.Conv2d(16, 33, 3, stride=2).eval()))
|
||||
shape_compute_graph = torch._C._jit_pass_propagate_shapes_on_graph_and_build_compute(mm.graph)
|
||||
output_sizes = mm.graph.findNode("aten::conv2d").output().type().symbolic_sizes()
|
||||
shape_compute_graph = (
|
||||
torch._C._jit_pass_propagate_shapes_on_graph_and_build_compute(mm.graph)
|
||||
)
|
||||
output_sizes = (
|
||||
mm.graph.findNode("aten::conv2d").output().type().symbolic_sizes()
|
||||
)
|
||||
# calculating 0, 2 and 3 index
|
||||
for i in [0, 2, 3]:
|
||||
self.assertTrue(output_sizes[i] < 0)
|
||||
@ -428,7 +559,9 @@ class TestSymbolicShapeAnalysis(JitTestCase):
|
||||
for o, oe in zip(output, output_eager[0:1] + output_eager[2:]):
|
||||
self.assertEqual(o, oe)
|
||||
|
||||
def checkSymShapeCompute(self, shape_compute_graph, nodes, node_output_sizes, shape_inputs):
|
||||
def checkSymShapeCompute(
|
||||
self, shape_compute_graph, nodes, node_output_sizes, shape_inputs
|
||||
):
|
||||
g = shape_compute_graph.partial_eval_shape_graph()
|
||||
self.assertTrue(len(list(g.inputs())) == len(shape_inputs))
|
||||
output_sym_map = shape_compute_graph.graph_output_to_symbolic_shape_dim()
|
||||
@ -451,27 +584,49 @@ class TestSymbolicShapeAnalysis(JitTestCase):
|
||||
self.assertEqual(sym_outputs[sym_shape_index], output_shape[i])
|
||||
|
||||
def test_partial_eval_stitching(self):
|
||||
conv1 = torch.nn.Conv2d(3, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)
|
||||
max_pool = torch.nn.MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)
|
||||
conv2 = nn.Conv2d(64, 128, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False)
|
||||
conv1 = torch.nn.Conv2d(
|
||||
3, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False
|
||||
)
|
||||
max_pool = torch.nn.MaxPool2d(
|
||||
kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False
|
||||
)
|
||||
conv2 = nn.Conv2d(
|
||||
64, 128, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False
|
||||
)
|
||||
|
||||
mod = torch.jit.freeze(torch.jit.script(nn.Sequential(conv1, max_pool, conv2).eval()))
|
||||
mod = torch.jit.freeze(
|
||||
torch.jit.script(nn.Sequential(conv1, max_pool, conv2).eval())
|
||||
)
|
||||
|
||||
conv1_output = conv1(torch.rand(1, 3, 224, 224))
|
||||
max_pool_output = max_pool(conv1_output)
|
||||
conv2_output = conv2(max_pool_output)
|
||||
|
||||
shape_compute_graph = torch._C._jit_pass_propagate_shapes_on_graph_and_build_compute(mod.graph)
|
||||
nodes = [mod.graph.findNode("aten::max_pool2d")] + list(mod.graph.findAllNodes("aten::conv2d"))
|
||||
output_shapes = [max_pool_output.size(), conv1_output.size(), conv2_output.size()]
|
||||
self.checkSymShapeCompute(shape_compute_graph, nodes, output_shapes, ([1, 3, 224, 224],))
|
||||
shape_compute_graph = (
|
||||
torch._C._jit_pass_propagate_shapes_on_graph_and_build_compute(mod.graph)
|
||||
)
|
||||
nodes = [mod.graph.findNode("aten::max_pool2d")] + list(
|
||||
mod.graph.findAllNodes("aten::conv2d")
|
||||
)
|
||||
output_shapes = [
|
||||
max_pool_output.size(),
|
||||
conv1_output.size(),
|
||||
conv2_output.size(),
|
||||
]
|
||||
self.checkSymShapeCompute(
|
||||
shape_compute_graph, nodes, output_shapes, ([1, 3, 224, 224],)
|
||||
)
|
||||
|
||||
def test_refinement_through_graph_stitching(self):
|
||||
class TwoConvs(torch.nn.Module):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
self.conv1 = torch.nn.Conv2d(3, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)
|
||||
self.conv2 = torch.nn.Conv2d(3, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)
|
||||
self.conv1 = torch.nn.Conv2d(
|
||||
3, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False
|
||||
)
|
||||
self.conv2 = torch.nn.Conv2d(
|
||||
3, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False
|
||||
)
|
||||
|
||||
def forward(self, x):
|
||||
a = self.conv1(x)
|
||||
@ -495,18 +650,29 @@ class TestSymbolicShapeAnalysis(JitTestCase):
|
||||
self.assertEqual(out1, out2)
|
||||
|
||||
def test_stitching_multi_output(self):
|
||||
max_pool = torch.nn.MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False, return_indices=True)
|
||||
max_pool = torch.nn.MaxPool2d(
|
||||
kernel_size=3,
|
||||
stride=2,
|
||||
padding=1,
|
||||
dilation=1,
|
||||
ceil_mode=False,
|
||||
return_indices=True,
|
||||
)
|
||||
tensor = torch.rand(1, 3, 224, 224)
|
||||
mod = torch.jit.trace(max_pool, (tensor,))
|
||||
mod = torch.jit.freeze(mod.eval())
|
||||
inp = list(mod.graph.inputs())[1]
|
||||
inp.setType(inp.type().with_sizes([None, None, None, None]))
|
||||
output_tensor = list(mod(tensor)[0].size())
|
||||
self.run_pass('lower_all_tuples', mod.graph)
|
||||
shape_compute_graph = torch._C._jit_pass_propagate_shapes_on_graph_and_build_compute(mod.graph)
|
||||
self.run_pass("lower_all_tuples", mod.graph)
|
||||
shape_compute_graph = (
|
||||
torch._C._jit_pass_propagate_shapes_on_graph_and_build_compute(mod.graph)
|
||||
)
|
||||
max_pool_node = mod.graph.findNode("aten::max_pool2d_with_indices")
|
||||
outs = list(max_pool_node.outputs())
|
||||
self.assertEqual(outs[0].type().symbolic_sizes(), outs[1].type().symbolic_sizes())
|
||||
self.assertEqual(
|
||||
outs[0].type().symbolic_sizes(), outs[1].type().symbolic_sizes()
|
||||
)
|
||||
g = shape_compute_graph.partial_eval_shape_graph()
|
||||
# to make into a jit function cant have multiple outputs
|
||||
g.makeMultiOutputIntoTuple()
|
||||
@ -528,7 +694,6 @@ class TestSymbolicShapeAnalysis(JitTestCase):
|
||||
self.assertEqual(out, [-2, -3])
|
||||
|
||||
def test_stitching_concat(self):
|
||||
|
||||
@torch.jit.script
|
||||
def foo1(a, b, x, y):
|
||||
return (a / b) + torch.cat([x, y])
|
||||
@ -542,15 +707,25 @@ class TestSymbolicShapeAnalysis(JitTestCase):
|
||||
for inp in foo.graph.inputs():
|
||||
inp.setType(inp.type().with_sizes([None, None]))
|
||||
|
||||
shape_compute_graph = torch._C._jit_pass_propagate_shapes_on_graph_and_build_compute(foo.graph)
|
||||
nodes = [g.findNode("aten::div")] + [g.findNode("aten::add")] + [g.findNode("aten::cat")]
|
||||
shape_compute_graph = (
|
||||
torch._C._jit_pass_propagate_shapes_on_graph_and_build_compute(
|
||||
foo.graph
|
||||
)
|
||||
)
|
||||
nodes = (
|
||||
[g.findNode("aten::div")]
|
||||
+ [g.findNode("aten::add")]
|
||||
+ [g.findNode("aten::cat")]
|
||||
)
|
||||
|
||||
inps = [1, 10], [20, 10], [15, 1], [5, 1]
|
||||
output_shapes = [[20, 10], [20, 10], [20, 1]]
|
||||
|
||||
self.checkSymShapeCompute(shape_compute_graph, nodes, output_shapes, inps)
|
||||
|
||||
@unittest.skipIf(not hasattr(torch.jit, "_shapes"), "shape functions not loaded in python")
|
||||
@unittest.skipIf(
|
||||
not hasattr(torch.jit, "_shapes"), "shape functions not loaded in python"
|
||||
)
|
||||
def test_shape_function_includes(self):
|
||||
inp_shape = [1, 16, 5, 10]
|
||||
weight_shape = [33, 16, 3, 3]
|
||||
@ -559,7 +734,9 @@ class TestSymbolicShapeAnalysis(JitTestCase):
|
||||
padding = [0, 0]
|
||||
dilation = [1, 1]
|
||||
groups = 1
|
||||
res = torch.jit._shapes.conv2d(inp_shape, weight_shape, bias, stride, padding, dilation, groups)
|
||||
res = torch.jit._shapes.conv2d(
|
||||
inp_shape, weight_shape, bias, stride, padding, dilation, groups
|
||||
)
|
||||
self.assertEqual(res, [1, 33, 2, 4])
|
||||
|
||||
m1_shape = [10, 20]
|
||||
@ -580,8 +757,11 @@ class TestSymbolicShapeAnalysis(JitTestCase):
|
||||
def wrong_input_types(x, y):
|
||||
x: List[int] = []
|
||||
return x
|
||||
|
||||
with self.assertRaisesRegex(RuntimeError, "Expected supertype of int"):
|
||||
torch._C._jit_register_shape_compute_graph_for_node(node, wrong_input_types.graph)
|
||||
torch._C._jit_register_shape_compute_graph_for_node(
|
||||
node, wrong_input_types.graph
|
||||
)
|
||||
|
||||
@torch.jit.script
|
||||
def wrong_output_types(x: List[int], y: List[int]):
|
||||
@ -589,7 +769,9 @@ class TestSymbolicShapeAnalysis(JitTestCase):
|
||||
return x
|
||||
|
||||
with self.assertRaisesRegex(RuntimeError, "but got graph_type"):
|
||||
torch._C._jit_register_shape_compute_graph_for_node(node, wrong_output_types.graph)
|
||||
torch._C._jit_register_shape_compute_graph_for_node(
|
||||
node, wrong_output_types.graph
|
||||
)
|
||||
|
||||
@torch.jit.script
|
||||
def too_many_inputs(x: List[int], y: List[int], z: Any, z2: Any):
|
||||
@ -597,7 +779,9 @@ class TestSymbolicShapeAnalysis(JitTestCase):
|
||||
return x
|
||||
|
||||
with self.assertRaises(RuntimeError) as error:
|
||||
torch._C._jit_register_shape_compute_graph_for_node(node, too_many_inputs.graph)
|
||||
torch._C._jit_register_shape_compute_graph_for_node(
|
||||
node, too_many_inputs.graph
|
||||
)
|
||||
|
||||
self.assertTrue("fewer arguments than schema" in str(error.exception))
|
||||
|
||||
@ -608,9 +792,22 @@ class TestSymbolicShapeAnalysis(JitTestCase):
|
||||
|
||||
inputs = list(foo.graph.inputs())
|
||||
inputs[0].setType(inputs[0].type().with_sizes([8, 2]))
|
||||
inputs[1].setType(inputs[1].type().with_sizes([8,]))
|
||||
inputs[1].setType(
|
||||
inputs[1]
|
||||
.type()
|
||||
.with_sizes(
|
||||
[
|
||||
8,
|
||||
]
|
||||
)
|
||||
)
|
||||
torch._C._jit_pass_propagate_shapes_on_graph(foo.graph)
|
||||
self.assertEqual(next(foo.graph.outputs()).type().sizes(), [8,])
|
||||
self.assertEqual(
|
||||
next(foo.graph.outputs()).type().sizes(),
|
||||
[
|
||||
8,
|
||||
],
|
||||
)
|
||||
|
||||
def test_squeeze_dims(self):
|
||||
@torch.jit.script
|
||||
|
||||
@ -10,10 +10,13 @@ pytorch_test_dir = os.path.dirname(os.path.dirname(os.path.realpath(__file__)))
|
||||
sys.path.append(pytorch_test_dir)
|
||||
from torch.testing._internal.jit_utils import JitTestCase
|
||||
|
||||
if __name__ == '__main__':
|
||||
raise RuntimeError("This test file is not meant to be run directly, use:\n\n"
|
||||
"\tpython test/test_jit.py TESTNAME\n\n"
|
||||
"instead.")
|
||||
if __name__ == "__main__":
|
||||
raise RuntimeError(
|
||||
"This test file is not meant to be run directly, use:\n\n"
|
||||
"\tpython test/test_jit.py TESTNAME\n\n"
|
||||
"instead."
|
||||
)
|
||||
|
||||
|
||||
class TestTensorCreationOps(JitTestCase):
|
||||
"""
|
||||
@ -27,7 +30,7 @@ class TestTensorCreationOps(JitTestCase):
|
||||
# as integers, which are not comparable against eager torch.dtype.
|
||||
assert perm.dtype == torch.int64
|
||||
|
||||
self.checkScript(randperm, (3, ))
|
||||
self.checkScript(randperm, (3,))
|
||||
|
||||
def test_randperm_specifed_dtype(self):
|
||||
def randperm(x: int):
|
||||
@ -36,7 +39,7 @@ class TestTensorCreationOps(JitTestCase):
|
||||
# as integers, which are not comparable against eager torch.dtype.
|
||||
assert perm.dtype == torch.float
|
||||
|
||||
self.checkScript(randperm, (3, ))
|
||||
self.checkScript(randperm, (3,))
|
||||
|
||||
def test_triu_indices_default_dtype(self):
|
||||
def triu_indices(rows: int, cols: int):
|
||||
|
||||
@ -8,8 +8,8 @@ import torch
|
||||
# Make the helper files in test/ importable
|
||||
pytorch_test_dir = os.path.dirname(os.path.dirname(os.path.realpath(__file__)))
|
||||
sys.path.append(pytorch_test_dir)
|
||||
from torch.testing._internal.jit_utils import JitTestCase
|
||||
from torch.testing import FileCheck
|
||||
from torch.testing._internal.jit_utils import JitTestCase
|
||||
|
||||
if __name__ == "__main__":
|
||||
raise RuntimeError(
|
||||
@ -18,6 +18,7 @@ if __name__ == "__main__":
|
||||
"instead."
|
||||
)
|
||||
|
||||
|
||||
class TestTensorMethods(JitTestCase):
|
||||
def test_getitem(self):
|
||||
def tensor_getitem(inp: torch.Tensor):
|
||||
@ -25,7 +26,7 @@ class TestTensorMethods(JitTestCase):
|
||||
return inp.__getitem__(indices)
|
||||
|
||||
inp = torch.rand(3, 4)
|
||||
self.checkScript(tensor_getitem, (inp, ))
|
||||
self.checkScript(tensor_getitem, (inp,))
|
||||
|
||||
scripted = torch.jit.script(tensor_getitem)
|
||||
FileCheck().check("aten::index").run(scripted.graph)
|
||||
@ -35,5 +36,6 @@ class TestTensorMethods(JitTestCase):
|
||||
return inp.__getitem__()
|
||||
|
||||
with self.assertRaisesRegexWithHighlight(
|
||||
RuntimeError, "expected exactly 1 argument", "inp.__getitem__"):
|
||||
RuntimeError, "expected exactly 1 argument", "inp.__getitem__"
|
||||
):
|
||||
torch.jit.script(tensor_getitem_invalid)
|
||||
|
||||
@ -1,27 +1,27 @@
|
||||
# Owner(s): ["oncall: jit"]
|
||||
|
||||
import copy
|
||||
import io
|
||||
import os
|
||||
import sys
|
||||
import copy
|
||||
import unittest
|
||||
from typing import Optional
|
||||
|
||||
import torch
|
||||
from typing import Optional
|
||||
from torch.testing._internal.common_utils import skipIfTorchDynamo
|
||||
|
||||
# Make the helper files in test/ importable
|
||||
pytorch_test_dir = os.path.dirname(os.path.dirname(os.path.realpath(__file__)))
|
||||
sys.path.append(pytorch_test_dir)
|
||||
from torch.testing._internal.jit_utils import JitTestCase
|
||||
from torch.testing import FileCheck
|
||||
from torch.testing._internal.common_utils import (
|
||||
find_library_location,
|
||||
IS_FBCODE,
|
||||
IS_MACOS,
|
||||
IS_SANDCASTLE,
|
||||
IS_WINDOWS,
|
||||
find_library_location,
|
||||
)
|
||||
from torch.testing import FileCheck
|
||||
from torch.testing._internal.jit_utils import JitTestCase
|
||||
|
||||
if __name__ == "__main__":
|
||||
raise RuntimeError(
|
||||
@ -30,14 +30,15 @@ if __name__ == "__main__":
|
||||
"instead."
|
||||
)
|
||||
|
||||
|
||||
@skipIfTorchDynamo("skipping as a precaution")
|
||||
class TestTorchbind(JitTestCase):
|
||||
def setUp(self):
|
||||
if IS_SANDCASTLE or IS_MACOS or IS_FBCODE:
|
||||
raise unittest.SkipTest("non-portable load_library call used in test")
|
||||
lib_file_path = find_library_location('libtorchbind_test.so')
|
||||
lib_file_path = find_library_location("libtorchbind_test.so")
|
||||
if IS_WINDOWS:
|
||||
lib_file_path = find_library_location('torchbind_test.dll')
|
||||
lib_file_path = find_library_location("torchbind_test.dll")
|
||||
torch.ops.load_library(str(lib_file_path))
|
||||
|
||||
def test_torchbind(self):
|
||||
@ -50,15 +51,17 @@ class TestTorchbind(JitTestCase):
|
||||
val = torch.classes._TorchScriptTesting._Foo(5, 3)
|
||||
val.increment(1)
|
||||
return val
|
||||
|
||||
test_equality(f, lambda x: x)
|
||||
|
||||
with self.assertRaisesRegex(RuntimeError, "Expected a value of type 'int'"):
|
||||
val = torch.classes._TorchScriptTesting._Foo(5, 3)
|
||||
val.increment('foo')
|
||||
val.increment("foo")
|
||||
|
||||
def f():
|
||||
ss = torch.classes._TorchScriptTesting._StackString(["asdf", "bruh"])
|
||||
return ss.pop()
|
||||
|
||||
test_equality(f, lambda x: x)
|
||||
|
||||
def f():
|
||||
@ -66,6 +69,7 @@ class TestTorchbind(JitTestCase):
|
||||
ss2 = torch.classes._TorchScriptTesting._StackString(["111", "222"])
|
||||
ss1.push(ss2.pop())
|
||||
return ss1.pop() + ss2.pop()
|
||||
|
||||
test_equality(f, lambda x: x)
|
||||
|
||||
# test nn module with prepare_scriptable function
|
||||
@ -116,8 +120,11 @@ class TestTorchbind(JitTestCase):
|
||||
scripted = torch.jit.script(foo)
|
||||
# Ensure we are creating the object and calling __init__
|
||||
# rather than calling the __init__wrapper nonsense
|
||||
fc = FileCheck().check('prim::CreateObject()')\
|
||||
.check('prim::CallMethod[name="__init__"]')
|
||||
fc = (
|
||||
FileCheck()
|
||||
.check("prim::CreateObject()")
|
||||
.check('prim::CallMethod[name="__init__"]')
|
||||
)
|
||||
fc.run(str(scripted.graph))
|
||||
out = scripted()
|
||||
self.assertEqual(out.pop(), "mom")
|
||||
@ -167,7 +174,7 @@ class TestTorchbind(JitTestCase):
|
||||
out, result = scripted()
|
||||
self.assertEqual(result, 10)
|
||||
|
||||
with self.assertRaisesRegex(RuntimeError, 'can\'t set attribute'):
|
||||
with self.assertRaisesRegex(RuntimeError, "can't set attribute"):
|
||||
out.y = 5
|
||||
|
||||
def foo_not_setter():
|
||||
@ -177,9 +184,11 @@ class TestTorchbind(JitTestCase):
|
||||
# getY method intentionally adds 4 to x
|
||||
return fooGetterSetter.y
|
||||
|
||||
with self.assertRaisesRegexWithHighlight(RuntimeError,
|
||||
'Tried to set read-only attribute: y',
|
||||
'fooGetterSetter.y = old + 4'):
|
||||
with self.assertRaisesRegexWithHighlight(
|
||||
RuntimeError,
|
||||
"Tried to set read-only attribute: y",
|
||||
"fooGetterSetter.y = old + 4",
|
||||
):
|
||||
scripted = torch.jit.script(foo_not_setter)
|
||||
|
||||
def test_torchbind_def_property_readwrite(self):
|
||||
@ -196,9 +205,9 @@ class TestTorchbind(JitTestCase):
|
||||
fooReadWrite.y = 5
|
||||
return fooReadWrite
|
||||
|
||||
with self.assertRaisesRegexWithHighlight(RuntimeError,
|
||||
'Tried to set read-only attribute: y',
|
||||
'fooReadWrite.y = 5'):
|
||||
with self.assertRaisesRegexWithHighlight(
|
||||
RuntimeError, "Tried to set read-only attribute: y", "fooReadWrite.y = 5"
|
||||
):
|
||||
scripted = torch.jit.script(foo_readwrite_error)
|
||||
|
||||
def test_torchbind_take_instance_as_method_arg(self):
|
||||
@ -250,7 +259,9 @@ class TestTorchbind(JitTestCase):
|
||||
return self.foo_mod.info()
|
||||
|
||||
def to_ivalue(self):
|
||||
torchbind_model = torch.classes._TorchScriptTesting._Foo(self.foo_mod.info(), 1)
|
||||
torchbind_model = torch.classes._TorchScriptTesting._Foo(
|
||||
self.foo_mod.info(), 1
|
||||
)
|
||||
return FooBar(torchbind_model)
|
||||
|
||||
inst = FooBar(torch.classes._TorchScriptTesting._Foo(2, 3))
|
||||
@ -338,7 +349,7 @@ class TestTorchbind(JitTestCase):
|
||||
self.assertEqual(torch.zeros(4, 4), traced())
|
||||
|
||||
def test_torchbind_pass_wrong_type(self):
|
||||
with self.assertRaisesRegex(RuntimeError, 'but instead found type \'Tensor\''):
|
||||
with self.assertRaisesRegex(RuntimeError, "but instead found type 'Tensor'"):
|
||||
torch.ops._TorchScriptTesting.take_an_instance(torch.rand(3, 4))
|
||||
|
||||
def test_torchbind_tracing_nested(self):
|
||||
@ -368,12 +379,15 @@ class TestTorchbind(JitTestCase):
|
||||
self.assertEqual(nt_loaded.pop(), exp)
|
||||
|
||||
def test_torchbind_instantiate_missing_class(self):
|
||||
with self.assertRaisesRegex(RuntimeError, 'Tried to instantiate class \'foo.IDontExist\', but it does not exist!'):
|
||||
with self.assertRaisesRegex(
|
||||
RuntimeError,
|
||||
"Tried to instantiate class 'foo.IDontExist', but it does not exist!",
|
||||
):
|
||||
torch.classes.foo.IDontExist(3, 4, 5)
|
||||
|
||||
def test_torchbind_optional_explicit_attr(self):
|
||||
class TorchBindOptionalExplicitAttr(torch.nn.Module):
|
||||
foo : Optional[torch.classes._TorchScriptTesting._StackString]
|
||||
foo: Optional[torch.classes._TorchScriptTesting._StackString]
|
||||
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
@ -384,13 +398,13 @@ class TestTorchbind(JitTestCase):
|
||||
if foo_obj is not None:
|
||||
return foo_obj.pop()
|
||||
else:
|
||||
return '<None>'
|
||||
return "<None>"
|
||||
|
||||
mod = TorchBindOptionalExplicitAttr()
|
||||
scripted = torch.jit.script(mod)
|
||||
|
||||
def test_torchbind_no_init(self):
|
||||
with self.assertRaisesRegex(RuntimeError, 'torch::init'):
|
||||
with self.assertRaisesRegex(RuntimeError, "torch::init"):
|
||||
x = torch.classes._TorchScriptTesting._NoInit()
|
||||
|
||||
def test_profiler_custom_op(self):
|
||||
@ -401,17 +415,17 @@ class TestTorchbind(JitTestCase):
|
||||
|
||||
found_event = False
|
||||
for e in prof.function_events:
|
||||
if e.name == '_TorchScriptTesting::take_an_instance':
|
||||
if e.name == "_TorchScriptTesting::take_an_instance":
|
||||
found_event = True
|
||||
self.assertTrue(found_event)
|
||||
|
||||
def test_torchbind_getattr(self):
|
||||
foo = torch.classes._TorchScriptTesting._StackString(["test"])
|
||||
self.assertEqual(None, getattr(foo, 'bar', None))
|
||||
self.assertEqual(None, getattr(foo, "bar", None))
|
||||
|
||||
def test_torchbind_attr_exception(self):
|
||||
foo = torch.classes._TorchScriptTesting._StackString(["test"])
|
||||
with self.assertRaisesRegex(AttributeError, 'does not have a field'):
|
||||
with self.assertRaisesRegex(AttributeError, "does not have a field"):
|
||||
foo.bar
|
||||
|
||||
def test_lambda_as_constructor(self):
|
||||
|
||||
File diff suppressed because it is too large
Load Diff
@ -1,21 +1,24 @@
|
||||
# Owner(s): ["oncall: jit"]
|
||||
|
||||
import io
|
||||
import os
|
||||
import sys
|
||||
import io
|
||||
|
||||
import torch
|
||||
|
||||
# Make the helper files in test/ importable
|
||||
pytorch_test_dir = os.path.dirname(os.path.dirname(os.path.realpath(__file__)))
|
||||
sys.path.append(pytorch_test_dir)
|
||||
from torch.testing._internal.jit_utils import JitTestCase
|
||||
from torch.testing._internal.common_utils import suppress_warnings
|
||||
from torch.testing._internal.jit_utils import JitTestCase
|
||||
|
||||
if __name__ == "__main__":
|
||||
raise RuntimeError(
|
||||
"This test file is not meant to be run directly, use:\n\n"
|
||||
"\tpython test/test_jit.py TESTNAME\n\n"
|
||||
"instead."
|
||||
)
|
||||
|
||||
if __name__ == '__main__':
|
||||
raise RuntimeError("This test file is not meant to be run directly, use:\n\n"
|
||||
"\tpython test/test_jit.py TESTNAME\n\n"
|
||||
"instead.")
|
||||
|
||||
class TestTypeSharing(JitTestCase):
|
||||
def assertSameType(self, m1, m2):
|
||||
@ -42,6 +45,7 @@ class TestTypeSharing(JitTestCase):
|
||||
|
||||
def forward(self, x):
|
||||
return x
|
||||
|
||||
a = torch.rand(2, 3)
|
||||
b = torch.rand(2, 3)
|
||||
c = torch.rand(2, 3)
|
||||
@ -53,6 +57,7 @@ class TestTypeSharing(JitTestCase):
|
||||
"""
|
||||
Types should be shared even if attribute values differ
|
||||
"""
|
||||
|
||||
class M(torch.nn.Module):
|
||||
def __init__(self, a, b, c):
|
||||
super().__init__()
|
||||
@ -62,6 +67,7 @@ class TestTypeSharing(JitTestCase):
|
||||
|
||||
def forward(self, x):
|
||||
return x
|
||||
|
||||
a = torch.rand(2, 3)
|
||||
b = torch.rand(2, 3)
|
||||
c = torch.rand(2, 3)
|
||||
@ -73,6 +79,7 @@ class TestTypeSharing(JitTestCase):
|
||||
"""
|
||||
Types should be shared for identical constant values, and different for different constant values
|
||||
"""
|
||||
|
||||
class M(torch.nn.Module):
|
||||
__constants__ = ["const"]
|
||||
|
||||
@ -111,6 +118,7 @@ class TestTypeSharing(JitTestCase):
|
||||
"""
|
||||
If submodules differ, the types should differ.
|
||||
"""
|
||||
|
||||
class M(torch.nn.Module):
|
||||
def __init__(self, in1, out1, in2, out2):
|
||||
super().__init__()
|
||||
@ -137,6 +145,7 @@ class TestTypeSharing(JitTestCase):
|
||||
The same module with an `foo` as a parameter vs. attribute shouldn't
|
||||
share types
|
||||
"""
|
||||
|
||||
class M(torch.nn.Module):
|
||||
def __init__(self, foo):
|
||||
super().__init__()
|
||||
@ -156,6 +165,7 @@ class TestTypeSharing(JitTestCase):
|
||||
Even if everything about the module is the same, different originating
|
||||
classes should prevent type sharing.
|
||||
"""
|
||||
|
||||
class A(torch.nn.Module):
|
||||
__constants__ = ["const"]
|
||||
|
||||
@ -192,6 +202,7 @@ class TestTypeSharing(JitTestCase):
|
||||
"""
|
||||
Mutating the value of an attribute should not change type sharing
|
||||
"""
|
||||
|
||||
class M(torch.nn.Module):
|
||||
def __init__(self, in1, out1, in2, out2):
|
||||
super().__init__()
|
||||
@ -214,6 +225,7 @@ class TestTypeSharing(JitTestCase):
|
||||
"""
|
||||
Assigning a new (python-only) attribute should not change type sharing
|
||||
"""
|
||||
|
||||
class M(torch.nn.Module):
|
||||
def __init__(self, in1, out1, in2, out2):
|
||||
super().__init__()
|
||||
@ -244,6 +256,7 @@ class TestTypeSharing(JitTestCase):
|
||||
"""
|
||||
Attributes whose type cannot be inferred should fail cleanly with nice hints
|
||||
"""
|
||||
|
||||
class M(torch.nn.Module):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
@ -255,15 +268,16 @@ class TestTypeSharing(JitTestCase):
|
||||
return self.foo
|
||||
|
||||
m = M()
|
||||
with self.assertRaisesRegexWithHighlight(RuntimeError,
|
||||
"failed to convert Python type",
|
||||
"self.foo"):
|
||||
with self.assertRaisesRegexWithHighlight(
|
||||
RuntimeError, "failed to convert Python type", "self.foo"
|
||||
):
|
||||
torch.jit.script(m)
|
||||
|
||||
def test_script_function_attribute_different(self):
|
||||
"""
|
||||
Different functions passed in should lead to different types
|
||||
"""
|
||||
|
||||
@torch.jit.script
|
||||
def fn1(x):
|
||||
return x + x
|
||||
@ -317,6 +331,7 @@ class TestTypeSharing(JitTestCase):
|
||||
"""
|
||||
Same functions passed in should lead to same types
|
||||
"""
|
||||
|
||||
@torch.jit.script
|
||||
def fn(x):
|
||||
return x + x
|
||||
@ -338,6 +353,7 @@ class TestTypeSharing(JitTestCase):
|
||||
"""
|
||||
Different functions passed in should lead to different types
|
||||
"""
|
||||
|
||||
def fn1(x):
|
||||
return x + x
|
||||
|
||||
@ -361,6 +377,7 @@ class TestTypeSharing(JitTestCase):
|
||||
"""
|
||||
Same functions passed in should lead to same types
|
||||
"""
|
||||
|
||||
def fn(x):
|
||||
return x + x
|
||||
|
||||
@ -383,6 +400,7 @@ class TestTypeSharing(JitTestCase):
|
||||
Since we can't guarantee that methods are the same between different
|
||||
trace runs, tracing must always generate a unique type.
|
||||
"""
|
||||
|
||||
class M(torch.nn.Module):
|
||||
def forward(self, x, y):
|
||||
if x.sum() > y.sum():
|
||||
@ -429,8 +447,8 @@ class TestTypeSharing(JitTestCase):
|
||||
def forward(self, x):
|
||||
return self.traced(x)
|
||||
|
||||
a = M((torch.ones(1), ))
|
||||
b = M((torch.zeros(1), ))
|
||||
a = M((torch.ones(1),))
|
||||
b = M((torch.zeros(1),))
|
||||
self.assertDifferentType(a, b)
|
||||
|
||||
def test_loaded_modules_work(self):
|
||||
@ -465,7 +483,6 @@ class TestTypeSharing(JitTestCase):
|
||||
buffer.seek(0)
|
||||
return torch.jit.script(Wrapper(torch.jit.load(buffer)))
|
||||
|
||||
|
||||
a = package(AB())
|
||||
a()
|
||||
b = package(A())
|
||||
@ -476,6 +493,7 @@ class TestTypeSharing(JitTestCase):
|
||||
We should be able to differentiate between two ModuleDict instances
|
||||
that have different keys but the same value types.
|
||||
"""
|
||||
|
||||
class A(torch.nn.Module):
|
||||
def forward(self, x):
|
||||
return x
|
||||
@ -488,9 +506,9 @@ class TestTypeSharing(JitTestCase):
|
||||
def forward(self, x):
|
||||
return x
|
||||
|
||||
a = Foo({'foo': A()})
|
||||
b = Foo({'bar': A()})
|
||||
c = Foo({'bar': A()})
|
||||
a = Foo({"foo": A()})
|
||||
b = Foo({"bar": A()})
|
||||
c = Foo({"bar": A()})
|
||||
self.assertDifferentType(a, b)
|
||||
self.assertSameType(b, c)
|
||||
|
||||
@ -500,13 +518,16 @@ class TestTypeSharing(JitTestCase):
|
||||
subclass that defines methods in its __init__ are not
|
||||
shared.
|
||||
"""
|
||||
|
||||
class A(torch.jit.ScriptModule):
|
||||
def __init__(self, val):
|
||||
super().__init__()
|
||||
self.define(f"""
|
||||
self.define(
|
||||
f"""
|
||||
def forward(self) -> int:
|
||||
return {val}
|
||||
""")
|
||||
"""
|
||||
)
|
||||
|
||||
one = A(1)
|
||||
two = A(2)
|
||||
@ -518,6 +539,7 @@ class TestTypeSharing(JitTestCase):
|
||||
"""
|
||||
Test that type sharing can be disabled.
|
||||
"""
|
||||
|
||||
class A(torch.nn.Module):
|
||||
def __init__(self, sub):
|
||||
super().__init__()
|
||||
@ -555,6 +577,7 @@ class TestTypeSharing(JitTestCase):
|
||||
Test that types are shared if the exclusion of their
|
||||
ignored attributes makes them equal.
|
||||
"""
|
||||
|
||||
class A(torch.nn.Module):
|
||||
__jit_ignored_attributes__ = ["a"]
|
||||
|
||||
@ -579,6 +602,7 @@ class TestTypeSharing(JitTestCase):
|
||||
Test that types are not shared if the exclusion of their
|
||||
ignored attributes makes them not equal.
|
||||
"""
|
||||
|
||||
class A(torch.nn.Module):
|
||||
__jit_ignored_attributes__ = ["a"]
|
||||
|
||||
|
||||
@ -1,26 +1,31 @@
|
||||
# Owner(s): ["oncall: jit"]
|
||||
|
||||
from collections import namedtuple
|
||||
from typing import Dict, Iterator, List, Optional, Tuple
|
||||
|
||||
from torch.testing._internal.jit_utils import JitTestCase
|
||||
from torch.testing import FileCheck
|
||||
from textwrap import dedent
|
||||
from jit.test_module_interface import TestModuleInterface # noqa: F401
|
||||
import inspect
|
||||
import os
|
||||
import sys
|
||||
from collections import namedtuple
|
||||
from textwrap import dedent
|
||||
from typing import Dict, Iterator, List, Optional, Tuple
|
||||
|
||||
import torch
|
||||
import torch.testing._internal.jit_utils
|
||||
from torch.testing import FileCheck
|
||||
|
||||
from torch.testing._internal.jit_utils import JitTestCase
|
||||
|
||||
from jit.test_module_interface import TestModuleInterface # noqa: F401
|
||||
|
||||
# Make the helper files in test/ importable
|
||||
pytorch_test_dir = os.path.dirname(os.path.dirname(os.path.realpath(__file__)))
|
||||
sys.path.append(pytorch_test_dir)
|
||||
|
||||
if __name__ == '__main__':
|
||||
raise RuntimeError("This test file is not meant to be run directly, use:\n\n"
|
||||
"\tpython test/test_jit.py TESTNAME\n\n"
|
||||
"instead.")
|
||||
if __name__ == "__main__":
|
||||
raise RuntimeError(
|
||||
"This test file is not meant to be run directly, use:\n\n"
|
||||
"\tpython test/test_jit.py TESTNAME\n\n"
|
||||
"instead."
|
||||
)
|
||||
|
||||
|
||||
class TestTypesAndAnnotation(JitTestCase):
|
||||
def test_pep585_type(self):
|
||||
@ -30,7 +35,7 @@ class TestTypesAndAnnotation(JitTestCase):
|
||||
xl: list[tuple[torch.Tensor]] = []
|
||||
xd: dict[str, int] = {}
|
||||
xl.append((x,))
|
||||
xd['foo'] = 1
|
||||
xd["foo"] = 1
|
||||
return xl.pop(), xd
|
||||
|
||||
self.checkScript(fn, [torch.randn(2, 2)])
|
||||
@ -47,7 +52,7 @@ class TestTypesAndAnnotation(JitTestCase):
|
||||
|
||||
self.checkScript(fn, [torch.randn(2, 2)])
|
||||
|
||||
GG = namedtuple('GG', ['f', 'g'])
|
||||
GG = namedtuple("GG", ["f", "g"])
|
||||
|
||||
class Foo(torch.nn.Module):
|
||||
@torch.jit.ignore
|
||||
@ -77,13 +82,17 @@ class TestTypesAndAnnotation(JitTestCase):
|
||||
return x + 10
|
||||
|
||||
class M(torch.nn.Module):
|
||||
def forward(self, in_batch: Dict[str, Optional[torch.Tensor]]) -> torch.Tensor:
|
||||
def forward(
|
||||
self, in_batch: Dict[str, Optional[torch.Tensor]]
|
||||
) -> torch.Tensor:
|
||||
self.dropout_modality(in_batch)
|
||||
fn(in_batch)
|
||||
return torch.tensor(1)
|
||||
|
||||
@torch.jit.ignore
|
||||
def dropout_modality(self, in_batch: Dict[str, Optional[torch.Tensor]]) -> Dict[str, Optional[torch.Tensor]]:
|
||||
def dropout_modality(
|
||||
self, in_batch: Dict[str, Optional[torch.Tensor]]
|
||||
) -> Dict[str, Optional[torch.Tensor]]:
|
||||
return in_batch
|
||||
|
||||
sm = torch.jit.script(M())
|
||||
@ -111,16 +120,17 @@ class TestTypesAndAnnotation(JitTestCase):
|
||||
return my_arg + 10
|
||||
|
||||
with self.assertRaisesRegex(RuntimeError, "argument 'my_arg'"):
|
||||
|
||||
@torch.jit.script
|
||||
def other_fn(x):
|
||||
return fn('2')
|
||||
return fn("2")
|
||||
|
||||
def test_type_annotate_py3(self):
|
||||
def fn():
|
||||
a : List[int] = []
|
||||
b : torch.Tensor = torch.ones(2, 2)
|
||||
c : Optional[torch.Tensor] = None
|
||||
d : Optional[torch.Tensor] = torch.ones(3, 4)
|
||||
a: List[int] = []
|
||||
b: torch.Tensor = torch.ones(2, 2)
|
||||
c: Optional[torch.Tensor] = None
|
||||
d: Optional[torch.Tensor] = torch.ones(3, 4)
|
||||
for _ in range(10):
|
||||
a.append(4)
|
||||
c = torch.ones(2, 2)
|
||||
@ -130,66 +140,88 @@ class TestTypesAndAnnotation(JitTestCase):
|
||||
self.checkScript(fn, ())
|
||||
|
||||
def wrong_type():
|
||||
wrong : List[int] = [0.5]
|
||||
wrong: List[int] = [0.5]
|
||||
return wrong
|
||||
|
||||
with self.assertRaisesRegex(RuntimeError, "List type annotation"
|
||||
r" `List\[int\]` did not match the "
|
||||
"types of the given list elements"):
|
||||
with self.assertRaisesRegex(
|
||||
RuntimeError,
|
||||
"List type annotation"
|
||||
r" `List\[int\]` did not match the "
|
||||
"types of the given list elements",
|
||||
):
|
||||
torch.jit.script(wrong_type)
|
||||
|
||||
def test_optional_no_element_type_annotation(self):
|
||||
"""
|
||||
Test that using an optional with no contained types produces an error.
|
||||
"""
|
||||
|
||||
def fn_with_comment(x: torch.Tensor) -> Optional:
|
||||
return (x, x)
|
||||
|
||||
def annotated_fn(x: torch.Tensor) -> Optional:
|
||||
return (x, x)
|
||||
|
||||
with self.assertRaisesRegex(RuntimeError, r"Attempted to use Optional without a contained type"):
|
||||
with self.assertRaisesRegex(
|
||||
RuntimeError, r"Attempted to use Optional without a contained type"
|
||||
):
|
||||
cu = torch.jit.CompilationUnit()
|
||||
cu.define(dedent(inspect.getsource(fn_with_comment)))
|
||||
|
||||
with self.assertRaisesRegex(RuntimeError, r"Attempted to use Optional without a contained type"):
|
||||
with self.assertRaisesRegex(
|
||||
RuntimeError, r"Attempted to use Optional without a contained type"
|
||||
):
|
||||
cu = torch.jit.CompilationUnit()
|
||||
cu.define(dedent(inspect.getsource(annotated_fn)))
|
||||
|
||||
with self.assertRaisesRegex(RuntimeError, r"Attempted to use Optional without a contained type"):
|
||||
with self.assertRaisesRegex(
|
||||
RuntimeError, r"Attempted to use Optional without a contained type"
|
||||
):
|
||||
torch.jit.script(fn_with_comment)
|
||||
|
||||
with self.assertRaisesRegex(RuntimeError, r"Attempted to use Optional without a contained type"):
|
||||
with self.assertRaisesRegex(
|
||||
RuntimeError, r"Attempted to use Optional without a contained type"
|
||||
):
|
||||
torch.jit.script(annotated_fn)
|
||||
|
||||
def test_tuple_no_element_type_annotation(self):
|
||||
"""
|
||||
Test that using a tuple with no contained types produces an error.
|
||||
"""
|
||||
|
||||
def fn_with_comment(x: torch.Tensor) -> Tuple:
|
||||
return (x, x)
|
||||
|
||||
def annotated_fn(x: torch.Tensor) -> Tuple:
|
||||
return (x, x)
|
||||
|
||||
with self.assertRaisesRegex(RuntimeError, r"Attempted to use Tuple without a contained type"):
|
||||
with self.assertRaisesRegex(
|
||||
RuntimeError, r"Attempted to use Tuple without a contained type"
|
||||
):
|
||||
cu = torch.jit.CompilationUnit()
|
||||
cu.define(dedent(inspect.getsource(fn_with_comment)))
|
||||
|
||||
with self.assertRaisesRegex(RuntimeError, r"Attempted to use Tuple without a contained type"):
|
||||
with self.assertRaisesRegex(
|
||||
RuntimeError, r"Attempted to use Tuple without a contained type"
|
||||
):
|
||||
cu = torch.jit.CompilationUnit()
|
||||
cu.define(dedent(inspect.getsource(annotated_fn)))
|
||||
|
||||
with self.assertRaisesRegex(RuntimeError, r"Attempted to use Tuple without a contained type"):
|
||||
with self.assertRaisesRegex(
|
||||
RuntimeError, r"Attempted to use Tuple without a contained type"
|
||||
):
|
||||
torch.jit.script(fn_with_comment)
|
||||
|
||||
with self.assertRaisesRegex(RuntimeError, r"Attempted to use Tuple without a contained type"):
|
||||
with self.assertRaisesRegex(
|
||||
RuntimeError, r"Attempted to use Tuple without a contained type"
|
||||
):
|
||||
torch.jit.script(annotated_fn)
|
||||
|
||||
def test_ignoring_module_attributes(self):
|
||||
"""
|
||||
Test that module attributes can be ignored.
|
||||
"""
|
||||
|
||||
class Sub(torch.nn.Module):
|
||||
def forward(self, a: int) -> int:
|
||||
return sum([a])
|
||||
@ -229,10 +261,11 @@ class TestTypesAndAnnotation(JitTestCase):
|
||||
|
||||
mod = ModuleUsesIgnoredAttr(1)
|
||||
|
||||
with self.assertRaisesRegexWithHighlight(RuntimeError, r"attribute was ignored during compilation", "self.sub"):
|
||||
with self.assertRaisesRegexWithHighlight(
|
||||
RuntimeError, r"attribute was ignored during compilation", "self.sub"
|
||||
):
|
||||
scripted_mod = torch.jit.script(mod)
|
||||
|
||||
|
||||
def test_ignoring_fn_with_nonscriptable_types(self):
|
||||
class CFX:
|
||||
def __init__(self, a: List[torch.Tensor]) -> None:
|
||||
@ -246,7 +279,9 @@ class TestTypesAndAnnotation(JitTestCase):
|
||||
return iter(self.a)
|
||||
|
||||
@torch.jit._drop
|
||||
def __fx_create_arg__(self, tracer: torch.fx.Tracer) -> torch.fx.node.Argument:
|
||||
def __fx_create_arg__(
|
||||
self, tracer: torch.fx.Tracer
|
||||
) -> torch.fx.node.Argument:
|
||||
# torch.fx classes are not scriptable
|
||||
return tracer.create_node(
|
||||
"call_function",
|
||||
@ -257,35 +292,36 @@ class TestTypesAndAnnotation(JitTestCase):
|
||||
|
||||
torch.jit.script(CFX)
|
||||
|
||||
|
||||
def test_unimported_type_resolution(self):
|
||||
# verify fallback from the python resolver to the c++ resolver
|
||||
|
||||
@ torch.jit.script
|
||||
@torch.jit.script
|
||||
def fn(x):
|
||||
# type: (number) -> number
|
||||
return x + 1
|
||||
|
||||
FileCheck().check('Scalar').run(fn.graph)
|
||||
FileCheck().check("Scalar").run(fn.graph)
|
||||
|
||||
def test_parser_bug(self):
|
||||
def parser_bug(o: Optional[torch.Tensor]):
|
||||
pass
|
||||
|
||||
def test_mismatched_annotation(self):
|
||||
with self.assertRaisesRegex(RuntimeError, 'annotated with type'):
|
||||
with self.assertRaisesRegex(RuntimeError, "annotated with type"):
|
||||
|
||||
@torch.jit.script
|
||||
def foo():
|
||||
x : str = 4
|
||||
x: str = 4
|
||||
return x
|
||||
|
||||
def test_reannotate(self):
|
||||
with self.assertRaisesRegex(RuntimeError, 'declare and annotate'):
|
||||
with self.assertRaisesRegex(RuntimeError, "declare and annotate"):
|
||||
|
||||
@torch.jit.script
|
||||
def foo():
|
||||
x = 5
|
||||
if 1 == 1:
|
||||
x : Optional[int] = 7
|
||||
x: Optional[int] = 7
|
||||
|
||||
def test_annotate_outside_init(self):
|
||||
msg = "annotations on instance attributes must be declared in __init__"
|
||||
@ -293,6 +329,7 @@ class TestTypesAndAnnotation(JitTestCase):
|
||||
|
||||
# Simple case
|
||||
with self.assertRaisesRegexWithHighlight(ValueError, msg, highlight):
|
||||
|
||||
@torch.jit.script
|
||||
class BadModule:
|
||||
def __init__(self, x: int):
|
||||
@ -303,6 +340,7 @@ class TestTypesAndAnnotation(JitTestCase):
|
||||
|
||||
# Type annotation in a loop
|
||||
with self.assertRaisesRegexWithHighlight(ValueError, msg, highlight):
|
||||
|
||||
@torch.jit.script
|
||||
class BadModuleLoop:
|
||||
def __init__(self, x: int):
|
||||
@ -324,8 +362,10 @@ class TestTypesAndAnnotation(JitTestCase):
|
||||
def test_inferred_type_error_message(self):
|
||||
inferred_type = torch._C.InferredType("ErrorReason")
|
||||
|
||||
with self.assertRaisesRegex(RuntimeError,
|
||||
"Tried to get the type from an InferredType but the type is null."):
|
||||
with self.assertRaisesRegex(
|
||||
RuntimeError,
|
||||
"Tried to get the type from an InferredType but the type is null.",
|
||||
):
|
||||
t = inferred_type.type()
|
||||
|
||||
with self.assertRaisesRegex(RuntimeError, "ErrorReason"):
|
||||
|
||||
@ -2,12 +2,12 @@
|
||||
|
||||
import os
|
||||
import sys
|
||||
from collections import namedtuple
|
||||
from typing import Dict, List, NamedTuple, Tuple
|
||||
|
||||
import torch
|
||||
from torch.testing._internal.jit_utils import JitTestCase, make_global
|
||||
from torch.testing._internal.common_utils import IS_WINDOWS
|
||||
from collections import namedtuple
|
||||
from typing import List, Tuple, Dict, NamedTuple
|
||||
from torch.testing._internal.jit_utils import JitTestCase, make_global
|
||||
|
||||
# Make the helper files in test/ importable
|
||||
pytorch_test_dir = os.path.dirname(os.path.dirname(os.path.realpath(__file__)))
|
||||
@ -20,14 +20,15 @@ if __name__ == "__main__":
|
||||
"instead."
|
||||
)
|
||||
|
||||
|
||||
class TestTyping(JitTestCase):
|
||||
def test_dict_in_not_in(self):
|
||||
def test_in_dict(x):
|
||||
# type: (Dict[str, int]) -> bool
|
||||
return 'hi' in x
|
||||
return "hi" in x
|
||||
|
||||
self.checkScript(test_in_dict, ({'hi': 2, 'bye': 3},))
|
||||
self.checkScript(test_in_dict, ({'bye': 3},))
|
||||
self.checkScript(test_in_dict, ({"hi": 2, "bye": 3},))
|
||||
self.checkScript(test_in_dict, ({"bye": 3},))
|
||||
|
||||
# Check evaluation order
|
||||
@torch.jit.script
|
||||
@ -57,8 +58,8 @@ class TestTyping(JitTestCase):
|
||||
else:
|
||||
return True
|
||||
|
||||
self.checkScript(test_not_in_dict, ({"hello": 1, "world": 2}, ))
|
||||
self.checkScript(test_not_in_dict, ({"world": 2}, ))
|
||||
self.checkScript(test_not_in_dict, ({"hello": 1, "world": 2},))
|
||||
self.checkScript(test_not_in_dict, ({"world": 2},))
|
||||
|
||||
def test_dict_tensor_key(a, t):
|
||||
# type: (Dict[Tensor, int], Tensor) -> bool
|
||||
@ -80,9 +81,12 @@ class TestTyping(JitTestCase):
|
||||
l: List[int] = [1, 2, "foo", 3]
|
||||
return l
|
||||
|
||||
with self.assertRaisesRegex(RuntimeError, "List type annotation"
|
||||
r" `List\[int\]` did not match the "
|
||||
"types of the given list elements"):
|
||||
with self.assertRaisesRegex(
|
||||
RuntimeError,
|
||||
"List type annotation"
|
||||
r" `List\[int\]` did not match the "
|
||||
"types of the given list elements",
|
||||
):
|
||||
torch.jit.script(fn)
|
||||
|
||||
def test_dict_type_refinement_annotation_key_mismatch(self):
|
||||
@ -92,10 +96,13 @@ class TestTyping(JitTestCase):
|
||||
d: Dict[int, str] = dict(zip(l1, l2))
|
||||
return d
|
||||
|
||||
with self.assertRaisesRegex(RuntimeError, "Dicts may only "
|
||||
"contain homogeneous keys, but the "
|
||||
"type of the first generated key "
|
||||
r"was Union\[int, str\]"):
|
||||
with self.assertRaisesRegex(
|
||||
RuntimeError,
|
||||
"Dicts may only "
|
||||
"contain homogeneous keys, but the "
|
||||
"type of the first generated key "
|
||||
r"was Union\[int, str\]",
|
||||
):
|
||||
torch.jit.script(fn)
|
||||
|
||||
def test_dict_type_refinement_annotation_value_mismatch(self):
|
||||
@ -105,28 +112,36 @@ class TestTyping(JitTestCase):
|
||||
d: Dict[str, int] = dict(zip(l1, l2))
|
||||
return d
|
||||
|
||||
with self.assertRaisesRegex(RuntimeError, "Dict type annotation"
|
||||
r" `Dict\[str, int\]` did not match"
|
||||
" the type of an actual value type"
|
||||
r" `Union\[int, str\]`"):
|
||||
with self.assertRaisesRegex(
|
||||
RuntimeError,
|
||||
"Dict type annotation"
|
||||
r" `Dict\[str, int\]` did not match"
|
||||
" the type of an actual value type"
|
||||
r" `Union\[int, str\]`",
|
||||
):
|
||||
torch.jit.script(fn)
|
||||
|
||||
def test_dict_invalid_annotations(self):
|
||||
# Check for invalid value type annotation
|
||||
def wrong_value_type(dictionary: Dict[str, torch.jit.ScriptModule]):
|
||||
return
|
||||
|
||||
with self.assertRaisesRegex(ValueError, "Unknown type annotation"):
|
||||
torch.jit.script(wrong_value_type)
|
||||
|
||||
# Check for invalid key type annotation
|
||||
def wrong_key_type(dictionary: Dict[torch.jit.ScriptModule, str]):
|
||||
return
|
||||
|
||||
with self.assertRaisesRegex(ValueError, "Unknown type annotation"):
|
||||
torch.jit.script(wrong_key_type)
|
||||
|
||||
# Check for invalid key and value type annotation
|
||||
def wrong_key_value_type(dictionary: Dict[torch.jit.ScriptModule, torch.jit.ScriptModule]):
|
||||
def wrong_key_value_type(
|
||||
dictionary: Dict[torch.jit.ScriptModule, torch.jit.ScriptModule]
|
||||
):
|
||||
return
|
||||
|
||||
with self.assertRaisesRegex(ValueError, "Unknown type annotation"):
|
||||
torch.jit.script(wrong_key_value_type)
|
||||
|
||||
@ -138,13 +153,16 @@ class TestTyping(JitTestCase):
|
||||
_, y = t2
|
||||
return x + y
|
||||
|
||||
t = torch.randn(2, 2), (1, torch.randn(2, 2)),
|
||||
t = (
|
||||
torch.randn(2, 2),
|
||||
(1, torch.randn(2, 2)),
|
||||
)
|
||||
f(t, "hi")
|
||||
graph = f.graph_for(t, "hi")
|
||||
input_types = list(next(graph.inputs()).type().elements())
|
||||
w = input_types[0]
|
||||
self.assertEqual(input_types[0].kind(), 'TensorType')
|
||||
self.assertEqual(input_types[1].elements()[1].kind(), 'TensorType')
|
||||
self.assertEqual(input_types[0].kind(), "TensorType")
|
||||
self.assertEqual(input_types[1].elements()[1].kind(), "TensorType")
|
||||
|
||||
def test_tuple_io(self):
|
||||
def stuff(x):
|
||||
@ -165,8 +183,7 @@ class TestTyping(JitTestCase):
|
||||
def foo():
|
||||
return tuple(1, 2)
|
||||
|
||||
self.checkScriptRaisesRegex(foo, (), Exception,
|
||||
"1 argument")
|
||||
self.checkScriptRaisesRegex(foo, (), Exception, "1 argument")
|
||||
|
||||
def cant_infer_size():
|
||||
return tuple([1, 2, 3]) # noqa: C409
|
||||
@ -179,12 +196,14 @@ class TestTyping(JitTestCase):
|
||||
# type: (int) -> Tuple[Tensor, Tensor]
|
||||
a = (torch.ones(x), torch.zeros(x))
|
||||
return a
|
||||
|
||||
self.checkScript(stuff2, (3,))
|
||||
|
||||
def test_list_io(self):
|
||||
def stuff3(x):
|
||||
# type: (List[int]) -> Tuple[Tensor, List[int]]
|
||||
return torch.ones(x), x
|
||||
|
||||
self.checkScript(stuff3, ([3, 2],))
|
||||
|
||||
def test_bool_list_io(self):
|
||||
@ -203,6 +222,7 @@ class TestTyping(JitTestCase):
|
||||
# type: (Tuple[int, List[List[int]]]) -> int
|
||||
x, y = z
|
||||
return y[0][1]
|
||||
|
||||
self.checkScript(foo, ((1, [[1, 2], [3, 4]]),))
|
||||
|
||||
def test_list_sum(self):
|
||||
@ -215,12 +235,12 @@ class TestTyping(JitTestCase):
|
||||
def fn2(x: List[bool]):
|
||||
return sum(x)
|
||||
|
||||
self.checkScript(fn, ([1, 2, 3], ))
|
||||
self.checkScript(fn1, ([1.0, 2.0, 3.0], ))
|
||||
self.checkScript(fn1, ([1, 2.8, 3], ))
|
||||
self.checkScript(fn2, ([True, False, False], ))
|
||||
self.checkScript(fn2, ([False, False, False], ))
|
||||
self.checkScript(fn2, ([0, 1, 1, 0], ))
|
||||
self.checkScript(fn, ([1, 2, 3],))
|
||||
self.checkScript(fn1, ([1.0, 2.0, 3.0],))
|
||||
self.checkScript(fn1, ([1, 2.8, 3],))
|
||||
self.checkScript(fn2, ([True, False, False],))
|
||||
self.checkScript(fn2, ([False, False, False],))
|
||||
self.checkScript(fn2, ([0, 1, 1, 0],))
|
||||
|
||||
def test_list_unification(self):
|
||||
def fn():
|
||||
@ -254,7 +274,6 @@ class TestTyping(JitTestCase):
|
||||
self.checkScript(self.get_sum_list_fn(), ([1],))
|
||||
|
||||
def test_sum_list_literal(self):
|
||||
|
||||
def sum_list():
|
||||
# type: () -> int
|
||||
sum = 0
|
||||
@ -266,8 +285,8 @@ class TestTyping(JitTestCase):
|
||||
self.checkScript(sum_list, ())
|
||||
|
||||
def test_sum_list_wrong_type(self):
|
||||
|
||||
with self.assertRaisesRegex(RuntimeError, "'int' object is not iterable"):
|
||||
|
||||
@torch.jit.script
|
||||
def sum_list(a):
|
||||
# type: (int) -> int
|
||||
@ -280,14 +299,18 @@ class TestTyping(JitTestCase):
|
||||
sum_list(1)
|
||||
|
||||
def test_list_iterables(self):
|
||||
with self.assertRaisesRegex(RuntimeError, 'List of iterables is not supported currently'):
|
||||
cu = torch.jit.CompilationUnit('''
|
||||
with self.assertRaisesRegex(
|
||||
RuntimeError, "List of iterables is not supported currently"
|
||||
):
|
||||
cu = torch.jit.CompilationUnit(
|
||||
"""
|
||||
def list_iterables(x):
|
||||
for i, j in [2, 3, 4], [5, 6, 7]:
|
||||
x += i
|
||||
x += j
|
||||
return x
|
||||
''')
|
||||
"""
|
||||
)
|
||||
|
||||
def test_for_in_string(self):
|
||||
def test_strings(x):
|
||||
@ -352,36 +375,43 @@ class TestTyping(JitTestCase):
|
||||
|
||||
def test_dict_comprehension(self):
|
||||
def fn():
|
||||
return {i : chr(i + 65) for i in range(4)}
|
||||
return {i: chr(i + 65) for i in range(4)}
|
||||
|
||||
self.checkScript(fn, ())
|
||||
|
||||
def test_dict_comprehension_with_type_annotation(self):
|
||||
def fn():
|
||||
d: Dict[int, str] = {i : chr(i + 65) for i in range(4)}
|
||||
d: Dict[int, str] = {i: chr(i + 65) for i in range(4)}
|
||||
return d
|
||||
|
||||
self.checkScript(fn, ())
|
||||
|
||||
with self.assertRaisesRegex(RuntimeError, ""):
|
||||
with self.assertRaisesRegex(AssertionError, "Expected Dict "
|
||||
"type annotation for dict "
|
||||
"comprehension, found "
|
||||
"Tuple[int, str]"):
|
||||
with self.assertRaisesRegex(
|
||||
AssertionError,
|
||||
"Expected Dict "
|
||||
"type annotation for dict "
|
||||
"comprehension, found "
|
||||
"Tuple[int, str]",
|
||||
):
|
||||
|
||||
@torch.jit.script
|
||||
def fn():
|
||||
d: Tuple[int, str] = {i : chr(i + 65) for i in range(4)}
|
||||
d: Tuple[int, str] = {i: chr(i + 65) for i in range(4)}
|
||||
return d
|
||||
|
||||
def test_dict_comprehension_scope(self):
|
||||
def comprehension_can_access_outer_scope_variables():
|
||||
lst = ["foo", "bar", "baz"]
|
||||
return {l : len(l) for l in lst}
|
||||
return {l: len(l) for l in lst}
|
||||
|
||||
self.checkScript(comprehension_can_access_outer_scope_variables, ())
|
||||
|
||||
with self.assertRaisesRegex(RuntimeError, "undefined value i"):
|
||||
|
||||
@torch.jit.script
|
||||
def outer_scope_cannot_access_comprehension_variables():
|
||||
d = {i : chr(i + 65) for i in range(4)}
|
||||
d = {i: chr(i + 65) for i in range(4)}
|
||||
i = i + 1 # noqa: F821
|
||||
|
||||
def test_for_tuple_assign(self):
|
||||
@ -402,22 +432,28 @@ class TestTyping(JitTestCase):
|
||||
sum += a[1]
|
||||
return sum
|
||||
|
||||
self.checkScript(test_tuple_assign, (((1, 2), (4, 7)), ))
|
||||
self.checkScript(test_tuple_assign, (((1, 2), (4, 7)),))
|
||||
|
||||
def test_single_starred_lhs(self):
|
||||
with self.assertRaisesRegex(RuntimeError, 'A Starred expression may only appear on the lhs within the presence'
|
||||
' of another non-starred expression'):
|
||||
cu = torch.jit.CompilationUnit('''
|
||||
with self.assertRaisesRegex(
|
||||
RuntimeError,
|
||||
"A Starred expression may only appear on the lhs within the presence"
|
||||
" of another non-starred expression",
|
||||
):
|
||||
cu = torch.jit.CompilationUnit(
|
||||
"""
|
||||
def single_starred_lhs(x):
|
||||
a = (x, x, x)
|
||||
*b, = a
|
||||
return b
|
||||
''')
|
||||
"""
|
||||
)
|
||||
|
||||
def test_singleton_tuple_unpack(self):
|
||||
def foo(a):
|
||||
b, = (a,)
|
||||
(b,) = (a,)
|
||||
return b + 1
|
||||
|
||||
self.checkScript(foo, (torch.rand(3),))
|
||||
|
||||
def test_tuple_assignments(self):
|
||||
@ -441,7 +477,9 @@ class TestTyping(JitTestCase):
|
||||
a[i], (x[i], b) = 1, (2, 3)
|
||||
return a[i] + 1, x + 5, b
|
||||
|
||||
self.checkScript(subscript_tuple_assign, ([12, 7, 9, 11], torch.tensor((3, 13, 17)), 0))
|
||||
self.checkScript(
|
||||
subscript_tuple_assign, ([12, 7, 9, 11], torch.tensor((3, 13, 17)), 0)
|
||||
)
|
||||
|
||||
def star_tuple_assign():
|
||||
# type: () -> Tuple[int, int, Tuple[int, int], Tuple[int, int]]
|
||||
@ -455,7 +493,7 @@ class TestTyping(JitTestCase):
|
||||
a[0] += 1
|
||||
return a
|
||||
|
||||
with self.assertRaisesRegex(RuntimeError, 'does not support augmented assign'):
|
||||
with self.assertRaisesRegex(RuntimeError, "does not support augmented assign"):
|
||||
scripted_aug_assign = torch.jit.script(subscript_tuple_augmented_assign)
|
||||
|
||||
def test_multiple_assign(self):
|
||||
@ -505,7 +543,6 @@ class TestTyping(JitTestCase):
|
||||
# type: (Optional[int]) -> int
|
||||
return torch.jit._unwrap_optional(x)
|
||||
|
||||
|
||||
@torch.jit.script
|
||||
def fn(x):
|
||||
# type: (int) -> int
|
||||
@ -540,7 +577,7 @@ class TestTyping(JitTestCase):
|
||||
# type: (Tuple[float, float]) -> int
|
||||
return opt_list(x) + broadcast_opt_list(x)
|
||||
|
||||
self.assertEqual(opt_list_tuple_caller((2., 3.)), 4)
|
||||
self.assertEqual(opt_list_tuple_caller((2.0, 3.0)), 4)
|
||||
|
||||
def test_optional_tuple(self):
|
||||
def fn(x=None):
|
||||
@ -556,10 +593,11 @@ class TestTyping(JitTestCase):
|
||||
|
||||
def test_namedtuple_redefine(self):
|
||||
global _1, _2
|
||||
_1 = namedtuple('GoogLeNetOutputs', ['logits', 'aux_logits2', 'aux_logits1'])
|
||||
_2 = namedtuple('GoogLeNetOutputs', ['different'])
|
||||
_1 = namedtuple("GoogLeNetOutputs", ["logits", "aux_logits2", "aux_logits1"])
|
||||
_2 = namedtuple("GoogLeNetOutputs", ["different"])
|
||||
|
||||
with self.assertRaisesRegex(RuntimeError, r"redefine"):
|
||||
|
||||
with self.assertRaisesRegex(RuntimeError, r'redefine'):
|
||||
@torch.jit.script
|
||||
def foo(x, y):
|
||||
# type: (_1, _2) -> _1
|
||||
@ -567,7 +605,9 @@ class TestTyping(JitTestCase):
|
||||
|
||||
def test_namedtuple_py2(self):
|
||||
global _GoogLeNetOutputs # see [local resolution in python]
|
||||
_GoogLeNetOutputs = namedtuple('GoogLeNetOutputs', ['logits', 'aux_logits2', 'aux_logits1'])
|
||||
_GoogLeNetOutputs = namedtuple(
|
||||
"GoogLeNetOutputs", ["logits", "aux_logits2", "aux_logits1"]
|
||||
)
|
||||
|
||||
@torch.jit.script
|
||||
def foo(x):
|
||||
@ -575,22 +615,27 @@ class TestTyping(JitTestCase):
|
||||
return x
|
||||
|
||||
vals = torch.rand(3), torch.rand(4), torch.rand(5)
|
||||
out = foo(_GoogLeNetOutputs(logits=vals[0], aux_logits2=vals[1], aux_logits1=vals[2]))
|
||||
out = foo(
|
||||
_GoogLeNetOutputs(logits=vals[0], aux_logits2=vals[1], aux_logits1=vals[2])
|
||||
)
|
||||
self.assertEqual(out.logits, vals[0])
|
||||
self.assertEqual(out.aux_logits2, vals[1])
|
||||
self.assertEqual(out.aux_logits1, vals[2])
|
||||
|
||||
def test_namedtuple_good_error(self):
|
||||
global _GoogLeNetOutputs # see [local resolution in python]
|
||||
_GoogLeNetOutputs = namedtuple('GoogLeNetOutputs', ['logits', 'aux_logits2', 'aux_logits1'])
|
||||
_GoogLeNetOutputs = namedtuple(
|
||||
"GoogLeNetOutputs", ["logits", "aux_logits2", "aux_logits1"]
|
||||
)
|
||||
|
||||
@torch.jit.script
|
||||
def foo(x):
|
||||
# type: (_GoogLeNetOutputs) -> _GoogLeNetOutputs
|
||||
return x
|
||||
|
||||
with self.assertRaisesRegex(RuntimeError,
|
||||
r'aka NamedTuple\(logits, aux_logits2, aux_logits1\)'):
|
||||
with self.assertRaisesRegex(
|
||||
RuntimeError, r"aka NamedTuple\(logits, aux_logits2, aux_logits1\)"
|
||||
):
|
||||
out = foo(_GoogLeNetOutputs(logits="3", aux_logits2="4", aux_logits1="5"))
|
||||
|
||||
def test_namedtuple_error_source_attribution(self):
|
||||
|
||||
@ -3,22 +3,25 @@
|
||||
import io
|
||||
import os
|
||||
import sys
|
||||
|
||||
import torch
|
||||
from torch.testing import FileCheck
|
||||
from enum import Enum
|
||||
from textwrap import dedent
|
||||
from typing import Dict, List, Optional, Tuple, Union
|
||||
|
||||
import torch
|
||||
from torch.testing import FileCheck
|
||||
|
||||
# Make the helper files in test/ importable
|
||||
pytorch_test_dir = os.path.dirname(os.path.dirname(os.path.realpath(__file__)))
|
||||
sys.path.append(pytorch_test_dir)
|
||||
from torch.testing._internal.jit_utils import JitTestCase, make_global
|
||||
|
||||
if __name__ == '__main__':
|
||||
raise RuntimeError("This test file is not meant to be run directly, use:\n\n"
|
||||
"\tpython test/test_jit.py TESTNAME\n\n"
|
||||
"instead.")
|
||||
if __name__ == "__main__":
|
||||
raise RuntimeError(
|
||||
"This test file is not meant to be run directly, use:\n\n"
|
||||
"\tpython test/test_jit.py TESTNAME\n\n"
|
||||
"instead."
|
||||
)
|
||||
|
||||
|
||||
class TestUnion(JitTestCase):
|
||||
"""
|
||||
@ -57,9 +60,12 @@ class TestUnion(JitTestCase):
|
||||
|
||||
scripted = torch.jit.script(fn)
|
||||
|
||||
with self.assertRaisesRegex(RuntimeError, "Expected a member of"
|
||||
r" Union\[float, int\] but "
|
||||
"instead found type str"):
|
||||
with self.assertRaisesRegex(
|
||||
RuntimeError,
|
||||
"Expected a member of"
|
||||
r" Union\[float, int\] but "
|
||||
"instead found type str",
|
||||
):
|
||||
scripted("1")
|
||||
|
||||
def test_union_with_collections(self):
|
||||
@ -71,22 +77,31 @@ class TestUnion(JitTestCase):
|
||||
|
||||
scripted = torch.jit.script(fn)
|
||||
|
||||
with self.assertRaisesRegex(RuntimeError, "Expected a member of"
|
||||
r" Union\[List\[int\], Dict\[str, "
|
||||
r"int\]\] but instead found type "
|
||||
r"Dict\[str, str\]"):
|
||||
with self.assertRaisesRegex(
|
||||
RuntimeError,
|
||||
"Expected a member of"
|
||||
r" Union\[List\[int\], Dict\[str, "
|
||||
r"int\]\] but instead found type "
|
||||
r"Dict\[str, str\]",
|
||||
):
|
||||
scripted({"foo": "bar", "baz": "qux"})
|
||||
|
||||
with self.assertRaisesRegex(RuntimeError, "Expected a member of"
|
||||
r" Union\[List\[int\], Dict\[str, "
|
||||
r"int\]\] but instead found type "
|
||||
r"List\[str\]"):
|
||||
with self.assertRaisesRegex(
|
||||
RuntimeError,
|
||||
"Expected a member of"
|
||||
r" Union\[List\[int\], Dict\[str, "
|
||||
r"int\]\] but instead found type "
|
||||
r"List\[str\]",
|
||||
):
|
||||
scripted(["foo", "bar", "baz"])
|
||||
|
||||
with self.assertRaisesRegex(RuntimeError, "Expected a member of"
|
||||
r" Union\[List\[int\], Dict\[str, "
|
||||
r"int\]\] but instead found type "
|
||||
"str"):
|
||||
with self.assertRaisesRegex(
|
||||
RuntimeError,
|
||||
"Expected a member of"
|
||||
r" Union\[List\[int\], Dict\[str, "
|
||||
r"int\]\] but instead found type "
|
||||
"str",
|
||||
):
|
||||
scripted("1")
|
||||
|
||||
def test_union_with_enum(self):
|
||||
@ -104,16 +119,18 @@ class TestUnion(JitTestCase):
|
||||
|
||||
scripted = torch.jit.script(fn)
|
||||
|
||||
with self.assertRaisesRegex(RuntimeError, "Expected a member of"
|
||||
r" Union\[__torch__.jit.test_union."
|
||||
r"Color, str\] but instead found "
|
||||
"type int"):
|
||||
with self.assertRaisesRegex(
|
||||
RuntimeError,
|
||||
"Expected a member of"
|
||||
r" Union\[__torch__.jit.test_union."
|
||||
r"Color, str\] but instead found "
|
||||
"type int",
|
||||
):
|
||||
scripted(1)
|
||||
|
||||
def test_union_in_class_constructor(self):
|
||||
|
||||
@torch.jit.script # noqa: B903
|
||||
class A: # noqa: B903
|
||||
class A: # noqa: B903
|
||||
def __init__(self, x: Union[int, str]) -> None:
|
||||
self.x = x
|
||||
|
||||
@ -125,9 +142,12 @@ class TestUnion(JitTestCase):
|
||||
|
||||
scripted = torch.jit.script(fn)
|
||||
|
||||
with self.assertRaisesRegex(RuntimeError, "Expected a member of"
|
||||
r" Union\[int, str\] but instead "
|
||||
r"found type List\[str\]"):
|
||||
with self.assertRaisesRegex(
|
||||
RuntimeError,
|
||||
"Expected a member of"
|
||||
r" Union\[int, str\] but instead "
|
||||
r"found type List\[str\]",
|
||||
):
|
||||
scripted(["foo", "bar", "baz"])
|
||||
|
||||
def test_union_return_type(self):
|
||||
@ -171,7 +191,7 @@ class TestUnion(JitTestCase):
|
||||
def test_union_variable_can_be_reassigned(self):
|
||||
@torch.jit.script
|
||||
def aux1(i: int):
|
||||
return int(i ** 2)
|
||||
return int(i**2)
|
||||
|
||||
@torch.jit.script
|
||||
def aux2(s: str):
|
||||
@ -225,8 +245,7 @@ class TestUnion(JitTestCase):
|
||||
|
||||
s = fn.graph
|
||||
|
||||
FileCheck().check("x : Union(float, int, str)") \
|
||||
.run(s)
|
||||
FileCheck().check("x : Union(float, int, str)").run(s)
|
||||
|
||||
def test_unions_of_a_single_argument_vanish(self):
|
||||
@torch.jit.script
|
||||
@ -235,8 +254,7 @@ class TestUnion(JitTestCase):
|
||||
|
||||
s = fn.graph
|
||||
|
||||
FileCheck().check("x : int") \
|
||||
.run(s)
|
||||
FileCheck().check("x : int").run(s)
|
||||
|
||||
def test_union_redundant_arguments_are_skipped(self):
|
||||
@torch.jit.script
|
||||
@ -245,8 +263,7 @@ class TestUnion(JitTestCase):
|
||||
|
||||
s = fn.graph
|
||||
|
||||
FileCheck().check("x : Union(int, str)") \
|
||||
.run(s)
|
||||
FileCheck().check("x : Union(int, str)").run(s)
|
||||
|
||||
def test_union_redundant_arguments_are_skipped_optional(self):
|
||||
@torch.jit.script
|
||||
@ -255,8 +272,7 @@ class TestUnion(JitTestCase):
|
||||
|
||||
s = fn.graph
|
||||
|
||||
FileCheck().check("x : Union(float, int, NoneType)") \
|
||||
.run(s)
|
||||
FileCheck().check("x : Union(float, int, NoneType)").run(s)
|
||||
|
||||
def test_union_redundant_arguments_are_skipped_subtyping(self):
|
||||
@torch.jit.script
|
||||
@ -265,8 +281,7 @@ class TestUnion(JitTestCase):
|
||||
|
||||
s = fn.graph
|
||||
|
||||
FileCheck().check("x : Union((int?, int), str)") \
|
||||
.run(s)
|
||||
FileCheck().check("x : Union((int?, int), str)").run(s)
|
||||
|
||||
def test_union_redundant_arguments_are_skipped_container(self):
|
||||
@torch.jit.script
|
||||
@ -275,8 +290,7 @@ class TestUnion(JitTestCase):
|
||||
|
||||
s = fn.graph
|
||||
|
||||
FileCheck().check("x : Union(float[], str[])") \
|
||||
.run(s)
|
||||
FileCheck().check("x : Union(float[], str[])").run(s)
|
||||
|
||||
def test_union_argument_order_is_ignored(self):
|
||||
@torch.jit.script
|
||||
@ -288,8 +302,7 @@ class TestUnion(JitTestCase):
|
||||
return "foo"
|
||||
|
||||
for s in (fn1.graph, fn2.graph):
|
||||
FileCheck().check("x : Union(int, str)") \
|
||||
.run(s)
|
||||
FileCheck().check("x : Union(int, str)").run(s)
|
||||
|
||||
def test_union_argument_order_is_ignored_container(self):
|
||||
@torch.jit.script
|
||||
@ -301,8 +314,7 @@ class TestUnion(JitTestCase):
|
||||
return "foo"
|
||||
|
||||
for s in (fn1.graph, fn2.graph):
|
||||
FileCheck().check("x : Union(int[], str[])") \
|
||||
.run(s)
|
||||
FileCheck().check("x : Union(int[], str[])").run(s)
|
||||
|
||||
def test_union_T_None_is_equivalent_to_optional_T(self):
|
||||
@torch.jit.script
|
||||
@ -366,9 +378,9 @@ class TestUnion(JitTestCase):
|
||||
|
||||
s = l.code
|
||||
|
||||
FileCheck().check("Union[int, NoneType, str]") \
|
||||
.check("Union[int, NoneType, str]") \
|
||||
.run(s)
|
||||
FileCheck().check("Union[int, NoneType, str]").check(
|
||||
"Union[int, NoneType, str]"
|
||||
).run(s)
|
||||
|
||||
def test_union_subclasses_larger_union(self):
|
||||
def fn() -> Union[int, str, torch.Tensor]:
|
||||
@ -386,9 +398,12 @@ class TestUnion(JitTestCase):
|
||||
x[1] = 2
|
||||
return x[1]
|
||||
|
||||
with self.assertRaisesRegex(RuntimeError, "only int, float, "
|
||||
"complex, Tensor, device and string keys "
|
||||
"are supported"):
|
||||
with self.assertRaisesRegex(
|
||||
RuntimeError,
|
||||
"only int, float, "
|
||||
"complex, Tensor, device and string keys "
|
||||
"are supported",
|
||||
):
|
||||
torch.jit.script(fn)
|
||||
|
||||
def test_union_as_dict_value(self):
|
||||
@ -402,7 +417,6 @@ class TestUnion(JitTestCase):
|
||||
|
||||
def test_union_module_with_union_instance_variable(self):
|
||||
class M(torch.nn.Module):
|
||||
|
||||
x: Union[int, str]
|
||||
|
||||
def __init__(self, x: Union[int, str]):
|
||||
@ -413,7 +427,12 @@ class TestUnion(JitTestCase):
|
||||
self.x = y
|
||||
return self.x
|
||||
|
||||
self.checkModule(M(2,), (1,))
|
||||
self.checkModule(
|
||||
M(
|
||||
2,
|
||||
),
|
||||
(1,),
|
||||
)
|
||||
self.checkModule(M("bar"), ("foo",))
|
||||
|
||||
def test_union_module_with_union_class_variable(self):
|
||||
@ -508,9 +527,7 @@ class TestUnion(JitTestCase):
|
||||
s = fn.graph
|
||||
|
||||
# Check that we don't have any branching statements
|
||||
FileCheck().check_not("block0()") \
|
||||
.check_not("block1()") \
|
||||
.run(s)
|
||||
FileCheck().check_not("block0()").check_not("block1()").run(s)
|
||||
|
||||
def test_union_type_refinement_statically_true(self):
|
||||
@torch.jit.script
|
||||
@ -525,9 +542,7 @@ class TestUnion(JitTestCase):
|
||||
s = fn.graph
|
||||
|
||||
# Check that we don't have any branching statements
|
||||
FileCheck().check_not("block0()") \
|
||||
.check_not("block1()") \
|
||||
.run(s)
|
||||
FileCheck().check_not("block0()").check_not("block1()").run(s)
|
||||
|
||||
def test_union_type_refinement_partial_static_refinement_tuple_rhs(self):
|
||||
def fn(x: Union[List[int], int]) -> int:
|
||||
@ -556,7 +571,7 @@ class TestUnion(JitTestCase):
|
||||
def test_union_type_refinement_internal_declaration(self):
|
||||
def fn(flag: bool) -> str:
|
||||
x: Union[int, str, None] = None
|
||||
if (flag):
|
||||
if flag:
|
||||
y = "foo"
|
||||
else:
|
||||
y = 1
|
||||
@ -589,9 +604,12 @@ class TestUnion(JitTestCase):
|
||||
else:
|
||||
return "bar"
|
||||
|
||||
with self.assertRaisesRegex(RuntimeError, "y is set to type str"
|
||||
" in the true branch and type int "
|
||||
"in the false branch"):
|
||||
with self.assertRaisesRegex(
|
||||
RuntimeError,
|
||||
"y is set to type str"
|
||||
" in the true branch and type int "
|
||||
"in the false branch",
|
||||
):
|
||||
torch.jit.script(fn)
|
||||
|
||||
def test_union_branching_does_not_widen_existing_inferred_type(self):
|
||||
@ -606,9 +624,12 @@ class TestUnion(JitTestCase):
|
||||
else:
|
||||
return "baz"
|
||||
|
||||
with self.assertRaisesRegex(RuntimeError, "previously had type "
|
||||
"str but is now being assigned to a"
|
||||
" value of type int"):
|
||||
with self.assertRaisesRegex(
|
||||
RuntimeError,
|
||||
"previously had type "
|
||||
"str but is now being assigned to a"
|
||||
" value of type int",
|
||||
):
|
||||
torch.jit.script(fn)
|
||||
|
||||
def test_union_schema_matching_on_internal_type(self):
|
||||
@ -645,8 +666,8 @@ class TestUnion(JitTestCase):
|
||||
|
||||
def test_union_memory_aliasing(self):
|
||||
def fn():
|
||||
x : List[torch.Tensor] = []
|
||||
z : List[Optional[List[torch.Tensor]]] = []
|
||||
x: List[torch.Tensor] = []
|
||||
z: List[Optional[List[torch.Tensor]]] = []
|
||||
z.append(x)
|
||||
x_alias = z[0]
|
||||
if torch.jit.isinstance(x_alias, List[torch.Tensor]):
|
||||
@ -682,203 +703,212 @@ class TestUnion(JitTestCase):
|
||||
code = template.format(ann=ann, lhs=lhs)
|
||||
with self.assertRaisesRegex(RuntimeError, msg):
|
||||
cu = torch.jit.CompilationUnit(code, _frames_up=1)
|
||||
string_frontend = getattr(cu, "fn") # noqa: B009
|
||||
string_frontend = getattr(cu, "fn") # noqa: B009
|
||||
|
||||
def test_union_with_list_assignment(self):
|
||||
template = dedent('''
|
||||
template = dedent(
|
||||
"""
|
||||
def fn():
|
||||
x: {ann} = {lhs}
|
||||
if torch.jit.isinstance(x, List[torch.Tensor]):
|
||||
x.append(torch.tensor(3))
|
||||
return x
|
||||
''')
|
||||
"""
|
||||
)
|
||||
|
||||
lhs = {"list_literal_empty" : "[]",
|
||||
|
||||
"list_literal_of_tensor" : "[torch.arange(3), torch.arange(5)]",
|
||||
|
||||
"list_literal_of_str" : "[\"foo\", \"bar\", \"baz\"]",
|
||||
|
||||
"list_literal_of_mixed" : "[torch.arange(5), 1]",
|
||||
|
||||
"list_comprehension_of_tensor" :
|
||||
"[torch.add(x, 1) for x in [torch.arange(3), torch.arange(5)]]",
|
||||
|
||||
"list_comprehension_of_str" :
|
||||
"[x + \"!\" for x in [\"foo\", \"bar\", \"baz\"]]",
|
||||
|
||||
"list_comprehension_of_mixed" :
|
||||
"[torch.add(1, x) for x in [torch.arange(5), 1]]"}
|
||||
lhs = {
|
||||
"list_literal_empty": "[]",
|
||||
"list_literal_of_tensor": "[torch.arange(3), torch.arange(5)]",
|
||||
"list_literal_of_str": '["foo", "bar", "baz"]',
|
||||
"list_literal_of_mixed": "[torch.arange(5), 1]",
|
||||
"list_comprehension_of_tensor": "[torch.add(x, 1) for x in [torch.arange(3), torch.arange(5)]]",
|
||||
"list_comprehension_of_str": '[x + "!" for x in ["foo", "bar", "baz"]]',
|
||||
"list_comprehension_of_mixed": "[torch.add(1, x) for x in [torch.arange(5), 1]]",
|
||||
}
|
||||
|
||||
"""
|
||||
Union[List[str], List[torch.Tensor]]
|
||||
"""
|
||||
self._assert_raises(template,
|
||||
"Union[List[str], List[torch.Tensor]]",
|
||||
lhs["list_literal_empty"],
|
||||
"there are multiple possible List type "
|
||||
"candidates in the Union annotation")
|
||||
self._assert_raises(
|
||||
template,
|
||||
"Union[List[str], List[torch.Tensor]]",
|
||||
lhs["list_literal_empty"],
|
||||
"there are multiple possible List type "
|
||||
"candidates in the Union annotation",
|
||||
)
|
||||
|
||||
self._assert_passes(template,
|
||||
"Union[List[str], List[torch.Tensor]]",
|
||||
lhs["list_literal_of_tensor"])
|
||||
self._assert_passes(
|
||||
template,
|
||||
"Union[List[str], List[torch.Tensor]]",
|
||||
lhs["list_literal_of_tensor"],
|
||||
)
|
||||
|
||||
self._assert_passes(template,
|
||||
"Union[List[str], List[torch.Tensor]]",
|
||||
lhs["list_literal_of_str"])
|
||||
self._assert_passes(
|
||||
template, "Union[List[str], List[torch.Tensor]]", lhs["list_literal_of_str"]
|
||||
)
|
||||
|
||||
self._assert_raises(template,
|
||||
"Union[List[str], List[torch.Tensor]]",
|
||||
lhs["list_literal_of_mixed"],
|
||||
"none of those types match the types of the"
|
||||
" given list elements")
|
||||
self._assert_raises(
|
||||
template,
|
||||
"Union[List[str], List[torch.Tensor]]",
|
||||
lhs["list_literal_of_mixed"],
|
||||
"none of those types match the types of the" " given list elements",
|
||||
)
|
||||
|
||||
self._assert_passes(template,
|
||||
"Union[List[str], List[torch.Tensor]]",
|
||||
lhs["list_comprehension_of_tensor"])
|
||||
self._assert_passes(
|
||||
template,
|
||||
"Union[List[str], List[torch.Tensor]]",
|
||||
lhs["list_comprehension_of_tensor"],
|
||||
)
|
||||
|
||||
self._assert_passes(template,
|
||||
"Union[List[str], List[torch.Tensor]]",
|
||||
lhs["list_comprehension_of_str"])
|
||||
self._assert_passes(
|
||||
template,
|
||||
"Union[List[str], List[torch.Tensor]]",
|
||||
lhs["list_comprehension_of_str"],
|
||||
)
|
||||
|
||||
# TODO: Support mixed list comprehensions
|
||||
self._assert_raises(template,
|
||||
"Union[List[str], List[torch.Tensor]]",
|
||||
lhs["list_comprehension_of_mixed"],
|
||||
"Arguments for call are not valid")
|
||||
self._assert_raises(
|
||||
template,
|
||||
"Union[List[str], List[torch.Tensor]]",
|
||||
lhs["list_comprehension_of_mixed"],
|
||||
"Arguments for call are not valid",
|
||||
)
|
||||
|
||||
"""
|
||||
Union[int, torch.Tensor]
|
||||
"""
|
||||
self._assert_raises(template,
|
||||
"Union[int, torch.Tensor]",
|
||||
lhs["list_literal_empty"],
|
||||
"Expected an Union type annotation with an "
|
||||
"inner List type")
|
||||
self._assert_raises(
|
||||
template,
|
||||
"Union[int, torch.Tensor]",
|
||||
lhs["list_literal_empty"],
|
||||
"Expected an Union type annotation with an " "inner List type",
|
||||
)
|
||||
|
||||
self._assert_raises(template, "Union[int, torch.Tensor]",
|
||||
lhs["list_literal_of_tensor"],
|
||||
"Expected an Union type annotation with an "
|
||||
"inner List type")
|
||||
self._assert_raises(
|
||||
template,
|
||||
"Union[int, torch.Tensor]",
|
||||
lhs["list_literal_of_tensor"],
|
||||
"Expected an Union type annotation with an " "inner List type",
|
||||
)
|
||||
|
||||
self._assert_raises(template, "Union[int, torch.Tensor]",
|
||||
lhs["list_comprehension_of_tensor"],
|
||||
"Expected an Union type annotation with an "
|
||||
"inner List type")
|
||||
self._assert_raises(
|
||||
template,
|
||||
"Union[int, torch.Tensor]",
|
||||
lhs["list_comprehension_of_tensor"],
|
||||
"Expected an Union type annotation with an " "inner List type",
|
||||
)
|
||||
|
||||
"""
|
||||
Union[List[torch.Tensor], int]
|
||||
"""
|
||||
self._assert_passes(template,
|
||||
"Union[List[torch.Tensor], int]",
|
||||
lhs["list_literal_empty"])
|
||||
self._assert_passes(
|
||||
template, "Union[List[torch.Tensor], int]", lhs["list_literal_empty"]
|
||||
)
|
||||
|
||||
self._assert_passes(template,
|
||||
"Union[List[torch.Tensor], int]",
|
||||
lhs["list_literal_of_tensor"])
|
||||
self._assert_passes(
|
||||
template, "Union[List[torch.Tensor], int]", lhs["list_literal_of_tensor"]
|
||||
)
|
||||
|
||||
self._assert_raises(template, "Union[List[torch.Tensor], int]",
|
||||
lhs["list_literal_of_str"],
|
||||
r"List type annotation `List\[Tensor\]` did "
|
||||
"not match the types of the given list "
|
||||
"elements")
|
||||
self._assert_raises(
|
||||
template,
|
||||
"Union[List[torch.Tensor], int]",
|
||||
lhs["list_literal_of_str"],
|
||||
r"List type annotation `List\[Tensor\]` did "
|
||||
"not match the types of the given list "
|
||||
"elements",
|
||||
)
|
||||
|
||||
self._assert_raises(template, "Union[List[torch.Tensor], int]",
|
||||
lhs["list_literal_of_mixed"],
|
||||
r"List type annotation `List\[Tensor\]` did "
|
||||
"not match the types of the given list "
|
||||
"elements")
|
||||
self._assert_raises(
|
||||
template,
|
||||
"Union[List[torch.Tensor], int]",
|
||||
lhs["list_literal_of_mixed"],
|
||||
r"List type annotation `List\[Tensor\]` did "
|
||||
"not match the types of the given list "
|
||||
"elements",
|
||||
)
|
||||
|
||||
self._assert_passes(template,
|
||||
"Union[List[torch.Tensor], int]",
|
||||
lhs["list_comprehension_of_tensor"])
|
||||
self._assert_passes(
|
||||
template,
|
||||
"Union[List[torch.Tensor], int]",
|
||||
lhs["list_comprehension_of_tensor"],
|
||||
)
|
||||
|
||||
self._assert_raises(template,
|
||||
"Union[List[torch.Tensor], int]",
|
||||
lhs["list_comprehension_of_str"],
|
||||
r"List type annotation `List\[Tensor\]` did "
|
||||
"not match the types of the given list "
|
||||
"elements")
|
||||
self._assert_raises(
|
||||
template,
|
||||
"Union[List[torch.Tensor], int]",
|
||||
lhs["list_comprehension_of_str"],
|
||||
r"List type annotation `List\[Tensor\]` did "
|
||||
"not match the types of the given list "
|
||||
"elements",
|
||||
)
|
||||
|
||||
# TODO(@ansley): Support mixed list comprehensions
|
||||
self._assert_raises(template,
|
||||
"Union[List[torch.Tensor], int]",
|
||||
lhs["list_comprehension_of_mixed"],
|
||||
"Arguments for call are not valid")
|
||||
self._assert_raises(
|
||||
template,
|
||||
"Union[List[torch.Tensor], int]",
|
||||
lhs["list_comprehension_of_mixed"],
|
||||
"Arguments for call are not valid",
|
||||
)
|
||||
|
||||
def test_union_with_dict_assignment(self):
|
||||
template = dedent('''
|
||||
template = dedent(
|
||||
"""
|
||||
def fn():
|
||||
x: {ann} = {lhs}
|
||||
if torch.jit.isinstance(x, Dict[str, torch.Tensor]):
|
||||
x["foo"] = torch.tensor(3)
|
||||
return x
|
||||
''')
|
||||
"""
|
||||
)
|
||||
|
||||
lhs = {"dict_literal_empty" : "{}",
|
||||
|
||||
"dict_literal_of_str_tensor" :
|
||||
"{\"foo\" : torch.arange(3), \"bar\" : torch.arange(5)}",
|
||||
|
||||
"dict_literal_of_str_int" :
|
||||
"{\"foo\" : 1, \"bar\" : 2}",
|
||||
|
||||
"dict_literal_of_mixed" :
|
||||
"{\"foo\" : torch.arange(3), \"bar\" : 2}",
|
||||
|
||||
"dict_comprehension_of_str_tensor" :
|
||||
"{x : torch.add(y, 1) for x, y in \
|
||||
zip([\"foo\", \"bar\"], [torch.arange(3), torch.arange(5)])}",
|
||||
|
||||
"dict_comprehension_of_str_int" :
|
||||
"{x : torch.add(y, 1) for x, y in \
|
||||
zip([\"foo\", \"bar\"], [1, 2]}",
|
||||
|
||||
"dict_comprehension_of_mixed" :
|
||||
"{x : torch.add(y, 1) for x, y in \
|
||||
zip([\"foo\", \"bar\"], [torch.arange(3), 2])}",
|
||||
|
||||
"dict_keyword" :
|
||||
"dict(foo=torch.arange(3), baz=torch.arange(5))",
|
||||
|
||||
"dict_keyword_with_iterable" :
|
||||
"dict([(\"foo\", torch.arange(3)), (\"bar\", torch.arange(5))])",
|
||||
|
||||
"dict_keyword_with_empty_iterable" :
|
||||
"dict([])",
|
||||
|
||||
"dict_keyword_with_internal_aggregate_function" :
|
||||
"dict(zip([\"foo\", \"bar\"], [torch.arange(3), torch.arange(5)])",
|
||||
|
||||
"dict_keyword_with_mapping" :
|
||||
"dict({\"foo\" : torch.arange(3), \"bar\" : torch.arange(5)})",
|
||||
|
||||
"dict_keyword_with_mapping_and_kwargs" :
|
||||
"dict({\"foo\" : torch.arange(3), \"bar\" : torch.arange(5)}, baz=torch.arange(7))",
|
||||
|
||||
}
|
||||
lhs = {
|
||||
"dict_literal_empty": "{}",
|
||||
"dict_literal_of_str_tensor": '{"foo" : torch.arange(3), "bar" : torch.arange(5)}',
|
||||
"dict_literal_of_str_int": '{"foo" : 1, "bar" : 2}',
|
||||
"dict_literal_of_mixed": '{"foo" : torch.arange(3), "bar" : 2}',
|
||||
"dict_comprehension_of_str_tensor": '{x : torch.add(y, 1) for x, y in \
|
||||
zip(["foo", "bar"], [torch.arange(3), torch.arange(5)])}',
|
||||
"dict_comprehension_of_str_int": '{x : torch.add(y, 1) for x, y in \
|
||||
zip(["foo", "bar"], [1, 2]}',
|
||||
"dict_comprehension_of_mixed": '{x : torch.add(y, 1) for x, y in \
|
||||
zip(["foo", "bar"], [torch.arange(3), 2])}',
|
||||
"dict_keyword": "dict(foo=torch.arange(3), baz=torch.arange(5))",
|
||||
"dict_keyword_with_iterable": 'dict([("foo", torch.arange(3)), ("bar", torch.arange(5))])',
|
||||
"dict_keyword_with_empty_iterable": "dict([])",
|
||||
"dict_keyword_with_internal_aggregate_function": 'dict(zip(["foo", "bar"], [torch.arange(3), torch.arange(5)])',
|
||||
"dict_keyword_with_mapping": 'dict({"foo" : torch.arange(3), "bar" : torch.arange(5)})',
|
||||
"dict_keyword_with_mapping_and_kwargs": 'dict({"foo" : torch.arange(3), "bar" : torch.arange(5)}, baz=torch.arange(7))',
|
||||
}
|
||||
|
||||
"""
|
||||
Union[Dict[str, torch.Tensor], Dict[str, int]]
|
||||
"""
|
||||
self._assert_raises(template,
|
||||
"Union[List[str], List[torch.Tensor]]",
|
||||
lhs["dict_literal_empty"],
|
||||
"Expected an Union type annotation with an "
|
||||
"inner Dict type")
|
||||
self._assert_raises(
|
||||
template,
|
||||
"Union[List[str], List[torch.Tensor]]",
|
||||
lhs["dict_literal_empty"],
|
||||
"Expected an Union type annotation with an " "inner Dict type",
|
||||
)
|
||||
|
||||
self._assert_passes(template,
|
||||
"Union[Dict[str, torch.Tensor], Dict[str, int]]",
|
||||
lhs["dict_literal_of_str_tensor"])
|
||||
self._assert_passes(
|
||||
template,
|
||||
"Union[Dict[str, torch.Tensor], Dict[str, int]]",
|
||||
lhs["dict_literal_of_str_tensor"],
|
||||
)
|
||||
|
||||
self._assert_passes(template,
|
||||
"Union[Dict[str, torch.Tensor], Dict[str, int]]",
|
||||
lhs["dict_literal_of_str_int"])
|
||||
self._assert_passes(
|
||||
template,
|
||||
"Union[Dict[str, torch.Tensor], Dict[str, int]]",
|
||||
lhs["dict_literal_of_str_int"],
|
||||
)
|
||||
|
||||
self._assert_raises(template, "Union[Dict[str, torch.Tensor], Dict[str, int]]",
|
||||
lhs["dict_literal_of_mixed"],
|
||||
"none of those dict types can hold the "
|
||||
"types of the given keys and values")
|
||||
self._assert_raises(
|
||||
template,
|
||||
"Union[Dict[str, torch.Tensor], Dict[str, int]]",
|
||||
lhs["dict_literal_of_mixed"],
|
||||
"none of those dict types can hold the "
|
||||
"types of the given keys and values",
|
||||
)
|
||||
|
||||
# TODO: String frontend does not support tuple unpacking
|
||||
# https://github.com/pytorch/pytorch/issues/64096
|
||||
@ -899,45 +929,57 @@ class TestUnion(JitTestCase):
|
||||
# TODO(@ansley): Follow-up project needed for full type
|
||||
# inference with dict keyword (supported for dict comprehension
|
||||
# and dict literal already; should not be a blocker for anyone)
|
||||
self._assert_raises(template,
|
||||
"Union[Dict[str, torch.Tensor], Dict[str, int]]",
|
||||
lhs["dict_keyword"],
|
||||
"full type inference is not yet supported")
|
||||
self._assert_raises(
|
||||
template,
|
||||
"Union[Dict[str, torch.Tensor], Dict[str, int]]",
|
||||
lhs["dict_keyword"],
|
||||
"full type inference is not yet supported",
|
||||
)
|
||||
|
||||
self._assert_raises(template,
|
||||
"Union[Dict[str, torch.Tensor], Dict[str, int]]",
|
||||
lhs["dict_keyword_with_iterable"],
|
||||
"full type inference is not yet supported")
|
||||
self._assert_raises(
|
||||
template,
|
||||
"Union[Dict[str, torch.Tensor], Dict[str, int]]",
|
||||
lhs["dict_keyword_with_iterable"],
|
||||
"full type inference is not yet supported",
|
||||
)
|
||||
|
||||
self._assert_raises(template,
|
||||
"Union[Dict[str, torch.Tensor], Dict[str, int]]",
|
||||
lhs["dict_keyword_with_empty_iterable"],
|
||||
"full type inference is not yet supported")
|
||||
self._assert_raises(
|
||||
template,
|
||||
"Union[Dict[str, torch.Tensor], Dict[str, int]]",
|
||||
lhs["dict_keyword_with_empty_iterable"],
|
||||
"full type inference is not yet supported",
|
||||
)
|
||||
|
||||
self._assert_raises(template,
|
||||
"Union[Dict[str, torch.Tensor], Dict[str, int]]",
|
||||
lhs["dict_keyword_with_mapping"],
|
||||
"full type inference is not yet supported")
|
||||
self._assert_raises(
|
||||
template,
|
||||
"Union[Dict[str, torch.Tensor], Dict[str, int]]",
|
||||
lhs["dict_keyword_with_mapping"],
|
||||
"full type inference is not yet supported",
|
||||
)
|
||||
|
||||
self._assert_raises(template,
|
||||
"Union[Dict[str, torch.Tensor], Dict[str, int]]",
|
||||
lhs["dict_keyword_with_mapping_and_kwargs"],
|
||||
"full type inference is not yet supported")
|
||||
self._assert_raises(
|
||||
template,
|
||||
"Union[Dict[str, torch.Tensor], Dict[str, int]]",
|
||||
lhs["dict_keyword_with_mapping_and_kwargs"],
|
||||
"full type inference is not yet supported",
|
||||
)
|
||||
|
||||
"""
|
||||
Union[int, torch.Tensor]
|
||||
"""
|
||||
self._assert_raises(template,
|
||||
"Union[int, torch.Tensor]",
|
||||
lhs["dict_literal_empty"],
|
||||
"Expected an Union type annotation with "
|
||||
"an inner Dict type")
|
||||
self._assert_raises(
|
||||
template,
|
||||
"Union[int, torch.Tensor]",
|
||||
lhs["dict_literal_empty"],
|
||||
"Expected an Union type annotation with " "an inner Dict type",
|
||||
)
|
||||
|
||||
self._assert_raises(template,
|
||||
"Union[int, torch.Tensor]",
|
||||
lhs["dict_literal_of_str_tensor"],
|
||||
"Expected an Union type annotation with "
|
||||
"an inner Dict type")
|
||||
self._assert_raises(
|
||||
template,
|
||||
"Union[int, torch.Tensor]",
|
||||
lhs["dict_literal_of_str_tensor"],
|
||||
"Expected an Union type annotation with " "an inner Dict type",
|
||||
)
|
||||
|
||||
# See above--string frontend does not support tuple unpacking
|
||||
# self._assert_raises(template, "Union[int, torch.Tensor]",
|
||||
@ -947,47 +989,61 @@ class TestUnion(JitTestCase):
|
||||
"""
|
||||
Union[Dict[str, torch.Tensor], int]
|
||||
"""
|
||||
self._assert_passes(template,
|
||||
"Union[Dict[str, torch.Tensor], int]",
|
||||
lhs["dict_literal_empty"])
|
||||
self._assert_passes(
|
||||
template, "Union[Dict[str, torch.Tensor], int]", lhs["dict_literal_empty"]
|
||||
)
|
||||
|
||||
self._assert_passes(template,
|
||||
"Union[Dict[str, torch.Tensor], int]",
|
||||
lhs["dict_literal_of_str_tensor"])
|
||||
self._assert_passes(
|
||||
template,
|
||||
"Union[Dict[str, torch.Tensor], int]",
|
||||
lhs["dict_literal_of_str_tensor"],
|
||||
)
|
||||
|
||||
self._assert_raises(template,
|
||||
"Union[Dict[str, torch.Tensor], int]",
|
||||
lhs["dict_literal_of_str_int"],
|
||||
"Type annotation was inferred to be "
|
||||
r"`Dict\[str, Tensor\]`, but the type of "
|
||||
"values given by the dict literal is")
|
||||
self._assert_raises(
|
||||
template,
|
||||
"Union[Dict[str, torch.Tensor], int]",
|
||||
lhs["dict_literal_of_str_int"],
|
||||
"Type annotation was inferred to be "
|
||||
r"`Dict\[str, Tensor\]`, but the type of "
|
||||
"values given by the dict literal is",
|
||||
)
|
||||
|
||||
self._assert_raises(template,
|
||||
"Union[Dict[str, torch.Tensor], int]",
|
||||
lhs["dict_literal_of_mixed"],
|
||||
"Type annotation was inferred to be "
|
||||
r"`Dict\[str, Tensor\]`, but the type of "
|
||||
"values given by the dict literal is")
|
||||
self._assert_raises(
|
||||
template,
|
||||
"Union[Dict[str, torch.Tensor], int]",
|
||||
lhs["dict_literal_of_mixed"],
|
||||
"Type annotation was inferred to be "
|
||||
r"`Dict\[str, Tensor\]`, but the type of "
|
||||
"values given by the dict literal is",
|
||||
)
|
||||
|
||||
self._assert_passes(template,
|
||||
"Union[Dict[str, torch.Tensor], int]",
|
||||
lhs["dict_keyword"])
|
||||
self._assert_passes(
|
||||
template, "Union[Dict[str, torch.Tensor], int]", lhs["dict_keyword"]
|
||||
)
|
||||
|
||||
self._assert_passes(template,
|
||||
"Union[Dict[str, torch.Tensor], int]",
|
||||
lhs["dict_keyword_with_iterable"])
|
||||
self._assert_passes(
|
||||
template,
|
||||
"Union[Dict[str, torch.Tensor], int]",
|
||||
lhs["dict_keyword_with_iterable"],
|
||||
)
|
||||
|
||||
self._assert_passes(template,
|
||||
"Union[Dict[str, torch.Tensor], int]",
|
||||
lhs["dict_keyword_with_empty_iterable"])
|
||||
self._assert_passes(
|
||||
template,
|
||||
"Union[Dict[str, torch.Tensor], int]",
|
||||
lhs["dict_keyword_with_empty_iterable"],
|
||||
)
|
||||
|
||||
self._assert_passes(template,
|
||||
"Union[Dict[str, torch.Tensor], int]",
|
||||
lhs["dict_keyword_with_mapping"])
|
||||
self._assert_passes(
|
||||
template,
|
||||
"Union[Dict[str, torch.Tensor], int]",
|
||||
lhs["dict_keyword_with_mapping"],
|
||||
)
|
||||
|
||||
self._assert_passes(template,
|
||||
"Union[Dict[str, torch.Tensor], int]",
|
||||
lhs["dict_keyword_with_mapping_and_kwargs"])
|
||||
self._assert_passes(
|
||||
template,
|
||||
"Union[Dict[str, torch.Tensor], int]",
|
||||
lhs["dict_keyword_with_mapping_and_kwargs"],
|
||||
)
|
||||
|
||||
# See above--string frontend does not support tuple unpacking
|
||||
# self._assert_passes(template,
|
||||
|
||||
@ -2,19 +2,21 @@
|
||||
|
||||
import os
|
||||
import sys
|
||||
import unittest
|
||||
|
||||
import torch
|
||||
import unittest
|
||||
|
||||
# Make the helper files in test/ importable
|
||||
pytorch_test_dir = os.path.dirname(os.path.dirname(os.path.realpath(__file__)))
|
||||
sys.path.append(pytorch_test_dir)
|
||||
from torch.testing._internal.jit_utils import JitTestCase
|
||||
|
||||
if __name__ == '__main__':
|
||||
raise RuntimeError("This test file is not meant to be run directly, use:\n\n"
|
||||
"\tpython test/test_jit.py TESTNAME\n\n"
|
||||
"instead.")
|
||||
if __name__ == "__main__":
|
||||
raise RuntimeError(
|
||||
"This test file is not meant to be run directly, use:\n\n"
|
||||
"\tpython test/test_jit.py TESTNAME\n\n"
|
||||
"instead."
|
||||
)
|
||||
|
||||
# NOTE: FIXING FAILING TESTS
|
||||
# If you are seeing a test failure from this file, congrats, you improved
|
||||
@ -22,6 +24,7 @@ if __name__ == '__main__':
|
||||
# the corresponding section in documentation that states the unsupported behavior.
|
||||
# see: `jit_unsupported.rst`
|
||||
|
||||
|
||||
class TestUnsupportedOps(JitTestCase):
|
||||
def test_factory_ops_requires_grad_fail(self):
|
||||
# Keyword argument {name} unknown is a JIT-only error message,
|
||||
@ -32,31 +35,31 @@ class TestUnsupportedOps(JitTestCase):
|
||||
def ones():
|
||||
return torch.ones([2], requires_grad=True)
|
||||
|
||||
with self.assertRaisesRegexWithHighlight(Exception,
|
||||
"Keyword argument requires_grad unknown",
|
||||
"torch.ones"):
|
||||
with self.assertRaisesRegexWithHighlight(
|
||||
Exception, "Keyword argument requires_grad unknown", "torch.ones"
|
||||
):
|
||||
torch.jit.script(ones)
|
||||
|
||||
def randn():
|
||||
return torch.randn([2], requires_grad=True)
|
||||
|
||||
with self.assertRaisesRegexWithHighlight(Exception,
|
||||
"Keyword argument requires_grad unknown",
|
||||
"torch.randn"):
|
||||
with self.assertRaisesRegexWithHighlight(
|
||||
Exception, "Keyword argument requires_grad unknown", "torch.randn"
|
||||
):
|
||||
torch.jit.script(randn)
|
||||
|
||||
def zeros():
|
||||
return torch.zeros([2], requires_grad=True)
|
||||
|
||||
with self.assertRaisesRegexWithHighlight(Exception,
|
||||
"Keyword argument requires_grad unknown",
|
||||
"torch.zeros"):
|
||||
with self.assertRaisesRegexWithHighlight(
|
||||
Exception, "Keyword argument requires_grad unknown", "torch.zeros"
|
||||
):
|
||||
torch.jit.script(zeros)
|
||||
|
||||
@unittest.skipIf(not torch._C.has_lapack, "PyTorch compiled without Lapack")
|
||||
def test_init_ops(self):
|
||||
def calculate_gain():
|
||||
return torch.nn.init.calculate_gain('leaky_relu', 0.2)
|
||||
return torch.nn.init.calculate_gain("leaky_relu", 0.2)
|
||||
|
||||
def eye_():
|
||||
return torch.nn.init.eye_(torch.zeros([2, 2]))
|
||||
@ -71,9 +74,16 @@ class TestUnsupportedOps(JitTestCase):
|
||||
return torch.nn.init.orthogonal_(torch.empty(3, 5))
|
||||
|
||||
def sparse():
|
||||
return torch.nn.init.sparse_(torch.empty(3, 5), sparsity=.1)
|
||||
return torch.nn.init.sparse_(torch.empty(3, 5), sparsity=0.1)
|
||||
|
||||
for func in [calculate_gain, eye_, dirac_, kaiming_uniform_, orthogonal_, sparse]:
|
||||
for func in [
|
||||
calculate_gain,
|
||||
eye_,
|
||||
dirac_,
|
||||
kaiming_uniform_,
|
||||
orthogonal_,
|
||||
sparse,
|
||||
]:
|
||||
# doesn't error in eager
|
||||
func()
|
||||
with self.assertRaisesRegex(Exception, ""):
|
||||
|
||||
@ -3,20 +3,24 @@
|
||||
import io
|
||||
import os
|
||||
import sys
|
||||
import torch
|
||||
import zipfile
|
||||
from torch.testing import FileCheck
|
||||
from typing import Union
|
||||
|
||||
import torch
|
||||
from torch.testing import FileCheck
|
||||
|
||||
# Make the helper files in test/ importable
|
||||
pytorch_test_dir = os.path.dirname(os.path.dirname(os.path.realpath(__file__)))
|
||||
sys.path.append(pytorch_test_dir)
|
||||
from torch.testing._internal.jit_utils import JitTestCase
|
||||
|
||||
if __name__ == '__main__':
|
||||
raise RuntimeError("This test file is not meant to be run directly, use:\n\n"
|
||||
"\tpython test/test_jit.py TESTNAME\n\n"
|
||||
"instead.")
|
||||
if __name__ == "__main__":
|
||||
raise RuntimeError(
|
||||
"This test file is not meant to be run directly, use:\n\n"
|
||||
"\tpython test/test_jit.py TESTNAME\n\n"
|
||||
"instead."
|
||||
)
|
||||
|
||||
|
||||
class TestUpgraders(JitTestCase):
|
||||
def _load_model_version(self, loaded_model):
|
||||
@ -28,10 +32,10 @@ class TestUpgraders(JitTestCase):
|
||||
# in a package between version 3 and 7.
|
||||
# So we have to check for both.
|
||||
try:
|
||||
version = int(zipped_model.read('archive/version').decode("utf-8"))
|
||||
version = int(zipped_model.read("archive/version").decode("utf-8"))
|
||||
return version
|
||||
except KeyError:
|
||||
version = int(zipped_model.read('archive/.data/version').decode("utf-8"))
|
||||
version = int(zipped_model.read("archive/.data/version").decode("utf-8"))
|
||||
return version
|
||||
|
||||
# TODO (tugsuu) We should ideally be generating this test cases.
|
||||
@ -62,15 +66,23 @@ class TestUpgraders(JitTestCase):
|
||||
upgrader_bumped_version = 3
|
||||
upgrader_name = "_test_serialization_subcmul_0_2"
|
||||
upgrader_schema = "aten::_test_serialization_subcmul(Tensor self, Tensor other, Scalar alpha=2) -> Tensor"
|
||||
dummy_entry = torch._C._UpgraderEntry(upgrader_bumped_version, upgrader_name, upgrader_schema)
|
||||
dummy_entry = torch._C._UpgraderEntry(
|
||||
upgrader_bumped_version, upgrader_name, upgrader_schema
|
||||
)
|
||||
|
||||
torch._C._test_only_add_entry_to_op_version_map("aten::_test_serialization_subcmul", dummy_entry)
|
||||
torch._C._test_only_add_entry_to_op_version_map(
|
||||
"aten::_test_serialization_subcmul", dummy_entry
|
||||
)
|
||||
map_after_test = torch._C._get_operator_version_map()
|
||||
self.assertTrue("aten::_test_serialization_subcmul" in map_after_test)
|
||||
self.assertTrue(len(map_after_test) - len(map_before_test) == 1)
|
||||
torch._C._test_only_remove_entry_to_op_version_map("aten::_test_serialization_subcmul")
|
||||
torch._C._test_only_remove_entry_to_op_version_map(
|
||||
"aten::_test_serialization_subcmul"
|
||||
)
|
||||
map_after_remove_test = torch._C._get_operator_version_map()
|
||||
self.assertTrue("aten::_test_serialization_subcmul" not in map_after_remove_test)
|
||||
self.assertTrue(
|
||||
"aten::_test_serialization_subcmul" not in map_after_remove_test
|
||||
)
|
||||
self.assertEqual(len(map_after_remove_test), len(map_before_test))
|
||||
|
||||
def test_populated_test_upgrader_graph(self):
|
||||
@ -151,7 +163,7 @@ class TestUpgraders(JitTestCase):
|
||||
model_path = pytorch_test_dir + "/jit/fixtures/test_versioned_linspace_v7.ptl"
|
||||
loaded_model = torch.jit.load(model_path)
|
||||
sample_inputs = ((3, 10), (-10, 10), (4.0, 6.0), (3 + 4j, 4 + 5j))
|
||||
for (a, b) in sample_inputs:
|
||||
for a, b in sample_inputs:
|
||||
output_with_step, output_without_step = loaded_model(a, b)
|
||||
# when no step is given, should have used 100
|
||||
self.assertTrue(output_without_step.size(dim=0) == 100)
|
||||
@ -161,7 +173,9 @@ class TestUpgraders(JitTestCase):
|
||||
self.assertTrue(version == 8)
|
||||
|
||||
def test_aten_linspace_out(self):
|
||||
model_path = pytorch_test_dir + "/jit/fixtures/test_versioned_linspace_out_v7.ptl"
|
||||
model_path = (
|
||||
pytorch_test_dir + "/jit/fixtures/test_versioned_linspace_out_v7.ptl"
|
||||
)
|
||||
loaded_model = torch.jit.load(model_path)
|
||||
sample_inputs = (
|
||||
(3, 10, torch.empty((100,), dtype=torch.int64)),
|
||||
@ -169,7 +183,7 @@ class TestUpgraders(JitTestCase):
|
||||
(4.0, 6.0, torch.empty((100,), dtype=torch.float64)),
|
||||
(3 + 4j, 4 + 5j, torch.empty((100,), dtype=torch.complex64)),
|
||||
)
|
||||
for (a, b, c) in sample_inputs:
|
||||
for a, b, c in sample_inputs:
|
||||
output = loaded_model(a, b, c)
|
||||
# when no step is given, should have used 100
|
||||
self.assertTrue(output.size(dim=0) == 100)
|
||||
@ -181,7 +195,7 @@ class TestUpgraders(JitTestCase):
|
||||
model_path = pytorch_test_dir + "/jit/fixtures/test_versioned_logspace_v8.ptl"
|
||||
loaded_model = torch.jit.load(model_path)
|
||||
sample_inputs = ((3, 10), (-10, 10), (4.0, 6.0), (3 + 4j, 4 + 5j))
|
||||
for (a, b) in sample_inputs:
|
||||
for a, b in sample_inputs:
|
||||
output_with_step, output_without_step = loaded_model(a, b)
|
||||
# when no step is given, should have used 100
|
||||
self.assertTrue(output_without_step.size(dim=0) == 100)
|
||||
@ -191,7 +205,9 @@ class TestUpgraders(JitTestCase):
|
||||
self.assertTrue(version == 9)
|
||||
|
||||
def test_aten_logspace_out(self):
|
||||
model_path = pytorch_test_dir + "/jit/fixtures/test_versioned_logspace_out_v8.ptl"
|
||||
model_path = (
|
||||
pytorch_test_dir + "/jit/fixtures/test_versioned_logspace_out_v8.ptl"
|
||||
)
|
||||
loaded_model = torch.jit.load(model_path)
|
||||
sample_inputs = (
|
||||
(3, 10, torch.empty((100,), dtype=torch.int64)),
|
||||
@ -199,7 +215,7 @@ class TestUpgraders(JitTestCase):
|
||||
(4.0, 6.0, torch.empty((100,), dtype=torch.float64)),
|
||||
(3 + 4j, 4 + 5j, torch.empty((100,), dtype=torch.complex64)),
|
||||
)
|
||||
for (a, b, c) in sample_inputs:
|
||||
for a, b, c in sample_inputs:
|
||||
output = loaded_model(a, b, c)
|
||||
# when no step is given, should have used 100
|
||||
self.assertTrue(output.size(dim=0) == 100)
|
||||
@ -208,21 +224,36 @@ class TestUpgraders(JitTestCase):
|
||||
self.assertTrue(version == 9)
|
||||
|
||||
def test_aten_test_serialization(self):
|
||||
model_path = pytorch_test_dir + "/jit/fixtures/_test_serialization_subcmul_v2.pt"
|
||||
model_path = (
|
||||
pytorch_test_dir + "/jit/fixtures/_test_serialization_subcmul_v2.pt"
|
||||
)
|
||||
|
||||
# add test version entry to the version map
|
||||
upgrader_bumped_version = 3
|
||||
upgrader_name = "_test_serialization_subcmul_0_2"
|
||||
upgrader_schema = "aten::_test_serialization_subcmul(Tensor self, Tensor other, Scalar alpha=2) -> Tensor"
|
||||
dummy_entry = torch._C._UpgraderEntry(upgrader_bumped_version, upgrader_name, upgrader_schema)
|
||||
dummy_entry = torch._C._UpgraderEntry(
|
||||
upgrader_bumped_version, upgrader_name, upgrader_schema
|
||||
)
|
||||
|
||||
torch._C._test_only_add_entry_to_op_version_map("aten::_test_serialization_subcmul", dummy_entry)
|
||||
torch._C._test_only_add_entry_to_op_version_map(
|
||||
"aten::_test_serialization_subcmul", dummy_entry
|
||||
)
|
||||
|
||||
# add test upgrader in the upgraders map
|
||||
@torch.jit.script
|
||||
def _test_serialization_subcmul_0_2(self: torch.Tensor, other: torch.Tensor, alpha: Union[int, float] = 2) -> torch.Tensor:
|
||||
def _test_serialization_subcmul_0_2(
|
||||
self: torch.Tensor, other: torch.Tensor, alpha: Union[int, float] = 2
|
||||
) -> torch.Tensor:
|
||||
return other - (self * alpha)
|
||||
torch._C._test_only_populate_upgraders({"_test_serialization_subcmul_0_2": str(_test_serialization_subcmul_0_2.graph)})
|
||||
|
||||
torch._C._test_only_populate_upgraders(
|
||||
{
|
||||
"_test_serialization_subcmul_0_2": str(
|
||||
_test_serialization_subcmul_0_2.graph
|
||||
)
|
||||
}
|
||||
)
|
||||
|
||||
# test if the server is able to find the test upgraders and apply to IR
|
||||
loaded_model = torch.jit.load(model_path)
|
||||
@ -238,11 +269,21 @@ class TestUpgraders(JitTestCase):
|
||||
# we check by its' code because graph variable names
|
||||
# can be different every time
|
||||
self.assertEqual(loaded_model.code, loaded_model_twice.code)
|
||||
torch._C._test_only_remove_entry_to_op_version_map("aten::_test_serialization_subcmul")
|
||||
torch._C._test_only_remove_upgraders({"_test_serialization_subcmul_0_2": str(_test_serialization_subcmul_0_2.graph)})
|
||||
torch._C._test_only_remove_entry_to_op_version_map(
|
||||
"aten::_test_serialization_subcmul"
|
||||
)
|
||||
torch._C._test_only_remove_upgraders(
|
||||
{
|
||||
"_test_serialization_subcmul_0_2": str(
|
||||
_test_serialization_subcmul_0_2.graph
|
||||
)
|
||||
}
|
||||
)
|
||||
|
||||
def test_aten_div_scalar_at_3(self):
|
||||
model_path = pytorch_test_dir + "/jit/fixtures/test_versioned_div_scalar_float_v3.pt"
|
||||
model_path = (
|
||||
pytorch_test_dir + "/jit/fixtures/test_versioned_div_scalar_float_v3.pt"
|
||||
)
|
||||
loaded_model = torch.jit.load(model_path)
|
||||
FileCheck().check("prim::If").run(loaded_model.graph)
|
||||
FileCheck().check_count("aten::div", 2).run(loaded_model.graph)
|
||||
@ -254,11 +295,15 @@ class TestUpgraders(JitTestCase):
|
||||
self.assertEqual(version, 4)
|
||||
loaded_model_twice = torch.jit.load(buffer)
|
||||
|
||||
self.assertEqual(loaded_model(torch.Tensor([5.0, 3.0]), 2.0),
|
||||
loaded_model_twice(torch.Tensor([5.0, 3.0]), 2.0))
|
||||
self.assertEqual(
|
||||
loaded_model(torch.Tensor([5.0, 3.0]), 2.0),
|
||||
loaded_model_twice(torch.Tensor([5.0, 3.0]), 2.0),
|
||||
)
|
||||
|
||||
def test_aten_div_tensor_out_at_3(self):
|
||||
model_path = pytorch_test_dir + "/jit/fixtures/test_versioned_div_tensor_out_v3.pt"
|
||||
model_path = (
|
||||
pytorch_test_dir + "/jit/fixtures/test_versioned_div_tensor_out_v3.pt"
|
||||
)
|
||||
loaded_model = torch.jit.load(model_path)
|
||||
FileCheck().check("prim::If").run(loaded_model.graph)
|
||||
FileCheck().check_count("aten::div", 2).run(loaded_model.graph)
|
||||
@ -274,7 +319,9 @@ class TestUpgraders(JitTestCase):
|
||||
self.assertEqual(loaded_model.code, loaded_model_twice.code)
|
||||
|
||||
def test_aten_full_at_4(self):
|
||||
model_path = pytorch_test_dir + "/jit/fixtures/test_versioned_full_integer_value_v4.pt"
|
||||
model_path = (
|
||||
pytorch_test_dir + "/jit/fixtures/test_versioned_full_integer_value_v4.pt"
|
||||
)
|
||||
loaded_model = torch.jit.load(model_path)
|
||||
FileCheck().check_count("aten::Float", 1).run(loaded_model.graph)
|
||||
FileCheck().check_count("aten::full", 2).run(loaded_model.graph)
|
||||
@ -290,7 +337,9 @@ class TestUpgraders(JitTestCase):
|
||||
self.assertEqual(loaded_model.code, loaded_model_twice.code)
|
||||
|
||||
def test_aten_full_out_at_4(self):
|
||||
model_path = pytorch_test_dir + "/jit/fixtures/test_versioned_full_preserved_v4.pt"
|
||||
model_path = (
|
||||
pytorch_test_dir + "/jit/fixtures/test_versioned_full_preserved_v4.pt"
|
||||
)
|
||||
loaded_model = torch.jit.load(model_path)
|
||||
FileCheck().check_count("aten::full", 5).run(loaded_model.graph)
|
||||
version = self._load_model_version(loaded_model)
|
||||
|
||||
@ -1,12 +1,12 @@
|
||||
# Owner(s): ["oncall: jit"]
|
||||
|
||||
import io
|
||||
import os
|
||||
import sys
|
||||
import io
|
||||
|
||||
import torch
|
||||
import warnings
|
||||
from contextlib import redirect_stderr
|
||||
|
||||
import torch
|
||||
from torch.testing import FileCheck
|
||||
|
||||
# Make the helper files in test/ importable
|
||||
@ -14,10 +14,12 @@ pytorch_test_dir = os.path.dirname(os.path.dirname(os.path.realpath(__file__)))
|
||||
sys.path.append(pytorch_test_dir)
|
||||
from torch.testing._internal.jit_utils import JitTestCase
|
||||
|
||||
if __name__ == '__main__':
|
||||
raise RuntimeError("This test file is not meant to be run directly, use:\n\n"
|
||||
"\tpython test/test_jit.py TESTNAME\n\n"
|
||||
"instead.")
|
||||
if __name__ == "__main__":
|
||||
raise RuntimeError(
|
||||
"This test file is not meant to be run directly, use:\n\n"
|
||||
"\tpython test/test_jit.py TESTNAME\n\n"
|
||||
"instead."
|
||||
)
|
||||
|
||||
|
||||
class TestWarn(JitTestCase):
|
||||
@ -30,12 +32,9 @@ class TestWarn(JitTestCase):
|
||||
with redirect_stderr(f):
|
||||
fn()
|
||||
|
||||
FileCheck() \
|
||||
.check_count(
|
||||
str="UserWarning: I am warning you",
|
||||
count=1,
|
||||
exactly=True) \
|
||||
.run(f.getvalue())
|
||||
FileCheck().check_count(
|
||||
str="UserWarning: I am warning you", count=1, exactly=True
|
||||
).run(f.getvalue())
|
||||
|
||||
def test_warn_only_once(self):
|
||||
@torch.jit.script
|
||||
@ -47,12 +46,9 @@ class TestWarn(JitTestCase):
|
||||
with redirect_stderr(f):
|
||||
fn()
|
||||
|
||||
FileCheck() \
|
||||
.check_count(
|
||||
str="UserWarning: I am warning you",
|
||||
count=1,
|
||||
exactly=True) \
|
||||
.run(f.getvalue())
|
||||
FileCheck().check_count(
|
||||
str="UserWarning: I am warning you", count=1, exactly=True
|
||||
).run(f.getvalue())
|
||||
|
||||
def test_warn_only_once_in_loop_func(self):
|
||||
def w():
|
||||
@ -67,12 +63,9 @@ class TestWarn(JitTestCase):
|
||||
with redirect_stderr(f):
|
||||
fn()
|
||||
|
||||
FileCheck() \
|
||||
.check_count(
|
||||
str="UserWarning: I am warning you",
|
||||
count=1,
|
||||
exactly=True) \
|
||||
.run(f.getvalue())
|
||||
FileCheck().check_count(
|
||||
str="UserWarning: I am warning you", count=1, exactly=True
|
||||
).run(f.getvalue())
|
||||
|
||||
def test_warn_once_per_func(self):
|
||||
def w1():
|
||||
@ -90,12 +83,9 @@ class TestWarn(JitTestCase):
|
||||
with redirect_stderr(f):
|
||||
fn()
|
||||
|
||||
FileCheck() \
|
||||
.check_count(
|
||||
str="UserWarning: I am warning you",
|
||||
count=2,
|
||||
exactly=True) \
|
||||
.run(f.getvalue())
|
||||
FileCheck().check_count(
|
||||
str="UserWarning: I am warning you", count=2, exactly=True
|
||||
).run(f.getvalue())
|
||||
|
||||
def test_warn_once_per_func_in_loop(self):
|
||||
def w1():
|
||||
@ -114,12 +104,9 @@ class TestWarn(JitTestCase):
|
||||
with redirect_stderr(f):
|
||||
fn()
|
||||
|
||||
FileCheck() \
|
||||
.check_count(
|
||||
str="UserWarning: I am warning you",
|
||||
count=2,
|
||||
exactly=True) \
|
||||
.run(f.getvalue())
|
||||
FileCheck().check_count(
|
||||
str="UserWarning: I am warning you", count=2, exactly=True
|
||||
).run(f.getvalue())
|
||||
|
||||
def test_warn_multiple_calls_multiple_warnings(self):
|
||||
@torch.jit.script
|
||||
@ -131,12 +118,9 @@ class TestWarn(JitTestCase):
|
||||
fn()
|
||||
fn()
|
||||
|
||||
FileCheck() \
|
||||
.check_count(
|
||||
str="UserWarning: I am warning you",
|
||||
count=2,
|
||||
exactly=True) \
|
||||
.run(f.getvalue())
|
||||
FileCheck().check_count(
|
||||
str="UserWarning: I am warning you", count=2, exactly=True
|
||||
).run(f.getvalue())
|
||||
|
||||
def test_warn_multiple_calls_same_func_diff_stack(self):
|
||||
def warn(caller: str):
|
||||
@ -155,13 +139,10 @@ class TestWarn(JitTestCase):
|
||||
foo()
|
||||
bar()
|
||||
|
||||
FileCheck() \
|
||||
.check_count(
|
||||
str="UserWarning: I am warning you from foo",
|
||||
count=1,
|
||||
exactly=True) \
|
||||
.check_count(
|
||||
str="UserWarning: I am warning you from bar",
|
||||
count=1,
|
||||
exactly=True) \
|
||||
.run(f.getvalue())
|
||||
FileCheck().check_count(
|
||||
str="UserWarning: I am warning you from foo", count=1, exactly=True
|
||||
).check_count(
|
||||
str="UserWarning: I am warning you from bar", count=1, exactly=True
|
||||
).run(
|
||||
f.getvalue()
|
||||
)
|
||||
|
||||
@ -32,6 +32,7 @@ class TestWith(JitTestCase):
|
||||
Check that with statements that use the 'as' keyword to bind expressions
|
||||
to targets work as expected.
|
||||
"""
|
||||
|
||||
@torch.jit.script
|
||||
class Context:
|
||||
"""
|
||||
@ -189,6 +190,7 @@ class TestWith(JitTestCase):
|
||||
Check that with statements that do not use the 'as' keyword to bind expressions
|
||||
to targets work as expected.
|
||||
"""
|
||||
|
||||
@torch.jit.script
|
||||
class Context:
|
||||
"""
|
||||
@ -345,6 +347,7 @@ class TestWith(JitTestCase):
|
||||
Check that exceptions thrown in the bodies of with-statements are
|
||||
handled correctly.
|
||||
"""
|
||||
|
||||
@torch.jit.script
|
||||
class Context:
|
||||
"""
|
||||
@ -416,15 +419,21 @@ class TestWith(JitTestCase):
|
||||
# checkScript and checkScriptRaisesRegex cannot be used because the string frontend will
|
||||
# not compile class types (of which Context, the context manager being used for this test
|
||||
# is one).
|
||||
with self.assertRaisesRegexWithHighlight(Exception, r"raised exception", "raise Exception(\"raised exception"):
|
||||
with self.assertRaisesRegexWithHighlight(
|
||||
Exception, r"raised exception", 'raise Exception("raised exception'
|
||||
):
|
||||
test_exception(torch.randn(2), c)
|
||||
self.assertEqual(c.count, 1)
|
||||
|
||||
with self.assertRaisesRegexWithHighlight(Exception, r"raised exception", "raise Exception(\"raised exception"):
|
||||
with self.assertRaisesRegexWithHighlight(
|
||||
Exception, r"raised exception", 'raise Exception("raised exception'
|
||||
):
|
||||
test_exception_nested(torch.randn(2), c)
|
||||
self.assertEqual(c.count, 1)
|
||||
|
||||
with self.assertRaisesRegexWithHighlight(Exception, r"raised exception", "raise Exception(\"raised exception"):
|
||||
with self.assertRaisesRegexWithHighlight(
|
||||
Exception, r"raised exception", 'raise Exception("raised exception'
|
||||
):
|
||||
test_exception_fn_call(torch.randn(2), c)
|
||||
self.assertEqual(c.count, 1)
|
||||
|
||||
@ -505,7 +514,9 @@ class TestWith(JitTestCase):
|
||||
|
||||
return x
|
||||
|
||||
def test_exit_incorrect_types(x: torch.Tensor, cm: ExitIncorrectTypes) -> torch.Tensor:
|
||||
def test_exit_incorrect_types(
|
||||
x: torch.Tensor, cm: ExitIncorrectTypes
|
||||
) -> torch.Tensor:
|
||||
with cm as _:
|
||||
pass
|
||||
|
||||
@ -523,7 +534,9 @@ class TestWith(JitTestCase):
|
||||
self.checkScript(test_no_enter_no_exit, (test_tensor, NoEnterNoExit()))
|
||||
|
||||
with self.assertRaisesRegexWithHighlight(
|
||||
RuntimeError, r"__enter__ must have only one argument and one return value", "cm"
|
||||
RuntimeError,
|
||||
r"__enter__ must have only one argument and one return value",
|
||||
"cm",
|
||||
):
|
||||
self.checkScript(test_bad_enter, (test_tensor, BadEnter()))
|
||||
|
||||
@ -539,7 +552,9 @@ class TestWith(JitTestCase):
|
||||
test_exit_incorrect_types, (test_tensor, ExitIncorrectTypes())
|
||||
)
|
||||
|
||||
with self.assertRaisesRegexWithHighlight(RuntimeError, r"must return an object", "\"not_object\""):
|
||||
with self.assertRaisesRegexWithHighlight(
|
||||
RuntimeError, r"must return an object", '"not_object"'
|
||||
):
|
||||
self.checkScript(test_enter_without_object, ())
|
||||
|
||||
def test_with_no_grad(self):
|
||||
@ -603,6 +618,7 @@ class TestWith(JitTestCase):
|
||||
Check that torch.autograd.profiler.record_function context manager is
|
||||
torchscriptable.
|
||||
"""
|
||||
|
||||
def with_rf(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
|
||||
with torch.autograd.profiler.record_function("foo"):
|
||||
# Nested record_function.
|
||||
|
||||
@ -7,6 +7,7 @@ import torch._C
|
||||
|
||||
torch.ops.load_library("//caffe2:xnnpack_backend")
|
||||
|
||||
|
||||
class TestXNNPackBackend(unittest.TestCase):
|
||||
def test_xnnpack_constant_data(self):
|
||||
class Module(torch.nn.Module):
|
||||
@ -24,17 +25,19 @@ class TestXNNPackBackend(unittest.TestCase):
|
||||
scripted_module,
|
||||
{
|
||||
"forward": {
|
||||
"inputs" : [torch.randn(4, 4, 4)],
|
||||
"outputs": [torch.randn(4, 4, 4)]
|
||||
"inputs": [torch.randn(4, 4, 4)],
|
||||
"outputs": [torch.randn(4, 4, 4)],
|
||||
}
|
||||
}
|
||||
},
|
||||
)
|
||||
|
||||
for i in range(0, 20):
|
||||
sample_input = torch.randn(4, 4, 4)
|
||||
actual_output = scripted_module(sample_input)
|
||||
expected_output = lowered_module(sample_input)
|
||||
self.assertTrue(torch.allclose(actual_output, expected_output, atol=1e-03, rtol=1e-03))
|
||||
self.assertTrue(
|
||||
torch.allclose(actual_output, expected_output, atol=1e-03, rtol=1e-03)
|
||||
)
|
||||
|
||||
def test_xnnpack_lowering(self):
|
||||
class Module(torch.nn.Module):
|
||||
@ -45,13 +48,11 @@ class TestXNNPackBackend(unittest.TestCase):
|
||||
|
||||
faulty_compile_spec = {
|
||||
"backward": {
|
||||
"inputs" : [torch.zeros(1)],
|
||||
"inputs": [torch.zeros(1)],
|
||||
"outputs": [torch.zeros(1)],
|
||||
}
|
||||
}
|
||||
error_msg = (
|
||||
"method_compile_spec does not contain the \"forward\" key."
|
||||
)
|
||||
error_msg = 'method_compile_spec does not contain the "forward" key.'
|
||||
|
||||
with self.assertRaisesRegex(
|
||||
RuntimeError,
|
||||
@ -64,21 +65,21 @@ class TestXNNPackBackend(unittest.TestCase):
|
||||
)
|
||||
|
||||
mismatch_compile_spec = {
|
||||
"forward" : {
|
||||
"inputs" : [torch.zeros(1), torch.zeros(1)],
|
||||
"outputs" : [torch.zeros(1)]
|
||||
"forward": {
|
||||
"inputs": [torch.zeros(1), torch.zeros(1)],
|
||||
"outputs": [torch.zeros(1)],
|
||||
}
|
||||
}
|
||||
error_msg = ("method_compile_spec inputs do not match expected number of forward inputs")
|
||||
error_msg = (
|
||||
"method_compile_spec inputs do not match expected number of forward inputs"
|
||||
)
|
||||
|
||||
with self.assertRaisesRegex(
|
||||
RuntimeError,
|
||||
error_msg,
|
||||
):
|
||||
_ = torch._C._jit_to_backend(
|
||||
"xnnpack",
|
||||
scripted_module,
|
||||
mismatch_compile_spec
|
||||
"xnnpack", scripted_module, mismatch_compile_spec
|
||||
)
|
||||
|
||||
lowered = torch._C._jit_to_backend(
|
||||
@ -86,10 +87,10 @@ class TestXNNPackBackend(unittest.TestCase):
|
||||
scripted_module,
|
||||
{
|
||||
"forward": {
|
||||
"inputs" : [torch.zeros(1)],
|
||||
"inputs": [torch.zeros(1)],
|
||||
"outputs": [torch.zeros(1)],
|
||||
}
|
||||
}
|
||||
},
|
||||
)
|
||||
lowered(torch.zeros(1))
|
||||
|
||||
@ -113,14 +114,16 @@ class TestXNNPackBackend(unittest.TestCase):
|
||||
add_module,
|
||||
{
|
||||
"forward": {
|
||||
"inputs" : [sample_inputs[0].clone(), sample_inputs[1].clone()],
|
||||
"outputs": [sample_output]
|
||||
"inputs": [sample_inputs[0].clone(), sample_inputs[1].clone()],
|
||||
"outputs": [sample_output],
|
||||
}
|
||||
}
|
||||
},
|
||||
)
|
||||
|
||||
actual_output = lowered_add_module.forward(sample_inputs[0], sample_inputs[1])
|
||||
self.assertTrue(torch.allclose(actual_output, expected_output, atol=1e-03, rtol=1e-03))
|
||||
self.assertTrue(
|
||||
torch.allclose(actual_output, expected_output, atol=1e-03, rtol=1e-03)
|
||||
)
|
||||
|
||||
def test_xnnpack_broadcasting(self):
|
||||
class AddModule(torch.nn.Module):
|
||||
@ -139,14 +142,16 @@ class TestXNNPackBackend(unittest.TestCase):
|
||||
add_module,
|
||||
{
|
||||
"forward": {
|
||||
"inputs" : [sample_inputs[0], sample_inputs[1]],
|
||||
"outputs": [sample_output]
|
||||
"inputs": [sample_inputs[0], sample_inputs[1]],
|
||||
"outputs": [sample_output],
|
||||
}
|
||||
}
|
||||
},
|
||||
)
|
||||
|
||||
actual_output = lowered_add_module.forward(sample_inputs[0], sample_inputs[1])
|
||||
self.assertTrue(torch.allclose(actual_output, expected_output, atol=1e-03, rtol=1e-03))
|
||||
self.assertTrue(
|
||||
torch.allclose(actual_output, expected_output, atol=1e-03, rtol=1e-03)
|
||||
)
|
||||
|
||||
def test_xnnpack_unsupported(self):
|
||||
class AddSpliceModule(torch.nn.Module):
|
||||
@ -173,8 +178,8 @@ class TestXNNPackBackend(unittest.TestCase):
|
||||
add_module,
|
||||
{
|
||||
"forward": {
|
||||
"inputs" : [sample_inputs[0], sample_inputs[1]],
|
||||
"outputs": [sample_output]
|
||||
"inputs": [sample_inputs[0], sample_inputs[1]],
|
||||
"outputs": [sample_output],
|
||||
}
|
||||
}
|
||||
},
|
||||
)
|
||||
|
||||
@ -1,23 +1,30 @@
|
||||
import argparse
|
||||
import os
|
||||
import sys
|
||||
|
||||
import torch
|
||||
|
||||
# grab modules from test_jit_hooks.cpp
|
||||
pytorch_test_dir = os.path.dirname(os.path.dirname(os.path.realpath(__file__)))
|
||||
sys.path.append(pytorch_test_dir)
|
||||
from jit.test_hooks_modules import (
|
||||
create_forward_tuple_input, create_module_forward_multiple_inputs,
|
||||
create_module_forward_single_input, create_module_hook_return_nothing,
|
||||
create_forward_tuple_input,
|
||||
create_module_forward_multiple_inputs,
|
||||
create_module_forward_single_input,
|
||||
create_module_hook_return_nothing,
|
||||
create_module_multiple_hooks_multiple_inputs,
|
||||
create_module_multiple_hooks_single_input, create_module_no_forward_input,
|
||||
create_module_same_hook_repeated, create_submodule_forward_multiple_inputs,
|
||||
create_module_multiple_hooks_single_input,
|
||||
create_module_no_forward_input,
|
||||
create_module_same_hook_repeated,
|
||||
create_submodule_forward_multiple_inputs,
|
||||
create_submodule_forward_single_input,
|
||||
create_submodule_hook_return_nothing,
|
||||
create_submodule_multiple_hooks_multiple_inputs,
|
||||
create_submodule_multiple_hooks_single_input,
|
||||
create_submodule_same_hook_repeated,
|
||||
create_submodule_to_call_directly_with_hooks)
|
||||
create_submodule_to_call_directly_with_hooks,
|
||||
)
|
||||
|
||||
|
||||
# Create saved modules for JIT forward hooks and pre-hooks
|
||||
def main():
|
||||
@ -30,23 +37,45 @@ def main():
|
||||
save_name = options.export_script_module_to + "_"
|
||||
|
||||
tests = [
|
||||
("test_submodule_forward_single_input", create_submodule_forward_single_input()),
|
||||
("test_submodule_forward_multiple_inputs", create_submodule_forward_multiple_inputs()),
|
||||
("test_submodule_multiple_hooks_single_input", create_submodule_multiple_hooks_single_input()),
|
||||
("test_submodule_multiple_hooks_multiple_inputs", create_submodule_multiple_hooks_multiple_inputs()),
|
||||
(
|
||||
"test_submodule_forward_single_input",
|
||||
create_submodule_forward_single_input(),
|
||||
),
|
||||
(
|
||||
"test_submodule_forward_multiple_inputs",
|
||||
create_submodule_forward_multiple_inputs(),
|
||||
),
|
||||
(
|
||||
"test_submodule_multiple_hooks_single_input",
|
||||
create_submodule_multiple_hooks_single_input(),
|
||||
),
|
||||
(
|
||||
"test_submodule_multiple_hooks_multiple_inputs",
|
||||
create_submodule_multiple_hooks_multiple_inputs(),
|
||||
),
|
||||
("test_submodule_hook_return_nothing", create_submodule_hook_return_nothing()),
|
||||
("test_submodule_same_hook_repeated", create_submodule_same_hook_repeated()),
|
||||
|
||||
("test_module_forward_single_input", create_module_forward_single_input()),
|
||||
("test_module_forward_multiple_inputs", create_module_forward_multiple_inputs()),
|
||||
("test_module_multiple_hooks_single_input", create_module_multiple_hooks_single_input()),
|
||||
("test_module_multiple_hooks_multiple_inputs", create_module_multiple_hooks_multiple_inputs()),
|
||||
(
|
||||
"test_module_forward_multiple_inputs",
|
||||
create_module_forward_multiple_inputs(),
|
||||
),
|
||||
(
|
||||
"test_module_multiple_hooks_single_input",
|
||||
create_module_multiple_hooks_single_input(),
|
||||
),
|
||||
(
|
||||
"test_module_multiple_hooks_multiple_inputs",
|
||||
create_module_multiple_hooks_multiple_inputs(),
|
||||
),
|
||||
("test_module_hook_return_nothing", create_module_hook_return_nothing()),
|
||||
("test_module_same_hook_repeated", create_module_same_hook_repeated()),
|
||||
|
||||
("test_module_no_forward_input", create_module_no_forward_input()),
|
||||
("test_forward_tuple_input", create_forward_tuple_input()),
|
||||
("test_submodule_to_call_directly_with_hooks", create_submodule_to_call_directly_with_hooks())
|
||||
(
|
||||
"test_submodule_to_call_directly_with_hooks",
|
||||
create_submodule_to_call_directly_with_hooks(),
|
||||
),
|
||||
]
|
||||
|
||||
for name, model in tests:
|
||||
|
||||
Reference in New Issue
Block a user