OpInfo JIT op.output_func handling support (#50775)

Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/50775

Test Plan: Imported from OSS

Reviewed By: mruberry

Differential Revision: D25964541

Pulled By: Lilyjjo

fbshipit-source-id: 8cf1ee9191d526cc46ae283f38c2d64bd60afdb2
This commit is contained in:
Lillian Johnson
2021-01-27 15:01:46 -08:00
committed by Facebook GitHub Bot
parent eaf5ca09dc
commit 3b6f30824c
5 changed files with 32 additions and 36 deletions

View File

@ -278,17 +278,15 @@ class TestCommon(JitCommonTestCase):
# autodiff support. Context manager forces the graph to contain
# DifferentiableGraph nodes if they are present
with disable_autodiff_subgraph_inlining():
def fn(*inputs, **kwargs):
output = func(*inputs, **kwargs)
return op.output_func(output)
# Check scripted forward, grad, and grad grad
script_fn = create_script_fn(self, name, func_type, op.output_func)
script_fn = create_script_fn(self, name, func_type)
check_against_reference(self,
script_fn,
fn,
func,
op.output_func,
(*sample.input,) + sample.args,
sample.kwargs,
no_grad=not test_backward)
@ -297,7 +295,8 @@ class TestCommon(JitCommonTestCase):
traced_fn = create_traced_fn(self, variant)
check_against_reference(self,
traced_fn,
fn,
func,
op.output_func,
(*sample.input,) + sample.args,
sample.kwargs,
no_grad=not test_backward)