Add factory functions to python frontend (#89230)
- Add a `full` nvprim to support factory functions, because the `full` reference uses `empty` and `fill` while nvFuser has a `full` factory function.
- Change the `full_like` reference to call `full`, to avoid defining another nvprim.
- Enable support for `new_zeros` to enable the `cudnn_batch_norm` decomposition.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/89230
Approved by: https://github.com/kevinstephano, https://github.com/mruberry
Committed by: PyTorch MergeBot
Parent: e645771e95
Commit: 3c9431f505
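In practice, the change means that tracing `torch.full` under `TorchRefsNvfuserCapabilityMode` now produces an `nvprims.full` node that the nvFuser executor can consume directly. Below is a minimal sketch mirroring the new test added in this commit; the shape, fill value, and CUDA device are illustrative, and it assumes a CUDA build with nvFuser available.

```python
import torch
from torch.fx.experimental.proxy_tensor import make_fx
from torch._prims.context import TorchRefsNvfuserCapabilityMode
from torch._prims.executor import execute

def func(size, value):
    # torch.full is intercepted by __torch_function__ and rewritten to nvprims.full
    return (torch.full(size, value, dtype=torch.float32, device="cuda"),)

with TorchRefsNvfuserCapabilityMode():
    gm = make_fx(func)((3, 3), 10.0)

# Run the traced graph through the nvFuser executor and compare with eager mode
out = execute(gm, (3, 3), 10.0, executor="nvfuser")
torch.testing.assert_close(out[0], func((3, 3), 10.0)[0])
```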
@@ -234,6 +234,46 @@ class TestPrims(TestCase):
        partitions = partitioner.propose_partitions()
        self.assertEqual(len(partitions), 1)

    @onlyCUDA
    @skipCUDAIfRocm
    @dtypes(torch.float32)
    def test_full(self, device, dtype):
        from torch.fx.experimental.proxy_tensor import make_fx
        from torch._prims.context import TorchRefsNvfuserCapabilityMode
        from torch._prims.executor import execute

        def func1(size, value, b):
            return (torch.full(size, value, dtype=dtype, device=device),)

        def func2(size, value, b):
            a = torch.full(size, value, dtype=dtype, device=device)
            b_sin = b.sin()
            return (torch.add(a, b_sin),)

        def func3(size, value, b):
            return (torch.full(size, value, dtype=dtype, device=device), b)

        def func4(size, value, b):
            b_sin = b.sin()
            return (torch.full(size, value, dtype=dtype, device=device), b_sin)

        def func5(size, value, b):
            b_sin = b.sin()
            a = torch.full(size, value, dtype=dtype, device=device)
            a_sin = a.sin()
            return (a, b_sin, a_sin)

        for func in (func1, func3, func2, func3, func4, func5):
            size = (3, 3)
            value = 10
            b = torch.randn(*size, dtype=dtype, device=device)

            with TorchRefsNvfuserCapabilityMode():
                gm = make_fx(func)(size, value, b)

            out = execute(gm, size, value, b, executor="strictly_nvfuser")
            self.assertEqual(out, func(size, value, b))

    @onlyCUDA
    @skipCUDAIfRocm
    def test_nvfuser_empty_fusion(self, device):
@@ -687,7 +727,13 @@ class TestPrims(TestCase):

        # Check that the graph can be executed with nvFuser
        out = execute(gm, sample.input, *sample.args, executor="nvfuser")
        self.assertEqual(out, gm(sample.input, *sample.args))
        ref_out = gm(sample.input, *sample.args)
        for idx, (left, right) in enumerate(zip(out, ref_out)):
            # Nvfuser does not support torch.uint8 dtype so check reserve output against 0 scalar
            if idx == 3:
                self.assertTrue(torch.all(torch.eq(left, 0)))
            else:
                self.assertEqual(left, right)

    # decomposition of native_batch_norm_backward uses a casting, which prevents nvprim lowering on CPU build
    @onlyCUDA
@@ -364,6 +364,16 @@ class TorchRefsNvfuserCapabilityMode(TorchRefsMode):
        )
        return result

    def _is_full(self, func):
        result = "torch.full" == torch.overrides.resolve_name(func) or (
            func
            in [
                torch.ops.aten.full,
                torch.ops.aten.full.names,
            ]
        )
        return result

    def __torch_function__(
        self,
        orig_func: Callable,
@@ -416,5 +426,8 @@ class TorchRefsNvfuserCapabilityMode(TorchRefsMode):
            warn("rand_like has ignored kwargs!")
            return torch.ops.nvprims.rand_like(*args)

        if self._is_full(orig_func):
            return torch.ops.nvprims.full(*args, **kwargs)

        # Then we use TorchRefsMode to interpret the rest
        return super().__torch_function__(orig_func, types, args, kwargs)
@@ -8,6 +8,7 @@
from typing import Any, Dict, Optional, Tuple

import torch
import torch._prims_common as utils

from torch._prims_common import (
    DimsSequenceType,
@@ -15,6 +16,7 @@ from torch._prims_common import (
    ELEMENTWISE_TYPE_PROMOTION_KIND,
    getnvFuserDtype,
    make_contiguous_strides_for,
    NumberType,
    ShapeType,
    TensorLikeType,
)
@@ -341,6 +343,26 @@ def _clone_nvfuser(fd: Any, input: TensorLikeType, *, memory_format=None):
    return fd.ops.set(input)


def _full_nvfuser(
    fd: Any,
    shape: ShapeType,
    fill_value: NumberType,
    *,
    dtype: Optional[torch.dtype] = None,
    layout: Optional[torch.layout] = None,
    device: Optional[torch.device] = None,
    pin_memory: bool = False,
    requires_grad: bool = False,
):
    assert device != torch.device("cpu")
    assert layout is None or layout is torch.strided
    assert pin_memory is False
    assert requires_grad is False
    dtype = dtype if dtype is not None else utils.type_to_dtype(type(fill_value))
    nvfuser_dtype = getnvFuserDtype(dtype)
    return fd.ops.full(shape, fill_value, nvfuser_dtype)


_nvfuser_impls["native_batch_norm"] = _native_batch_norm_nvfuser
_nvfuser_impls["broadcast_in_dim"] = _broadcast_in_dim_nvfuser
_nvfuser_impls["convert_element_type"] = _convert_element_type_nvfuser
@@ -355,6 +377,70 @@ _nvfuser_impls["var"] = _var_nvfuser
_nvfuser_impls["var_mean"] = _var_mean_nvfuser
_nvfuser_impls["amax"] = _amax_nvfuser
_nvfuser_impls["amin"] = _amin_nvfuser
_nvfuser_impls["full"] = _full_nvfuser


def register_full():
    name = "full"

    nvprim.define(
        "full(SymInt[] size, Scalar fill_value, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, "
        + "bool? pin_memory=None, bool? requires_grad=None) -> Tensor"
    )

    def _meta_impl(
        size,
        fill_value,
        *,
        out=None,
        dtype=None,
        layout=None,
        device=None,
        requires_grad=False,
    ):
        strides = make_contiguous_strides_for(size)
        return torch._prims.TensorMeta(
            None,
            shape=size,
            strides=strides,
            dtype=dtype,
            device=device,
        )

    def _prim_impl(
        size,
        fill_value,
        *,
        out=None,
        dtype=None,
        layout=None,
        device=None,
        pin_memory=False,
        requires_grad=False,
    ):
        return torch.full(
            size,
            fill_value,
            out=out,
            dtype=dtype,
            layout=layout,
            device=device,
            pin_memory=pin_memory,
            requires_grad=requires_grad,
        )

    nvprim_impl.impl(name, _prim_impl)
    nvprim_meta_impl.impl(name, _meta_impl)

    prim_packet = getattr(torch.ops.nvprims, name)
    prim = prim_packet.default
    nvprim_autograd_impl.impl(name, backwards_not_supported(prim))
    for p in (prim_packet, prim):
        p.__doc__ = "Create a tensor with given size and filled with value"
        p.impl_nvfuser = _nvfuser_impls["full"]
        p.is_recomputable = _nvfuser_is_recomputable["full"]
        p.return_type = torch._prims_common.RETURN_TYPE.NEW  # type: ignore[attr-defined]


# functorch.compile.min_cut_rematerialization_partition accepts a list of
# operators that can be recomputed in the backward pass. This list is used to
@@ -397,6 +483,7 @@ _nvfuser_is_recomputable: Dict[str, bool] = {
    "expm1": True,
    "floor": True,
    "fmod": True,
    "full": True,
    "ge": True,
    "gt": True,
    "imag": True,
@@ -715,6 +802,7 @@ def register_nvprims():
    register_view()
    register_native_batch_norm()
    register_rand_like()
    register_full()

    for name in nvprim_names:
        main_prim = getattr(torch.ops.prims, name)
@@ -20,6 +20,7 @@ if hasattr(torch._C, "_nvfuser"):
        torch.bfloat16: DataType.BFloat16,
        torch.long: DataType.Int,
        torch.int: DataType.Int32,
        torch.uint8: DataType.Int32,
        torch.bool: DataType.Bool,
        # Python scalars
        complex: DataType.ComplexDouble,
@@ -33,6 +33,7 @@ enum class RecordType {
  VarianceMeanOp,
  ViewOp,
  PermuteOp,
  FullOp
};

//! RecordFunctor is the base class record for operations recorded by
@@ -1581,6 +1582,95 @@ struct BatchNormOpRecord : RecordFunctor {
  bool channels_last_;
};

struct FullOpRecord : RecordFunctor {
  FullOpRecord(
      std::vector<State> _args,
      std::vector<State> _outputs,
      std::vector<int64_t>& shape,
      Nvf::DataType dtype)
      : RecordFunctor(
            std::move(_args),
            std::move(_outputs),
            "ops.full",
            RecordType::FullOp),
        shape_(std::move(shape)),
        dtype_(dtype) {}
  virtual ~FullOpRecord() = default;
  virtual RecordFunctor* clone() final {
    return new FullOpRecord(*this);
  }

  //! Child specific hash function in lower 32 bits.
  //! | 31 --- 24 | 23 -------------------------- 0 |
  //! | Dtype     | Shape hash code                 |
  virtual size_t hash() const final {
    auto result = RecordFunctor::hash();
    size_t shape_hash = 0;
    for (auto p : shape_) {
      shape_hash ^= static_cast<size_t>(p);
    }
    result |= ((static_cast<size_t>(dtype_) & 0xff) << 24);
    result |= (shape_hash & 0xffff);
    return result;
  }

  virtual bool operator==(const RecordFunctor& other) const final {
    auto result = false;
    if (auto child_ptr = dynamic_cast<const FullOpRecord*>(&other)) {
      result = RecordFunctor::operator==(other);
      if (result) {
        result = (shape_.size() == child_ptr->shape_.size());
        if (result) {
          for (size_t i = 0; i < shape_.size(); ++i) {
            if (shape_[i] != child_ptr->shape_[i]) {
              result = false;
              break;
            }
          }
        }
      }
    }
    return result;
  }

  void operator()(FusionDefinition& fd) final {
    auto arg = fd.getFusionState(args_.at(0).index)->template as<Nvf::Val>();

    std::vector<torch::jit::fuser::cuda::Val*> nvf_shape(
        shape_.size(), nullptr);
    for (const auto idx : c10::irange(shape_.size())) {
      nvf_shape[idx] = Nvf::IrBuilder::create<Nvf::Int>(shape_.at(idx));
    }
    auto output = torch::jit::fuser::cuda::full(nvf_shape, arg, dtype_);
    fd.setFusionState(outputs_.at(0).index, output);
  }

  virtual void print(std::ostream& os, bool close_function = true) const {
    RecordFunctor::print(os, false);
    os << ", shape=[";
    bool first_arg = true;
    for (auto p : shape_) {
      if (first_arg) {
        first_arg = false;
      } else {
        os << ", ";
      }
      os << p;
    }
    os << "]";
    os << ", dtype=" << dtypeToPyString(dtype_);
    if (close_function) {
      os << ")";
    }
  }

 private:
  //! Represents shape of new tensor
  std::vector<int64_t> shape_;
  //! Type of output
  Nvf::DataType dtype_;
};

} // namespace nvfuser

//! Creating the template specialized hash and equal_to functions for a
@@ -1210,7 +1210,6 @@ void initNvFuserPythonBindings(PyObject* module) {
      py::arg("arg"),
      py::arg("dims"),
      py::return_value_policy::reference);

  nvf_ops.def(
      "squeeze",
      [](nvfuser::FusionDefinition::Operators& self,
@@ -1250,7 +1249,25 @@ void initNvFuserPythonBindings(PyObject* module) {
      py::arg("original_shape"),
      py::arg("new_shape"),
      py::return_value_policy::reference);

  nvf_ops.def(
      "full",
      [](nvfuser::FusionDefinition::Operators& self,
         std::vector<int64_t>& size,
         nvfuser::Scalar arg,
         Nvf::DataType dtype) -> nvfuser::Tensor {
        nvfuser::FusionDefinition* fd = self.fusion_definition;
        nvfuser::Tensor output = fd->defineTensor();
        fd->defineRecord(new nvfuser::FullOpRecord(
            {fd->recordingState(arg())},
            {fd->recordingState(output())},
            size,
            dtype));
        return output;
      },
      py::arg("size"),
      py::arg("arg"),
      py::arg("dtype"),
      py::return_value_policy::reference);
  nvf_ops.def(
      "var",
      [](nvfuser::FusionDefinition::Operators& self,