# Owner(s): ["oncall: jit"]

import torch
import torch._lazy.metrics as metrics
import torch._lazy.ts_backend
from torch.testing._internal.common_utils import run_tests, skipIfTorchDynamo, TestCase
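
# Initialize the TorchScript backend so that tensors can be created on the
# "lazy" device below.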
torch._lazy.ts_backend.init()


class LazyGeneratorTest(TestCase):
    def test_generator(self):
        """
        Test that generators are inserted into the TorchScript graph. A
        different global seed is set before each call to generate_tensor,
        but the resulting tensors should still match because the explicitly
        seeded generators, not the global seed, drive the RNG.
        """
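
        # Each generator is seeded explicitly, so the output of
        # generate_tensor should not depend on the ambient global seed.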
        def generate_tensor():
            g1 = torch.Generator()
            g1.manual_seed(2023)
            t1 = torch.tensor(1.0)
            t1.uniform_(generator=g1)

            g2 = torch.Generator()
            g2.manual_seed(2024)
            t2 = torch.tensor(1.0)
            t2.normal_(generator=g2)

            return t1, t2
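
        # The global seeds differ on purpose; matching results show that the
        # generators above were captured in the graph.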
        torch.manual_seed(1)

        with torch.device("cpu"):
            cpu_t1, cpu_t2 = generate_tensor()

        torch.manual_seed(2)

        with torch.device("lazy"):
            lazy_t1, lazy_t2 = generate_tensor()
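
        # mark_step materializes the lazy tensors so they can be compared
        # against the eager CPU results.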
        torch._lazy.mark_step()

        assert torch.allclose(cpu_t1, lazy_t1.to("cpu")), (
            f"Expected {cpu_t1}, got {lazy_t1.to('cpu')}"
        )
        assert torch.allclose(cpu_t2, lazy_t2.to("cpu")), (
            f"Expected {cpu_t2}, got {lazy_t2.to('cpu')}"
        )

    @skipIfTorchDynamo("Torch Dynamo does not support torch.Generator type")
    def test_generator_causes_multiple_compiles(self):
        """
        Test that inserting generators with different seeds causes a
        recompile, and that reusing a previously seen seed hits the
        compile cache.
        """
        def generate_tensor(seed):
            t = torch.tensor(1.0)
            g = torch.Generator()
            g.manual_seed(seed)
            t.uniform_(-1, 1, generator=g)
            return t

        metrics.reset()
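
        # Each distinct generator seed is baked into the trace, so it should
        # produce a fresh (uncached) compile.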
with torch.device("lazy"):
|
|
t = generate_tensor(1)
|
|
torch._lazy.mark_step()
|
|
|
|
uncached_compile = metrics.counter_value("UncachedCompile")
|
|
assert uncached_compile == 1, (
|
|
f"Expected 1 uncached compiles, got {uncached_compile}"
|
|
)
|
|
|
|
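
            # A new seed changes the graph and forces a second uncached
            # compile.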
            t = generate_tensor(2)
            torch._lazy.mark_step()

            uncached_compile = metrics.counter_value("UncachedCompile")
            assert uncached_compile == 2, (
                f"Expected 2 uncached compiles, got {uncached_compile}"
            )
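
            # Seed 1 was compiled before, so this run should hit the cache:
            # the uncached counter stays at 2 and the cached counter reads 1.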
            t = generate_tensor(1)  # noqa: F841
            torch._lazy.mark_step()

            uncached_compile = metrics.counter_value("UncachedCompile")
            assert uncached_compile == 2, (
                f"Expected 2 uncached compiles, got {uncached_compile}"
            )
            cached_compile = metrics.counter_value("CachedCompile")
            assert cached_compile == 1, (
                f"Expected 1 cached compile, got {cached_compile}"
            )

        metrics.reset()
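
        # The most recent computation used seed 1; its graph should embed
        # the CPU generator and the aten::uniform call.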
        latest_graph = torch._C._lazy_ts_backend._get_latest_computation_graph()
        assert 'torch.Generator(device="cpu", seed=1)' in latest_graph
        assert "aten::uniform" in latest_graph


if __name__ == "__main__":
    run_tests()