Add nvFuser support for torch.native_batch_norm (#85562)
This PR adds nvFuser's implementation for batch_norm, since there is no reference implementation yet (https://github.com/pytorch/pytorch/pull/81191) and no in-place copy support (https://github.com/pytorch/pytorch/pull/84545).

Pull Request resolved: https://github.com/pytorch/pytorch/pull/85562
Approved by: https://github.com/kevinstephano, https://github.com/ngimel
Commit 68a6113248 (parent d28a882319), committed by PyTorch MergeBot.
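For orientation, here is a condensed sketch of the flow this change enables, adapted from the new TestPrims test included below; the input shape and hyperparameters are illustrative, not taken from the PR:

import torch
from torch._prims.context import TorchRefsNvfuserCapabilityMode
from torch._prims.executor import execute
from torch.fx.experimental.proxy_tensor import make_fx

def func(input, weight, bias, running_mean, running_var):
    # training=True, momentum=0.1, eps=1e-5
    return torch.native_batch_norm(
        input, weight, bias, running_mean, running_var, True, 0.1, 1e-5
    )

x = torch.randn(2, 3, 4, 4, device="cuda")
weight = torch.randn(3, device="cuda")
bias = torch.randn(3, device="cuda")
running_mean = torch.zeros(3, device="cuda")
running_var = torch.ones(3, device="cuda")

# Under the capability mode the call is captured as
# torch.ops.nvprims.native_batch_norm rather than the aten op.
with TorchRefsNvfuserCapabilityMode():
    gm = make_fx(func)(x, weight, bias, running_mean, running_var)

# out is the (output, save_mean, save_invstd) triple; "strictly_nvfuser"
# raises if any node cannot be lowered to nvFuser.
out = execute(gm, x, weight, bias, running_mean, running_var,
              executor="strictly_nvfuser")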
@@ -395,6 +395,7 @@ class TestOperators(TestCase):
        skip('nn.functional.max_unpool1d'),  # fails everywhere except on mac
        skip('nn.functional.max_unpool2d'),  # fails everywhere except on windows
        skip('nn.functional.max_unpool3d'),  # fails everywhere except on mac
        xfail("native_batch_norm"),

        xfail('nn.functional.rrelu')  # in-place test errors out with no formula implemented
    }))
@@ -643,6 +644,7 @@ class TestOperators(TestCase):
        xfail("nn.functional.batch_norm", 'without_cudnn'),
        # view doesn't work on sparse
        xfail("to_sparse"),
        xfail("native_batch_norm"),
    }))
    @ops(op_db + additional_op_db, allowed_dtypes=(torch.float,))
    @toleranceOverride({torch.float32: tol(atol=1e-04, rtol=1e-04)})
@@ -725,6 +727,7 @@ class TestOperators(TestCase):
        # ---------------------------- BUGS ------------------------------------
        # All of the following are bugs and need to be fixed
        skip('linalg.svdvals'),  # really annoying thing where it passes correctness check but not has_batch_rule
        skip("native_batch_norm"),
        xfail('__getitem__', ''),  # dynamic error
        xfail('linalg.eig'),  # Uses aten::allclose
        xfail('linalg.householder_product'),  # needs select_scatter
@@ -833,6 +836,7 @@ class TestOperators(TestCase):
        # erroring because running_mean and running_var aren't differentiable
        xfail('nn.functional.batch_norm'),
        xfail('nn.functional.batch_norm', 'without_cudnn'),
        xfail("native_batch_norm"),
        # ----------------------------------------------------------------------
    }

@@ -1030,6 +1034,7 @@ class TestOperators(TestCase):
        xfail('linalg.vecdot', ''),
        xfail('segment_reduce', 'lengths'),
        xfail('sparse.sampled_addmm', ''),
        xfail("native_batch_norm"),
    }))
    def test_vmapvjp_has_batch_rule(self, device, dtype, op):
        if not op.supports_autograd:
@@ -1095,6 +1100,7 @@ class TestOperators(TestCase):
        xfail('nn.functional.dropout3d', ''),
        xfail('as_strided_scatter', ''),
        xfail('sparse.sampled_addmm', ''),
        xfail("native_batch_norm"),
    }))
    def test_vjpvmap(self, device, dtype, op):
        # NB: there is no vjpvmap_has_batch_rule test because that is almost
@@ -1338,6 +1344,10 @@ class TestOperators(TestCase):
        xfail('to'),  # RuntimeError: required rank 4 tensor to use channels_last format
        xfail('to_sparse'),  # Forward AD not implemented and no decomposition
        xfail('view_as_complex'),  # RuntimeError: Tensor must have a last dimension with stride 1
        # RuntimeError: Batch norm got a batched tensor as
        # input while the running_mean or running_var, which will be updated in
        # place, were not batched.
        xfail("native_batch_norm"),
    }))
    @ops(op_db + additional_op_db, allowed_dtypes=(torch.float,))
    @toleranceOverride({torch.float32: tol(atol=1e-04, rtol=1e-04)})
@@ -3287,6 +3287,7 @@ class TestVmapOperatorsOpInfo(TestCase):
    @toleranceOverride({torch.float32: tol(atol=1e-04, rtol=1e-04)})
    @skipOps('TestVmapOperatorsOpInfo', 'test_vmap_exhaustive', vmap_fail.union({
        xfail('cat'),
        xfail('native_batch_norm'),
    }))
    def test_vmap_exhaustive(self, device, dtype, op):
        # needs to be fixed
@@ -3306,6 +3307,7 @@ class TestVmapOperatorsOpInfo(TestCase):
        xfail('cat'),
        xfail('complex'),
        xfail('copysign'),
        xfail('native_batch_norm'),
        xfail('histogram'),
        xfail('index_fill'),
        xfail('nansum'),
@@ -548,6 +548,69 @@ class TestPrims(TestCase):
            self.assertFalse(node.target == torch.ops.prims.add.default)
            self.assertFalse(node.target == torch.ops.aten.add.default)

    @onlyCUDA
    @skipCUDAIfRocm
    @dtypes(torch.float32, torch.float64)
    def test_native_batch_norm_nvprims(self, device, dtype):
        from torch._prims.context import TorchRefsNvfuserCapabilityMode
        from torch._prims.executor import execute

        # This test verifies that native_batch_norm is translated into nvprims
        # and can be executed with nvFuser
        from torch.fx.experimental.proxy_tensor import make_fx
        from torch.testing._internal.common_methods_invocations import (
            sample_inputs_native_batch_norm,
        )

        samples = sample_inputs_native_batch_norm(
            None, device, dtype, requires_grad=False
        )
        batch_norms = [
            torch.native_batch_norm,
            torch.ops.aten.native_batch_norm,
            torch.ops.aten.native_batch_norm.default,
            torch.ops.nvprims.native_batch_norm.default,
        ]
        for sample, batch_norm in product(samples, batch_norms):
            if sample.input.numel() == 0:
                continue

            def func(
                input, weight, bias, running_mean, running_var, training, momentum, eps
            ):
                return batch_norm(
                    input,
                    weight,
                    bias,
                    running_mean,
                    running_var,
                    training,
                    momentum,
                    eps,
                )

            with TorchRefsNvfuserCapabilityMode():
                gm = make_fx(func)(sample.input, *sample.args)

            call_function_nodes = list(
                filter(lambda n: n.op == "call_function", gm.graph.nodes)
            )
            includes_aten_batch_norm = any(
                torch.ops.aten.native_batch_norm.default == node.target
                for node in call_function_nodes
            )
            self.assertFalse(includes_aten_batch_norm)

            includes_nvprims_batch_norm = any(
                torch.ops.nvprims.native_batch_norm.default == node.target
                for node in call_function_nodes
            )
            self.assertTrue(includes_nvprims_batch_norm)

            # Check that the graph can be executed with nvFuser
            out = execute(gm, sample.input, *sample.args, executor="strictly_nvfuser")
            self.assertEqual(out, gm(sample.input, *sample.args))

    # decomposition of native_batch_norm_backward uses a casting, which prevents nvprim lowering on CPU build
    @onlyCUDA
    @dtypes(torch.float32, torch.float16)
@@ -265,6 +265,12 @@ class TorchRefsNvfuserCapabilityMode(TorchRefsMode):
            and "aten.var_mean" in str(func)
        )

    def _is_native_batch_norm(self, func):
        return "torch.native_batch_norm" == torch.overrides.resolve_name(func) or (
            func == torch.ops.aten.native_batch_norm.default
            or func == torch.ops.aten.native_batch_norm
        )

    def _is_rand_like(self, func):
        result = "torch.rand_like" == torch.overrides.resolve_name(func) or (
            func == torch.ops.aten.rand_like or func == torch.ops.aten.rand_like.default
@@ -283,9 +289,14 @@ class TorchRefsNvfuserCapabilityMode(TorchRefsMode):
        # First we intercept calls for nvfuser-specific prims bypassing generic torch._refs
        if self._is_var_mean(orig_func):
            return torch.ops.nvprims.var_mean(*args, **kwargs)

        if self._is_native_batch_norm(orig_func):
            return torch.ops.nvprims.native_batch_norm(*args, **kwargs)

        if self._is_rand_like(orig_func):
            if len(kwargs) > 0:
                warn("rand_like has ignored kwargs!")
            return torch.ops.nvprims.rand_like(*args)

        # Then we use TorchRefsMode to interpret the rest
        return super().__torch_function__(orig_func, types, args, kwargs)
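The `_is_native_batch_norm` check above matches both the `torch.native_batch_norm` Python entry point (via `resolve_name`) and the aten op objects seen during tracing. A minimal, self-contained sketch of the same torch-function interception pattern; the class name is hypothetical, and the real mode additionally inherits TorchRefsMode's handling of everything else:

import torch
from torch.overrides import TorchFunctionMode

class RerouteNativeBatchNorm(TorchFunctionMode):  # hypothetical name
    def __torch_function__(self, orig_func, types, args=(), kwargs=None):
        kwargs = kwargs or {}
        # Intercept before any generic handling, as the mode above does.
        if orig_func in (torch.native_batch_norm,
                         torch.ops.aten.native_batch_norm,
                         torch.ops.aten.native_batch_norm.default):
            return torch.ops.nvprims.native_batch_norm(*args, **kwargs)
        return orig_func(*args, **kwargs)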
@@ -136,6 +136,18 @@ def make_nvfuser_fusion(gm: GraphModule, *nv_args_templates):
                args, kwargs = self.fetch_args_kwargs_from_env(node)
                args = [args[0], original_shape, args[1]]
                return self.call_function(node.target, args, node.kwargs)

            if node.target in [
                torch.ops.nvprims.native_batch_norm,
                torch.ops.nvprims.native_batch_norm.default,
            ]:
                args, kwargs = self.fetch_args_kwargs_from_env(node)
                assert len(args) == 8
                training = args[5]
                args6_end = tuple(map(_to_nvfuser_constant, args[6:]))
                args = args[:5] + (training,) + args6_end
                return node.target.impl_nvfuser(fd, *args, **kwargs)

            return super().run_node(node)

        def call_function(self, target, args, kwargs):
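The special case above keeps the five tensor arguments and the Python bool `training` untouched and only wraps the trailing scalars (`momentum`, `eps`) as nvFuser constants before dispatching to the op's `impl_nvfuser`. A standalone sketch of that argument split; the `_to_nvfuser_constant` body here is a stand-in with the same unary shape as the real helper:

# Stand-in: the real helper wraps Python scalars as nvFuser fusion
# constants (details assumed here).
def _to_nvfuser_constant(arg):
    return ("nvfuser_constant", arg) if isinstance(arg, (int, float)) else arg

def lower_native_batch_norm_args(args):
    # args = (input, weight, bias, running_mean, running_var,
    #         training, momentum, eps)
    assert len(args) == 8
    training = args[5]                 # plain Python bool, forwarded unchanged
    scalars = tuple(map(_to_nvfuser_constant, args[6:]))  # momentum, eps
    return args[:5] + (training,) + scalars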
@@ -210,6 +210,29 @@ _nvfuser_impls["{fname}"] = _{fname}_nvfuser
    )


def _native_batch_norm_nvfuser(
    fd, input, weight, bias, running_mean, running_var, training, momentum, eps
):
    if weight is None:
        weight = fd.define_null_tensor()
    if bias is None:
        bias = fd.define_null_tensor()
    if running_mean is None:
        running_mean = fd.define_null_tensor()
    if running_var is None:
        running_var = fd.define_null_tensor()
    return fd.ops.batch_norm(
        input,
        weight,
        bias,
        running_mean,
        running_var,
        training,
        momentum,
        eps,
    )


def _broadcast_in_dim_nvfuser(
    fd: Any,
    a: TensorLikeType,
@@ -299,6 +322,7 @@ def _amin_nvfuser(
    return fd.ops.min(a, dims, keep_dims)


_nvfuser_impls["native_batch_norm"] = _native_batch_norm_nvfuser
_nvfuser_impls["broadcast_in_dim"] = _broadcast_in_dim_nvfuser
_nvfuser_impls["convert_element_type"] = _convert_element_type_nvfuser
_nvfuser_impls["transpose"] = _transpose_nvfuser
@@ -312,6 +336,36 @@ _nvfuser_impls["amax"] = _amax_nvfuser
_nvfuser_impls["amin"] = _amin_nvfuser


def register_native_batch_norm():
    """This function is used to register the native_batch_norm function in torch.ops.nvprims module."""
    name = "native_batch_norm"

    nvprim.define(
        f"{name}(Tensor input, Tensor? weight, Tensor? bias, Tensor? running_mean, Tensor? running_var, "
        + "bool training, float momentum, float eps)"
        + " -> (Tensor, Tensor, Tensor)"
    )

    def _prim_impl(
        input, weight, bias, running_mean, running_var, training, momentum, eps
    ):
        return torch.native_batch_norm(
            input, weight, bias, running_mean, running_var, training, momentum, eps
        )

    nvprim_impl.impl(name, _prim_impl)
    nvprim_autograd_impl.impl(
        name, backwards_not_supported(torch.ops.nvprims.native_batch_norm.default)
    )

    prim_packet = torch.ops.nvprims.native_batch_norm
    prim = prim_packet.default
    for p in (prim_packet, prim):
        p.__doc__ = "Computes batch normalization."
        p.impl_nvfuser = _nvfuser_impls["native_batch_norm"]
        p.return_type = torch._prims_common.RETURN_TYPE.NEW  # type: ignore[attr-defined]


def register_rand_like():
    name = "rand_like"
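Once `register_native_batch_norm()` has run (it is wired into `register_nvprims()` in the next hunk), the packet is callable like any other op; the eager path goes through `_prim_impl`, i.e. plain `torch.native_batch_norm`. An illustrative call exercising the schema's optional tensors:

import torch

x = torch.randn(2, 3, 4)
# weight/bias/running stats are Tensor? in the schema, so None is accepted
# in training mode, mirroring sample_inputs_batch_norm's no-kwargs case.
out, save_mean, save_invstd = torch.ops.nvprims.native_batch_norm(
    x, None, None, None, None, True, 0.1, 1e-5
)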
@@ -471,6 +525,7 @@ def register_var_mean():

def register_nvprims():
    """Registers all nvFuser primitives in the torch.ops.nvprims module."""
    register_var_mean()
    register_native_batch_norm()
    register_rand_like()

    for name in nvprim_names:
@@ -587,8 +587,11 @@ ForwardNormResult batch_norm(
     auto invstd_bcast = broadcast(unbiased_invstd, broadcast_mask);

     // During inference, mean/invstd output are empty tensors
-    mean = TensorViewBuilder().shape(std::vector<int64_t>{0}).build();
-    invstd = TensorViewBuilder().shape(std::vector<int64_t>{0}).build();
+    // on CPU, but not on CUDA. We need to make sure we have the same
+    // behavior as with eager mode on CUDA.
+    mean = set(running_mean); // use set to avoid "trivial input forwarding NOT
+    // IMPLEMENTED" error
+    invstd = unbiased_invstd;
     y = mul(x_sub_mean, invstd_bcast);
   }
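The rewritten comment above records a real eager-mode discrepancy: in inference mode, CPU's `native_batch_norm` returns empty `save_mean`/`save_invstd`, while CUDA returns channel-sized tensors, so nvFuser now forwards `running_mean` and the unbiased invstd to match CUDA. A quick illustration of the CPU side (shapes illustrative; the same call on CUDA yields shape [3], which is why the fix is needed):

import torch

x = torch.randn(2, 3, 4)
running_mean, running_var = torch.zeros(3), torch.ones(3)
_, save_mean, save_invstd = torch.native_batch_norm(
    x, None, None, running_mean, running_var, False, 0.1, 1e-5  # training=False
)
print(save_mean.shape)  # torch.Size([0]) on CPU; torch.Size([3]) on CUDA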
@@ -32,7 +32,12 @@ void FusionInterface::addOutput(Nvf::Val* output) const {

 std::vector<at::Tensor> FusionInterface::execute(
     const at::ArrayRef<c10::IValue>& inputs) const {
-  return fusionExecutorCachePtr()->runFusionWithInputs(inputs);
+  // aliasOutputToInput always adds Tensors as outputs that we don't want
+  // to return to the user. We need to remove them.
+  auto count_output_aliases = fusionPtr()->getOutputAliasIndices().size();
+  auto result = fusionExecutorCachePtr()->runFusionWithInputs(inputs);
+  result.erase(result.begin(), result.begin() + count_output_aliases);
+  return result;
 }

 Nvf::FusionGuard FusionInterface::guard() const {
@@ -15,11 +15,13 @@ namespace nvfuser {
enum class RecordType {
  Base = 0,
  Op,
  BatchNormOp,
  BroadcastOp,
  CastOp,
  Constant,
  End,
  Tensor,
  NullTensor,
  Output,
  ReductionOp,
  Scalar,
@@ -895,6 +897,41 @@ struct TensorRecord : RecordFunctor {
  bool is_cpu_;
};

struct NullTensorRecord : RecordFunctor {
  NullTensorRecord(std::vector<State> _outputs)
      : RecordFunctor(
            {},
            std::move(_outputs),
            "null_tensor",
            RecordType::NullTensor) {}
  virtual ~NullTensorRecord() = default;
  virtual RecordFunctor* clone() final {
    return new NullTensorRecord(*this);
  }

  //! Nothing extra necessary in hash
  //! Child specific hash function in lower 32 bits.
  //! | 31 --------------------------------------- 0 |
  //! | None                                         |
  virtual size_t hash() const final {
    auto result = RecordFunctor::hash();
    return result;
  }

  virtual bool operator==(const RecordFunctor& other) const final {
    auto result = false;
    if (dynamic_cast<const NullTensorRecord*>(&other)) {
      result = RecordFunctor::operator==(other);
    }
    return result;
  }

  virtual void operator()(FusionDefinition& fd) final {
    Nvf::TensorView* tv = nullptr;
    fd.setFusionState(outputs_.at(0).index, tv);
  }
};

//! Specialized Record Functor for recording FusionDefinition outputs.

template <class OutputType>
@@ -1313,6 +1350,70 @@ struct VarianceMeanOpRecord : NormOpRecord {
  }
};

struct BatchNormOpRecord : RecordFunctor {
  BatchNormOpRecord(
      std::vector<State> args,
      std::vector<State> outputs,
      bool training,
      bool channels_last)
      : RecordFunctor(
            std::move(args),
            std::move(outputs),
            "ops.batch_norm",
            RecordType::BatchNormOp),
        training_(training),
        channels_last_(channels_last) {}
  virtual ~BatchNormOpRecord() = default;
  virtual RecordFunctor* clone() final {
    return new BatchNormOpRecord(*this);
  }

  virtual bool operator==(const RecordFunctor& other) const final {
    auto result = false;
    if (auto child_ptr = dynamic_cast<const BatchNormOpRecord*>(&other)) {
      result = RecordFunctor::operator==(other);
      result = result && (training_ == child_ptr->training_);
      result = result && (channels_last_ == child_ptr->channels_last_);
    }
    return result;
  }

  virtual size_t hash() const final {
    auto result = RecordFunctor::hash();
    return result | (static_cast<size_t>(training_) << 28) |
        (static_cast<size_t>(channels_last_) << 29);
  }

  void operator()(FusionDefinition& fd) final {
    auto x = fd.getFusionState(args_.at(0).index)->as<Nvf::TensorView>();
    auto weight = fd.getFusionState(args_.at(1).index)->as<Nvf::TensorView>();
    auto bias = fd.getFusionState(args_.at(2).index)->as<Nvf::TensorView>();
    auto running_mean =
        fd.getFusionState(args_.at(3).index)->as<Nvf::TensorView>();
    auto running_var =
        fd.getFusionState(args_.at(4).index)->as<Nvf::TensorView>();
    auto momentum = fd.getFusionState(args_.at(5).index)->as<Nvf::Val>();
    auto eps = fd.getFusionState(args_.at(6).index)->as<Nvf::Val>();
    auto output = Nvf::batch_norm(
        x,
        weight,
        bias,
        running_mean,
        running_var,
        training_,
        momentum,
        eps,
        channels_last_);
    fd.setFusionState(outputs_.at(0).index, output.output);
    fd.setFusionState(outputs_.at(1).index, output.mean);
    fd.setFusionState(outputs_.at(2).index, output.invstd);
  }

 private:
  bool training_;
  bool channels_last_;
};

} // namespace nvfuser

//! Creating the template specialized hash and equal_to functions for a
@@ -126,6 +126,16 @@ void initNvFuserPythonBindings(PyObject* module) {
        self.defineRecord(new nvfuser::OutputRecord<Nvf::TensorView>(
            {self.recordingState(output())}));
      })
      .def(
          "define_null_tensor",
          [](nvfuser::FusionDefinition& self) -> nvfuser::Tensor {
            FUSER_PERF_SCOPE("FusionDefinition.define_null_tensor");
            nvfuser::Tensor out = self.defineTensor();
            self.defineRecord(
                new nvfuser::NullTensorRecord({self.recordingState(out())}));
            return out;
          },
          py::return_value_policy::reference)
      .def(
          "define_tensor",
          [](nvfuser::FusionDefinition& self,
@@ -1259,6 +1269,48 @@ void initNvFuserPythonBindings(PyObject* module) {
      py::arg("correction"),
      py::arg("keepdim") = false,
      py::return_value_policy::reference);
  nvf_ops.def(
      "batch_norm",
      [](nvfuser::FusionDefinition::Operators& self,
         nvfuser::Tensor x,
         nvfuser::Tensor weight,
         nvfuser::Tensor bias,
         nvfuser::Tensor running_mean,
         nvfuser::Tensor running_var,
         bool training,
         nvfuser::Scalar momentum,
         nvfuser::Scalar eps,
         bool channels_last) -> decltype(auto) {
        FUSER_PERF_SCOPE("Operators.batch_norm");
        nvfuser::FusionDefinition* fd = self.fusion_definition;
        nvfuser::Tensor output = fd->defineTensor();
        nvfuser::Tensor mean = fd->defineTensor();
        nvfuser::Tensor invstd = fd->defineTensor();
        fd->defineRecord(new nvfuser::BatchNormOpRecord(
            {fd->recordingState(x()),
             fd->recordingState(weight()),
             fd->recordingState(bias()),
             fd->recordingState(running_mean()),
             fd->recordingState(running_var()),
             fd->recordingState(momentum()),
             fd->recordingState(eps())},
            {fd->recordingState(output()),
             fd->recordingState(mean()),
             fd->recordingState(invstd())},
            training,
            channels_last));
        return std::make_tuple(output, mean, invstd);
      },
      py::arg("x"),
      py::arg("weight").none(true),
      py::arg("bias").none(true),
      py::arg("running_mean").none(true),
      py::arg("running_var").none(true),
      py::arg("training"),
      py::arg("momentum"),
      py::arg("eps"),
      py::arg("channels_last") = false,
      py::return_value_policy::reference);
  nvf_ops.def(
      "broadcast_in_dim",
      [](nvfuser::FusionDefinition::Operators& self,
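A hedged sketch of exercising the new frontend pieces directly: `define_null_tensor` standing in for absent optional tensors, and `ops.batch_norm` with its `(training, momentum, eps, channels_last)` trailing arguments. The import path reflects this era of the code (it has since moved), `define_tensor(4)` assumes the ndims overload, and the values are illustrative:

import torch
from torch._C._nvfuser import Fusion, FusionDefinition

fusion = Fusion()
with FusionDefinition(fusion) as fd:
    x = fd.define_tensor(4)            # 4-D symbolic input
    weight = fd.define_null_tensor()   # absent optional tensors become
    bias = fd.define_null_tensor()     # NullTensorRecord entries
    running_mean = fd.define_null_tensor()
    running_var = fd.define_null_tensor()
    momentum = fd.define_constant(0.1)
    eps = fd.define_constant(1e-5)
    out, mean, invstd = fd.ops.batch_norm(
        x, weight, bias, running_mean, running_var,
        True, momentum, eps, False,    # training, momentum, eps, channels_last
    )
    fd.add_output(out)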
@@ -446,7 +446,22 @@ def sample_inputs_batch_norm(op_info, device, dtype, requires_grad, **kwargs):

     # Test case for no optional kwargs
     # running_mean and running_var are required in evaluation mode (training: False) but not in training mode
-    yield SampleInput(make_arg((1, 2, 3)), args=(None, None), kwargs={'training': True})
+    yield SampleInput(make_arg((1, 2, 3)), args=(None, None, None, None), kwargs={'training': True})
+
+
+def sample_inputs_native_batch_norm(op_info, device, dtype, requires_grad, **kwargs):
+    samples = sample_inputs_batch_norm(op_info, device, dtype, requires_grad, **kwargs)
+    for sample in samples:
+        # torch.native_batch_norm does not support 0 numel tensors
+        # IndexError: Dimension out of range (expected to be in range of [-1, 0], but got 1)
+        if sample.input.numel() == 0:
+            continue
+        args = sample.args
+        training = sample.kwargs.get('training', True)
+        momentum = sample.kwargs.get('momentum', 0.5)
+        eps = sample.kwargs.get('eps', 1e-5)
+        yield SampleInput(sample.input, args=(args[2], args[3], args[0], args[1], training, momentum, eps))


 def sample_inputs_nn_activation_relu(op_info, device, dtype, requires_grad, **kwargs):
     make_arg = partial(make_tensor, device=device, dtype=dtype, requires_grad=requires_grad)
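The `(args[2], args[3], args[0], args[1], ...)` shuffle above exists because `nn.functional.batch_norm` and `torch.native_batch_norm` take the same tensors in different positions; equivalent calls side by side, with illustrative tensors:

import torch

x = torch.randn(2, 3, 4)
weight, bias = torch.randn(3), torch.randn(3)
running_mean, running_var = torch.zeros(3), torch.ones(3)

# functional signature: (input, running_mean, running_var, weight, bias, ...)
y = torch.nn.functional.batch_norm(
    x, running_mean, running_var, weight, bias, training=True
)
# native signature: (input, weight, bias, running_mean, running_var,
#                    training, momentum, eps) -> (output, save_mean, save_invstd)
out, save_mean, save_invstd = torch.native_batch_norm(
    x, weight, bias, running_mean, running_var, True, 0.1, 1e-5
)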
@@ -10640,6 +10655,30 @@ op_db: List[OpInfo] = [
           # possibly because of the welford implementation.
           DecorateInfo(unittest.skip("Skipped!"), 'TestCudaFuserOpInfo', 'test_nvfuser_extremal_values'),
       )),
    OpInfo('native_batch_norm',
           aten_name='native_batch_norm',
           dtypes=floating_types_and(torch.bfloat16),
           dtypesIfCUDA=floating_types_and(torch.float16, torch.bfloat16),
           supports_forward_ad=True,
           supports_fwgrad_bwgrad=True,
           assert_jit_shape_analysis=True,
           sample_inputs_func=sample_inputs_native_batch_norm,
           skips=(
               # NotImplementedError: Could not run
               # 'aten::native_batch_norm.out' with arguments from the 'CPU' backend.
               DecorateInfo(unittest.expectedFailure, 'TestCommon', 'test_out', device_type="cpu"),
               DecorateInfo(unittest.expectedFailure, 'TestCommon', 'test_out_warning', device_type="cpu"),
               # RuntimeError: out_invstd.dim() == 1 && out_invstd.is_contiguous() && out_invstd.sizes()[0]
               DecorateInfo(unittest.expectedFailure, 'TestCommon', 'test_out', device_type="cuda"),
               # IndexError: tuple index out of range
               DecorateInfo(unittest.expectedFailure, 'TestGradients', 'test_forward_mode_AD'),
               # RuntimeError: deepEquals(input.iValue, deepCopiedInput) INTERNAL ASSERT FAILED
               DecorateInfo(unittest.expectedFailure, 'TestJit', 'test_variant_consistency_jit'),
               # AssertionError: Booleans mismatch: True is not False
               DecorateInfo(unittest.skip("Skipped!"), 'TestFakeTensor', 'test_fake_autocast'),
               DecorateInfo(unittest.skip("Skipped!"), 'TestFakeTensor', 'test_fake'),
           )
           ),
    OpInfo('nn.functional.cosine_similarity',
           aten_name="cosine_similarity",
           dtypes=floating_types_and(torch.bfloat16),
@@ -17787,6 +17826,20 @@ python_ref_db = [
            DecorateInfo(unittest.expectedFailure, 'TestCommon', 'test_python_ref',),
        ),
    ),
    PythonRefInfo(
        "ops.nvprims.native_batch_norm",
        torch_opinfo_name="native_batch_norm",
        # Complex types are currently disabled
        dtypes=floating_types(),
        supports_out=False,
        # This function is expected not to work with TorchRefsMode(strict=True)
        decorators=(
            DecorateInfo(unittest.expectedFailure, 'TestCommon', 'test_python_ref',),
            # There's a discrepancy in returned shape between CPU and other devices
            # AssertionError: Shapes torch.Size([0]) and torch.Size([2]) are not equal!
            DecorateInfo(unittest.expectedFailure, 'TestCommon', 'test_python_ref_meta', device_type="cpu"),
        ),
    ),
    #
    # Linear Algebra Operators
    #