[JIT] Exit Transform Rewrite (#38282)

Summary:
After an early return, we conditionalize all further execution. This means that currently the pattern of
`if return elif return elif return` generates better code than `if return if return if return`. It's obviously not good to have semantically equivalent code generate worse IR, so we should rewrite the graph to handle this case. This came up in https://github.com/pytorch/pytorch/pull/37171

```
torch.jit.script
def test_foo(x: bool, y: bool):
    if x:
        return 1
    return 2
print(test_foo.code)
```
generates:
```
def test_foo(x: bool,
    y: bool) -> int:
  _0 = uninitialized(int)
  if x:
    _1, _2 = True, 1
  else:
    _1, _2 = False, _0
  if _1:
    _3 = _2
  else:
    _3 = 2
  return _3
```
while
```
torch.jit.script
def test_foo(x: bool, y: bool):
    if x:
        return 1
    else:
        return 2
print(test_foo.code)
```
generates:
```
def test_foo(x: bool,
    y: bool) -> int:
  if x:
    _0 = 1
  else:
    _0 = 2
  return _0
```
Pull Request resolved: https://github.com/pytorch/pytorch/pull/38282

Differential Revision: D21576733

Pulled By: eellison

fbshipit-source-id: 80cf1ad7fbda6d8d58557abbfb21c90eafae7488
This commit is contained in:
Elias Ellison
2020-05-15 12:20:13 -07:00
committed by Facebook GitHub Bot
parent 62afc2d63d
commit daa85cfe2e
5 changed files with 152 additions and 21 deletions

View File

@ -473,6 +473,7 @@ class TestFuser(JitTestCase):
@unittest.skipIf(not RUN_CUDA, "fuser requires CUDA")
@unittest.skipIf(GRAPH_EXECUTOR != ProfilingMode.LEGACY, "broken with profiling on")
@torch.jit._disable_emit_hooks_decorator
@_inline_everything
def test_fuse_decompose_normalization(self):
class ResLike(torch.jit.ScriptModule):