Dictionarize check_inputs coming from trace

Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/20813

Differential Revision: D15466836

Pulled By: Krovatkin

fbshipit-source-id: ffdb418592b76dc67c65c59f4dc7303f08734f97
This commit is contained in:
Nikolay Korovaiko
2019-05-23 10:58:51 -07:00
committed by Facebook Github Bot
parent 2c556a9489
commit 31e2d20c5e
2 changed files with 9 additions and 2 deletions

View File

@ -3409,6 +3409,7 @@ def foo(x):
module = torch.jit.trace_module(n, inputs, True, True, check_inputs)
module = torch.jit.trace(n.forward, example_forward_input)
module = torch.jit.trace(n.forward, example_forward_input, True, True, [example_forward_input])
with self.assertRaisesRegex(AttributeError, "trace doesn't support compiling individual module's functions"):
module = torch.jit.trace(n.weighted_kernel_sum, inputs)

View File

@ -638,6 +638,11 @@ def make_module(mod, _module_class, executor_options):
_module_class = TopLevelTracedModule
return _module_class(mod, **executor_options)
def wrap_check_inputs(check_inputs):
if check_inputs is None:
return None
return [{'forward' : c} for c in check_inputs]
def trace(func,
example_inputs,
@ -721,13 +726,14 @@ def trace(func,
if isinstance(func, torch.nn.Module):
return trace_module(func, {'forward': example_inputs}, optimize,
check_trace, check_inputs,
check_trace, wrap_check_inputs(check_inputs),
check_tolerance, _force_outplace, _module_class)
if (hasattr(func, '__self__') and isinstance(func.__self__, torch.nn.Module) and
func.__name__ == 'forward'):
return trace_module(func.__self__, {'forward': example_inputs}, optimize,
check_trace, check_inputs,
check_trace, wrap_check_inputs(check_inputs),
check_tolerance, _force_outplace, _module_class)
executor_options = {'optimize': bool(optimize)}