mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 21:14:14 +08:00
[OpenReg][5/N] add set_.source_Storage for openreg (#155191)
**Changes**: - add set_.source_Storage for openreg to support torch.load & torch.serialization - uncomment some related tests in the test_openreg.py Pull Request resolved: https://github.com/pytorch/pytorch/pull/155191 Approved by: https://github.com/albanD ghstack dependencies: #153947, #154018, #154019, #154106, #154181, #155101
This commit is contained in:
@ -6,6 +6,7 @@
|
||||
#include <ATen/ops/as_strided_cpu_dispatch.h>
|
||||
#include <ATen/ops/quantize_per_tensor_native.h>
|
||||
#include <ATen/ops/set_cpu_dispatch.h>
|
||||
#include <ATen/ops/set_native.h>
|
||||
#include <ATen/native/DispatchStub.h>
|
||||
#include <ATen/native/transformers/attention.h>
|
||||
#include <ATen/native/transformers/sdp_utils_cpp.h>
|
||||
@ -107,7 +108,7 @@ at::Tensor as_strided_openreg(
|
||||
return at::cpu::as_strided(self, size, stride, storage_offset_);
|
||||
}
|
||||
|
||||
at::Tensor& set_openreg(
|
||||
at::Tensor& set_source_Storage_storage_offsetset_openreg(
|
||||
at::Tensor& result,
|
||||
at::Storage storage,
|
||||
int64_t storage_offset,
|
||||
@ -269,7 +270,8 @@ TORCH_LIBRARY_IMPL(aten, PrivateUse1, m) {
|
||||
m.impl("empty.memory_format", empty_openreg);
|
||||
m.impl("empty_strided", empty_strided_openreg);
|
||||
m.impl("as_strided", as_strided_openreg);
|
||||
m.impl("set_.source_Storage_storage_offset", set_openreg);
|
||||
m.impl("set_.source_Storage", at::native::set_);
|
||||
m.impl("set_.source_Storage_storage_offset", set_source_Storage_storage_offsetset_openreg);
|
||||
m.impl("quantize_per_tensor", at::native::quantize_per_tensor);
|
||||
m.impl("_fused_sdp_choice", &_fused_sdp_choice_privateuse1);
|
||||
m.impl("_scaled_dot_product_fused_attention_overrideable", &custom_scaled_dot_product_fused_attention_overrideable);
|
||||
|
||||
@ -1,6 +1,7 @@
|
||||
# Owner(s): ["module: PrivateUse1"]
|
||||
|
||||
import os
|
||||
import tempfile
|
||||
import types
|
||||
import unittest
|
||||
|
||||
@ -307,12 +308,11 @@ class TestOpenReg(TestCase):
|
||||
storage = torch.UntypedStorage(4, device=torch.device("openreg:0"))
|
||||
self.assertEqual(torch.serialization.location_tag(storage), "openreg:0")
|
||||
|
||||
# Need to support torch.storage.UntypedStorage first in prepare_for_sending.convert
|
||||
# storage_cpu = torch.empty(4, 4).storage()
|
||||
# storage_openreg = torch.serialization.default_restore_location(
|
||||
# storage_cpu, "openreg:0"
|
||||
# )
|
||||
# self.assertTrue(storage_openreg.is_openreg) # type: ignore[misc]
|
||||
storage_cpu = torch.empty(4, 4).storage()
|
||||
storage_openreg = torch.serialization.default_restore_location(
|
||||
storage_cpu, "openreg:0"
|
||||
)
|
||||
self.assertTrue(storage_openreg.is_openreg) # type: ignore[misc]
|
||||
|
||||
tensor = torch.empty(3, 3, device="openreg")
|
||||
self.assertEqual(torch._utils.get_tensor_metadata(tensor), {}) # type: ignore[misc]
|
||||
@ -320,18 +320,17 @@ class TestOpenReg(TestCase):
|
||||
torch._utils.set_tensor_metadata(tensor, metadata) # type: ignore[misc]
|
||||
self.assertEqual(torch._utils.get_tensor_metadata(tensor), metadata) # type: ignore[misc]
|
||||
|
||||
# Need to support torch.storage.UntypedStorage first in prepare_for_sending.convert
|
||||
# with tempfile.TemporaryDirectory() as tmpdir:
|
||||
# path = os.path.join(tmpdir, "data.pt")
|
||||
# torch.save(tensor, path)
|
||||
with tempfile.TemporaryDirectory() as tmpdir:
|
||||
path = os.path.join(tmpdir, "data.pt")
|
||||
torch.save(tensor, path)
|
||||
|
||||
# tensor_openreg = torch.load(path)
|
||||
# self.assertTrue(tensor_openreg.is_openreg)
|
||||
# self.assertEqual(torch._utils.get_tensor_metadata(tensor_openreg), metadata) # type: ignore[misc]
|
||||
tensor_openreg = torch.load(path)
|
||||
self.assertTrue(tensor_openreg.is_openreg)
|
||||
self.assertEqual(torch._utils.get_tensor_metadata(tensor_openreg), metadata) # type: ignore[misc]
|
||||
|
||||
# tensor_cpu = torch.load(path, map_location="cpu")
|
||||
# self.assertFalse(tensor_cpu.is_openreg)
|
||||
# self.assertEqual(torch._utils.get_tensor_metadata(tensor), {}) # type: ignore[misc]
|
||||
tensor_cpu = torch.load(path, map_location="cpu")
|
||||
self.assertFalse(tensor_cpu.is_openreg)
|
||||
self.assertEqual(torch._utils.get_tensor_metadata(tensor_cpu), {}) # type: ignore[misc]
|
||||
|
||||
# Opeartors
|
||||
def test_factory(self):
|
||||
|
||||
Reference in New Issue
Block a user