mirror of
https://github.com/pytorch/pytorch.git
synced 2025-11-05 16:44:58 +08:00
Summary: Action based on https://github.com/pytorch/pytorch/issues/66232 cc pietern mrshenli pritamdamania87 zhaojuanmao satgera rohan-varma gqchen aazzolini osalpekar jiayisuse SciPioneer H-Huang Pull Request resolved: https://github.com/pytorch/pytorch/pull/66797 Reviewed By: gchanan Differential Revision: D31761389 Pulled By: janeyx99 fbshipit-source-id: c27c9ab4acec1eb71d5edd4538cd113b770dfc6c
143 lines
3.0 KiB
Python
143 lines
3.0 KiB
Python
# Owner(s): ["oncall: distributed"]
|
|
|
|
# Copyright 2019 Kakao Brain
|
|
#
|
|
# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
|
|
#
|
|
# This source code is licensed under the BSD license found in the
|
|
# LICENSE file in the root directory of this source tree.
|
|
import pytest
|
|
import torch
|
|
import torch.cuda
|
|
|
|
from torch.distributed.pipeline.sync.microbatch import Batch, check, gather, scatter
|
|
|
|
|
|
def test_batch_atomic():
|
|
x = torch.tensor(42)
|
|
b = Batch(x)
|
|
|
|
assert b.atomic
|
|
|
|
assert b.tensor is x
|
|
with pytest.raises(AttributeError):
|
|
b.tensors
|
|
|
|
assert list(b) == [x]
|
|
assert len(b) == 1
|
|
assert b[0] is x
|
|
|
|
|
|
def test_batch_non_atomic():
|
|
x, y = torch.tensor(42), torch.tensor(21)
|
|
b = Batch((x, y))
|
|
|
|
assert not b.atomic
|
|
|
|
with pytest.raises(AttributeError):
|
|
b.tensor
|
|
|
|
assert list(b) == [x, y]
|
|
assert len(b) == 2
|
|
assert b[0] is x
|
|
assert b[1] is y
|
|
|
|
|
|
def test_batch_call():
|
|
a = Batch(torch.tensor(42))
|
|
b = Batch((torch.tensor(42), torch.tensor(21)))
|
|
|
|
def f(x):
|
|
return x
|
|
|
|
def g(x, y):
|
|
return x, y
|
|
|
|
assert a.call(f).atomic
|
|
assert not b.call(g).atomic
|
|
|
|
|
|
def test_batch_setitem_by_index():
|
|
a = Batch(torch.tensor(42))
|
|
b = Batch((torch.tensor(42), torch.tensor(21)))
|
|
|
|
a[0] = torch.tensor(0)
|
|
b[0] = torch.tensor(0)
|
|
|
|
assert a.atomic
|
|
assert a[0].item() == 0
|
|
|
|
assert not b.atomic
|
|
assert len(b) == 2
|
|
assert b[0].item() == 0
|
|
assert b[1].item() == 21
|
|
|
|
|
|
def test_batch_setitem_by_slice():
|
|
a = Batch(torch.tensor(42))
|
|
b = Batch((torch.tensor(42), torch.tensor(21)))
|
|
|
|
a[:] = (torch.tensor(0),)
|
|
b[:] = (torch.tensor(0),)
|
|
|
|
assert a.atomic
|
|
assert a[0].item() == 0
|
|
|
|
assert not b.atomic
|
|
assert len(b) == 1
|
|
assert b[0].item() == 0
|
|
|
|
|
|
def test_check():
|
|
check(torch.device("cpu"), torch.tensor(42))
|
|
check(torch.device("cpu"), torch.tensor(4), torch.tensor(2))
|
|
|
|
with pytest.raises(TypeError):
|
|
check(torch.device("cpu"), 42)
|
|
|
|
with pytest.raises(TypeError):
|
|
check(torch.device("cpu"), "str")
|
|
|
|
with pytest.raises(TypeError):
|
|
check(torch.device("cpu"), (torch.tensor(4), 2))
|
|
|
|
|
|
def test_gather_tensors():
|
|
a = torch.zeros(1, 1)
|
|
b = torch.zeros(1, 1)
|
|
|
|
ab = gather([Batch(a), Batch(b)])
|
|
|
|
assert ab.size() == (2, 1)
|
|
|
|
|
|
def test_gather_tuples():
|
|
a = (torch.zeros(1, 1), torch.zeros(2, 2))
|
|
b = (torch.zeros(1, 1), torch.zeros(2, 2))
|
|
|
|
ab = gather([Batch(a), Batch(b)])
|
|
|
|
assert isinstance(ab, tuple)
|
|
assert ab[0].size() == (2, 1)
|
|
assert ab[1].size() == (4, 2)
|
|
|
|
|
|
def test_scatter_tensor():
|
|
ab = torch.zeros(2, 1)
|
|
|
|
a, b = scatter(ab, chunks=2)
|
|
|
|
assert a.tensor.size() == (1, 1)
|
|
assert b.tensor.size() == (1, 1)
|
|
|
|
|
|
def test_scatter_multiple_tensors():
|
|
ab = (torch.zeros(2, 1), torch.zeros(4, 2))
|
|
|
|
a, b = scatter(*ab, chunks=2)
|
|
|
|
assert list(a)[0].size() == (1, 1)
|
|
assert list(b)[0].size() == (1, 1)
|
|
assert list(a)[1].size() == (2, 2)
|
|
assert list(b)[1].size() == (2, 2)
|