mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 21:14:14 +08:00
Revert "Fix the Problems About Defining Static Variable in Inline Function (#147095)"
This reverts commit 3da14d38bd396f5bbe8494872d1509efa1a6f048. Reverted https://github.com/pytorch/pytorch/pull/147095 on behalf of https://github.com/atalman due to breaks internally ([comment](https://github.com/pytorch/pytorch/pull/147095#issuecomment-2787129770))
This commit is contained in:
@ -4,6 +4,7 @@ import _codecs
|
||||
import io
|
||||
import os
|
||||
import sys
|
||||
import tempfile
|
||||
import unittest
|
||||
from typing import Union
|
||||
from unittest.mock import patch
|
||||
@ -345,22 +346,23 @@ class TestCppExtensionOpenRgistration(common.TestCase):
|
||||
cpu_untyped_storage_pinned = cpu_untyped_storage.pin_memory("openreg")
|
||||
self.assertTrue(cpu_untyped_storage_pinned.is_pinned("openreg"))
|
||||
|
||||
@unittest.skip(
|
||||
"Temporarily disable due to the tiny differences between clang++ and g++ in defining static variable in inline function"
|
||||
)
|
||||
def test_open_device_serialization(self):
|
||||
self.module.set_custom_device_index(-1)
|
||||
storage = torch.UntypedStorage(4, device=torch.device("openreg"))
|
||||
self.assertEqual(torch.serialization.location_tag(storage), "openreg:0")
|
||||
self.assertEqual(torch.serialization.location_tag(storage), "openreg")
|
||||
|
||||
self.module.set_custom_device_index(0)
|
||||
storage = torch.UntypedStorage(4, device=torch.device("openreg"))
|
||||
self.assertEqual(torch.serialization.location_tag(storage), "openreg:0")
|
||||
|
||||
# TODO(FFFrog): Comment this because openreg.device is missing
|
||||
# Uncomment this after improving openreg
|
||||
# cpu_storage = torch.empty(4, 4).storage()
|
||||
# openreg_storage = torch.serialization.default_restore_location(
|
||||
# cpu_storage, "openreg:0"
|
||||
# )
|
||||
# self.assertTrue(openreg_storage.is_openreg)
|
||||
cpu_storage = torch.empty(4, 4).storage()
|
||||
openreg_storage = torch.serialization.default_restore_location(
|
||||
cpu_storage, "openreg:0"
|
||||
)
|
||||
self.assertTrue(openreg_storage.is_openreg)
|
||||
|
||||
# test tensor MetaData serialization
|
||||
x = torch.empty(4, 4).long()
|
||||
@ -369,24 +371,22 @@ class TestCppExtensionOpenRgistration(common.TestCase):
|
||||
self.module.custom_set_backend_meta(y)
|
||||
self.assertTrue(self.module.check_backend_meta(y))
|
||||
|
||||
# TODO(FFFrog): Comment this because openreg.device is missing
|
||||
# Uncomment this after improving openreg
|
||||
# self.module.custom_serialization_registry()
|
||||
# with tempfile.TemporaryDirectory() as tmpdir:
|
||||
# path = os.path.join(tmpdir, "data.pt")
|
||||
# torch.save(y, path)
|
||||
# z1 = torch.load(path)
|
||||
# loads correctly onto the openreg backend device
|
||||
# self.assertTrue(z1.is_openreg)
|
||||
# loads BackendMeta data correctly
|
||||
# self.assertTrue(self.module.check_backend_meta(z1))
|
||||
self.module.custom_serialization_registry()
|
||||
with tempfile.TemporaryDirectory() as tmpdir:
|
||||
path = os.path.join(tmpdir, "data.pt")
|
||||
torch.save(y, path)
|
||||
z1 = torch.load(path)
|
||||
# loads correctly onto the openreg backend device
|
||||
self.assertTrue(z1.is_openreg)
|
||||
# loads BackendMeta data correctly
|
||||
self.assertTrue(self.module.check_backend_meta(z1))
|
||||
|
||||
# cross-backend
|
||||
# z2 = torch.load(path, map_location="cpu")
|
||||
# loads correctly onto the cpu backend device
|
||||
# self.assertFalse(z2.is_openreg)
|
||||
# loads BackendMeta data correctly
|
||||
# self.assertFalse(self.module.check_backend_meta(z2))
|
||||
# cross-backend
|
||||
z2 = torch.load(path, map_location="cpu")
|
||||
# loads correctly onto the cpu backend device
|
||||
self.assertFalse(z2.is_openreg)
|
||||
# loads BackendMeta data correctly
|
||||
self.assertFalse(self.module.check_backend_meta(z2))
|
||||
|
||||
def test_open_device_storage_resize(self):
|
||||
cpu_tensor = torch.randn([8])
|
||||
|
@ -1,4 +1,5 @@
|
||||
#include <torch/csrc/jit/serialization/pickle.h>
|
||||
#include <torch/csrc/jit/serialization/pickler.h>
|
||||
#include <torch/serialize.h>
|
||||
|
||||
#include <vector>
|
||||
|
@ -3,6 +3,7 @@
|
||||
#include <torch/csrc/distributed/rpc/message.h>
|
||||
#include <torch/csrc/distributed/rpc/rpc_command_base.h>
|
||||
#include <torch/csrc/distributed/rpc/types.h>
|
||||
#include <torch/csrc/jit/serialization/pickler.h>
|
||||
namespace torch::distributed::rpc {
|
||||
|
||||
class TORCH_API PythonRemoteCall : public RpcCommandBase {
|
||||
|
@ -4,6 +4,7 @@
|
||||
#include <torch/csrc/distributed/rpc/rpc_command_base.h>
|
||||
#include <torch/csrc/distributed/rpc/types.h>
|
||||
#include <torch/csrc/jit/runtime/operator.h>
|
||||
#include <torch/csrc/jit/serialization/pickler.h>
|
||||
#include <vector>
|
||||
|
||||
namespace torch::distributed::rpc {
|
||||
|
@ -3,6 +3,7 @@
|
||||
#include <torch/csrc/distributed/rpc/message.h>
|
||||
#include <torch/csrc/distributed/rpc/rpc_command_base.h>
|
||||
#include <torch/csrc/jit/runtime/operator.h>
|
||||
#include <torch/csrc/jit/serialization/pickler.h>
|
||||
#include <optional>
|
||||
#include <vector>
|
||||
|
||||
|
@ -3,6 +3,7 @@
|
||||
#include <torch/csrc/distributed/rpc/script_call.h>
|
||||
#include <torch/csrc/distributed/rpc/types.h>
|
||||
#include <torch/csrc/jit/runtime/operator.h>
|
||||
#include <torch/csrc/jit/serialization/pickler.h>
|
||||
#include <vector>
|
||||
|
||||
namespace torch::distributed::rpc {
|
||||
|
@ -2,6 +2,7 @@
|
||||
|
||||
#include <torch/csrc/distributed/rpc/message.h>
|
||||
#include <torch/csrc/distributed/rpc/rpc_command_base.h>
|
||||
#include <torch/csrc/jit/serialization/pickler.h>
|
||||
|
||||
namespace torch::distributed::rpc {
|
||||
|
||||
|
@ -16,7 +16,6 @@
|
||||
#include <torch/csrc/jit/serialization/import_export_functions.h>
|
||||
#include <torch/csrc/jit/serialization/import_export_helpers.h>
|
||||
#include <torch/csrc/jit/serialization/onnx.h>
|
||||
#include <torch/csrc/jit/serialization/pickler.h>
|
||||
#include <torch/csrc/onnx/back_compat.h>
|
||||
#include <torch/csrc/onnx/onnx.h>
|
||||
#include <torch/version.h>
|
||||
|
@ -5,6 +5,7 @@
|
||||
#include <torch/csrc/jit/ir/ir.h>
|
||||
#include <torch/csrc/jit/serialization/export_bytecode.h>
|
||||
#include <torch/csrc/jit/serialization/flatbuffer_serializer.h>
|
||||
#include <torch/csrc/jit/serialization/pickler.h>
|
||||
#include <torch/csrc/jit/serialization/python_print.h>
|
||||
#include <torch/csrc/jit/serialization/storage_context.h>
|
||||
#include <torch/csrc/jit/serialization/type_name_uniquer.h>
|
||||
|
@ -807,24 +807,4 @@ bool checkHasValidSetGetState(const std::shared_ptr<c10::ClassType>& cls) {
|
||||
return true;
|
||||
}
|
||||
|
||||
std::unordered_set<c10::DeviceType>& GetBackendMetaAllowlist() {
|
||||
static std::unordered_set<c10::DeviceType> DeviceTypeAllowlist{
|
||||
c10::DeviceType::PrivateUse1};
|
||||
return DeviceTypeAllowlist;
|
||||
}
|
||||
|
||||
std::array<
|
||||
std::optional<std::pair<BackendMetaPtr, BackendMetaPtr>>,
|
||||
at::COMPILE_TIME_MAX_DEVICE_TYPES>&
|
||||
GetBackendMetaSerialization() {
|
||||
// The array to save function pointer for BackendMeta serialization.
|
||||
// key is the DeviceType, value is std::pair obj.
|
||||
// value.first represent get function and value.seconde represent set function
|
||||
static std::array<
|
||||
std::optional<std::pair<BackendMetaPtr, BackendMetaPtr>>,
|
||||
at::COMPILE_TIME_MAX_DEVICE_TYPES>
|
||||
BackendMetaSerialization;
|
||||
return BackendMetaSerialization;
|
||||
}
|
||||
|
||||
} // namespace torch::jit
|
||||
|
@ -299,14 +299,27 @@ using BackendMetaPtr = std::function<
|
||||
void(const at::Tensor&, std::unordered_map<std::string, bool>&)>;
|
||||
|
||||
// A allowlist of device type, currently available is PrivateUse1
|
||||
TORCH_API std::unordered_set<c10::DeviceType>& GetBackendMetaAllowlist();
|
||||
inline std::unordered_set<c10::DeviceType>& GetBackendMetaAllowlist() {
|
||||
static std::unordered_set<c10::DeviceType> DeviceTypeAllowlist{
|
||||
c10::DeviceType::PrivateUse1};
|
||||
return DeviceTypeAllowlist;
|
||||
}
|
||||
|
||||
// Dynamically obtain serialization function pairs
|
||||
// that require the corresponding backend.
|
||||
TORCH_API std::array<
|
||||
inline std::array<
|
||||
std::optional<std::pair<BackendMetaPtr, BackendMetaPtr>>,
|
||||
at::COMPILE_TIME_MAX_DEVICE_TYPES>&
|
||||
GetBackendMetaSerialization();
|
||||
GetBackendMetaSerialization() {
|
||||
// The array to save function pointer for BackendMeta serialization.
|
||||
// key is the DeviceType, value is std::pair obj.
|
||||
// value.first represent get function and value.seconde represent set function
|
||||
static std::array<
|
||||
std::optional<std::pair<BackendMetaPtr, BackendMetaPtr>>,
|
||||
at::COMPILE_TIME_MAX_DEVICE_TYPES>
|
||||
BackendMetaSerialization;
|
||||
return BackendMetaSerialization;
|
||||
}
|
||||
|
||||
// Register function pointer of Tensor BackendMetadata for serialization.
|
||||
TORCH_API inline void TensorBackendMetaRegistry(
|
||||
|
@ -5,6 +5,7 @@
|
||||
#endif
|
||||
#include <torch/csrc/jit/api/function_impl.h>
|
||||
#include <torch/csrc/jit/mobile/type_parser.h>
|
||||
#include <torch/csrc/jit/serialization/pickler.h>
|
||||
#include <torch/csrc/jit/serialization/storage_context.h>
|
||||
#include <torch/csrc/jit/serialization/unpickler.h>
|
||||
#include <torch/csrc/utils/byte_order.h>
|
||||
|
Reference in New Issue
Block a user