mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 21:14:14 +08:00
Gets rid of all the single test excludes. Pull Request resolved: https://github.com/pytorch/pytorch/pull/117765 Approved by: https://github.com/voznesenskym
126 lines
5.3 KiB
Python
126 lines
5.3 KiB
Python
# Owner(s): ["oncall: jit"]
|
|
|
|
import os
|
|
import sys
|
|
import unittest
|
|
|
|
import torch
|
|
import torch._C
|
|
from pathlib import Path
|
|
from torch.testing._internal.common_utils import IS_FBCODE, skipIfTorchDynamo
|
|
|
|
# Make the helper files in test/ importable. This must run *before* the
# test_nnapi import below, since that helper is presumably one of the files
# this path entry is meant to expose (TODO confirm its location in the repo).
pytorch_test_dir = os.path.dirname(os.path.dirname(os.path.realpath(__file__)))
sys.path.append(pytorch_test_dir)

# hacky way to skip these tests in fbcode:
# during test execution in fbcode, test_nnapi is available during test discovery,
# but not during test execution. So we can't try-catch here, otherwise it'll think
# it sees tests but then fails when it tries to actually run them.
if not IS_FBCODE:
    from test_nnapi import TestNNAPI
    HAS_TEST_NNAPI = True
else:
    # Stand-in base class so the module still imports; the test class below is
    # skipped via @unittest.skipIf(IS_FBCODE, ...).
    from torch.testing._internal.common_utils import TestCase as TestNNAPI
    HAS_TEST_NNAPI = False
|
|
|
|
# Refuse direct execution: these tests are meant to be collected and run
# through test/test_jit.py.
if __name__ == "__main__":
    _usage_message = (
        "This test file is not meant to be run directly, use:\n\n"
        "\tpython test/test_jit.py TESTNAME\n\n"
        "instead."
    )
    raise RuntimeError(_usage_message)
|
|
|
|
# Unit Tests for Nnapi backend with delegate.
# Inherits most tests from TestNNAPI, which loads Android NNAPI models
# without the delegate API.

# Location of the prebuilt NNAPI delegate library, under build/lib of the
# directory three levels above this file. The first skipIf on the test class
# checks that this file exists, which is what skips the tests on IS_WINDOWS or
# IS_MACOS.
torch_root = Path(__file__).resolve().parent.parent.parent
lib_path = torch_root.joinpath("build", "lib", "libnnapi_backend.so")
|
|
@skipIfTorchDynamo("weird py38 failures")
@unittest.skipIf(not os.path.exists(lib_path),
                 "Skipping the test as libnnapi_backend.so was not found")
@unittest.skipIf(IS_FBCODE, "test_nnapi.py not found")
class TestNnapiBackend(TestNNAPI):
    """Unit tests for the Nnapi backend with delegate.

    Inherits most tests from TestNNAPI, which loads Android NNAPI models
    without the delegate API. ``call_lowering_to_nnapi`` is overridden so that
    lowering goes through ``torch._C._jit_to_backend("nnapi", ...)`` instead
    (presumably the hook the inherited TestNNAPI tests call -- confirm in
    test_nnapi.py).
    """

    def setUp(self):
        super().setUp()

        # Load nnapi delegate library. Done before touching the global default
        # dtype: loading may raise, and tearDown does not run after a failed
        # setUp, so nothing global must be modified before this point.
        torch.ops.load_library(str(lib_path))

        # Save the current default dtype so tearDown can restore it.
        self.default_dtype = torch.get_default_dtype()
        # Change dtype to float32 (since a different unit test changed dtype to float64,
        # which is not supported by the Android NNAPI delegate).
        # Float32 should typically be the default in other files.
        torch.set_default_dtype(torch.float32)

    def tearDown(self):
        # Change dtype back to default (otherwise, other unit tests will complain),
        # then chain to the base class so its own teardown logic still runs.
        torch.set_default_dtype(self.default_dtype)
        super().tearDown()

    # Override
    def call_lowering_to_nnapi(self, traced_module, args):
        """Lower ``traced_module`` to the "nnapi" backend via the delegate API.

        ``args`` (a Tensor or a list of Tensors) becomes the "inputs" entry of
        the "forward" method's compile spec. Returns the lowered module
        produced by ``torch._C._jit_to_backend``.
        """
        compile_spec = {"forward": {"inputs": args}}
        return torch._C._jit_to_backend("nnapi", traced_module, compile_spec)

    def test_tensor_input(self):
        """Lowering accepts the inputs as a bare Tensor or as a Tensor list."""
        # Lower a simple module
        args = torch.tensor([[1.0, -1.0, 2.0, -2.0]]).unsqueeze(-1).unsqueeze(-1)
        module = torch.nn.PReLU()
        traced = torch.jit.trace(module, args)

        # Argument input is a single Tensor
        self.call_lowering_to_nnapi(traced, args)
        # Argument input is a Tensor in a list
        self.call_lowering_to_nnapi(traced, [args])

    # Test exceptions for incorrect compile specs
    def test_compile_spec_santiy(self):
        """Each malformed compile spec is rejected with a descriptive error."""
        args = torch.tensor([[1.0, -1.0, 2.0, -2.0]]).unsqueeze(-1).unsqueeze(-1)
        module = torch.nn.PReLU()
        traced = torch.jit.trace(module, args)

        # Common suffix of every compile-spec error; used as a regex, so it must
        # match the backend's error text verbatim.
        errorMsgTail = r"""
method_compile_spec should contain a Tensor or Tensor List which bundles input parameters: shape, dtype, quantization, and dimorder.
For input shapes, use 0 for run/load time flexible input.
method_compile_spec must use the following format:
{"forward": {"inputs": at::Tensor}} OR {"forward": {"inputs": c10::List<at::Tensor>}}"""

        no_inputs_dict_msg = (
            "method_compile_spec does not contain a dictionary with an \"inputs\" key, "
            "under it's \"forward\" key."
        )
        no_tensor_msg = (
            "method_compile_spec does not contain either a Tensor or TensorList, "
            "under it's \"inputs\" key."
        )

        # (malformed compile spec, expected error head) pairs.
        cases = [
            # No forward key
            ({"backward": {"inputs": args}},
             "method_compile_spec does not contain the \"forward\" key."),
            # No dictionary under the forward key
            ({"forward": 1}, no_inputs_dict_msg),
            # No inputs key (in the dictionary under the forward key)
            ({"forward": {"not inputs": args}}, no_inputs_dict_msg),
            # No Tensor or TensorList under the inputs key
            ({"forward": {"inputs": 1}}, no_tensor_msg),
            ({"forward": {"inputs": [1]}}, no_tensor_msg),
        ]
        for compile_spec, error_head in cases:
            with self.assertRaisesRegex(RuntimeError, error_head + errorMsgTail):
                torch._C._jit_to_backend("nnapi", traced, compile_spec)
|