Mirror of https://github.com/pytorch/pytorch.git
Remove empty_like+fill from AOT Autograd graphs for nvFuser (#86908)
AOT Autograd records the C++ expression `1 - tensor` as a sequence of empty_like, fill, and sub (see https://github.com/pytorch/pytorch/issues/86612). Neither empty_like nor fill is supported by nvFuser yet. This PR is a workaround that enables fusion of `silu_backward`.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/86908
Approved by: https://github.com/ngimel
committed by: PyTorch MergeBot
parent: 56a744bf47
commit: fc3afc8407
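To make the workaround concrete: the idea is a small FX graph rewrite that spots the empty_like + fill pair autograd emits for `1 - tensor` and collapses it before the graph reaches nvFuser. The sketch below is illustrative only; it is not the actual `_remove_empty_like_fill` from `torch._prims.nvfuser_executor`, and both the helper name and the choice of `aten.full_like` as the replacement op are assumptions made for this example.

import torch
from torch.fx import GraphModule


# A minimal sketch, not the real _remove_empty_like_fill: fold
# fill.Scalar(empty_like(x), value) into a single full_like(x, value) call
# (using aten.full_like here is an assumption made for illustration).
def remove_empty_like_fill_sketch(gm: GraphModule) -> GraphModule:
    aten = torch.ops.aten
    for node in list(gm.graph.nodes):
        # Match fill.Scalar whose tensor argument comes from empty_like
        if node.op == "call_function" and node.target == aten.fill.Scalar:
            src, value = node.args
            if getattr(src, "op", None) == "call_function" and src.target == aten.empty_like.default:
                # Insert the single replacement op right before the fill node
                with gm.graph.inserting_before(node):
                    replacement = gm.graph.call_function(
                        aten.full_like.default, (src.args[0], value)
                    )
                node.replace_all_uses_with(replacement)
                gm.graph.erase_node(node)
                # Drop the now-unused empty_like node as well
                if not src.users:
                    gm.graph.erase_node(src)
    gm.graph.lint()
    gm.recompile()
    return gm

The test added in the diff below exercises the real pass: it traces silu's backward, checks that aten.empty_like.default and aten.fill.Scalar appear in the graph, runs _remove_empty_like_fill, and then verifies that both ops are gone while the result still matches the eager function.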
@@ -642,6 +642,55 @@ class TestPrims(TestCase):
        )
        self.assertFalse(includes_batch_norm_backward)

    @onlyCUDA
    @skipCUDAIfRocm
    @dtypes(torch.float32)
    def test_silu_backward_no_filled_tensor(self, device, dtype):
        # This test verifies a workaround for
        # https://github.com/pytorch/pytorch/issues/86612
        from torch.fx.experimental.proxy_tensor import make_fx
        from functorch import functionalize
        from torch._prims.nvfuser_executor import _remove_empty_like_fill
        from torch._prims.context import TorchRefsNvfuserCapabilityMode

        def func(a):
            out = torch.nn.functional.silu(a)
            grad = torch.ones_like(out)
            return torch.autograd.grad([out], [a], [grad])

        make_arg = partial(make_tensor, device=device, dtype=dtype, requires_grad=True)
        a = make_arg((3, 4))
        gm = make_fx(func)(a)
        # functionalize(gm) doesn't work with non-detached inputs
        gm = make_fx(functionalize(gm))(a.detach())

        # replace aten.sub with nvprims.sub
        with TorchRefsNvfuserCapabilityMode():
            gm = make_fx(gm)(a)

        # Check that the graph contains empty_like
        any_aten_empty_like = any(
            node.target == torch.ops.aten.empty_like.default for node in gm.graph.nodes
        )
        self.assertTrue(any_aten_empty_like)
        any_aten_fill = any(
            node.target == torch.ops.aten.fill.Scalar for node in gm.graph.nodes
        )
        self.assertTrue(any_aten_fill)

        # Now remove the empty_like and fill
        gm = _remove_empty_like_fill(gm)
        any_aten_empty_like = any(
            node.target == torch.ops.aten.empty_like.default for node in gm.graph.nodes
        )
        self.assertFalse(any_aten_empty_like)
        any_aten_fill = any(
            node.target == torch.ops.aten.fill.Scalar for node in gm.graph.nodes
        )
        self.assertFalse(any_aten_fill)
        self.assertEqual(gm(a), func(a))

    @onlyCUDA
    @skipCUDAIfRocm
    @dtypes(torch.float32)