mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-21 05:34:18 +08:00
OpInfo JIT op.output_func handling support (#50775)
Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/50775 Test Plan: Imported from OSS Reviewed By: mruberry Differential Revision: D25964541 Pulled By: Lilyjjo fbshipit-source-id: 8cf1ee9191d526cc46ae283f38c2d64bd60afdb2
This commit is contained in:
committed by
Facebook GitHub Bot
parent
eaf5ca09dc
commit
3b6f30824c
@ -278,17 +278,15 @@ class TestCommon(JitCommonTestCase):
|
||||
# autodiff support. Context manager forces the graph to contain
|
||||
# DifferentiableGraph nodes if they are present
|
||||
with disable_autodiff_subgraph_inlining():
|
||||
def fn(*inputs, **kwargs):
|
||||
output = func(*inputs, **kwargs)
|
||||
return op.output_func(output)
|
||||
|
||||
|
||||
# Check scripted forward, grad, and grad grad
|
||||
script_fn = create_script_fn(self, name, func_type, op.output_func)
|
||||
script_fn = create_script_fn(self, name, func_type)
|
||||
|
||||
check_against_reference(self,
|
||||
script_fn,
|
||||
fn,
|
||||
func,
|
||||
op.output_func,
|
||||
(*sample.input,) + sample.args,
|
||||
sample.kwargs,
|
||||
no_grad=not test_backward)
|
||||
@ -297,7 +295,8 @@ class TestCommon(JitCommonTestCase):
|
||||
traced_fn = create_traced_fn(self, variant)
|
||||
check_against_reference(self,
|
||||
traced_fn,
|
||||
fn,
|
||||
func,
|
||||
op.output_func,
|
||||
(*sample.input,) + sample.args,
|
||||
sample.kwargs,
|
||||
no_grad=not test_backward)
|
||||
|
Reference in New Issue
Block a user