Workarounds for cudnn_batch_norm with TorchRefsNvfuserCapabilityMode (#86796)
This PR adds workarounds so that AOT Autograd graphs containing `aten.cudnn_batch_norm` and `aten.cudnn_batch_norm_backward` work with `TorchRefsNvfuserCapabilityMode`. The problem with the decomposition of `aten.cudnn_batch_norm` is that it uses a `new_empty` call that nvFuser does not support, and we are conservative about lowering functions to nvprims by default. The problem with the decomposition of `aten.cudnn_batch_norm_backward` is described in https://github.com/pytorch/pytorch/pull/86115#issue-1394883782, but changing that decomposition directly in that PR makes many tests fail.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/86796
Approved by: https://github.com/mruberry
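For context, here is a condensed sketch of the behavior these workarounds enable, modeled on the test added in this commit. It assumes a CUDA build of PyTorch with nvFuser available; the `fn` wrapper and the tensor shapes are illustrative only and are not taken from the PR.

import torch
from torch._prims.context import TorchRefsNvfuserCapabilityMode
from torch._prims.executor import execute
from torch.fx.experimental.proxy_tensor import make_fx


def fn(input, weight, bias, running_mean, running_var):
    # The aten op that AOT Autograd graphs contain for CUDA batch norm
    return torch.ops.aten.cudnn_batch_norm.default(
        input, weight, bias, running_mean, running_var,
        True,  # training
        0.1,   # exponential_average_factor
        1e-5,  # epsilon
    )


x = torch.randn(2, 3, 8, 8, device="cuda")
weight = torch.randn(3, device="cuda")
bias = torch.randn(3, device="cuda")
running_mean = torch.zeros(3, device="cuda")
running_var = torch.ones(3, device="cuda")

# Tracing under the capability mode rewrites aten.cudnn_batch_norm into
# nvprims.native_batch_norm plus dummy reserve-space outputs ...
with TorchRefsNvfuserCapabilityMode():
    gm = make_fx(fn)(x, weight, bias, running_mean, running_var)

# ... so the traced graph can be executed with the nvFuser executor.
out = execute(gm, x, weight, bias, running_mean, running_var, executor="nvfuser")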
Committed by: PyTorch MergeBot
Parent: 33343def0b
Commit: 31931515bc
@@ -611,6 +611,63 @@ class TestPrims(TestCase):
             out = execute(gm, sample.input, *sample.args, executor="strictly_nvfuser")
             self.assertEqual(out, gm(sample.input, *sample.args))
 
+    @onlyCUDA
+    @skipCUDAIfRocm
+    @dtypes(torch.float32, torch.float64)
+    def test_cudnn_batch_norm_nvprims(self, device, dtype):
+        from torch._prims.context import TorchRefsNvfuserCapabilityMode
+        from torch._prims.executor import execute
+
+        # This test verifies that cudnn_batch_norm is translated into nvprims
+        # and can be executed with nvFuser
+        from torch.fx.experimental.proxy_tensor import make_fx
+        from torch.testing._internal.common_methods_invocations import (
+            sample_inputs_native_batch_norm,
+        )
+
+        samples = sample_inputs_native_batch_norm(
+            None, device, dtype, requires_grad=False
+        )
+        for sample in samples:
+            if sample.input.numel() == 0:
+                continue
+
+            def func(
+                input, weight, bias, running_mean, running_var, training, momentum, eps
+            ):
+                return torch.ops.aten.cudnn_batch_norm.default(
+                    input,
+                    weight,
+                    bias,
+                    running_mean,
+                    running_var,
+                    training,
+                    momentum,
+                    eps,
+                )
+
+            with TorchRefsNvfuserCapabilityMode():
+                gm = make_fx(func)(sample.input, *sample.args)
+
+            call_function_nodes = list(
+                filter(lambda n: n.op == "call_function", gm.graph.nodes)
+            )
+            includes_aten_batch_norm = any(
+                torch.ops.aten.cudnn_batch_norm.default == node.target
+                for node in call_function_nodes
+            )
+            self.assertFalse(includes_aten_batch_norm)
+
+            includes_nvprims_batch_norm = any(
+                torch.ops.nvprims.native_batch_norm.default == node.target
+                for node in call_function_nodes
+            )
+            self.assertTrue(includes_nvprims_batch_norm)
+
+            # Check that the graph can be executed with nvFuser
+            out = execute(gm, sample.input, *sample.args, executor="nvfuser")
+            self.assertEqual(out, gm(sample.input, *sample.args))
+
     # decomposition of native_batch_norm_backward uses a casting, which prevents nvprim lowering on CPU build
     @onlyCUDA
     @dtypes(torch.float32, torch.float16)
@@ -624,23 +681,34 @@ class TestPrims(TestCase):
         sample = next(samples_iter)
         grad = torch.randn_like(sample.input)
 
-        def func(grad, input, weight, rm, rv, eps, train):
+        def func1(grad, input, weight, rm, rv, eps, train):
             return torch.ops.aten.native_batch_norm_backward.default(
                 grad, input, weight, rm, rv, rm, rv, train, eps, [True, True, True]
             )
 
+        def func2(grad, input, weight, rm, rv, eps, train):
+            return torch.ops.aten.cudnn_batch_norm_backward.default(
+                input, grad, weight, rm, rv, rm, rv, eps, grad
+            )
+
         args = sample.args
         kwargs = sample.kwargs
         all_args = [grad, sample.input, args[2], args[0], args[1], kwargs['eps'], kwargs['training']]
-        with TorchRefsNvfuserCapabilityMode():
-            gm = make_fx(func)(*all_args)
-
-        call_function_nodes = list(filter(lambda n: n.op == "call_function", gm.graph.nodes))
-        includes_batch_norm_backward = any(
-            torch.ops.aten.native_batch_norm_backward.default == node.target
-            for node in call_function_nodes
-        )
-        self.assertFalse(includes_batch_norm_backward)
+        for func in (func1, func2):
+            with TorchRefsNvfuserCapabilityMode():
+                gm = make_fx(func)(*all_args)
+
+            call_function_nodes = list(filter(lambda n: n.op == "call_function", gm.graph.nodes))
+            includes_batch_norm_backward = any(
+                torch.ops.aten.native_batch_norm_backward.default == node.target
+                for node in call_function_nodes
+            )
+            self.assertFalse(includes_batch_norm_backward)
+            all_nvprims = all(
+                str(node.target).startswith("nvprims") for node in call_function_nodes
+            )
+            self.assertTrue(all_nvprims)
 
     @onlyCUDA
     @skipCUDAIfRocm
@@ -267,6 +267,68 @@ class TorchRefsNvfuserCapabilityMode(TorchRefsMode):
             prims_mode_cls=functools.partial(NvfuserPrimsMode, skip_ops=skip_ops),
         )
 
+    # TODO: remove this once version from _decomp/decompositions.py is working
+    # with this context manager
+    # This is a workaround for AOT Autograd graphs
+    def _cudnn_batch_norm(
+        self,
+        input,
+        weight,
+        bias,
+        running_mean,
+        running_var,
+        training,
+        exponential_average_factor,
+        epsilon,
+    ):
+        a, b, c = torch.ops.nvprims.native_batch_norm(
+            input,
+            weight,
+            bias,
+            running_mean,
+            running_var,
+            training,
+            exponential_average_factor,
+            epsilon,
+        )
+        if training:
+            return (a, b, c, input.new_zeros((0,), dtype=torch.uint8))
+        return (
+            a,
+            weight.new_zeros((0,)),
+            weight.new_zeros((0,)),
+            input.new_zeros((0,), dtype=torch.uint8),
+        )
+
+    # This is a workaround for AOT Autograd graphs
+    def _cudnn_batch_norm_backward(
+        self,
+        input,
+        grad_output,
+        weight,
+        running_mean,
+        running_var,
+        save_mean,
+        save_var,
+        epsilon,
+        reserveSpace,
+    ):
+        func = torch._decomp.decomposition_table[
+            torch.ops.aten.native_batch_norm_backward.default
+        ]
+        return func(
+            grad_output,
+            input,
+            weight,
+            running_mean,
+            running_var,
+            save_mean,
+            save_var,
+            True,
+            epsilon,
+            [True, True, True],
+        )
+
     def _is_var_mean(self, func):
         return "torch.var_mean" == torch.overrides.resolve_name(func) or (
             (
@@ -313,6 +375,22 @@ class TorchRefsNvfuserCapabilityMode(TorchRefsMode):
         if self._is_var_mean(orig_func):
            return torch.ops.nvprims.var_mean(*args, **kwargs)
 
+        if (
+            orig_func == torch.ops.aten.cudnn_batch_norm.default
+            or orig_func == torch.ops.aten.cudnn_batch_norm
+        ):
+            with self:
+                return self._cudnn_batch_norm(*args, **kwargs)
+
+        # A workaround for AOT Autograd graphs
+        # See https://github.com/pytorch/pytorch/pull/86115#issue-1394883782
+        if (
+            orig_func == torch.ops.aten.cudnn_batch_norm_backward.default
+            or orig_func == torch.ops.aten.cudnn_batch_norm_backward
+        ):
+            with self:
+                return self._cudnn_batch_norm_backward(*args, **kwargs)
+
         if self._is_view_or_reshape(orig_func):
             a, *shape = args
             shape = torch._prims_common.extract_shape_from_varargs(