mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 21:14:14 +08:00
Revert "Make open device registration tests standalone (#153855)"
This reverts commit 8823138e47a3200c313f6bf2d21eb689d8150f39.
Reverted https://github.com/pytorch/pytorch/pull/153855 on behalf of https://github.com/clee2000 due to causing some linux aarch64 tests to fail [GH job link](https://github.com/pytorch/pytorch/actions/runs/15566289293/job/43832373302) [HUD commit link](8823138e47
), should be easy fix, rename in places where its mentioned, there might be more than just aarch64 though ([comment](https://github.com/pytorch/pytorch/pull/153855#issuecomment-2960191503))
This commit is contained in:
@ -1,5 +1,6 @@
|
||||
import distutils.command.clean
|
||||
import os
|
||||
import platform
|
||||
import shutil
|
||||
import sys
|
||||
from pathlib import Path
|
||||
@ -40,8 +41,11 @@ if __name__ == "__main__":
|
||||
CXX_FLAGS = ["/sdl"]
|
||||
else:
|
||||
CXX_FLAGS = ["/sdl", "/permissive-"]
|
||||
else:
|
||||
elif platform.machine() == "s390x":
|
||||
# no -Werror on s390x due to newer compiler
|
||||
CXX_FLAGS = {"cxx": ["-g", "-Wall"]}
|
||||
else:
|
||||
CXX_FLAGS = {"cxx": ["-g", "-Wall", "-Werror"]}
|
||||
|
||||
sources = list(CSRS_DIR.glob("*.cpp"))
|
||||
|
||||
|
@ -1236,6 +1236,7 @@ CUSTOM_HANDLERS = {
|
||||
"test_ci_sanity_check_fail": run_ci_sanity_check,
|
||||
"test_autoload_enable": test_autoload_enable,
|
||||
"test_autoload_disable": test_autoload_disable,
|
||||
"test_cpp_extensions_open_device_registration": run_test_with_openreg,
|
||||
"test_openreg": run_test_with_openreg,
|
||||
"test_transformers_privateuse1": run_test_with_openreg,
|
||||
}
|
||||
|
@ -1,16 +1,15 @@
|
||||
# Owner(s): ["module: cpp-extensions"]
|
||||
|
||||
import _codecs
|
||||
import importlib
|
||||
import io
|
||||
import os
|
||||
import sys
|
||||
import tempfile
|
||||
import unittest
|
||||
from pathlib import Path
|
||||
from unittest.mock import patch
|
||||
|
||||
import numpy as np
|
||||
import pytorch_openreg # noqa: F401
|
||||
|
||||
import torch
|
||||
import torch.testing._internal.common_utils as common
|
||||
@ -61,21 +60,18 @@ class TestCppExtensionOpenRgistration(common.TestCase):
|
||||
|
||||
@classmethod
|
||||
def setUpClass(cls):
|
||||
# load custom device extension
|
||||
extension_root = Path(__file__).parent.parent
|
||||
torch.testing._internal.common_utils.remove_cpp_extensions_build_root()
|
||||
|
||||
cls.module = torch.utils.cpp_extension.load(
|
||||
name="custom_device_extension",
|
||||
sources=[
|
||||
f"{extension_root}/custom_device/csrc/extension.cpp",
|
||||
"cpp_extensions/open_registration_extension.cpp",
|
||||
],
|
||||
extra_include_paths=[],
|
||||
extra_include_paths=["cpp_extensions"],
|
||||
extra_cflags=["-g"],
|
||||
verbose=True,
|
||||
)
|
||||
|
||||
# install / load pytorch_openreg extension
|
||||
common.install_cpp_extension(extension_root=extension_root)
|
||||
globals()["pytorch_openreg"] = importlib.import_module("pytorch_openreg")
|
||||
torch.utils.generate_methods_for_privateuse1_backend(for_storage=True)
|
||||
|
||||
def test_base_device_registration(self):
|
Reference in New Issue
Block a user