mirror of
				https://github.com/pytorch/pytorch.git
				synced 2025-11-04 16:04:58 +08:00 
			
		
		
		
	[nn] add + operator for torch.nn.Sequential to concatenate (#81170)
				
					
				
			Fixes #78512 #### TODO - [x] add tests cc @kshitij12345! Pull Request resolved: https://github.com/pytorch/pytorch/pull/81170 Approved by: https://github.com/albanD
This commit is contained in:
		
				
					committed by
					
						
						PyTorch MergeBot
					
				
			
			
				
	
			
			
			
						parent
						
							8740c68c41
						
					
				
				
					commit
					3da8c909da
				
			@ -1580,6 +1580,15 @@ class TestNN(NNTestCase):
 | 
			
		||||
        del n[1::2]
 | 
			
		||||
        self.assertEqual(n, nn.Sequential(l1, l3))
 | 
			
		||||
 | 
			
		||||
    def test_Sequential_add(self):
 | 
			
		||||
        l1 = nn.Linear(1, 2)
 | 
			
		||||
        l2 = nn.Linear(2, 3)
 | 
			
		||||
        l3 = nn.Linear(3, 4)
 | 
			
		||||
        l4 = nn.Linear(4, 5)
 | 
			
		||||
        n = nn.Sequential(l1, l2)
 | 
			
		||||
        other = nn.Sequential(l3, l4)
 | 
			
		||||
        self.assertEqual(n + other, nn.Sequential(l1, l2, l3, l4))
 | 
			
		||||
 | 
			
		||||
    def test_Sequential_append(self):
 | 
			
		||||
        l1 = nn.Linear(10, 20)
 | 
			
		||||
        l2 = nn.Linear(20, 30)
 | 
			
		||||
 | 
			
		||||
@ -122,6 +122,19 @@ class Sequential(Module):
 | 
			
		||||
    def __len__(self) -> int:
 | 
			
		||||
        return len(self._modules)
 | 
			
		||||
 | 
			
		||||
    def __add__(self, other) -> 'Sequential':
 | 
			
		||||
        if isinstance(other, Sequential):
 | 
			
		||||
            ret = Sequential()
 | 
			
		||||
            for layer in self:
 | 
			
		||||
                ret.append(layer)
 | 
			
		||||
            for layer in other:
 | 
			
		||||
                ret.append(layer)
 | 
			
		||||
            return ret
 | 
			
		||||
        else:
 | 
			
		||||
            raise ValueError('add operator supports only objects '
 | 
			
		||||
                             'of Sequential class, but {} is given.'.format(
 | 
			
		||||
                                 str(type(other))))
 | 
			
		||||
 | 
			
		||||
    @_copy_to_script_wrapper
 | 
			
		||||
    def __dir__(self):
 | 
			
		||||
        keys = super(Sequential, self).__dir__()
 | 
			
		||||
 | 
			
		||||
		Reference in New Issue
	
	Block a user