[OpenReg][5/N] add set_.source_Storage for openreg (#155191)

**Changes**:
- add set_.source_Storage for openreg to support torch.load & torch.serialization
- uncomment some related tests in the test_openreg.py
Pull Request resolved: https://github.com/pytorch/pytorch/pull/155191
Approved by: https://github.com/albanD
ghstack dependencies: #153947, #154018, #154019, #154106, #154181, #155101
This commit is contained in:
FFFrog
2025-06-13 16:41:12 +08:00
committed by PyTorch MergeBot
parent e4fd0bf771
commit 187828dcb4
2 changed files with 19 additions and 18 deletions

View File

@ -6,6 +6,7 @@
#include <ATen/ops/as_strided_cpu_dispatch.h>
#include <ATen/ops/quantize_per_tensor_native.h>
#include <ATen/ops/set_cpu_dispatch.h>
#include <ATen/ops/set_native.h>
#include <ATen/native/DispatchStub.h>
#include <ATen/native/transformers/attention.h>
#include <ATen/native/transformers/sdp_utils_cpp.h>
@ -107,7 +108,7 @@ at::Tensor as_strided_openreg(
return at::cpu::as_strided(self, size, stride, storage_offset_);
}
at::Tensor& set_openreg(
at::Tensor& set_source_Storage_storage_offsetset_openreg(
at::Tensor& result,
at::Storage storage,
int64_t storage_offset,
@ -269,7 +270,8 @@ TORCH_LIBRARY_IMPL(aten, PrivateUse1, m) {
m.impl("empty.memory_format", empty_openreg);
m.impl("empty_strided", empty_strided_openreg);
m.impl("as_strided", as_strided_openreg);
m.impl("set_.source_Storage_storage_offset", set_openreg);
m.impl("set_.source_Storage", at::native::set_);
m.impl("set_.source_Storage_storage_offset", set_source_Storage_storage_offsetset_openreg);
m.impl("quantize_per_tensor", at::native::quantize_per_tensor);
m.impl("_fused_sdp_choice", &_fused_sdp_choice_privateuse1);
m.impl("_scaled_dot_product_fused_attention_overrideable", &custom_scaled_dot_product_fused_attention_overrideable);

View File

@ -1,6 +1,7 @@
# Owner(s): ["module: PrivateUse1"]
import os
import tempfile
import types
import unittest
@ -307,12 +308,11 @@ class TestOpenReg(TestCase):
storage = torch.UntypedStorage(4, device=torch.device("openreg:0"))
self.assertEqual(torch.serialization.location_tag(storage), "openreg:0")
# Need to support torch.storage.UntypedStorage first in prepare_for_sending.convert
# storage_cpu = torch.empty(4, 4).storage()
# storage_openreg = torch.serialization.default_restore_location(
# storage_cpu, "openreg:0"
# )
# self.assertTrue(storage_openreg.is_openreg) # type: ignore[misc]
storage_cpu = torch.empty(4, 4).storage()
storage_openreg = torch.serialization.default_restore_location(
storage_cpu, "openreg:0"
)
self.assertTrue(storage_openreg.is_openreg) # type: ignore[misc]
tensor = torch.empty(3, 3, device="openreg")
self.assertEqual(torch._utils.get_tensor_metadata(tensor), {}) # type: ignore[misc]
@ -320,18 +320,17 @@ class TestOpenReg(TestCase):
torch._utils.set_tensor_metadata(tensor, metadata) # type: ignore[misc]
self.assertEqual(torch._utils.get_tensor_metadata(tensor), metadata) # type: ignore[misc]
# Need to support torch.storage.UntypedStorage first in prepare_for_sending.convert
# with tempfile.TemporaryDirectory() as tmpdir:
# path = os.path.join(tmpdir, "data.pt")
# torch.save(tensor, path)
with tempfile.TemporaryDirectory() as tmpdir:
path = os.path.join(tmpdir, "data.pt")
torch.save(tensor, path)
# tensor_openreg = torch.load(path)
# self.assertTrue(tensor_openreg.is_openreg)
# self.assertEqual(torch._utils.get_tensor_metadata(tensor_openreg), metadata) # type: ignore[misc]
tensor_openreg = torch.load(path)
self.assertTrue(tensor_openreg.is_openreg)
self.assertEqual(torch._utils.get_tensor_metadata(tensor_openreg), metadata) # type: ignore[misc]
# tensor_cpu = torch.load(path, map_location="cpu")
# self.assertFalse(tensor_cpu.is_openreg)
# self.assertEqual(torch._utils.get_tensor_metadata(tensor), {}) # type: ignore[misc]
tensor_cpu = torch.load(path, map_location="cpu")
self.assertFalse(tensor_cpu.is_openreg)
self.assertEqual(torch._utils.get_tensor_metadata(tensor_cpu), {}) # type: ignore[misc]
# Opeartors
def test_factory(self):