mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 21:14:14 +08:00
Dictionarize check_inputs coming from trace
Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/20813 Differential Revision: D15466836 Pulled By: Krovatkin fbshipit-source-id: ffdb418592b76dc67c65c59f4dc7303f08734f97
This commit is contained in:
committed by
Facebook Github Bot
parent
2c556a9489
commit
31e2d20c5e
@ -3409,6 +3409,7 @@ def foo(x):
|
||||
module = torch.jit.trace_module(n, inputs, True, True, check_inputs)
|
||||
|
||||
module = torch.jit.trace(n.forward, example_forward_input)
|
||||
module = torch.jit.trace(n.forward, example_forward_input, True, True, [example_forward_input])
|
||||
with self.assertRaisesRegex(AttributeError, "trace doesn't support compiling individual module's functions"):
|
||||
module = torch.jit.trace(n.weighted_kernel_sum, inputs)
|
||||
|
||||
|
@ -638,6 +638,11 @@ def make_module(mod, _module_class, executor_options):
|
||||
_module_class = TopLevelTracedModule
|
||||
return _module_class(mod, **executor_options)
|
||||
|
||||
def wrap_check_inputs(check_inputs):
|
||||
if check_inputs is None:
|
||||
return None
|
||||
|
||||
return [{'forward' : c} for c in check_inputs]
|
||||
|
||||
def trace(func,
|
||||
example_inputs,
|
||||
@ -721,13 +726,14 @@ def trace(func,
|
||||
|
||||
if isinstance(func, torch.nn.Module):
|
||||
return trace_module(func, {'forward': example_inputs}, optimize,
|
||||
check_trace, check_inputs,
|
||||
check_trace, wrap_check_inputs(check_inputs),
|
||||
check_tolerance, _force_outplace, _module_class)
|
||||
|
||||
if (hasattr(func, '__self__') and isinstance(func.__self__, torch.nn.Module) and
|
||||
func.__name__ == 'forward'):
|
||||
|
||||
return trace_module(func.__self__, {'forward': example_inputs}, optimize,
|
||||
check_trace, check_inputs,
|
||||
check_trace, wrap_check_inputs(check_inputs),
|
||||
check_tolerance, _force_outplace, _module_class)
|
||||
|
||||
executor_options = {'optimize': bool(optimize)}
|
||||
|
Reference in New Issue
Block a user