Revert "Fix the Problems About Defining Static Variable in Inline Function (#147095)"

This reverts commit 3da14d38bd396f5bbe8494872d1509efa1a6f048.

Reverted https://github.com/pytorch/pytorch/pull/147095 on behalf of https://github.com/atalman due to breaks internally ([comment](https://github.com/pytorch/pytorch/pull/147095#issuecomment-2787129770))
This commit is contained in:
PyTorch MergeBot
2025-04-08 17:10:36 +00:00
parent 3e0038ae85
commit 4926bd6004
12 changed files with 49 additions and 49 deletions

View File

@ -4,6 +4,7 @@ import _codecs
import io
import os
import sys
import tempfile
import unittest
from typing import Union
from unittest.mock import patch
@ -345,22 +346,23 @@ class TestCppExtensionOpenRgistration(common.TestCase):
cpu_untyped_storage_pinned = cpu_untyped_storage.pin_memory("openreg")
self.assertTrue(cpu_untyped_storage_pinned.is_pinned("openreg"))
@unittest.skip(
"Temporarily disable due to the tiny differences between clang++ and g++ in defining static variable in inline function"
)
def test_open_device_serialization(self):
self.module.set_custom_device_index(-1)
storage = torch.UntypedStorage(4, device=torch.device("openreg"))
self.assertEqual(torch.serialization.location_tag(storage), "openreg:0")
self.assertEqual(torch.serialization.location_tag(storage), "openreg")
self.module.set_custom_device_index(0)
storage = torch.UntypedStorage(4, device=torch.device("openreg"))
self.assertEqual(torch.serialization.location_tag(storage), "openreg:0")
# TODO(FFFrog): Comment this because openreg.device is missing
# Uncomment this after improving openreg
# cpu_storage = torch.empty(4, 4).storage()
# openreg_storage = torch.serialization.default_restore_location(
# cpu_storage, "openreg:0"
# )
# self.assertTrue(openreg_storage.is_openreg)
cpu_storage = torch.empty(4, 4).storage()
openreg_storage = torch.serialization.default_restore_location(
cpu_storage, "openreg:0"
)
self.assertTrue(openreg_storage.is_openreg)
# test tensor MetaData serialization
x = torch.empty(4, 4).long()
@ -369,24 +371,22 @@ class TestCppExtensionOpenRgistration(common.TestCase):
self.module.custom_set_backend_meta(y)
self.assertTrue(self.module.check_backend_meta(y))
# TODO(FFFrog): Comment this because openreg.device is missing
# Uncomment this after improving openreg
# self.module.custom_serialization_registry()
# with tempfile.TemporaryDirectory() as tmpdir:
# path = os.path.join(tmpdir, "data.pt")
# torch.save(y, path)
# z1 = torch.load(path)
# loads correctly onto the openreg backend device
# self.assertTrue(z1.is_openreg)
# loads BackendMeta data correctly
# self.assertTrue(self.module.check_backend_meta(z1))
self.module.custom_serialization_registry()
with tempfile.TemporaryDirectory() as tmpdir:
path = os.path.join(tmpdir, "data.pt")
torch.save(y, path)
z1 = torch.load(path)
# loads correctly onto the openreg backend device
self.assertTrue(z1.is_openreg)
# loads BackendMeta data correctly
self.assertTrue(self.module.check_backend_meta(z1))
# cross-backend
# z2 = torch.load(path, map_location="cpu")
# loads correctly onto the cpu backend device
# self.assertFalse(z2.is_openreg)
# loads BackendMeta data correctly
# self.assertFalse(self.module.check_backend_meta(z2))
# cross-backend
z2 = torch.load(path, map_location="cpu")
# loads correctly onto the cpu backend device
self.assertFalse(z2.is_openreg)
# loads BackendMeta data correctly
self.assertFalse(self.module.check_backend_meta(z2))
def test_open_device_storage_resize(self):
cpu_tensor = torch.randn([8])

View File

@ -1,4 +1,5 @@
#include <torch/csrc/jit/serialization/pickle.h>
#include <torch/csrc/jit/serialization/pickler.h>
#include <torch/serialize.h>
#include <vector>

View File

@ -3,6 +3,7 @@
#include <torch/csrc/distributed/rpc/message.h>
#include <torch/csrc/distributed/rpc/rpc_command_base.h>
#include <torch/csrc/distributed/rpc/types.h>
#include <torch/csrc/jit/serialization/pickler.h>
namespace torch::distributed::rpc {
class TORCH_API PythonRemoteCall : public RpcCommandBase {

View File

@ -4,6 +4,7 @@
#include <torch/csrc/distributed/rpc/rpc_command_base.h>
#include <torch/csrc/distributed/rpc/types.h>
#include <torch/csrc/jit/runtime/operator.h>
#include <torch/csrc/jit/serialization/pickler.h>
#include <vector>
namespace torch::distributed::rpc {

View File

@ -3,6 +3,7 @@
#include <torch/csrc/distributed/rpc/message.h>
#include <torch/csrc/distributed/rpc/rpc_command_base.h>
#include <torch/csrc/jit/runtime/operator.h>
#include <torch/csrc/jit/serialization/pickler.h>
#include <optional>
#include <vector>

View File

@ -3,6 +3,7 @@
#include <torch/csrc/distributed/rpc/script_call.h>
#include <torch/csrc/distributed/rpc/types.h>
#include <torch/csrc/jit/runtime/operator.h>
#include <torch/csrc/jit/serialization/pickler.h>
#include <vector>
namespace torch::distributed::rpc {

View File

@ -2,6 +2,7 @@
#include <torch/csrc/distributed/rpc/message.h>
#include <torch/csrc/distributed/rpc/rpc_command_base.h>
#include <torch/csrc/jit/serialization/pickler.h>
namespace torch::distributed::rpc {

View File

@ -16,7 +16,6 @@
#include <torch/csrc/jit/serialization/import_export_functions.h>
#include <torch/csrc/jit/serialization/import_export_helpers.h>
#include <torch/csrc/jit/serialization/onnx.h>
#include <torch/csrc/jit/serialization/pickler.h>
#include <torch/csrc/onnx/back_compat.h>
#include <torch/csrc/onnx/onnx.h>
#include <torch/version.h>

View File

@ -5,6 +5,7 @@
#include <torch/csrc/jit/ir/ir.h>
#include <torch/csrc/jit/serialization/export_bytecode.h>
#include <torch/csrc/jit/serialization/flatbuffer_serializer.h>
#include <torch/csrc/jit/serialization/pickler.h>
#include <torch/csrc/jit/serialization/python_print.h>
#include <torch/csrc/jit/serialization/storage_context.h>
#include <torch/csrc/jit/serialization/type_name_uniquer.h>

View File

@ -807,24 +807,4 @@ bool checkHasValidSetGetState(const std::shared_ptr<c10::ClassType>& cls) {
return true;
}
std::unordered_set<c10::DeviceType>& GetBackendMetaAllowlist() {
static std::unordered_set<c10::DeviceType> DeviceTypeAllowlist{
c10::DeviceType::PrivateUse1};
return DeviceTypeAllowlist;
}
std::array<
std::optional<std::pair<BackendMetaPtr, BackendMetaPtr>>,
at::COMPILE_TIME_MAX_DEVICE_TYPES>&
GetBackendMetaSerialization() {
// The array to save function pointer for BackendMeta serialization.
// key is the DeviceType, value is std::pair obj.
// value.first represent get function and value.seconde represent set function
static std::array<
std::optional<std::pair<BackendMetaPtr, BackendMetaPtr>>,
at::COMPILE_TIME_MAX_DEVICE_TYPES>
BackendMetaSerialization;
return BackendMetaSerialization;
}
} // namespace torch::jit

View File

@ -299,14 +299,27 @@ using BackendMetaPtr = std::function<
void(const at::Tensor&, std::unordered_map<std::string, bool>&)>;
// A allowlist of device type, currently available is PrivateUse1
TORCH_API std::unordered_set<c10::DeviceType>& GetBackendMetaAllowlist();
inline std::unordered_set<c10::DeviceType>& GetBackendMetaAllowlist() {
static std::unordered_set<c10::DeviceType> DeviceTypeAllowlist{
c10::DeviceType::PrivateUse1};
return DeviceTypeAllowlist;
}
// Dynamically obtain serialization function pairs
// that require the corresponding backend.
TORCH_API std::array<
inline std::array<
std::optional<std::pair<BackendMetaPtr, BackendMetaPtr>>,
at::COMPILE_TIME_MAX_DEVICE_TYPES>&
GetBackendMetaSerialization();
GetBackendMetaSerialization() {
// The array to save function pointer for BackendMeta serialization.
// key is the DeviceType, value is std::pair obj.
// value.first represent get function and value.seconde represent set function
static std::array<
std::optional<std::pair<BackendMetaPtr, BackendMetaPtr>>,
at::COMPILE_TIME_MAX_DEVICE_TYPES>
BackendMetaSerialization;
return BackendMetaSerialization;
}
// Register function pointer of Tensor BackendMetadata for serialization.
TORCH_API inline void TensorBackendMetaRegistry(

View File

@ -5,6 +5,7 @@
#endif
#include <torch/csrc/jit/api/function_impl.h>
#include <torch/csrc/jit/mobile/type_parser.h>
#include <torch/csrc/jit/serialization/pickler.h>
#include <torch/csrc/jit/serialization/storage_context.h>
#include <torch/csrc/jit/serialization/unpickler.h>
#include <torch/csrc/utils/byte_order.h>