nvfuser torchbench patch (#84411)

1. Patch nvfuser_execute to fall back to the aten executor when no CUDA tensors are provided as inputs.
2. Extend the nvfuser Python API to support CPU scalar tensors.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/84411
Approved by: https://github.com/ngimel, https://github.com/kevinstephano, https://github.com/IvanYashchuk

Committed by: PyTorch MergeBot
Parent: 7c3102f3f0
Commit: 1a33e944b5
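
For context, a minimal sketch of the behavior this patch enables, condensed from the new test_cpu_tensor test below (assumes a CUDA-enabled build of this PyTorch revision):

    import warnings
    import torch
    from torch.fx.experimental.proxy_tensor import make_fx
    from torch._prims.context import TorchRefsNvfuserCapabilityMode
    from torch._prims.executor import execute

    def fn(t0, t1, cpu_scalar):
        return t0 + t1 + cpu_scalar

    a = torch.randn(12, 1, device="cuda")
    b = torch.randn(12, 12, device="cuda")
    c = torch.tensor(0.5)  # 0-dim CPU scalar tensor

    with TorchRefsNvfuserCapabilityMode():
        gm = make_fx(fn)(a, b, c)

    # CUDA inputs plus a CPU scalar: handled by nvFuser codegen, no fallback.
    actual = execute(gm, a, b, c, executor="nvfuser")

    # All-CPU tensor inputs: nvfuser_execute now warns and falls back to the
    # aten executor instead of failing.
    with warnings.catch_warnings(record=True) as caught:
        fallback = execute(gm, a.cpu(), b.cpu(), c, executor="nvfuser")
    assert any("fallback to aten executor" in str(w.message) for w in caught)
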
@@ -2,6 +2,7 @@
 
 from functools import partial
 from itertools import product
+import warnings
 from warnings import catch_warnings
 import unittest
@@ -23,6 +24,7 @@ import torch._refs as refs
 if TEST_SCIPY:
     import scipy.special
 
+NVPRIM_ATEN_FALLBACK_WARNING = "fallback to aten executor"
 
 class TestPrims(TestCase):
     @onlyCUDA
@@ -181,6 +183,7 @@ class TestPrims(TestCase):
         from torch._prims.context import TorchRefsNvfuserCapabilityMode
         from torch.fx.experimental.proxy_tensor import make_fx
         from torch._prims.executor import execute
+        from torch._prims.nvfuser_executor import make_nvfuser_fusion
 
         a = torch.randn(3, 3, device=device)
 
@@ -190,8 +193,13 @@
         with TorchRefsNvfuserCapabilityMode():
             gm = make_fx(func)()
 
-        with self.assertRaisesRegex(AssertionError, "There must be at least one argument"):
+        with warnings.catch_warnings(record=True) as caught:
             execute(gm, executor="strictly_nvfuser")
+        # fusion execute with no cuda input is handled by the nvprim aten fallback
+        self.assertTrue(any(NVPRIM_ATEN_FALLBACK_WARNING in str(w.message) for w in caught))
+
+        with self.assertRaisesRegex(AssertionError, "There must be at least one argument"):
+            make_nvfuser_fusion(gm)
 
         with self.assertRaisesRegex(AssertionError, "Number of placeholder nodes in the graph must match"):
             execute(gm, a, executor="strictly_nvfuser")
@@ -473,6 +481,47 @@ class TestPrims(TestCase):
         )
         self.assertTrue(includes_nvprims_var_mean)
 
+    @onlyCUDA
+    @skipCUDAIfRocm
+    @dtypes(torch.float32, torch.float16)
+    def test_cpu_tensor(self, device, dtype):
+        from torch.fx.experimental.proxy_tensor import make_fx
+        from torch._prims.context import TorchRefsNvfuserCapabilityMode
+        from torch._prims.executor import execute
+
+        def _wrapper(t0, t1, cpu_scalar):
+            return t0 + t1 + cpu_scalar
+
+        make_arg = partial(make_tensor, device=device, dtype=dtype)
+        a = make_arg((12, 1))
+        b = make_arg((12, 12))
+        c = torch.tensor(0.5)
+
+        with TorchRefsNvfuserCapabilityMode():
+            gm = make_fx(_wrapper)(a, b, c)
+
+        with warnings.catch_warnings(record=True) as caught:
+            actual = execute(gm, a, b, c, executor="nvfuser")
+        # cpu scalar tensor is handled by nvfuser codegen, so it shouldn't fall back
+        self.assertFalse(any(NVPRIM_ATEN_FALLBACK_WARNING in str(w.message) for w in caught))
+
+        expected = execute(gm, a, b, c, executor="aten")
+        self.assertEqual(expected, actual)
+
+        call_function_nodes = list(filter(lambda n: n.op == "call_function", gm.graph.nodes))
+        includes_aten_add = any(
+            torch.ops.aten.add.default == node.target
+            for node in call_function_nodes
+        )
+        self.assertFalse(includes_aten_add)
+
+        with warnings.catch_warnings(record=True) as caught:
+            nvprim_aten_fallback = execute(gm, a.cpu(), b.cpu(), c, executor="nvfuser")
+        # cpu tensors are handled by the nvprim aten fallback; assert that the warning was raised
+        self.assertTrue(any(NVPRIM_ATEN_FALLBACK_WARNING in str(w.message) for w in caught))
+
+        self.assertEqual(expected, nvprim_aten_fallback)
+
     @onlyCUDA
     @skipCUDAIfRocm
     @dtypes(torch.float32)
@@ -30,6 +30,7 @@ class nvFuserTensorTemplate:
     size: tuple
     stride: tuple
     dtype: DataType
+    is_cpu: bool
 
 
 @dataclass(frozen=True)
@@ -41,7 +42,10 @@ def to_nvfuser_template_args(args):
     def to_nvfuser(arg):
         if isinstance(arg, torch.Tensor):
             return nvFuserTensorTemplate(
-                arg.size(), arg.stride(), getnvFuserDtype(arg.dtype)
+                arg.size(),
+                arg.stride(),
+                getnvFuserDtype(arg.dtype),
+                arg.is_cpu,  # type: ignore[attr-defined]
             )
         elif isinstance(arg, Number):
             return nvFuserScalarTemplate(getnvFuserDtype(number_type(arg)))
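
As a small illustration of the template change (a sketch, assuming a CUDA-enabled build where the nvFuser extension is available), a 0-dim CPU tensor now round-trips through to_nvfuser_template_args with is_cpu set:

    import torch
    from torch._prims.nvfuser_executor import to_nvfuser_template_args

    c = torch.tensor(0.5)  # 0-dim CPU tensor
    (tmpl,) = to_nvfuser_template_args([c])
    print(tmpl.size, tmpl.stride, tmpl.is_cpu)  # torch.Size([]) () True
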
@@ -134,7 +138,7 @@ def make_nvfuser_fusion(gm: GraphModule, *nv_args_templates):
 
     def templates_to_nvfuser_inputs(arg):
         if isinstance(arg, nvFuserTensorTemplate):
-            x = fd.define_tensor(arg.size, arg.stride, arg.dtype)
+            x = fd.define_tensor(arg.size, arg.stride, arg.dtype, arg.is_cpu)
             return x
         elif isinstance(arg, nvFuserScalarTemplate):
             x = fd.define_scalar(arg.dtype)
@@ -155,21 +159,36 @@ def make_nvfuser_fusion(gm: GraphModule, *nv_args_templates):
 def nvfuser_execute(gm: GraphModule, *args):
     flat_args, _ = tree_flatten(args)
 
-    # Construction of the fusion is expensive and cached based on the GraphModule
-    # and symbolic nvFuser args.
-    nv_template_args = to_nvfuser_template_args(flat_args)
-    fusion, unflatten_spec = make_nvfuser_fusion(gm, *nv_template_args)  # type: ignore[misc]
+    # check for cuda only fusion
+    if any(isinstance(arg, torch.Tensor) and arg.is_cuda for arg in flat_args) and all(  # type: ignore[attr-defined]
+        (
+            not isinstance(arg, torch.Tensor)
+            or (arg.is_cpu and arg.ndim == 0)  # type: ignore[attr-defined]
+            or arg.is_cuda  # type: ignore[attr-defined]
+        )
+        for arg in flat_args
+    ):
 
-    # Inputs to fusion.execute correspond to the same template/symbolic inputs
-    # marked with `define_tensor/scalar`
-    concrete_fusion_inputs = tuple(
-        arg for arg in flat_args if isinstance(arg, (torch.Tensor, Number))
-    )
+        # Construction of the fusion is expensive and cached based on the GraphModule
+        # and symbolic nvFuser args.
+        nv_template_args = to_nvfuser_template_args(flat_args)
+        fusion, unflatten_spec = make_nvfuser_fusion(gm, *nv_template_args)  # type: ignore[misc]
 
-    return tree_unflatten(
-        fusion.execute(concrete_fusion_inputs),  # type: ignore[has-type]
-        unflatten_spec,  # type: ignore[has-type]
-    )
+        # Inputs to fusion.execute correspond to the same template/symbolic inputs
+        # marked with `define_tensor/scalar`
+        concrete_fusion_inputs = tuple(
+            arg for arg in flat_args if isinstance(arg, (torch.Tensor, Number))
+        )
+
+        return tree_unflatten(
+            fusion.execute(concrete_fusion_inputs),  # type: ignore[has-type]
+            unflatten_spec,  # type: ignore[has-type]
+        )
+    else:
+        warn(
+            "nvfuser_executor is executed with non-cuda args, fallback to aten executor"
+        )
+        return gm.forward(*args)
 
 
 class NvfuserPrimOperatorSupport(torch.fx.passes.operator_support.OperatorSupport):
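
Restated outside the executor for clarity, the gating predicate above requires at least one CUDA tensor input and allows every tensor input to be either a CUDA tensor or a 0-dim CPU scalar tensor; non-tensor scalars pass through. The helper name below is ours, not part of the patch:

    import torch

    def is_cuda_only_fusion(flat_args):
        # At least one CUDA tensor input...
        return any(
            isinstance(arg, torch.Tensor) and arg.is_cuda for arg in flat_args
        ) and all(
            # ...and every tensor input is CUDA or a 0-dim CPU scalar tensor.
            not isinstance(arg, torch.Tensor)
            or (arg.is_cpu and arg.ndim == 0)
            or arg.is_cuda
            for arg in flat_args
        )

    if torch.cuda.is_available():
        cuda_t = torch.randn(2, 2, device="cuda")
        cpu_scalar = torch.tensor(0.5)   # 0-dim CPU tensor: allowed
        cpu_t = torch.randn(2, 2)        # non-scalar CPU tensor: forces fallback
        assert is_cuda_only_fusion([cuda_t, cpu_scalar, 1.0])
        assert not is_cuda_only_fusion([cuda_t, cpu_t])
        assert not is_cuda_only_fusion([cpu_scalar])  # no CUDA tensor at all
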
@@ -241,11 +241,13 @@ struct InputTensorRecord : RecordFunctor {
       std::vector<size_t> _outputs,
       std::vector<int64_t> _symbolic_sizes,
       std::vector<bool> _contiguous_info,
-      NvfDataType _dtype)
+      NvfDataType _dtype,
+      bool _is_cpu = false)
       : RecordFunctor({}, std::move(_outputs)),
         symbolic_sizes(std::move(_symbolic_sizes)),
         contiguous_info(std::move(_contiguous_info)),
-        dtype(_dtype) {}
+        dtype(_dtype),
+        is_cpu(_is_cpu) {}
   virtual ~InputTensorRecord() = default;
 
   void operator()(FusionDefinition& fd) final {
@@ -256,6 +258,12 @@ struct InputTensorRecord : RecordFunctor {
             .dtype(dtype)
             .build();
 
+    if (symbolic_sizes.empty() && is_cpu) {
+      tv->setCpuScalar(true);
+    } else {
+      TORCH_CHECK(!is_cpu, "cpu non-scalar tensor is not supported");
+    }
+
     fd.setFusionState(outputs.at(0), tv);
     fd.addInput(tv);
   }
@@ -269,6 +277,8 @@ struct InputTensorRecord : RecordFunctor {
   std::vector<bool> contiguous_info;
   //! Tensor data type.
   NvfDataType dtype;
+  //! Whether the input is a CPU scalar tensor.
+  bool is_cpu;
 };
 
 //! Specialized Record Functor for recording FusionDefinition outputs.
@@ -117,7 +117,8 @@ void initNvFuserPythonBindings(PyObject* module) {
       [](nvfuser::FusionDefinition& self,
         std::vector<int64_t> sizes,
         std::vector<int64_t> strides,
-        NvfDataType dtype = NvfDataType::Float) -> nvfuser::Tensor* {
+        NvfDataType dtype = NvfDataType::Float,
+        bool is_cpu = false) -> nvfuser::Tensor* {
        TORCH_CHECK(
            sizes.size() == strides.size(),
            "The number of sizes does not match the number of strides.",
@@ -159,13 +160,15 @@ void initNvFuserPythonBindings(PyObject* module) {
             {out->index},
             std::move(maybe_symbolic_sizes),
             std::move(contig_info),
-            dtype));
+            dtype,
+            is_cpu));
 
         return out;
       },
       py::arg("sizes"),
       py::arg("strides"),
       py::arg("dtype") = NvfDataType::Float,
+      py::arg("is_cpu") = false,
       py::return_value_policy::reference)
       .def(
           "define_constant",
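
A hedged usage sketch of the extended define_tensor binding from the Python frontend. The Fusion/FusionDefinition import path and the ops namespace reflect the 2022-era torch._C._nvfuser API and are assumptions here, not part of this diff; only the sizes/strides/dtype/is_cpu keyword arguments come from the binding above:

    import torch
    from torch._C._nvfuser import DataType, Fusion, FusionDefinition  # assumed import path

    fusion = Fusion()
    with FusionDefinition(fusion) as fd:
        # Empty sizes/strides plus is_cpu=True defines a CPU scalar input;
        # InputTensorRecord calls setCpuScalar(true) on the resulting tensor.
        s = fd.define_tensor(sizes=[], strides=[], dtype=DataType.Float, is_cpu=True)
        t = fd.define_tensor(sizes=[12, 12], strides=[12, 1], dtype=DataType.Float)
        out = fd.ops.add(t, s)  # assumed ops namespace; some revisions spell it fd.Ops
        fd.add_output(out)

    if torch.cuda.is_available():
        a = torch.tensor(0.5)                   # CPU scalar, matches `s`
        b = torch.randn(12, 12, device="cuda")  # matches `t`
        (result,) = fusion.execute([a, b])      # inputs in definition order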