Add nvFuser support for torch.Tensor.view (#84634)
This is an alternative to https://github.com/pytorch/pytorch/pull/83739. While PrimTorch has `view` as a reference, we would like to use nvFuser's implementation for `view` for now. Later we might transition to PrimTorch's `torch._refs.view`. See `test_nvprims_view` for examples of things that are now sent to nvFuser. Note that nvFuser's `view` is a copy-like operation.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/84634
Approved by: https://github.com/kevinstephano, https://github.com/mruberry
Committed by: PyTorch MergeBot
Parent: b48deedb77
Commit: fd80684784
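Illustrative sketch, not part of the diff: the end-to-end flow this change enables, based on the new `test_nvprims_view` test below and assuming a CUDA build with nvFuser available.

    import torch
    from torch.fx.experimental.proxy_tensor import make_fx
    from torch._prims.context import TorchRefsNvfuserCapabilityMode
    from torch._prims.executor import execute

    def fn(a):
        # torch.Tensor.view is rewritten to torch.ops.nvprims.view (a copy-like op)
        return a.view(5, 4, 3)

    a = torch.randn(3, 4, 5, device="cuda")
    with TorchRefsNvfuserCapabilityMode():
        gm = make_fx(fn)(a)  # traced graph now contains nvprims.view.default
    out = execute(gm, a, executor="strictly_nvfuser")
    assert torch.equal(out, fn(a))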
@@ -378,9 +378,15 @@ class TestCommon(TestCase):
        if executor == "nvfuser" and isinstance(op, ReductionPythonRefInfo):
            skip_zero_dim = True

        # skip zero-dim tensors for some composites of reduction operations
        normalization_ops = ["_refs.softmax", "_refs.logsumexp", "_refs.log_softmax", "_refs.sum_to_size"]
        if executor == "nvfuser" and op.name in normalization_ops:
        # skip zero-dim tensors for some composites of reduction operations and view
        skip_zero_dim_ops = [
            "_refs.softmax",
            "_refs.logsumexp",
            "_refs.log_softmax",
            "_refs.sum_to_size",
            "ops.nvprims.view",
        ]
        if executor == "nvfuser" and op.name in skip_zero_dim_ops:
            skip_zero_dim = True

        from torch._prims.executor import make_traced

@@ -688,6 +688,53 @@ class TestPrims(TestCase):
            )
            self.assertTrue(includes_nvprims_var_mean)

    @onlyCUDA
    @skipCUDAIfRocm
    @dtypes(torch.float16, torch.float32)
    def test_nvprims_view(self, device, dtype):
        from torch.fx.experimental.proxy_tensor import make_fx
        from torch._prims.context import TorchRefsNvfuserCapabilityMode
        from torch._prims.executor import execute

        make_arg = partial(make_tensor, device=device, dtype=dtype)
        a = make_arg((3, 4, 5))

        def func1(a):
            return a.view(tuple(reversed(a.shape)))

        def func2(a):
            return a.reshape(tuple(reversed(a.shape)))

        def func3(a):
            return torch.view_copy(a, tuple(reversed(a.shape)))

        def func4(a):
            return torch.reshape(a, tuple(reversed(a.shape)))

        def func5(a):
            return torch.ops.aten.view.default(a, tuple(reversed(a.shape)))

        def func6(a):
            return torch.ops.aten._unsafe_view.default(a, tuple(reversed(a.shape)))

        def func7(a):
            return torch.ops.aten.view_copy.default(a, tuple(reversed(a.shape)))

        for func in (func1, func2, func3, func4, func5, func6, func7):
            with TorchRefsNvfuserCapabilityMode():
                gm = make_fx(func)(a)

            call_function_nodes = list(filter(lambda n: n.op == "call_function", gm.graph.nodes))
            includes_nvprims_view = any(
                torch.ops.nvprims.view.default == node.target
                for node in call_function_nodes
            )
            self.assertTrue(includes_nvprims_view)

            # Try executing the graph
            out = execute(gm, a, executor="strictly_nvfuser")
            self.assertEqual(out, func(a))

    @onlyCUDA
    @skipCUDAIfRocm
    @dtypes(torch.float32, torch.float16)

@@ -251,13 +251,13 @@ class TorchRefsNvfuserCapabilityMode(TorchRefsMode):
        aten_ops_to_skip = (
            "aten.transpose.int",
            "aten.t.default",
            "aten.view.default",
            "aten.unsqueeze.default",
            "aten.permute.default",
            "aten._log_softmax.default",
            "aten._log_softmax_backward_data.default",
            "aten.expand.default",
        )
        self.skip_ops = tuple(skip_ops) + aten_ops_to_skip
        super().__init__(
            strict=False,
            should_fallback_fn=functools.partial(

@@ -276,6 +276,18 @@ class TorchRefsNvfuserCapabilityMode(TorchRefsMode):
            and "aten.var_mean" in str(func)
        )

    def _is_view_or_reshape(self, func):
        allowed_ops = {
            "torch.Tensor.view",
            "torch.Tensor.reshape",
            "torch.view_copy",
            "torch.reshape",
            "aten.view.default",
            "aten._unsafe_view.default",
            "aten.view_copy.default",
        } - set(self.skip_ops)
        return torch.overrides.resolve_name(func) in allowed_ops

    def _is_native_batch_norm(self, func):
        return "torch.native_batch_norm" == torch.overrides.resolve_name(func) or (
            func == torch.ops.aten.native_batch_norm.default

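Illustrative sketch, not part of the diff: `_is_view_or_reshape` relies on `torch.overrides.resolve_name` mapping the intercepted callable back to a dotted name that can be compared against `allowed_ops`. A small sanity-check of the behavior the membership test above depends on (inferred from the diff, not verified here):

    import torch
    # resolve_name turns the callable seen by __torch_function__ into the
    # string form used in allowed_ops / skip_ops.
    assert torch.overrides.resolve_name(torch.Tensor.view) == "torch.Tensor.view"
    assert torch.overrides.resolve_name(torch.ops.aten.view.default) == "aten.view.default"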
@@ -301,12 +313,21 @@ class TorchRefsNvfuserCapabilityMode(TorchRefsMode):
        if self._is_var_mean(orig_func):
            return torch.ops.nvprims.var_mean(*args, **kwargs)

        if self._is_view_or_reshape(orig_func):
            a, *shape = args
            shape = torch._prims_common.extract_shape_from_varargs(
                shape, validate=False
            )  # type: ignore[assignment]
            if len(kwargs) > 0:
                warn("view has ignored kwargs!")
            return torch.ops.nvprims.view(a, shape)

        if self._is_native_batch_norm(orig_func):
            return torch.ops.nvprims.native_batch_norm(*args, **kwargs)

        if self._is_rand_like(orig_func):
            if len(kwargs) > 0:
                warn("rand_like has ignored kwars!")
                warn("rand_like has ignored kwargs!")
            return torch.ops.nvprims.rand_like(*args)

        # Then we use TorchRefsMode to interpret the rest

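Illustrative sketch, not part of the diff: the `extract_shape_from_varargs(..., validate=False)` call above is what lets the interception accept both calling conventions of `view`/`reshape`. Assuming a tensor `a` of shape (3, 4, 5):

    # Both spellings reach __torch_function__ with different args layouts,
    # and extract_shape_from_varargs normalizes them to one shape tuple.
    a.view(5, 4, 3)     # args == (a, 5, 4, 3)    -> shape == (5, 4, 3)
    a.view((5, 4, 3))   # args == (a, (5, 4, 3))  -> shape == (5, 4, 3)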
@@ -263,6 +263,15 @@ def _view_of_nvfuser(fd, a):
    return fd.ops.set(a)


def _view_nvfuser(
    fd,
    a,
    a_shape,
    new_shape,
):
    return fd.ops.view(a, a_shape, new_shape)


def _sum_nvfuser(
    fd: Any,
    a: TensorLikeType,

@@ -334,6 +343,7 @@ _nvfuser_impls["clone"] = _clone_nvfuser
_nvfuser_impls["transpose"] = _transpose_nvfuser
_nvfuser_impls["squeeze"] = _squeeze_nvfuser
_nvfuser_impls["view_of"] = _view_of_nvfuser
_nvfuser_impls["view"] = _view_nvfuser
_nvfuser_impls["rand_like"] = _rand_like_nvfuser
_nvfuser_impls["sum"] = _sum_nvfuser
_nvfuser_impls["var"] = _var_nvfuser

@@ -528,9 +538,43 @@ def register_var_mean():
        p.return_type = torch._prims_common.RETURN_TYPE.NEW  # type: ignore[attr-defined]


def register_view():
    """This function is used to register the view function in torch.ops.view module."""
    # View is implemented as a decomposition into prims.split_dim,
    # prims.collapse_dim, and prims.reshape, but we would like to intercept
    # non-decomposed view for now
    name = "view"

    nvprim.define("view(Tensor inp, SymInt[] original_shape, SymInt[] shape) -> Tensor")
    nvprim.define("view.shape(Tensor inp, SymInt[] shape) -> Tensor")

    # This function is used under _AutoDispatchBelowAutograd context
    def _prim_impl(a, original_shape, new_shape):
        return a.reshape(new_shape)

    nvprim_impl.impl(name, _prim_impl)

    prim_packet = torch.ops.nvprims.view
    prim = prim_packet.default

    def _view_no_original_shape_overload_impl(a, shape):
        if list(a.shape) == list(shape):
            return torch.ops.nvprims.view_of(a)
        return torch.ops.nvprims.view.default(a, a.shape, shape)

    nvprim_implicit_impl.impl("view.shape", _view_no_original_shape_overload_impl)
    nvprim_autograd_impl.impl(name, backwards_not_supported(prim))

    for p in (prim_packet, prim):
        p.__doc__ = "Creates a tensor with the specified shape containing a copy of the data in a."
        p.impl_nvfuser = _nvfuser_impls["view"]
        p.return_type = torch._prims_common.RETURN_TYPE.NEW  # type: ignore[attr-defined]


def register_nvprims():
    """Registers all nvFuser primitives in the torch.ops.nvprims module."""
    register_var_mean()
    register_view()
    register_native_batch_norm()
    register_rand_like()

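Illustrative sketch, not part of the diff: once `register_view()` has run, the two overloads defined above can be exercised directly. A rough sketch assuming a tensor `a` of shape (3, 4, 5):

    # Calling the packet with only a target shape resolves to the "view.shape"
    # overload, whose implicit impl fills in the original shape itself ...
    b = torch.ops.nvprims.view(a, (5, 4, 3))
    # ... which is equivalent to calling the default overload explicitly.
    c = torch.ops.nvprims.view.default(a, a.shape, (5, 4, 3))
    # A same-shape "view" is routed to nvprims.view_of instead of nvprims.view.
    d = torch.ops.nvprims.view(a, a.shape)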
@@ -6,6 +6,8 @@
#include <torch/csrc/jit/codegen/cuda/python_frontend/fusion_definition.h>
#include <torch/csrc/jit/codegen/cuda/utils.h>

#include <algorithm>

namespace nvfuser {

//! This enum it to give a Record Type for record hashing given that the

@@ -29,6 +31,7 @@ enum class RecordType {
  Start,
  VarianceOp,
  VarianceMeanOp,
  ViewOp,
  PermuteOp,
};

@@ -277,6 +280,99 @@ struct OpRecord : RecordFunctor {
  std::function<OutType(ArgTypes...)> fusion_op_;
};

struct ViewOpRecord : RecordFunctor {
  ViewOpRecord(
      std::vector<State> _args,
      std::vector<State> _outputs,
      std::vector<int64_t>& original_shape,
      std::vector<int64_t>& new_shape)
      : RecordFunctor(
            std::move(_args),
            std::move(_outputs),
            "ops.view",
            RecordType::ViewOp),
        original_shape_(std::move(original_shape)),
        new_shape_(std::move(new_shape)) {}
  virtual ~ViewOpRecord() = default;
  virtual RecordFunctor* clone() final {
    return new ViewOpRecord(*this);
  }

  //! Child specific hash function in lower 32 bits.
  //! | 31 -------------- 16 | 15 -------------- 0 |
  //! | original_shape hash  | new_shape hash      |
  virtual size_t hash() const final {
    auto result = RecordFunctor::hash();
    size_t new_shape_hash = 0;
    for (auto shape : new_shape_) {
      new_shape_hash ^= static_cast<size_t>(shape);
    }
    size_t original_shape_hash = 0;
    for (auto shape : original_shape_) {
      original_shape_hash |= 1 << ((new_shape_.size() - 1) - shape);
    }
    original_shape_hash = (original_shape_hash & 0xffff) << 16;
    return result | original_shape_hash | (new_shape_hash & 0xffff);
  }

  virtual bool operator==(const RecordFunctor& other) const final {
    auto result = false;
    if (auto child_ptr = dynamic_cast<const ViewOpRecord*>(&other)) {
      result = RecordFunctor::operator==(other);
      result &= std::equal(
          original_shape_.begin(),
          original_shape_.end(),
          child_ptr->original_shape_.begin());
      result &= std::equal(
          new_shape_.begin(), new_shape_.end(), child_ptr->new_shape_.begin());
    }
    return result;
  }

  void operator()(FusionDefinition& fd) final {
    auto arg =
        fd.getFusionState(args_.at(0).index)->template as<Nvf::TensorView>();
    auto output =
        torch::jit::fuser::cuda::view(arg, original_shape_, new_shape_);
    fd.setFusionState(outputs_.at(0).index, output);
  }

  virtual void print(std::ostream& os, bool close_function = true) const {
    RecordFunctor::print(os, false);
    os << ", original_shape=[";
    bool first_arg = true;
    for (auto shape : original_shape_) {
      if (first_arg) {
        first_arg = false;
      } else {
        os << ", ";
      }
      os << shape;
    }
    os << "]";
    os << ", new_shape=[";
    first_arg = true;
    for (auto shape : new_shape_) {
      if (first_arg) {
        first_arg = false;
      } else {
        os << ", ";
      }
      os << shape;
    }
    os << "]";
    if (close_function) {
      os << ")";
    }
  }

 private:
  //! Represents the tensor dimensions of the input tensor.
  std::vector<int64_t> original_shape_;
  //! Represents the tensor dimensions of the output tensor.
  std::vector<int64_t> new_shape_;
};

struct PermuteOpRecord : RecordFunctor {
  PermuteOpRecord(
      std::vector<State> _args,

@@ -1222,6 +1222,26 @@ void initNvFuserPythonBindings(PyObject* module) {
      py::arg("original_shape"),
      py::arg("dim"),
      py::return_value_policy::reference);
  nvf_ops.def(
      "view",
      [](nvfuser::FusionDefinition::Operators& self,
         nvfuser::Tensor arg,
         std::vector<int64_t>& original_shape,
         std::vector<int64_t>& new_shape) -> nvfuser::Tensor {
        nvfuser::FusionDefinition* fd = self.fusion_definition;
        nvfuser::Tensor output = fd->defineTensor();
        self.fusion_definition->defineRecord(new nvfuser::ViewOpRecord(
            {fd->recordingState(arg())},
            {fd->recordingState(output())},
            original_shape,
            new_shape));
        return output;
      },
      py::arg("arg"),
      py::arg("original_shape"),
      py::arg("new_shape"),
      py::return_value_policy::reference);

  nvf_ops.def(
      "var",
      [](nvfuser::FusionDefinition::Operators& self,

@@ -17619,6 +17619,15 @@ python_ref_db = [
            DecorateInfo(unittest.expectedFailure, 'TestCommon', 'test_python_ref_meta', device_type="cpu"),
        ),
    ),
    PythonRefInfo(
        "ops.nvprims.view",
        torch_opinfo_name="view",
        validate_view_consistency=False,
        # This function is expected not to work with TorchRefsMode(strict=True)
        decorators=(
            DecorateInfo(unittest.expectedFailure, 'TestCommon', 'test_python_ref',),
        ),
    ),
    #
    # Linear Algebra Operators
    #

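Illustrative sketch, not part of the diff: the `validate_view_consistency=False` flag in the OpInfo entry above reflects the point from the commit message that `nvprims.view` is copy-like rather than aliasing. A rough sketch of the difference the consistency check would otherwise flag:

    a = torch.arange(6.0).reshape(2, 3)
    b = a.view(3, 2)           # eager torch.Tensor.view aliases a's storage
    a.mul_(2)
    assert b[0, 0] == a[0, 0]  # the alias observes the in-place update
    # nvprims.view instead "creates a tensor with the specified shape containing
    # a copy of the data in a" (its registered __doc__), so an executed graph
    # returns fresh memory and would not observe such mutations.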