Add nvFuser support for torch.Tensor.view (#84634)

This is an alternative to https://github.com/pytorch/pytorch/pull/83739. While PrimTorch has `view` as a reference, we would like to use nvFuser's implementation for `view` for now. Later we might transition to PrimTorch's `torch._refs.view`.

See `test_nvprims_view` for examples of the call patterns that are now sent to nvFuser. Note that nvFuser's `view` is a copy-like operation rather than an aliasing view.
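
As a quick illustration of what this enables (a minimal sketch distilled from `test_nvprims_view`; it assumes a CUDA build with nvFuser available), tracing a plain `Tensor.view` call under `TorchRefsNvfuserCapabilityMode` now produces a graph containing `nvprims.view`, which the nvFuser executor can run directly:

```python
import torch
from torch.fx.experimental.proxy_tensor import make_fx
from torch._prims.context import TorchRefsNvfuserCapabilityMode
from torch._prims.executor import execute

a = torch.randn(3, 4, 5, device="cuda")

def func(x):
    return x.view(tuple(reversed(x.shape)))

# Interpret torch.Tensor.view as torch.ops.nvprims.view while tracing.
with TorchRefsNvfuserCapabilityMode():
    gm = make_fx(func)(a)

assert any(
    node.op == "call_function" and node.target == torch.ops.nvprims.view.default
    for node in gm.graph.nodes
)

# Run the traced graph entirely through nvFuser.
out = execute(gm, a, executor="strictly_nvfuser")
torch.testing.assert_close(out, func(a))
```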
Pull Request resolved: https://github.com/pytorch/pytorch/pull/84634
Approved by: https://github.com/kevinstephano, https://github.com/mruberry
Author: Ivan Yashchuk
Date: 2022-10-14 12:08:02 +00:00
Committed by: PyTorch MergeBot
Parent: b48deedb77
Commit: fd80684784
7 changed files with 248 additions and 5 deletions


@ -378,9 +378,15 @@ class TestCommon(TestCase):
         if executor == "nvfuser" and isinstance(op, ReductionPythonRefInfo):
             skip_zero_dim = True
 
-        # skip zero-dim tensors for some composites of reduction operations
-        normalization_ops = ["_refs.softmax", "_refs.logsumexp", "_refs.log_softmax", "_refs.sum_to_size"]
-        if executor == "nvfuser" and op.name in normalization_ops:
+        # skip zero-dim tensors for some composites of reduction operations and view
+        skip_zero_dim_ops = [
+            "_refs.softmax",
+            "_refs.logsumexp",
+            "_refs.log_softmax",
+            "_refs.sum_to_size",
+            "ops.nvprims.view",
+        ]
+        if executor == "nvfuser" and op.name in skip_zero_dim_ops:
             skip_zero_dim = True
 
         from torch._prims.executor import make_traced


@ -688,6 +688,53 @@ class TestPrims(TestCase):
            )
            self.assertTrue(includes_nvprims_var_mean)

    @onlyCUDA
    @skipCUDAIfRocm
    @dtypes(torch.float16, torch.float32)
    def test_nvprims_view(self, device, dtype):
        from torch.fx.experimental.proxy_tensor import make_fx
        from torch._prims.context import TorchRefsNvfuserCapabilityMode
        from torch._prims.executor import execute

        make_arg = partial(make_tensor, device=device, dtype=dtype)
        a = make_arg((3, 4, 5))

        def func1(a):
            return a.view(tuple(reversed(a.shape)))

        def func2(a):
            return a.reshape(tuple(reversed(a.shape)))

        def func3(a):
            return torch.view_copy(a, tuple(reversed(a.shape)))

        def func4(a):
            return torch.reshape(a, tuple(reversed(a.shape)))

        def func5(a):
            return torch.ops.aten.view.default(a, tuple(reversed(a.shape)))

        def func6(a):
            return torch.ops.aten._unsafe_view.default(a, tuple(reversed(a.shape)))

        def func7(a):
            return torch.ops.aten.view_copy.default(a, tuple(reversed(a.shape)))

        for func in (func1, func2, func3, func4, func5, func6, func7):
            with TorchRefsNvfuserCapabilityMode():
                gm = make_fx(func)(a)

            call_function_nodes = list(filter(lambda n: n.op == "call_function", gm.graph.nodes))
            includes_nvprims_view = any(
                torch.ops.nvprims.view.default == node.target
                for node in call_function_nodes
            )
            self.assertTrue(includes_nvprims_view)

            # Try executing the graph
            out = execute(gm, a, executor="strictly_nvfuser")
            self.assertEqual(out, func(a))

    @onlyCUDA
    @skipCUDAIfRocm
    @dtypes(torch.float32, torch.float16)


@ -251,13 +251,13 @@ class TorchRefsNvfuserCapabilityMode(TorchRefsMode):
         aten_ops_to_skip = (
             "aten.transpose.int",
             "aten.t.default",
-            "aten.view.default",
             "aten.unsqueeze.default",
             "aten.permute.default",
             "aten._log_softmax.default",
             "aten._log_softmax_backward_data.default",
             "aten.expand.default",
         )
+        self.skip_ops = tuple(skip_ops) + aten_ops_to_skip
         super().__init__(
             strict=False,
             should_fallback_fn=functools.partial(
@ -276,6 +276,18 @@ class TorchRefsNvfuserCapabilityMode(TorchRefsMode):
and "aten.var_mean" in str(func)
)
def _is_view_or_reshape(self, func):
allowed_ops = {
"torch.Tensor.view",
"torch.Tensor.reshape",
"torch.view_copy",
"torch.reshape",
"aten.view.default",
"aten._unsafe_view.default",
"aten.view_copy.default",
} - set(self.skip_ops)
return torch.overrides.resolve_name(func) in allowed_ops
def _is_native_batch_norm(self, func):
return "torch.native_batch_norm" == torch.overrides.resolve_name(func) or (
func == torch.ops.aten.native_batch_norm.default
@ -301,12 +313,21 @@ class TorchRefsNvfuserCapabilityMode(TorchRefsMode):
         if self._is_var_mean(orig_func):
             return torch.ops.nvprims.var_mean(*args, **kwargs)
 
+        if self._is_view_or_reshape(orig_func):
+            a, *shape = args
+            shape = torch._prims_common.extract_shape_from_varargs(
+                shape, validate=False
+            )  # type: ignore[assignment]
+            if len(kwargs) > 0:
+                warn("view has ignored kwargs!")
+            return torch.ops.nvprims.view(a, shape)
+
         if self._is_native_batch_norm(orig_func):
             return torch.ops.nvprims.native_batch_norm(*args, **kwargs)
 
         if self._is_rand_like(orig_func):
             if len(kwargs) > 0:
-                warn("rand_like has ignored kwars!")
+                warn("rand_like has ignored kwargs!")
             return torch.ops.nvprims.rand_like(*args)
 
         # Then we use TorchRefsMode to interpret the rest


@ -263,6 +263,15 @@ def _view_of_nvfuser(fd, a):
    return fd.ops.set(a)


def _view_nvfuser(
    fd,
    a,
    a_shape,
    new_shape,
):
    return fd.ops.view(a, a_shape, new_shape)


def _sum_nvfuser(
    fd: Any,
    a: TensorLikeType,
@ -334,6 +343,7 @@ _nvfuser_impls["clone"] = _clone_nvfuser
_nvfuser_impls["transpose"] = _transpose_nvfuser
_nvfuser_impls["squeeze"] = _squeeze_nvfuser
_nvfuser_impls["view_of"] = _view_of_nvfuser
_nvfuser_impls["view"] = _view_nvfuser
_nvfuser_impls["rand_like"] = _rand_like_nvfuser
_nvfuser_impls["sum"] = _sum_nvfuser
_nvfuser_impls["var"] = _var_nvfuser
@ -528,9 +538,43 @@ def register_var_mean():
        p.return_type = torch._prims_common.RETURN_TYPE.NEW  # type: ignore[attr-defined]


def register_view():
    """This function is used to register the view function in the torch.ops.nvprims module."""
    # View is implemented as a decomposition into prims.split_dim,
    # prims.collapse_dim, and prims.reshape, but we would like to intercept
    # non-decomposed view for now
    name = "view"

    nvprim.define("view(Tensor inp, SymInt[] original_shape, SymInt[] shape) -> Tensor")
    nvprim.define("view.shape(Tensor inp, SymInt[] shape) -> Tensor")

    # This function is used under _AutoDispatchBelowAutograd context
    def _prim_impl(a, original_shape, new_shape):
        return a.reshape(new_shape)

    nvprim_impl.impl(name, _prim_impl)

    prim_packet = torch.ops.nvprims.view
    prim = prim_packet.default

    def _view_no_original_shape_overload_impl(a, shape):
        if list(a.shape) == list(shape):
            return torch.ops.nvprims.view_of(a)
        return torch.ops.nvprims.view.default(a, a.shape, shape)

    nvprim_implicit_impl.impl("view.shape", _view_no_original_shape_overload_impl)

    nvprim_autograd_impl.impl(name, backwards_not_supported(prim))

    for p in (prim_packet, prim):
        p.__doc__ = "Creates a tensor with the specified shape containing a copy of the data in a."
        p.impl_nvfuser = _nvfuser_impls["view"]
        p.return_type = torch._prims_common.RETURN_TYPE.NEW  # type: ignore[attr-defined]


def register_nvprims():
    """Registers all nvFuser primitives in the torch.ops.nvprims module."""
    register_var_mean()
    register_view()
    register_native_batch_norm()
    register_rand_like()
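
A rough sketch of how the two `view` overloads registered above behave when called eagerly. This assumes the nvprims registration has already run (importing the `torch._prims` nvFuser machinery is assumed to trigger it) and that eager calls fall through to the registered `_prim_impl`, which simply reshapes:

```python
import torch
import torch._prims  # assumed to trigger register_nvprims(); adjust if registration lives elsewhere

a = torch.randn(3, 4, 5)

# Default overload: takes the original shape explicitly and returns a copy.
b = torch.ops.nvprims.view(a, a.shape, (5, 4, 3))
assert b.shape == (5, 4, 3)

# "shape" overload: when the requested shape equals the current one,
# it short-circuits to nvprims.view_of instead of emitting a view.
c = torch.ops.nvprims.view.shape(a, (3, 4, 5))
assert c.shape == a.shape
```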


@ -6,6 +6,8 @@
#include <torch/csrc/jit/codegen/cuda/python_frontend/fusion_definition.h>
#include <torch/csrc/jit/codegen/cuda/utils.h>
#include <algorithm>
namespace nvfuser {
//! This enum is to give a Record Type for record hashing given that the
@ -29,6 +31,7 @@ enum class RecordType {
  Start,
  VarianceOp,
  VarianceMeanOp,
  ViewOp,
  PermuteOp,
};
@ -277,6 +280,99 @@ struct OpRecord : RecordFunctor {
std::function<OutType(ArgTypes...)> fusion_op_;
};
struct ViewOpRecord : RecordFunctor {
ViewOpRecord(
std::vector<State> _args,
std::vector<State> _outputs,
std::vector<int64_t>& original_shape,
std::vector<int64_t>& new_shape)
: RecordFunctor(
std::move(_args),
std::move(_outputs),
"ops.view",
RecordType::ViewOp),
original_shape_(std::move(original_shape)),
new_shape_(std::move(new_shape)) {}
virtual ~ViewOpRecord() = default;
virtual RecordFunctor* clone() final {
return new ViewOpRecord(*this);
}
//! Child specific hash function in lower 32 bits.
//! | 31 -------------- 16 | 15 -------------- 0 |
//! | original_shape hash | new_shape hash |
virtual size_t hash() const final {
auto result = RecordFunctor::hash();
size_t new_shape_hash = 0;
for (auto shape : new_shape_) {
new_shape_hash ^= static_cast<size_t>(shape);
}
size_t original_shape_hash = 0;
for (auto shape : original_shape_) {
original_shape_hash |= 1 << ((new_shape_.size() - 1) - shape);
}
original_shape_hash = (original_shape_hash & 0xffff) << 16;
return result | original_shape_hash | (new_shape_hash & 0xffff);
}
virtual bool operator==(const RecordFunctor& other) const final {
auto result = false;
if (auto child_ptr = dynamic_cast<const ViewOpRecord*>(&other)) {
result = RecordFunctor::operator==(other);
result &= std::equal(
original_shape_.begin(),
original_shape_.end(),
child_ptr->original_shape_.begin());
result &= std::equal(
new_shape_.begin(), new_shape_.end(), child_ptr->new_shape_.begin());
}
return result;
}
void operator()(FusionDefinition& fd) final {
auto arg =
fd.getFusionState(args_.at(0).index)->template as<Nvf::TensorView>();
auto output =
torch::jit::fuser::cuda::view(arg, original_shape_, new_shape_);
fd.setFusionState(outputs_.at(0).index, output);
}
virtual void print(std::ostream& os, bool close_function = true) const {
RecordFunctor::print(os, false);
os << ", original_shape=[";
bool first_arg = true;
for (auto shape : original_shape_) {
if (first_arg) {
first_arg = false;
} else {
os << ", ";
}
os << shape;
}
os << "]";
os << ", new_shape=[";
first_arg = true;
for (auto shape : new_shape_) {
if (first_arg) {
first_arg = false;
} else {
os << ", ";
}
os << shape;
}
os << "]";
if (close_function) {
os << ")";
}
}
private:
//! Represents the tensor dimensions of the input tensor.
std::vector<int64_t> original_shape_;
//! Represents the tensor dimensions of the output tensor.
std::vector<int64_t> new_shape_;
};
struct PermuteOpRecord : RecordFunctor {
PermuteOpRecord(
std::vector<State> _args,


@ -1222,6 +1222,26 @@ void initNvFuserPythonBindings(PyObject* module) {
py::arg("original_shape"),
py::arg("dim"),
py::return_value_policy::reference);
nvf_ops.def(
"view",
[](nvfuser::FusionDefinition::Operators& self,
nvfuser::Tensor arg,
std::vector<int64_t>& original_shape,
std::vector<int64_t>& new_shape) -> nvfuser::Tensor {
nvfuser::FusionDefinition* fd = self.fusion_definition;
nvfuser::Tensor output = fd->defineTensor();
self.fusion_definition->defineRecord(new nvfuser::ViewOpRecord(
{fd->recordingState(arg())},
{fd->recordingState(output())},
original_shape,
new_shape));
return output;
},
py::arg("arg"),
py::arg("original_shape"),
py::arg("new_shape"),
py::return_value_policy::reference);
nvf_ops.def(
"var",
[](nvfuser::FusionDefinition::Operators& self,
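
For reference, the new binding can also be exercised directly through nvFuser's Python frontend. This is a rough sketch only: the entry points (`Fusion`, `FusionDefinition`, `define_tensor`, `execute`) reflect the contemporaneous `torch._C._nvfuser` API as used by the nvFuser executor, and the exact signatures are an assumption that may differ from your build:

```python
import torch
from torch._C._nvfuser import DataType, Fusion, FusionDefinition  # frontend location as of this PR

inp = torch.randn(3, 4, 5, device="cuda", dtype=torch.float32)

fusion = Fusion()
with FusionDefinition(fusion) as fd:
    # Concrete sizes/strides, as the nvFuser executor does when lowering nvprims.
    t0 = fd.define_tensor(list(inp.size()), list(inp.stride()), DataType.Float)
    # The binding added above: ops.view(arg, original_shape, new_shape).
    t1 = fd.ops.view(t0, [3, 4, 5], [5, 4, 3])
    fd.add_output(t1)

(out,) = fusion.execute([inp])
assert out.shape == (5, 4, 3)
```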


@ -17619,6 +17619,15 @@ python_ref_db = [
            DecorateInfo(unittest.expectedFailure, 'TestCommon', 'test_python_ref_meta', device_type="cpu"),
        ),
    ),
    PythonRefInfo(
        "ops.nvprims.view",
        torch_opinfo_name="view",
        validate_view_consistency=False,
        # This function is expected not to work with TorchRefsMode(strict=True)
        decorators=(
            DecorateInfo(unittest.expectedFailure, 'TestCommon', 'test_python_ref',),
        ),
    ),
    #
    # Linear Algebra Operators
    #