mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 21:14:14 +08:00
Fix and unskip cpp extension tests for ARM (#83115)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/83115 Approved by: https://github.com/albanD
This commit is contained in:
committed by
PyTorch MergeBot
parent
958651327f
commit
b18962552e
@ -1,13 +1,12 @@
|
||||
# Owner(s): ["module: dispatch"]
|
||||
|
||||
import torch._C as C
|
||||
from torch.testing._internal.common_utils import TestCase, run_tests, IS_ARM64
|
||||
from torch.testing._internal.common_utils import TestCase, run_tests
|
||||
from torch._python_dispatcher import PythonDispatcher
|
||||
|
||||
from collections import namedtuple
|
||||
import itertools
|
||||
import os
|
||||
import unittest
|
||||
import re
|
||||
import torch.utils.cpp_extension
|
||||
|
||||
@ -768,7 +767,6 @@ CompositeImplicitAutograd[alias] (inactive): fn1 :: (Tensor _0) -> Tensor _0 [ b
|
||||
msg=f"Expect zero dangling impls, but found: {dangling_impls}"
|
||||
)
|
||||
|
||||
@unittest.skipIf(IS_ARM64, "Not working on arm")
|
||||
def test_find_dangling_impls_ext(self):
|
||||
extension_path = os.path.join(os.path.dirname(os.path.abspath(__file__)), 'cpp_extensions', 'dangling_impl_extension.cpp')
|
||||
module = torch.utils.cpp_extension.load(
|
||||
|
@ -19,7 +19,7 @@ from torch.utils.checkpoint import checkpoint, checkpoint_sequential
|
||||
import torch.utils.cpp_extension
|
||||
from torch.autograd._functions.utils import check_onnx_broadcast
|
||||
from torch.onnx.symbolic_opset9 import _prepare_onnx_paddings
|
||||
from torch.testing._internal.common_utils import load_tests, IS_SANDCASTLE, IS_WINDOWS, IS_ARM64
|
||||
from torch.testing._internal.common_utils import load_tests, IS_SANDCASTLE, IS_WINDOWS
|
||||
|
||||
# load_tests from torch.testing._internal.common_utils is used to automatically filter tests for
|
||||
# sharding on sandcastle. This line silences flake warnings
|
||||
@ -682,7 +682,6 @@ class TestAssert(TestCase):
|
||||
|
||||
@unittest.skipIf(IS_SANDCASTLE, "cpp_extension is OSS only")
|
||||
class TestStandaloneCPPJIT(TestCase):
|
||||
@unittest.skipIf(IS_ARM64, "Not working on arm")
|
||||
def test_load_standalone(self):
|
||||
build_dir = tempfile.mkdtemp()
|
||||
try:
|
||||
|
@ -37,6 +37,8 @@ except subprocess.CalledProcessError:
|
||||
except FileNotFoundError:
|
||||
# Do not print warning. This is okay. This file can also be imported for non-ROCm builds.
|
||||
pass
|
||||
except PermissionError:
|
||||
pass
|
||||
|
||||
rocm_version = (0, 0, 0)
|
||||
rocm_version_h = f"{rocm_path}/include/rocm_version.h"
|
||||
|
Reference in New Issue
Block a user