[nn] add + operator for torch.nn.Sequential to concatenate (#81170)

Fixes #78512

#### TODO
- [x] add tests

cc @kshitij12345!
Pull Request resolved: https://github.com/pytorch/pytorch/pull/81170
Approved by: https://github.com/albanD
This commit is contained in:
Khushi Agrawal
2022-07-11 17:49:56 +00:00
committed by PyTorch MergeBot
parent 8740c68c41
commit 3da8c909da
2 changed files with 22 additions and 0 deletions

View File

@ -1580,6 +1580,15 @@ class TestNN(NNTestCase):
del n[1::2]
self.assertEqual(n, nn.Sequential(l1, l3))
def test_Sequential_add(self):
l1 = nn.Linear(1, 2)
l2 = nn.Linear(2, 3)
l3 = nn.Linear(3, 4)
l4 = nn.Linear(4, 5)
n = nn.Sequential(l1, l2)
other = nn.Sequential(l3, l4)
self.assertEqual(n + other, nn.Sequential(l1, l2, l3, l4))
def test_Sequential_append(self):
l1 = nn.Linear(10, 20)
l2 = nn.Linear(20, 30)

View File

@ -122,6 +122,19 @@ class Sequential(Module):
def __len__(self) -> int:
return len(self._modules)
def __add__(self, other) -> 'Sequential':
if isinstance(other, Sequential):
ret = Sequential()
for layer in self:
ret.append(layer)
for layer in other:
ret.append(layer)
return ret
else:
raise ValueError('add operator supports only objects '
'of Sequential class, but {} is given.'.format(
str(type(other))))
@_copy_to_script_wrapper
def __dir__(self):
keys = super(Sequential, self).__dir__()