Add nvFuser support for torch.native_batch_norm (#85562)

This PR adds nvFuser's implementation for batch_norm, since there is no reference yet (https://github.com/pytorch/pytorch/pull/81191) and no in-place copy support (https://github.com/pytorch/pytorch/pull/84545).
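With this change, tracing torch.native_batch_norm under TorchRefsNvfuserCapabilityMode records the call as torch.ops.nvprims.native_batch_norm, which nvFuser can then execute. A minimal sketch of the end-to-end flow, adapted from the test added in this PR (shapes, device, and hyperparameters are illustrative):

import torch
from torch._prims.context import TorchRefsNvfuserCapabilityMode
from torch._prims.executor import execute
from torch.fx.experimental.proxy_tensor import make_fx

def func(input, running_mean, running_var):
    # weight/bias are optional; training=True updates the running stats in place
    return torch.native_batch_norm(
        input, None, None, running_mean, running_var, True, 0.1, 1e-5
    )

input = torch.randn(2, 3, 4, device="cuda")
running_mean = torch.zeros(3, device="cuda")
running_var = torch.ones(3, device="cuda")

# Inside the capability mode the aten op is intercepted, so the traced
# graph contains torch.ops.nvprims.native_batch_norm instead.
with TorchRefsNvfuserCapabilityMode():
    gm = make_fx(func)(input, running_mean, running_var)

# Run the traced graph strictly through nvFuser.
out, mean, invstd = execute(
    gm, input, running_mean, running_var, executor="strictly_nvfuser"
)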

Pull Request resolved: https://github.com/pytorch/pytorch/pull/85562
Approved by: https://github.com/kevinstephano, https://github.com/ngimel
Ivan Yashchuk
2022-10-03 15:03:08 +00:00
committed by PyTorch MergeBot
parent d28a882319
commit 68a6113248
11 changed files with 371 additions and 4 deletions

View File

@@ -395,6 +395,7 @@ class TestOperators(TestCase):
skip('nn.functional.max_unpool1d'), # fails everywhere except on mac
skip('nn.functional.max_unpool2d'), # fails everywhere except on windows
skip('nn.functional.max_unpool3d'), # fails everywhere except on mac
xfail("native_batch_norm"),
xfail('nn.functional.rrelu') # in-place test errors out with no formula implemented
}))
@@ -643,6 +644,7 @@ class TestOperators(TestCase):
xfail("nn.functional.batch_norm", 'without_cudnn'),
# view doesn't work on sparse
xfail("to_sparse"),
xfail("native_batch_norm"),
}))
@ops(op_db + additional_op_db, allowed_dtypes=(torch.float,))
@toleranceOverride({torch.float32: tol(atol=1e-04, rtol=1e-04)})
@@ -725,6 +727,7 @@ class TestOperators(TestCase):
# ---------------------------- BUGS ------------------------------------
# All of the following are bugs and need to be fixed
skip('linalg.svdvals'), # really annoying thing where it passes correctness check but not has_batch_rule
skip("native_batch_norm"),
xfail('__getitem__', ''), # dynamic error
xfail('linalg.eig'), # Uses aten::allclose
xfail('linalg.householder_product'), # needs select_scatter
@@ -833,6 +836,7 @@ class TestOperators(TestCase):
# erroring because running_mean and running_var aren't differentiable
xfail('nn.functional.batch_norm'),
xfail('nn.functional.batch_norm', 'without_cudnn'),
xfail("native_batch_norm"),
# ----------------------------------------------------------------------
}
@@ -1030,6 +1034,7 @@ class TestOperators(TestCase):
xfail('linalg.vecdot', ''),
xfail('segment_reduce', 'lengths'),
xfail('sparse.sampled_addmm', ''),
xfail("native_batch_norm"),
}))
def test_vmapvjp_has_batch_rule(self, device, dtype, op):
if not op.supports_autograd:
@@ -1095,6 +1100,7 @@ class TestOperators(TestCase):
xfail('nn.functional.dropout3d', ''),
xfail('as_strided_scatter', ''),
xfail('sparse.sampled_addmm', ''),
xfail("native_batch_norm"),
}))
def test_vjpvmap(self, device, dtype, op):
# NB: there is no vjpvmap_has_batch_rule test because that is almost
@@ -1338,6 +1344,10 @@ class TestOperators(TestCase):
xfail('to'), # RuntimeError: required rank 4 tensor to use channels_last format
xfail('to_sparse'), # Forward AD not implemented and no decomposition
xfail('view_as_complex'), # RuntimeError: Tensor must have a last dimension with stride 1
# RuntimeError: Batch norm got a batched tensor as
# input while the running_mean or running_var, which will be updated in
# place, were not batched.
xfail("native_batch_norm"),
}))
@ops(op_db + additional_op_db, allowed_dtypes=(torch.float,))
@toleranceOverride({torch.float32: tol(atol=1e-04, rtol=1e-04)})

View File

@@ -3287,6 +3287,7 @@ class TestVmapOperatorsOpInfo(TestCase):
@toleranceOverride({torch.float32: tol(atol=1e-04, rtol=1e-04)})
@skipOps('TestVmapOperatorsOpInfo', 'test_vmap_exhaustive', vmap_fail.union({
xfail('cat'),
xfail('native_batch_norm'),
}))
def test_vmap_exhaustive(self, device, dtype, op):
# needs to be fixed
@@ -3306,6 +3307,7 @@ class TestVmapOperatorsOpInfo(TestCase):
xfail('cat'),
xfail('complex'),
xfail('copysign'),
xfail('native_batch_norm'),
xfail('histogram'),
xfail('index_fill'),
xfail('nansum'),

View File

@@ -548,6 +548,69 @@ class TestPrims(TestCase):
self.assertFalse(node.target == torch.ops.prims.add.default)
self.assertFalse(node.target == torch.ops.aten.add.default)
@onlyCUDA
@skipCUDAIfRocm
@dtypes(torch.float32, torch.float64)
def test_native_batch_norm_nvprims(self, device, dtype):
from torch._prims.context import TorchRefsNvfuserCapabilityMode
from torch._prims.executor import execute
from torch.fx.experimental.proxy_tensor import make_fx
from torch.testing._internal.common_methods_invocations import (
sample_inputs_native_batch_norm,
)
# This test verifies that native_batch_norm is translated into nvprims
# and can be executed with nvFuser.
samples = sample_inputs_native_batch_norm(
None, device, dtype, requires_grad=False
)
batch_norms = [
torch.native_batch_norm,
torch.ops.aten.native_batch_norm,
torch.ops.aten.native_batch_norm.default,
torch.ops.nvprims.native_batch_norm.default,
]
for sample, batch_norm in product(samples, batch_norms):
if sample.input.numel() == 0:
continue
def func(
input, weight, bias, running_mean, running_var, training, momentum, eps
):
return batch_norm(
input,
weight,
bias,
running_mean,
running_var,
training,
momentum,
eps,
)
with TorchRefsNvfuserCapabilityMode():
gm = make_fx(func)(sample.input, *sample.args)
call_function_nodes = list(
filter(lambda n: n.op == "call_function", gm.graph.nodes)
)
includes_aten_batch_norm = any(
torch.ops.aten.native_batch_norm.default == node.target
for node in call_function_nodes
)
self.assertFalse(includes_aten_batch_norm)
includes_nvprims_batch_norm = any(
torch.ops.nvprims.native_batch_norm.default == node.target
for node in call_function_nodes
)
self.assertTrue(includes_nvprims_batch_norm)
# Check that the graph can be executed with nvFuser
out = execute(gm, sample.input, *sample.args, executor="strictly_nvfuser")
self.assertEqual(out, gm(sample.input, *sample.args))
# The decomposition of native_batch_norm_backward uses a cast, which prevents nvprim lowering on CPU builds
@onlyCUDA
@dtypes(torch.float32, torch.float16)

View File

@@ -265,6 +265,12 @@ class TorchRefsNvfuserCapabilityMode(TorchRefsMode):
and "aten.var_mean" in str(func)
)
def _is_native_batch_norm(self, func):
return "torch.native_batch_norm" == torch.overrides.resolve_name(func) or (
func == torch.ops.aten.native_batch_norm.default
or func == torch.ops.aten.native_batch_norm
)
def _is_rand_like(self, func):
result = "torch.rand_like" == torch.overrides.resolve_name(func) or (
func == torch.ops.aten.rand_like or func == torch.ops.aten.rand_like.default
@@ -283,9 +289,14 @@ class TorchRefsNvfuserCapabilityMode(TorchRefsMode):
# First we intercept calls for nvfuser-specific prims bypassing generic torch._refs
if self._is_var_mean(orig_func):
return torch.ops.nvprims.var_mean(*args, **kwargs)
if self._is_native_batch_norm(orig_func):
return torch.ops.nvprims.native_batch_norm(*args, **kwargs)
if self._is_rand_like(orig_func):
if len(kwargs) > 0:
warn("rand_like has ignored kwargs!")
return torch.ops.nvprims.rand_like(*args)
# Then we use TorchRefsMode to interpret the rest
return super().__torch_function__(orig_func, types, args, kwargs)

View File

@@ -136,6 +136,18 @@ def make_nvfuser_fusion(gm: GraphModule, *nv_args_templates):
args, kwargs = self.fetch_args_kwargs_from_env(node)
args = [args[0], original_shape, args[1]]
return self.call_function(node.target, args, node.kwargs)
if node.target in [
torch.ops.nvprims.native_batch_norm,
torch.ops.nvprims.native_batch_norm.default,
]:
args, kwargs = self.fetch_args_kwargs_from_env(node)
assert len(args) == 8
# args: input, weight, bias, running_mean, running_var, training, momentum, eps.
# `training` stays a plain Python bool, while momentum and eps are converted
# to nvFuser constants before calling the nvFuser implementation.
training = args[5]
args6_end = tuple(map(_to_nvfuser_constant, args[6:]))
args = args[:5] + (training,) + args6_end
return node.target.impl_nvfuser(fd, *args, **kwargs)
return super().run_node(node)
def call_function(self, target, args, kwargs):

View File

@@ -210,6 +210,29 @@ _nvfuser_impls["{fname}"] = _{fname}_nvfuser
)
def _native_batch_norm_nvfuser(
fd, input, weight, bias, running_mean, running_var, training, momentum, eps
):
if weight is None:
weight = fd.define_null_tensor()
if bias is None:
bias = fd.define_null_tensor()
if running_mean is None:
running_mean = fd.define_null_tensor()
if running_var is None:
running_var = fd.define_null_tensor()
return fd.ops.batch_norm(
input,
weight,
bias,
running_mean,
running_var,
training,
momentum,
eps,
)
def _broadcast_in_dim_nvfuser(
fd: Any,
a: TensorLikeType,
@@ -299,6 +322,7 @@ def _amin_nvfuser(
return fd.ops.min(a, dims, keep_dims)
_nvfuser_impls["native_batch_norm"] = _native_batch_norm_nvfuser
_nvfuser_impls["broadcast_in_dim"] = _broadcast_in_dim_nvfuser
_nvfuser_impls["convert_element_type"] = _convert_element_type_nvfuser
_nvfuser_impls["transpose"] = _transpose_nvfuser
@@ -312,6 +336,36 @@ _nvfuser_impls["amax"] = _amax_nvfuser
_nvfuser_impls["amin"] = _amin_nvfuser
def register_native_batch_norm():
"""This function is used to register the native_batch_norm function in torch.ops.nvprims module."""
name = "native_batch_norm"
nvprim.define(
f"{name}(Tensor input, Tensor? weight, Tensor? bias, Tensor? running_mean, Tensor? running_var, "
+ "bool training, float momentum, float eps)"
+ " -> (Tensor, Tensor, Tensor)"
)
def _prim_impl(
input, weight, bias, running_mean, running_var, training, momentum, eps
):
return torch.native_batch_norm(
input, weight, bias, running_mean, running_var, training, momentum, eps
)
nvprim_impl.impl(name, _prim_impl)
nvprim_autograd_impl.impl(
name, backwards_not_supported(torch.ops.nvprims.native_batch_norm.default)
)
prim_packet = torch.ops.nvprims.native_batch_norm
prim = prim_packet.default
for p in (prim_packet, prim):
p.__doc__ = "Computes batch normalization."
p.impl_nvfuser = _nvfuser_impls["native_batch_norm"]
p.return_type = torch._prims_common.RETURN_TYPE.NEW # type: ignore[attr-defined]
def register_rand_like():
name = "rand_like"
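For reference, the op registered by register_native_batch_norm above is also callable directly in eager mode, where _prim_impl simply forwards to torch.native_batch_norm. A small sanity-check sketch (shapes and device are illustrative):

import torch

x = torch.randn(2, 3, 4, device="cuda")
running_mean = torch.zeros(3, device="cuda")
running_var = torch.ones(3, device="cuda")

# Dispatches to _prim_impl, i.e. plain torch.native_batch_norm.
out, mean, invstd = torch.ops.nvprims.native_batch_norm(
    x, None, None, running_mean, running_var, True, 0.1, 1e-5
)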
@@ -471,6 +525,7 @@ def register_var_mean():
def register_nvprims():
"""Registers all nvFuser primitives in the torch.ops.nvprims module."""
register_var_mean()
register_native_batch_norm()
register_rand_like()
for name in nvprim_names:

View File

@@ -587,8 +587,11 @@ ForwardNormResult batch_norm(
auto invstd_bcast = broadcast(unbiased_invstd, broadcast_mask);
// During inference, mean/invstd output are empty tensors on CPU, but not
// on CUDA. We need to make sure we have the same behavior as with eager
// mode on CUDA.
mean = set(running_mean); // use set to avoid "trivial input forwarding NOT
// IMPLEMENTED" error
invstd = unbiased_invstd;
y = mul(x_sub_mean, invstd_bcast);
}

View File

@@ -32,7 +32,12 @@ void FusionInterface::addOutput(Nvf::Val* output) const {
std::vector<at::Tensor> FusionInterface::execute(
const at::ArrayRef<c10::IValue>& inputs) const {
// aliasOutputToInput always adds Tensors as outputs that we don't want
// to return to the user. We need to remove them.
auto count_output_aliases = fusionPtr()->getOutputAliasIndices().size();
auto result = fusionExecutorCachePtr()->runFusionWithInputs(inputs);
result.erase(result.begin(), result.begin() + count_output_aliases);
return result;
}
Nvf::FusionGuard FusionInterface::guard() const {

View File

@@ -15,11 +15,13 @@ namespace nvfuser {
enum class RecordType {
Base = 0,
Op,
BatchNormOp,
BroadcastOp,
CastOp,
Constant,
End,
Tensor,
NullTensor,
Output,
ReductionOp,
Scalar,
@@ -895,6 +897,41 @@ struct TensorRecord : RecordFunctor {
bool is_cpu_;
};
struct NullTensorRecord : RecordFunctor {
NullTensorRecord(std::vector<State> _outputs)
: RecordFunctor(
{},
std::move(_outputs),
"null_tensor",
RecordType::NullTensor) {}
virtual ~NullTensorRecord() = default;
virtual RecordFunctor* clone() final {
return new NullTensorRecord(*this);
}
//! Nothing extra necessary in hash
//! Child specific hash function in lower 32 bits.
//! | 31 --------------------------------------- 0 |
//! | None |
virtual size_t hash() const final {
auto result = RecordFunctor::hash();
return result;
}
virtual bool operator==(const RecordFunctor& other) const final {
auto result = false;
if (dynamic_cast<const NullTensorRecord*>(&other)) {
result = RecordFunctor::operator==(other);
}
return result;
}
virtual void operator()(FusionDefinition& fd) final {
Nvf::TensorView* tv = nullptr;
fd.setFusionState(outputs_.at(0).index, tv);
}
};
//! Specialized Record Functor for recording FusionDefinition outputs.
template <class OutputType>
@@ -1313,6 +1350,70 @@ struct VarianceMeanOpRecord : NormOpRecord {
}
};
struct BatchNormOpRecord : RecordFunctor {
BatchNormOpRecord(
std::vector<State> args,
std::vector<State> outputs,
bool training,
bool channels_last)
: RecordFunctor(
std::move(args),
std::move(outputs),
"ops.batch_norm",
RecordType::BatchNormOp),
training_(training),
channels_last_(channels_last) {}
virtual ~BatchNormOpRecord() = default;
virtual RecordFunctor* clone() final {
return new BatchNormOpRecord(*this);
}
virtual bool operator==(const RecordFunctor& other) const final {
auto result = false;
if (auto child_ptr = dynamic_cast<const BatchNormOpRecord*>(&other)) {
result = RecordFunctor::operator==(other);
result = result && (training_ == child_ptr->training_);
result = result && (channels_last_ == child_ptr->channels_last_);
}
return result;
}
virtual size_t hash() const final {
auto result = RecordFunctor::hash();
return result | (static_cast<size_t>(training_) << 28) |
(static_cast<size_t>(channels_last_) << 29);
}
void operator()(FusionDefinition& fd) final {
auto x = fd.getFusionState(args_.at(0).index)->as<Nvf::TensorView>();
auto weight = fd.getFusionState(args_.at(1).index)->as<Nvf::TensorView>();
auto bias = fd.getFusionState(args_.at(2).index)->as<Nvf::TensorView>();
auto running_mean =
fd.getFusionState(args_.at(3).index)->as<Nvf::TensorView>();
auto running_var =
fd.getFusionState(args_.at(4).index)->as<Nvf::TensorView>();
auto momentum = fd.getFusionState(args_.at(5).index)->as<Nvf::Val>();
auto eps = fd.getFusionState(args_.at(6).index)->as<Nvf::Val>();
auto output = Nvf::batch_norm(
x,
weight,
bias,
running_mean,
running_var,
training_,
momentum,
eps,
channels_last_);
fd.setFusionState(outputs_.at(0).index, output.output);
fd.setFusionState(outputs_.at(1).index, output.mean);
fd.setFusionState(outputs_.at(2).index, output.invstd);
}
private:
bool training_;
bool channels_last_;
};
} // namespace nvfuser
//! Creating the template specialized hash and equal_to functions for a

View File

@@ -126,6 +126,16 @@ void initNvFuserPythonBindings(PyObject* module) {
self.defineRecord(new nvfuser::OutputRecord<Nvf::TensorView>(
{self.recordingState(output())}));
})
.def(
"define_null_tensor",
[](nvfuser::FusionDefinition& self) -> nvfuser::Tensor {
FUSER_PERF_SCOPE("FusionDefinition.define_null_tensor");
nvfuser::Tensor out = self.defineTensor();
self.defineRecord(
new nvfuser::NullTensorRecord({self.recordingState(out())}));
return out;
},
py::return_value_policy::reference)
.def(
"define_tensor",
[](nvfuser::FusionDefinition& self,
@@ -1259,6 +1269,48 @@ void initNvFuserPythonBindings(PyObject* module) {
py::arg("correction"),
py::arg("keepdim") = false,
py::return_value_policy::reference);
nvf_ops.def(
"batch_norm",
[](nvfuser::FusionDefinition::Operators& self,
nvfuser::Tensor x,
nvfuser::Tensor weight,
nvfuser::Tensor bias,
nvfuser::Tensor running_mean,
nvfuser::Tensor running_var,
bool training,
nvfuser::Scalar momentum,
nvfuser::Scalar eps,
bool channels_last) -> decltype(auto) {
FUSER_PERF_SCOPE("Operators.batch_norm");
nvfuser::FusionDefinition* fd = self.fusion_definition;
nvfuser::Tensor output = fd->defineTensor();
nvfuser::Tensor mean = fd->defineTensor();
nvfuser::Tensor invstd = fd->defineTensor();
fd->defineRecord(new nvfuser::BatchNormOpRecord(
{fd->recordingState(x()),
fd->recordingState(weight()),
fd->recordingState(bias()),
fd->recordingState(running_mean()),
fd->recordingState(running_var()),
fd->recordingState(momentum()),
fd->recordingState(eps())},
{fd->recordingState(output()),
fd->recordingState(mean()),
fd->recordingState(invstd())},
training,
channels_last));
return std::make_tuple(output, mean, invstd);
},
py::arg("x"),
py::arg("weight").none(true),
py::arg("bias").none(true),
py::arg("running_mean").none(true),
py::arg("running_var").none(true),
py::arg("training"),
py::arg("momentum"),
py::arg("eps"),
py::arg("channels_last") = false,
py::return_value_policy::reference);
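Together with define_null_tensor, this binding lets the Python frontend build a batch-norm fusion directly. A rough sketch of driving it by hand; the harness around it (Fusion, FusionDefinition, define_tensor, define_constant, add_output, Fusion.execute) is the pre-existing Python frontend of this era and is assumed here, not added by this PR:

import torch
from torch._C._nvfuser import Fusion, FusionDefinition

fusion = Fusion()
with FusionDefinition(fusion) as fd:
    x = fd.define_tensor(3)           # symbolic 3-D input
    running_mean = fd.define_tensor(1)
    running_var = fd.define_tensor(1)
    weight = fd.define_null_tensor()  # absent optional tensors become null tensors
    bias = fd.define_null_tensor()
    momentum = fd.define_constant(0.1)
    eps = fd.define_constant(1e-5)
    out, mean, invstd = fd.ops.batch_norm(
        x, weight, bias, running_mean, running_var,
        True,   # training
        momentum, eps,
        False,  # channels_last
    )
    fd.add_output(out)
    fd.add_output(mean)
    fd.add_output(invstd)

inputs = [
    torch.randn(2, 3, 4, device="cuda"),
    torch.zeros(3, device="cuda"),
    torch.ones(3, device="cuda"),
]
# The aliased running-stat outputs are trimmed by FusionInterface::execute
# (see the change above), leaving the three declared outputs.
out, mean, invstd = fusion.execute(inputs)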
nvf_ops.def(
"broadcast_in_dim",
[](nvfuser::FusionDefinition::Operators& self,

View File

@@ -446,7 +446,22 @@ def sample_inputs_batch_norm(op_info, device, dtype, requires_grad, **kwargs):
# Test case for no optional kwargs
# running_mean and running_var are required in evaluation mode (training: False) but not in training mode
yield SampleInput(make_arg((1, 2, 3)), args=(None, None, None, None), kwargs={'training': True})
def sample_inputs_native_batch_norm(op_info, device, dtype, requires_grad, **kwargs):
samples = sample_inputs_batch_norm(op_info, device, dtype, requires_grad, **kwargs)
for sample in samples:
# torch.native_batch_norm does not support 0 numel tensors
# IndexError: Dimension out of range (expected to be in range of [-1, 0], but got 1)
if sample.input.numel() == 0:
continue
args = sample.args
training = sample.kwargs.get('training', True)
momentum = sample.kwargs.get('momentum', 0.5)
eps = sample.kwargs.get('eps', 1e-5)
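# F.batch_norm samples carry args=(running_mean, running_var, weight, bias),
# while native_batch_norm expects (weight, bias, running_mean, running_var,
# training, momentum, eps), hence the reordering below.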
yield SampleInput(sample.input, args=(args[2], args[3], args[0], args[1], training, momentum, eps))
def sample_inputs_nn_activation_relu(op_info, device, dtype, requires_grad, **kwargs):
make_arg = partial(make_tensor, device=device, dtype=dtype, requires_grad=requires_grad)
@@ -10640,6 +10655,30 @@ op_db: List[OpInfo] = [
# possibly because of the welford implementation.
DecorateInfo(unittest.skip("Skipped!"), 'TestCudaFuserOpInfo', 'test_nvfuser_extremal_values'),
)),
OpInfo('native_batch_norm',
aten_name='native_batch_norm',
dtypes=floating_types_and(torch.bfloat16),
dtypesIfCUDA=floating_types_and(torch.float16, torch.bfloat16),
supports_forward_ad=True,
supports_fwgrad_bwgrad=True,
assert_jit_shape_analysis=True,
sample_inputs_func=sample_inputs_native_batch_norm,
skips=(
# NotImplementedError: Could not run
# 'aten::native_batch_norm.out' with arguments from the 'CPU' backend.
DecorateInfo(unittest.expectedFailure, 'TestCommon', 'test_out', device_type="cpu"),
DecorateInfo(unittest.expectedFailure, 'TestCommon', 'test_out_warning', device_type="cpu"),
# RuntimeError: out_invstd.dim() == 1 && out_invstd.is_contiguous() && out_invstd.sizes()[0]
DecorateInfo(unittest.expectedFailure, 'TestCommon', 'test_out', device_type="cuda"),
# IndexError: tuple index out of range
DecorateInfo(unittest.expectedFailure, 'TestGradients', 'test_forward_mode_AD'),
# RuntimeError: deepEquals(input.iValue, deepCopiedInput) INTERNAL ASSERT FAILED
DecorateInfo(unittest.expectedFailure, 'TestJit', 'test_variant_consistency_jit'),
# AssertionError: Booleans mismatch: True is not False
DecorateInfo(unittest.skip("Skipped!"), 'TestFakeTensor', 'test_fake_autocast'),
DecorateInfo(unittest.skip("Skipped!"), 'TestFakeTensor', 'test_fake'),
)
),
OpInfo('nn.functional.cosine_similarity',
aten_name="cosine_similarity",
dtypes=floating_types_and(torch.bfloat16),
@@ -17787,6 +17826,20 @@ python_ref_db = [
DecorateInfo(unittest.expectedFailure, 'TestCommon', 'test_python_ref',),
),
),
PythonRefInfo(
"ops.nvprims.native_batch_norm",
torch_opinfo_name="native_batch_norm",
# Complex types are currently disabled
dtypes=floating_types(),
supports_out=False,
# This function is expected not to work with TorchRefsMode(strict=True)
decorators=(
DecorateInfo(unittest.expectedFailure, 'TestCommon', 'test_python_ref',),
# There's a discrepancy in returned shape between CPU and other devices
# AssertionError: Shapes torch.Size([0]) and torch.Size([2]) are not equal!
DecorateInfo(unittest.expectedFailure, 'TestCommon', 'test_python_ref_meta', device_type="cpu"),
),
),
#
# Linear Algebra Operators
#