Workarounds for cudnn_batch_norm with TorchRefsNvfuserCapabilityMode (#86796)

This PR adds workarounds so that AOT Autograd graphs containing `aten.cudnn_batch_norm` and `aten.cudnn_batch_norm_backward` are supported under `TorchRefsNvfuserCapabilityMode`.
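
For orientation, here is a minimal sketch of the flow this targets (the tensor shapes and hyperparameter values are illustrative, not taken from the PR): an FX graph that still contains `aten.cudnn_batch_norm` is traced under `TorchRefsNvfuserCapabilityMode`, which rewrites the call into `nvprims.native_batch_norm` so the nvFuser executor can consume it, mirroring the new test added below.

    # Sketch only: requires a CUDA build with nvFuser; values are made up.
    import torch
    from torch._prims.context import TorchRefsNvfuserCapabilityMode
    from torch._prims.executor import execute
    from torch.fx.experimental.proxy_tensor import make_fx

    def f(x, weight, bias, running_mean, running_var):
        # training=True, momentum=0.1, eps=1e-5 are arbitrary illustrative values
        return torch.ops.aten.cudnn_batch_norm.default(
            x, weight, bias, running_mean, running_var, True, 0.1, 1e-5
        )

    x = torch.randn(2, 3, 8, 8, device="cuda")
    weight = torch.randn(3, device="cuda")
    bias = torch.randn(3, device="cuda")
    running_mean = torch.zeros(3, device="cuda")
    running_var = torch.ones(3, device="cuda")

    with TorchRefsNvfuserCapabilityMode():
        gm = make_fx(f)(x, weight, bias, running_mean, running_var)

    # The traced graph should now call torch.ops.nvprims.native_batch_norm
    # instead of aten.cudnn_batch_norm, and it can be run with the nvFuser executor.
    print(gm.graph)
    out = execute(gm, x, weight, bias, running_mean, running_var, executor="nvfuser")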

The problem with the decomposition of `aten.cudnn_batch_norm` is that it uses a `new_empty` call, which nvFuser does not support, and by default we are conservative about lowering functions to nvprims.
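
To make the issue concrete, the existing decomposition has roughly the following shape (a simplified paraphrase, not a verbatim copy of `torch/_decomp/decompositions.py`); the `new_empty` call that materializes cuDNN's empty "reserve space" output is the part that blocks nvprim lowering, which is why the workaround added below builds the same outputs around `nvprims.native_batch_norm` using `new_zeros` instead.

    # Simplified paraphrase of the aten.cudnn_batch_norm decomposition, for
    # illustration only; the new_empty call is the unsupported piece.
    import torch

    def cudnn_batch_norm_decomposition_sketch(
        input, weight, bias, running_mean, running_var, training, momentum, eps
    ):
        a, b, c = torch.ops.aten.native_batch_norm(
            input, weight, bias, running_mean, running_var, training, momentum, eps
        )
        # cudnn_batch_norm returns a fourth, empty uint8 "reserve space" tensor
        reserve = input.new_empty((0,), dtype=torch.uint8)
        if training:
            return a, b, c, reserve
        return a, weight.new_zeros((0,)), weight.new_zeros((0,)), reserve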

The problem with the decomposition of `aten.cudnn_batch_norm_backward` is described in https://github.com/pytorch/pytorch/pull/86115#issue-1394883782; changing the decomposition directly, as attempted in that PR, makes many tests fail.
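
For reference, the two backward ops take their arguments in different orders (the `cudnn` parameter list below matches the workaround added to `TorchRefsNvfuserCapabilityMode` in this PR; the `native` one is paraphrased and should be checked against `native_functions.yaml`). The workaround simply forwards to the registered `native_batch_norm_backward` decomposition with `grad_output` and `input` swapped, `train=True`, a full output mask, and the cuDNN-only `reserveSpace` argument dropped:

    # Approximate schemas:
    #   aten.cudnn_batch_norm_backward(input, grad_output, weight, running_mean,
    #                                  running_var, save_mean, save_var,
    #                                  epsilon, reserveSpace)
    #   aten.native_batch_norm_backward(grad_out, input, weight, running_mean,
    #                                   running_var, save_mean, save_invstd,
    #                                   train, eps, output_mask)
    import torch
    import torch._decomp

    def cudnn_backward_via_native_decomp_sketch(
        input, grad_output, weight, running_mean, running_var,
        save_mean, save_var, epsilon, reserveSpace,
    ):
        # Condensed restatement of the workaround's argument mapping.
        decomp = torch._decomp.decomposition_table[
            torch.ops.aten.native_batch_norm_backward.default
        ]
        return decomp(
            grad_output, input, weight, running_mean, running_var,
            save_mean, save_var, True, epsilon, [True, True, True],
        )
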
Pull Request resolved: https://github.com/pytorch/pytorch/pull/86796
Approved by: https://github.com/mruberry
Author: Ivan Yashchuk
Date: 2022-10-17 18:46:28 +00:00
Committed by: PyTorch MergeBot
Commit: 31931515bc (parent: 33343def0b)
2 changed files with 155 additions and 9 deletions


@@ -611,6 +611,63 @@ class TestPrims(TestCase):
             out = execute(gm, sample.input, *sample.args, executor="strictly_nvfuser")
             self.assertEqual(out, gm(sample.input, *sample.args))
 
+    @onlyCUDA
+    @skipCUDAIfRocm
+    @dtypes(torch.float32, torch.float64)
+    def test_cudnn_batch_norm_nvprims(self, device, dtype):
+        from torch._prims.context import TorchRefsNvfuserCapabilityMode
+        from torch._prims.executor import execute
+
+        # This test verifies that cudnn_batch_norm is translated into nvprims
+        # and can be executed with nvFuser
+        from torch.fx.experimental.proxy_tensor import make_fx
+        from torch.testing._internal.common_methods_invocations import (
+            sample_inputs_native_batch_norm,
+        )
+
+        samples = sample_inputs_native_batch_norm(
+            None, device, dtype, requires_grad=False
+        )
+        for sample in samples:
+            if sample.input.numel() == 0:
+                continue
+
+            def func(
+                input, weight, bias, running_mean, running_var, training, momentum, eps
+            ):
+                return torch.ops.aten.cudnn_batch_norm.default(
+                    input,
+                    weight,
+                    bias,
+                    running_mean,
+                    running_var,
+                    training,
+                    momentum,
+                    eps,
+                )
+
+            with TorchRefsNvfuserCapabilityMode():
+                gm = make_fx(func)(sample.input, *sample.args)
+
+            call_function_nodes = list(
+                filter(lambda n: n.op == "call_function", gm.graph.nodes)
+            )
+            includes_aten_batch_norm = any(
+                torch.ops.aten.cudnn_batch_norm.default == node.target
+                for node in call_function_nodes
+            )
+            self.assertFalse(includes_aten_batch_norm)
+
+            includes_nvprims_batch_norm = any(
+                torch.ops.nvprims.native_batch_norm.default == node.target
+                for node in call_function_nodes
+            )
+            self.assertTrue(includes_nvprims_batch_norm)
+
+            # Check that the graph can be executed with nvFuser
+            out = execute(gm, sample.input, *sample.args, executor="nvfuser")
+            self.assertEqual(out, gm(sample.input, *sample.args))
+
     # decomposition of native_batch_norm_backward uses a casting, which prevents nvprim lowering on CPU build
     @onlyCUDA
     @dtypes(torch.float32, torch.float16)
@@ -624,23 +681,34 @@ class TestPrims(TestCase):
         sample = next(samples_iter)
         grad = torch.randn_like(sample.input)
 
-        def func(grad, input, weight, rm, rv, eps, train):
+        def func1(grad, input, weight, rm, rv, eps, train):
             return torch.ops.aten.native_batch_norm_backward.default(
                 grad, input, weight, rm, rv, rm, rv, train, eps, [True, True, True]
             )
 
+        def func2(grad, input, weight, rm, rv, eps, train):
+            return torch.ops.aten.cudnn_batch_norm_backward.default(
+                input, grad, weight, rm, rv, rm, rv, eps, grad
+            )
+
         args = sample.args
         kwargs = sample.kwargs
         all_args = [grad, sample.input, args[2], args[0], args[1], kwargs['eps'], kwargs['training']]
 
-        with TorchRefsNvfuserCapabilityMode():
-            gm = make_fx(func)(*all_args)
-        call_function_nodes = list(filter(lambda n: n.op == "call_function", gm.graph.nodes))
-        includes_batch_norm_backward = any(
-            torch.ops.aten.native_batch_norm_backward.default == node.target
-            for node in call_function_nodes
-        )
-        self.assertFalse(includes_batch_norm_backward)
+        for func in (func1, func2):
+            with TorchRefsNvfuserCapabilityMode():
+                gm = make_fx(func)(*all_args)
+
+            call_function_nodes = list(filter(lambda n: n.op == "call_function", gm.graph.nodes))
+            includes_batch_norm_backward = any(
+                torch.ops.aten.native_batch_norm_backward.default == node.target
+                for node in call_function_nodes
+            )
+            self.assertFalse(includes_batch_norm_backward)
+            all_nvprims = all(
+                str(node.target).startswith("nvprims") for node in call_function_nodes
+            )
+            self.assertTrue(all_nvprims)
 
     @onlyCUDA
     @skipCUDAIfRocm


@@ -267,6 +267,68 @@ class TorchRefsNvfuserCapabilityMode(TorchRefsMode):
             prims_mode_cls=functools.partial(NvfuserPrimsMode, skip_ops=skip_ops),
         )
 
+    # TODO: remove this once version from _decomp/decompositions.py is working
+    # with this context manager
+    # This is a workaround for AOT Autograd graphs
+    def _cudnn_batch_norm(
+        self,
+        input,
+        weight,
+        bias,
+        running_mean,
+        running_var,
+        training,
+        exponential_average_factor,
+        epsilon,
+    ):
+        a, b, c = torch.ops.nvprims.native_batch_norm(
+            input,
+            weight,
+            bias,
+            running_mean,
+            running_var,
+            training,
+            exponential_average_factor,
+            epsilon,
+        )
+        if training:
+            return (a, b, c, input.new_zeros((0,), dtype=torch.uint8))
+        return (
+            a,
+            weight.new_zeros((0,)),
+            weight.new_zeros((0,)),
+            input.new_zeros((0,), dtype=torch.uint8),
+        )
+
+    # This is a workaround for AOT Autograd graphs
+    def _cudnn_batch_norm_backward(
+        self,
+        input,
+        grad_output,
+        weight,
+        running_mean,
+        running_var,
+        save_mean,
+        save_var,
+        epsilon,
+        reserveSpace,
+    ):
+        func = torch._decomp.decomposition_table[
+            torch.ops.aten.native_batch_norm_backward.default
+        ]
+        return func(
+            grad_output,
+            input,
+            weight,
+            running_mean,
+            running_var,
+            save_mean,
+            save_var,
+            True,
+            epsilon,
+            [True, True, True],
+        )
+
     def _is_var_mean(self, func):
         return "torch.var_mean" == torch.overrides.resolve_name(func) or (
             (
@@ -313,6 +375,22 @@ class TorchRefsNvfuserCapabilityMode(TorchRefsMode):
         if self._is_var_mean(orig_func):
             return torch.ops.nvprims.var_mean(*args, **kwargs)
+
+        if (
+            orig_func == torch.ops.aten.cudnn_batch_norm.default
+            or orig_func == torch.ops.aten.cudnn_batch_norm
+        ):
+            with self:
+                return self._cudnn_batch_norm(*args, **kwargs)
+
+        # A workaround for AOT Autograd graphs
+        # See https://github.com/pytorch/pytorch/pull/86115#issue-1394883782
+        if (
+            orig_func == torch.ops.aten.cudnn_batch_norm_backward.default
+            or orig_func == torch.ops.aten.cudnn_batch_norm_backward
+        ):
+            with self:
+                return self._cudnn_batch_norm_backward(*args, **kwargs)
+
         if self._is_view_or_reshape(orig_func):
             a, *shape = args
             shape = torch._prims_common.extract_shape_from_varargs(