Add factory functions to python frontend (#89230)

- Add a `full` nvprim to support factory functions: the `full` reference is implemented with `empty` and `fill`, whereas we already have a native `full` factory function.
- Change `full_like` reference to call `full` to avoid defining another nvprim.
- Enable support for new_zeros to enable `cudnn_batch_norm` decomposition.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/89230
Approved by: https://github.com/kevinstephano, https://github.com/mruberry
This commit is contained in:
Ryan Spring
2022-12-06 07:16:19 +00:00
committed by PyTorch MergeBot
parent e645771e95
commit 3c9431f505
6 changed files with 258 additions and 3 deletions

View File

@ -234,6 +234,46 @@ class TestPrims(TestCase):
partitions = partitioner.propose_partitions()
self.assertEqual(len(partitions), 1)
@onlyCUDA
@skipCUDAIfRocm
@dtypes(torch.float32)
def test_full(self, device, dtype):
    """Check that ``torch.full`` can be traced and executed strictly with nvFuser.

    Each case is traced with ``make_fx`` under ``TorchRefsNvfuserCapabilityMode``,
    executed with the ``"strictly_nvfuser"`` executor, and compared against
    calling the Python function directly.
    """
    from torch.fx.experimental.proxy_tensor import make_fx
    from torch._prims.context import TorchRefsNvfuserCapabilityMode
    from torch._prims.executor import execute

    def func1(size, value, b):
        # Only output is the freshly created tensor; ``b`` is unused.
        return (torch.full(size, value, dtype=dtype, device=device),)

    def func2(size, value, b):
        # The created tensor is consumed by another op together with an input.
        a = torch.full(size, value, dtype=dtype, device=device)
        b_sin = b.sin()
        return (torch.add(a, b_sin),)

    def func3(size, value, b):
        # The created tensor is returned next to an untouched input.
        return (torch.full(size, value, dtype=dtype, device=device), b)

    def func4(size, value, b):
        # The created tensor is returned next to a computed value.
        b_sin = b.sin()
        return (torch.full(size, value, dtype=dtype, device=device), b_sin)

    def func5(size, value, b):
        # The created tensor is both returned and used by a later op.
        b_sin = b.sin()
        a = torch.full(size, value, dtype=dtype, device=device)
        a_sin = a.sin()
        return (a, b_sin, a_sin)

    # Every case is listed exactly once (func3 used to appear twice).
    for func in (func1, func2, func3, func4, func5):
        size = (3, 3)
        value = 10
        b = torch.randn(*size, dtype=dtype, device=device)

        with TorchRefsNvfuserCapabilityMode():
            gm = make_fx(func)(size, value, b)

        out = execute(gm, size, value, b, executor="strictly_nvfuser")
        self.assertEqual(out, func(size, value, b))
@onlyCUDA
@skipCUDAIfRocm
def test_nvfuser_empty_fusion(self, device):
@ -687,7 +727,13 @@ class TestPrims(TestCase):
# Check that the graph can be executed with nvFuser
out = execute(gm, sample.input, *sample.args, executor="nvfuser")
self.assertEqual(out, gm(sample.input, *sample.args))
ref_out = gm(sample.input, *sample.args)
for idx, (left, right) in enumerate(zip(out, ref_out)):
# Nvfuser does not support torch.uint8 dtype so check reserve output against 0 scalar
if idx == 3:
self.assertTrue(torch.all(torch.eq(left, 0)))
else:
self.assertEqual(left, right)
# decomposition of native_batch_norm_backward uses a cast, which prevents nvprim lowering on a CPU build
@onlyCUDA

View File

@ -364,6 +364,16 @@ class TorchRefsNvfuserCapabilityMode(TorchRefsMode):
)
return result
def _is_full(self, func):
result = "torch.full" == torch.overrides.resolve_name(func) or (
func
in [
torch.ops.aten.full,
torch.ops.aten.full.names,
]
)
return result
def __torch_function__(
self,
orig_func: Callable,
@ -416,5 +426,8 @@ class TorchRefsNvfuserCapabilityMode(TorchRefsMode):
warn("rand_like has ignored kwargs!")
return torch.ops.nvprims.rand_like(*args)
if self._is_full(orig_func):
return torch.ops.nvprims.full(*args, **kwargs)
# Then we use TorchRefsMode to interpret the rest
return super().__torch_function__(orig_func, types, args, kwargs)

View File

@ -8,6 +8,7 @@
from typing import Any, Dict, Optional, Tuple
import torch
import torch._prims_common as utils
from torch._prims_common import (
DimsSequenceType,
@ -15,6 +16,7 @@ from torch._prims_common import (
ELEMENTWISE_TYPE_PROMOTION_KIND,
getnvFuserDtype,
make_contiguous_strides_for,
NumberType,
ShapeType,
TensorLikeType,
)
@ -341,6 +343,26 @@ def _clone_nvfuser(fd: Any, input: TensorLikeType, *, memory_format=None):
return fd.ops.set(input)
def _full_nvfuser(
fd: Any,
shape: ShapeType,
fill_value: NumberType,
*,
dtype: Optional[torch.dtype] = None,
layout: Optional[torch.layout] = None,
device: Optional[torch.device] = None,
pin_memory: bool = False,
requires_grad: bool = False,
):
assert device != torch.device("cpu")
assert layout is None or layout is torch.strided
assert pin_memory is False
assert requires_grad is False
dtype = dtype if dtype is not None else utils.type_to_dtype(type(fill_value))
nvfuser_dtype = getnvFuserDtype(dtype)
return fd.ops.full(shape, fill_value, nvfuser_dtype)
_nvfuser_impls["native_batch_norm"] = _native_batch_norm_nvfuser
_nvfuser_impls["broadcast_in_dim"] = _broadcast_in_dim_nvfuser
_nvfuser_impls["convert_element_type"] = _convert_element_type_nvfuser
@ -355,6 +377,70 @@ _nvfuser_impls["var"] = _var_nvfuser
_nvfuser_impls["var_mean"] = _var_mean_nvfuser
_nvfuser_impls["amax"] = _amax_nvfuser
_nvfuser_impls["amin"] = _amin_nvfuser
_nvfuser_impls["full"] = _full_nvfuser
def register_full():
    """Define the ``full`` nvprim and register its implementations.

    Registers the operator schema, the eager implementation (which forwards
    to ``torch.full``), the meta implementation (which only describes the
    output), and an autograd implementation that reports backward as
    unsupported. It then attaches the nvFuser lowering and metadata to both
    the operator packet and its default overload.
    """
    name = "full"

    nvprim.define(
        "full(SymInt[] size, Scalar fill_value, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, "
        + "bool? pin_memory=None, bool? requires_grad=None) -> Tensor"
    )

    def _meta_impl(
        size,
        fill_value,
        *,
        out=None,
        dtype=None,
        layout=None,
        device=None,
        pin_memory=False,
        requires_grad=False,
    ):
        # Accepts every keyword declared in the schema (including
        # ``pin_memory``) so that the meta kernel never rejects a call the
        # eager implementation would accept. Only size, dtype and device
        # influence the returned metadata.
        strides = make_contiguous_strides_for(size)
        return torch._prims.TensorMeta(
            None,
            shape=size,
            strides=strides,
            dtype=dtype,
            device=device,
        )

    def _prim_impl(
        size,
        fill_value,
        *,
        out=None,
        dtype=None,
        layout=None,
        device=None,
        pin_memory=False,
        requires_grad=False,
    ):
        # Eager implementation: forward everything to torch.full.
        return torch.full(
            size,
            fill_value,
            out=out,
            dtype=dtype,
            layout=layout,
            device=device,
            pin_memory=pin_memory,
            requires_grad=requires_grad,
        )

    nvprim_impl.impl(name, _prim_impl)
    nvprim_meta_impl.impl(name, _meta_impl)

    prim_packet = getattr(torch.ops.nvprims, name)
    prim = prim_packet.default

    nvprim_autograd_impl.impl(name, backwards_not_supported(prim))

    # Attach the same metadata to the packet and to its default overload.
    for p in (prim_packet, prim):
        p.__doc__ = "Create a tensor with given size and filled with value"
        p.impl_nvfuser = _nvfuser_impls["full"]
        p.is_recomputable = _nvfuser_is_recomputable["full"]
        p.return_type = torch._prims_common.RETURN_TYPE.NEW  # type: ignore[attr-defined]
# functorch.compile.min_cut_rematerialization_partition accepts a list of
# operators that can be recomputed in the backward pass. This list is used to
@ -397,6 +483,7 @@ _nvfuser_is_recomputable: Dict[str, bool] = {
"expm1": True,
"floor": True,
"fmod": True,
"full": True,
"ge": True,
"gt": True,
"imag": True,
@ -715,6 +802,7 @@ def register_nvprims():
register_view()
register_native_batch_norm()
register_rand_like()
register_full()
for name in nvprim_names:
main_prim = getattr(torch.ops.prims, name)

View File

@ -20,6 +20,7 @@ if hasattr(torch._C, "_nvfuser"):
torch.bfloat16: DataType.BFloat16,
torch.long: DataType.Int,
torch.int: DataType.Int32,
torch.uint8: DataType.Int32,
torch.bool: DataType.Bool,
# Python scalars
complex: DataType.ComplexDouble,

View File

@ -33,6 +33,7 @@ enum class RecordType {
VarianceMeanOp,
ViewOp,
PermuteOp,
FullOp
};
//! RecordFunctor is the base class record for operations recorded by
@ -1581,6 +1582,95 @@ struct BatchNormOpRecord : RecordFunctor {
bool channels_last_;
};
//! Record for the `ops.full` operation: creates a tensor of a fixed shape
//! and dtype filled with a scalar taken from the record's first argument.
struct FullOpRecord : RecordFunctor {
  //! \param _args    Input states; `_args[0]` is the fill value.
  //! \param _outputs Output states; `_outputs[0]` receives the new tensor.
  //! \param shape    Extents of the tensor to create. Taken by value so the
  //!                 caller's vector is never silently moved-from.
  //! \param dtype    Data type of the created tensor.
  FullOpRecord(
      std::vector<State> _args,
      std::vector<State> _outputs,
      std::vector<int64_t> shape,
      Nvf::DataType dtype)
      : RecordFunctor(
            std::move(_args),
            std::move(_outputs),
            "ops.full",
            RecordType::FullOp),
        shape_(std::move(shape)),
        dtype_(dtype) {}
  virtual ~FullOpRecord() = default;
  virtual RecordFunctor* clone() final {
    return new FullOpRecord(*this);
  }

  //! Child specific hash function in lower 32 bits.
  //! | 31 --- 24 | 23 --------------------------  0 |
  //! | Dtype     | Shape hash code                  |
  virtual size_t hash() const final {
    auto result = RecordFunctor::hash();
    size_t shape_hash = 0;
    for (const auto extent : shape_) {
      shape_hash ^= static_cast<size_t>(extent);
    }
    result |= ((static_cast<size_t>(dtype_) & 0xff) << 24);
    // Use the full 24 bits reserved for the shape in the layout above.
    result |= (shape_hash & 0xffffff);
    return result;
  }

  //! Two records are equal when the base record matches and both the shape
  //! and the dtype are identical. The dtype must take part in the
  //! comparison: it changes the generated operation, so records differing
  //! only in dtype must not be treated as the same record.
  virtual bool operator==(const RecordFunctor& other) const final {
    if (auto child_ptr = dynamic_cast<const FullOpRecord*>(&other)) {
      return RecordFunctor::operator==(other) &&
          shape_ == child_ptr->shape_ && dtype_ == child_ptr->dtype_;
    }
    return false;
  }

  //! Replays the record into \p fd: reads the fill value from the fusion
  //! state of the first argument, creates an `Nvf::Int` for every extent of
  //! the shape, calls `full`, and stores the result as the first output.
  void operator()(FusionDefinition& fd) final {
    auto arg = fd.getFusionState(args_.at(0).index)->template as<Nvf::Val>();

    std::vector<torch::jit::fuser::cuda::Val*> nvf_shape(
        shape_.size(), nullptr);
    for (const auto idx : c10::irange(shape_.size())) {
      nvf_shape[idx] = Nvf::IrBuilder::create<Nvf::Int>(shape_.at(idx));
    }
    auto output = torch::jit::fuser::cuda::full(nvf_shape, arg, dtype_);
    fd.setFusionState(outputs_.at(0).index, output);
  }

  //! Prints the base record followed by `, shape=[...]` and `, dtype=...`;
  //! the closing parenthesis is emitted only when \p close_function is true.
  virtual void print(std::ostream& os, bool close_function = true) const {
    RecordFunctor::print(os, false);
    os << ", shape=[";
    bool first_arg = true;
    for (auto p : shape_) {
      if (first_arg) {
        first_arg = false;
      } else {
        os << ", ";
      }
      os << p;
    }
    os << "]";
    os << ", dtype=" << dtypeToPyString(dtype_);
    if (close_function) {
      os << ")";
    }
  }

 private:
  //! Represents shape of new tensor
  std::vector<int64_t> shape_;
  //! Type of output
  Nvf::DataType dtype_;
};
} // namespace nvfuser
//! Creating the template specialized hash and equal_to functions for a

View File

@ -1210,7 +1210,6 @@ void initNvFuserPythonBindings(PyObject* module) {
py::arg("arg"),
py::arg("dims"),
py::return_value_policy::reference);
nvf_ops.def(
"squeeze",
[](nvfuser::FusionDefinition::Operators& self,
@ -1250,7 +1249,25 @@ void initNvFuserPythonBindings(PyObject* module) {
py::arg("original_shape"),
py::arg("new_shape"),
py::return_value_policy::reference);
// Python binding for `ops.full(size, arg, dtype)`.
// Defines a new output tensor on the owning FusionDefinition, records a
// FullOpRecord whose only input state is the scalar `arg` and whose only
// output state is that tensor, and returns the tensor to the caller.
nvf_ops.def(
"full",
[](nvfuser::FusionDefinition::Operators& self,
std::vector<int64_t>& size,
nvfuser::Scalar arg,
Nvf::DataType dtype) -> nvfuser::Tensor {
nvfuser::FusionDefinition* fd = self.fusion_definition;
nvfuser::Tensor output = fd->defineTensor();
// The record is allocated with `new` and passed to defineRecord.
fd->defineRecord(new nvfuser::FullOpRecord(
{fd->recordingState(arg())},
{fd->recordingState(output())},
size,
dtype));
return output;
},
py::arg("size"),
py::arg("arg"),
py::arg("dtype"),
py::return_value_policy::reference);
nvf_ops.def(
"var",
[](nvfuser::FusionDefinition::Operators& self,