Remove empty_like+fill from AOT Autograd graphs for nvFuser (#86908)

AOT Autograd traces the C++ expression `1 - tensor` as a sequence of empty_like, fill, and sub calls (see https://github.com/pytorch/pytorch/issues/86612).

Neither empty_like nor fill is supported by the nvFuser executor yet. This PR adds a workaround that enables fusion of `silu_backward`.
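For illustration, a graph pass of this kind could look roughly like the sketch below. This is only a minimal sketch, not the actual `_remove_empty_like_fill` implementation from this PR: the function name `remove_empty_like_fill_sketch` is made up, and substituting `aten.full_like` for the `empty_like` + `fill` pair is an assumption made here purely to keep the example semantics-preserving.

```python
import torch
from torch.fx import GraphModule, Node


def remove_empty_like_fill_sketch(gm: GraphModule) -> GraphModule:
    # Hypothetical sketch: rewrite fill.Scalar(empty_like(x), value) into
    # full_like(x, value); the real pass in this PR may do something different.
    for node in list(gm.graph.nodes):
        if node.target != torch.ops.aten.fill.Scalar:
            continue
        empty_like_node, value = node.args
        if not (
            isinstance(empty_like_node, Node)
            and empty_like_node.target == torch.ops.aten.empty_like.default
        ):
            continue
        # Insert an equivalent full_like(x, value) and reroute all users to it
        with gm.graph.inserting_before(node):
            full_like = gm.graph.call_function(
                torch.ops.aten.full_like.default,
                (empty_like_node.args[0], value),
            )
        node.replace_all_uses_with(full_like)
        gm.graph.erase_node(node)
        # Drop the now-unused empty_like node
        if len(empty_like_node.users) == 0:
            gm.graph.erase_node(empty_like_node)
    gm.graph.lint()
    gm.recompile()
    return gm
```

In the test added below, the real pass (`_remove_empty_like_fill`) is applied to the traced graph, and the rewritten graph is checked against eager execution.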
Pull Request resolved: https://github.com/pytorch/pytorch/pull/86908
Approved by: https://github.com/ngimel
Author: Ivan Yashchuk
Date: 2022-10-14 19:49:39 +00:00
Committed by: PyTorch MergeBot
Parent: 56a744bf47
Commit: fc3afc8407
2 changed files with 77 additions and 0 deletions

@@ -642,6 +642,55 @@ class TestPrims(TestCase):
        )
        self.assertFalse(includes_batch_norm_backward)

    @onlyCUDA
    @skipCUDAIfRocm
    @dtypes(torch.float32)
    def test_silu_backward_no_filled_tensor(self, device, dtype):
        # This test verifies a workaround for
        # https://github.com/pytorch/pytorch/issues/86612
        from torch.fx.experimental.proxy_tensor import make_fx
        from functorch import functionalize
        from torch._prims.nvfuser_executor import _remove_empty_like_fill
        from torch._prims.context import TorchRefsNvfuserCapabilityMode

        def func(a):
            out = torch.nn.functional.silu(a)
            grad = torch.ones_like(out)
            return torch.autograd.grad([out], [a], [grad])

        make_arg = partial(make_tensor, device=device, dtype=dtype, requires_grad=True)
        a = make_arg((3, 4))
        gm = make_fx(func)(a)

        # functionalize(gm) doesn't work with non-detached inputs
        gm = make_fx(functionalize(gm))(a.detach())

        # replace aten.sub with nvprims.sub
        with TorchRefsNvfuserCapabilityMode():
            gm = make_fx(gm)(a)

        # Check that the graph contains empty_like
        any_aten_empty_like = any(
            node.target == torch.ops.aten.empty_like.default for node in gm.graph.nodes
        )
        self.assertTrue(any_aten_empty_like)
        any_aten_fill = any(
            node.target == torch.ops.aten.fill.Scalar for node in gm.graph.nodes
        )
        self.assertTrue(any_aten_fill)

        # Now remove the empty_like and fill
        gm = _remove_empty_like_fill(gm)
        any_aten_empty_like = any(
            node.target == torch.ops.aten.empty_like.default for node in gm.graph.nodes
        )
        self.assertFalse(any_aten_empty_like)
        any_aten_fill = any(
            node.target == torch.ops.aten.fill.Scalar for node in gm.graph.nodes
        )
        self.assertFalse(any_aten_fill)
        self.assertEqual(gm(a), func(a))

    @onlyCUDA
    @skipCUDAIfRocm
    @dtypes(torch.float32)