mirror of
				https://github.com/pytorch/pytorch.git
				synced 2025-11-04 16:04:58 +08:00 
			
		
		
		
	[jit] allow compilation using optional modules (#32539)
Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/32539 Before: if something in `_modules` was `None`, we would barf. This is incorrect because it's allowed for users to put `None` there, in case a module is optional. This case ought to be handled correctly during scripting. Fixes https://github.com/pytorch/pytorch/issues/32469 Test Plan: Imported from OSS Differential Revision: D19552346 Pulled By: suo fbshipit-source-id: aba7fdc19fd84d195c81cdaca8a75013a8626a8b
This commit is contained in:
		
				
					committed by
					
						
						Facebook Github Bot
					
				
			
			
				
	
			
			
			
						parent
						
							7d0f0b62de
						
					
				
				
					commit
					ef5637f85e
				
			@ -650,3 +650,19 @@ class TestRecursiveScript(JitTestCase):
 | 
			
		||||
                return self.encoder(input)
 | 
			
		||||
 | 
			
		||||
        self.checkModule(ContainsLoaded(), (torch.rand(2, 3), ))
 | 
			
		||||
 | 
			
		||||
    def test_optional_module(self):
 | 
			
		||||
        class Dummy(nn.Module):
 | 
			
		||||
            def __init__(self):
 | 
			
		||||
                super(Dummy, self).__init__()
 | 
			
		||||
                self.foo = nn.Linear(2, 2)
 | 
			
		||||
 | 
			
		||||
            def forward(self, x):
 | 
			
		||||
                if self.foo is not None:
 | 
			
		||||
                    return self.foo(x)
 | 
			
		||||
                return x
 | 
			
		||||
 | 
			
		||||
        mod = Dummy()
 | 
			
		||||
        self.checkModule(mod, (torch.rand(2, 2),))
 | 
			
		||||
        mod.foo = None
 | 
			
		||||
        self.checkModule(mod, (torch.rand(2, 2),))
 | 
			
		||||
 | 
			
		||||
		Reference in New Issue
	
	Block a user