mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 12:54:11 +08:00
Part of my effort to move everything to pytest and decrease the number of test-runner frameworks in CI. Gives XMLs, but they might look a bit weird because of module-level tests vs. tests in classes. Doesn't give the skip/disable test infra, because that is tied to classes. (For future reference: we could either put the tests in classes or move the check_if_enable logic into a pytest hook.) Tested in CI and checked that the same number of tests are run. Pull Request resolved: https://github.com/pytorch/pytorch/pull/95659 Approved by: https://github.com/huydhn
58 lines
1.5 KiB
Python
58 lines
1.5 KiB
Python
# Owner(s): ["oncall: distributed"]
|
|
|
|
# Copyright 2019 Kakao Brain
|
|
#
|
|
# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
|
|
#
|
|
# This source code is licensed under the BSD license found in the
|
|
# LICENSE file in the root directory of this source tree.
|
|
import torch
|
|
|
|
from torch.distributed.pipeline.sync.phony import get_phony
|
|
from torch.testing._internal.common_utils import run_tests
|
|
|
|
|
|
def test_phony_size():
    """A phony tensor is an empty, one-dimensional placeholder of size (0,)."""
    cpu = torch.device("cpu")
    phony = get_phony(cpu, requires_grad=False)
    assert tuple(phony.size()) == (0,)
|
|
|
|
|
|
def test_phony_requires_grad():
    """get_phony honours the ``requires_grad`` flag it is given."""
    cpu = torch.device("cpu")
    with_grad = get_phony(cpu, requires_grad=True)
    without_grad = get_phony(cpu, requires_grad=False)
    assert with_grad.requires_grad
    assert not without_grad.requires_grad
|
|
|
|
|
|
def test_cached_phony():
    """Identical get_phony requests return the same tensor object.

    Requests that differ only in ``requires_grad`` must get distinct objects.
    """
    cpu = torch.device("cpu")

    grad_first = get_phony(cpu, requires_grad=True)
    grad_second = get_phony(cpu, requires_grad=True)
    assert grad_first is grad_second

    nograd_first = get_phony(cpu, requires_grad=False)
    nograd_second = get_phony(cpu, requires_grad=False)
    assert nograd_first is nograd_second

    # Same device, different requires_grad: must not share a cached tensor.
    assert grad_first is not nograd_first
|
|
|
|
|
|
def test_phony_in_autograd_function():
    """A phony returned from an autograd.Function is a new graph node.

    The output gets a ``grad_fn`` and is not the cached ``requires_grad=True``
    phony, which itself stays a leaf without a ``grad_fn``.
    """

    class Phonify(torch.autograd.Function):
        @staticmethod
        def forward(ctx, tensor):
            # detach() hands autograd a new tensor object rather than the
            # cached phony, presumably so no grad_fn is attached to the shared
            # cached tensor.
            return get_phony(tensor.device, requires_grad=False).detach()

    source = torch.rand(1, requires_grad=True)

    from_function = Phonify.apply(source)
    cached = get_phony(torch.device("cpu"), requires_grad=True)

    assert from_function is not cached
    assert from_function.grad_fn is not None
    assert cached.grad_fn is None
|
|
|
|
|
|
if __name__ == "__main__":
    # Executed directly rather than imported: hand off to PyTorch's shared
    # test entry point.
    run_tests()
|