Changes by apply order (see the sketch after this list for a before/after illustration):
1. Replace all `".."` and `os.pardir` usage with `os.path.dirname(...)`.
2. Replace nested `os.path.dirname(os.path.dirname(...))` call with `str(Path(...).parent.parent)`.
3. Reorder `.absolute()` ~/ `.resolve()`~ and `.parent`: always resolve the path first.
    `.parent{...}.absolute()` -> `.absolute().parent{...}`
4. Replace chained `.parent x N` with `.parents[${N - 1}]`: the code is easier to read (see 5.)
    `.parent.parent.parent.parent` -> `.parents[3]`
5. ~Replace `.parents[${N - 1}]` with `.parents[${N} - 1]`: the code is easier to read and does not introduce any runtime overhead.~
    ~`.parents[3]` -> `.parents[4 - 1]`~
6. ~Replace `.parents[2 - 1]` with `.parent.parent`: because the code is shorter and easier to read.~
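To make the rules concrete, here is a minimal before/after sketch. The paths and variable names are made up for illustration and are not taken from the diff itself; the same pattern shows up in the file below as `Path(__file__).resolve().parents[2]`.

```python
import os.path
from pathlib import Path

# Rules 1-2: nested os.path.dirname(...) calls collapse into a pathlib chain.
# Before:
test_dir = os.path.dirname(os.path.dirname(os.path.realpath(__file__)))
# After (equivalent result, still returned as a str):
test_dir = str(Path(__file__).resolve().parent.parent)

# Rules 3-4: resolve/absolutize first, then walk upward, and fold repeated
# `.parent` into `.parents[N - 1]`.
# Before:
root = Path(__file__).parent.parent.parent.absolute()
# After:
root = Path(__file__).absolute().parents[2]
```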
Pull Request resolved: https://github.com/pytorch/pytorch/pull/129374
Approved by: https://github.com/justinchuby, https://github.com/malfet
# Owner(s): ["oncall: jit"]

import os
import sys
import unittest
from pathlib import Path

import torch
import torch._C
from torch.testing._internal.common_utils import IS_FBCODE, skipIfTorchDynamo


# hacky way to skip these tests in fbcode:
# during test execution in fbcode, test_nnapi is available during test discovery,
# but not during test execution. So we can't try-catch here, otherwise it'll think
# it sees tests but then fails when it tries to actually run them.
if not IS_FBCODE:
    from test_nnapi import TestNNAPI

    HAS_TEST_NNAPI = True
else:
    from torch.testing._internal.common_utils import TestCase as TestNNAPI

    HAS_TEST_NNAPI = False


# Make the helper files in test/ importable
pytorch_test_dir = os.path.dirname(os.path.dirname(os.path.realpath(__file__)))
sys.path.append(pytorch_test_dir)

if __name__ == "__main__":
    raise RuntimeError(
        "This test file is not meant to be run directly, use:\n\n"
        "\tpython test/test_jit.py TESTNAME\n\n"
        "instead."
    )

"""
 | 
						|
Unit Tests for Nnapi backend with delegate
 | 
						|
Inherits most tests from TestNNAPI, which loads Android NNAPI models
 | 
						|
without the delegate API.
 | 
						|
"""
 | 
						|
# First skip is needed for IS_WINDOWS or IS_MACOS to skip the tests.
 | 
						|
torch_root = Path(__file__).resolve().parents[2]
 | 
						|
lib_path = torch_root / "build" / "lib" / "libnnapi_backend.so"
 | 
						|
 | 
						|
 | 
						|
@skipIfTorchDynamo("weird py38 failures")
@unittest.skipIf(
    not os.path.exists(lib_path),
    "Skipping the test as libnnapi_backend.so was not found",
)
@unittest.skipIf(IS_FBCODE, "test_nnapi.py not found")
class TestNnapiBackend(TestNNAPI):
    def setUp(self):
        super().setUp()

        # Save default dtype
        module = torch.nn.PReLU()
        self.default_dtype = module.weight.dtype
        # Change dtype to float32 (since a different unit test changed dtype to float64,
        # which is not supported by the Android NNAPI delegate)
        # Float32 should typically be the default in other files.
        torch.set_default_dtype(torch.float32)

        # Load nnapi delegate library
        torch.ops.load_library(str(lib_path))

    # Override
    def call_lowering_to_nnapi(self, traced_module, args):
        compile_spec = {"forward": {"inputs": args}}
        return torch._C._jit_to_backend("nnapi", traced_module, compile_spec)

    def test_tensor_input(self):
        # Lower a simple module
        args = torch.tensor([[1.0, -1.0, 2.0, -2.0]]).unsqueeze(-1).unsqueeze(-1)
        module = torch.nn.PReLU()
        traced = torch.jit.trace(module, args)

        # Argument input is a single Tensor
        self.call_lowering_to_nnapi(traced, args)
        # Argument input is a Tensor in a list
        self.call_lowering_to_nnapi(traced, [args])

    # Test exceptions for incorrect compile specs
    def test_compile_spec_santiy(self):
        args = torch.tensor([[1.0, -1.0, 2.0, -2.0]]).unsqueeze(-1).unsqueeze(-1)
        module = torch.nn.PReLU()
        traced = torch.jit.trace(module, args)

        errorMsgTail = r"""
method_compile_spec should contain a Tensor or Tensor List which bundles input parameters: shape, dtype, quantization, and dimorder.
For input shapes, use 0 for run/load time flexible input.
method_compile_spec must use the following format:
{"forward": {"inputs": at::Tensor}} OR {"forward": {"inputs": c10::List<at::Tensor>}}"""

        # No forward key
        compile_spec = {"backward": {"inputs": args}}
        with self.assertRaisesRegex(
            RuntimeError,
            'method_compile_spec does not contain the "forward" key.' + errorMsgTail,
        ):
            torch._C._jit_to_backend("nnapi", traced, compile_spec)

        # No dictionary under the forward key
        compile_spec = {"forward": 1}
        with self.assertRaisesRegex(
            RuntimeError,
            'method_compile_spec does not contain a dictionary with an "inputs" key, '
            'under it\'s "forward" key.' + errorMsgTail,
        ):
            torch._C._jit_to_backend("nnapi", traced, compile_spec)

        # No inputs key (in the dictionary under the forward key)
        compile_spec = {"forward": {"not inputs": args}}
        with self.assertRaisesRegex(
            RuntimeError,
            'method_compile_spec does not contain a dictionary with an "inputs" key, '
            'under it\'s "forward" key.' + errorMsgTail,
        ):
            torch._C._jit_to_backend("nnapi", traced, compile_spec)

        # No Tensor or TensorList under the inputs key
        compile_spec = {"forward": {"inputs": 1}}
        with self.assertRaisesRegex(
            RuntimeError,
            'method_compile_spec does not contain either a Tensor or TensorList, under it\'s "inputs" key.'
            + errorMsgTail,
        ):
            torch._C._jit_to_backend("nnapi", traced, compile_spec)
        compile_spec = {"forward": {"inputs": [1]}}
        with self.assertRaisesRegex(
            RuntimeError,
            'method_compile_spec does not contain either a Tensor or TensorList, under it\'s "inputs" key.'
            + errorMsgTail,
        ):
            torch._C._jit_to_backend("nnapi", traced, compile_spec)

    def tearDown(self):
        # Change dtype back to default (Otherwise, other unit tests will complain)
        torch.set_default_dtype(self.default_dtype)