Fix and unskip cpp extension tests for ARM (#83115)

Pull Request resolved: https://github.com/pytorch/pytorch/pull/83115
Approved by: https://github.com/albanD
This commit is contained in:
soulitzer
2022-08-11 11:39:23 -04:00
committed by PyTorch MergeBot
parent 958651327f
commit b18962552e
3 changed files with 4 additions and 5 deletions

View File

@ -1,13 +1,12 @@
# Owner(s): ["module: dispatch"]
import torch._C as C
from torch.testing._internal.common_utils import TestCase, run_tests, IS_ARM64
from torch.testing._internal.common_utils import TestCase, run_tests
from torch._python_dispatcher import PythonDispatcher
from collections import namedtuple
import itertools
import os
import unittest
import re
import torch.utils.cpp_extension
@ -768,7 +767,6 @@ CompositeImplicitAutograd[alias] (inactive): fn1 :: (Tensor _0) -> Tensor _0 [ b
msg=f"Expect zero dangling impls, but found: {dangling_impls}"
)
@unittest.skipIf(IS_ARM64, "Not working on arm")
def test_find_dangling_impls_ext(self):
extension_path = os.path.join(os.path.dirname(os.path.abspath(__file__)), 'cpp_extensions', 'dangling_impl_extension.cpp')
module = torch.utils.cpp_extension.load(

View File

@ -19,7 +19,7 @@ from torch.utils.checkpoint import checkpoint, checkpoint_sequential
import torch.utils.cpp_extension
from torch.autograd._functions.utils import check_onnx_broadcast
from torch.onnx.symbolic_opset9 import _prepare_onnx_paddings
from torch.testing._internal.common_utils import load_tests, IS_SANDCASTLE, IS_WINDOWS, IS_ARM64
from torch.testing._internal.common_utils import load_tests, IS_SANDCASTLE, IS_WINDOWS
# load_tests from torch.testing._internal.common_utils is used to automatically filter tests for
# sharding on sandcastle. This line silences flake warnings
@ -682,7 +682,6 @@ class TestAssert(TestCase):
@unittest.skipIf(IS_SANDCASTLE, "cpp_extension is OSS only")
class TestStandaloneCPPJIT(TestCase):
@unittest.skipIf(IS_ARM64, "Not working on arm")
def test_load_standalone(self):
build_dir = tempfile.mkdtemp()
try:

View File

@ -37,6 +37,8 @@ except subprocess.CalledProcessError:
except FileNotFoundError:
# Do not print warning. This is okay. This file can also be imported for non-ROCm builds.
pass
except PermissionError:
pass
rocm_version = (0, 0, 0)
rocm_version_h = f"{rocm_path}/include/rocm_version.h"