# Owner(s): ["oncall: jit"]

import re

import torch
import torch._lazy.metrics as metrics
import torch._lazy.ts_backend
from torch.testing._internal.common_utils import run_tests, TestCase

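# Initialize the TorchScript-based lazy tensor backend used by the "lazy"
# device below.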
torch._lazy.ts_backend.init()

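# Matches the ", NodeType=..." suffix in IR dumps; it is stripped in
# test_data_assign so the expected IR strings stay stable.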
NODE_TYPE_PATTERN = re.compile(r", NodeType=[^\n]+")

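# Tests for lazy-tensor functionalization: syncing lazily initialized
# parameters and re-assigning `.data` on a lazy tensor.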
class LazyFuncionalizationTest(TestCase):
    def test_lazy_init_with_view(self):
        def f(device, reset_storage=False):
            torch.manual_seed(2023)

            if device == "lazy":
                metrics.reset()

            class Model(torch.nn.Module):
                def __init__(self) -> None:
                    super().__init__()
                    self.fc1 = torch.nn.Linear(4, 2, bias=False)

                def forward(self, x):
                    return x @ self.fc1.weight.transpose(0, 1)

            with torch.device(device):
                model = Model()

            if device == "lazy":
                if reset_storage:
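                    # Reset the functional storage of the lazily initialized
                    # weight (internal torch._C API) before the first sync.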
                    torch._C._unsafe_reset_storage(model.fc1.weight)

                torch._lazy.mark_step()

                sync_tensors = metrics.counter_value("SyncedTensorsWithIR")
                if reset_storage:
                    assert sync_tensors == 1
                else:
                    # There is an extra tensor being unnecessarily synced if
                    # the functional storage is not reset.
                    assert sync_tensors == 2

            x = torch.ones(4)
            out = model(x)

            if device == "lazy":
                torch._lazy.mark_step()

            return out

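        # CPU is the reference; the two lazy runs differ only in whether the
        # weight's functional storage is reset before the first sync.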
        cpu_out = f("cpu")
        lazy_out_1 = f("lazy", reset_storage=False)
        lazy_out_2 = f("lazy", reset_storage=True)

        self.assertEqual(cpu_out, lazy_out_1.to("cpu"))
        self.assertEqual(cpu_out, lazy_out_2.to("cpu"))

    def test_data_assign(self):
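        # Render the lazy IR of a tensor, stripping the NodeType field so the
        # expected dumps below stay stable.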
        def text(lazyt):
            raw = torch._C._lazy._get_tensors_text([lazyt])
            return NODE_TYPE_PATTERN.sub("", raw)

        origin = torch.rand(3, dtype=torch.float32)
        tensor = origin.to("lazy")

        self.assertExpectedInline(
            text(tensor),
            """\
IR {
  %0 = [Float[3]] lazy_tensors::device_data(), device=CPU0, ROOT=0
}
""",
        )

        # Change the tensor's data-type and assign the result to 'data'.
        # This should update the inner tensor of the FunctionalTensorWrapper,
        # changing the corresponding IR node.
        modified_tensor = tensor.to(torch.bfloat16)
        tensor.data = modified_tensor

        self.assertExpectedInline(
            text(tensor),
            """\
IR {
  %0 = [Float[3]] lazy_tensors::device_data(), device=CPU0
  %1 = [BFloat16[3]] aten::_to_copy(%0), dtype=BFloat16, layout=null, device=null, pin_memory=null, non_blocking=0, memory_format=null, ROOT=0
}
""",  # noqa: B950
        )


if __name__ == "__main__":
    run_tests()