nvfuser torchbench patch (#84411)

1. Patch nvfuser_execute to fall back to the aten nvprim executor when no CUDA tensors are provided as inputs.
2. Extend the nvFuser Python API to support CPU scalar tensors.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/84411
Approved by: https://github.com/ngimel, https://github.com/kevinstephano, https://github.com/IvanYashchuk
Author: jjsjann123
Date: 2022-09-07 05:22:37 +00:00
Committed by: PyTorch MergeBot
Parent: 7c3102f3f0
Commit: 1a33e944b5

4 changed files with 101 additions and 20 deletions
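A minimal sketch of change 1 in action, mirroring the test added below (assumes a CUDA build of this era so the nvFuser executor is registered; not part of the patch itself):

import warnings
import torch
from torch._prims.context import TorchRefsNvfuserCapabilityMode
from torch._prims.executor import execute
from torch.fx.experimental.proxy_tensor import make_fx

def add_fn(t0, t1):
    return t0 + t1

a = torch.randn(4, 4, device="cuda")
b = torch.randn(4, 4, device="cuda")

with TorchRefsNvfuserCapabilityMode():
    gm = make_fx(add_fn)(a, b)

# Executing the traced module with CPU-only inputs no longer asserts:
# nvfuser_execute now warns and falls back to the aten executor instead.
with warnings.catch_warnings(record=True) as caught:
    out = execute(gm, a.cpu(), b.cpu(), executor="nvfuser")
assert any("fallback to aten executor" in str(w.message) for w in caught)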


@@ -2,6 +2,7 @@
 from functools import partial
 from itertools import product
+import warnings
 from warnings import catch_warnings
 import unittest
@@ -23,6 +24,7 @@ import torch._refs as refs
 if TEST_SCIPY:
     import scipy.special
 
+NVPRIM_ATEN_FALLBACK_WARNING = "fallback to aten executor"
 
 class TestPrims(TestCase):
     @onlyCUDA
@@ -181,6 +183,7 @@ class TestPrims(TestCase):
         from torch._prims.context import TorchRefsNvfuserCapabilityMode
         from torch.fx.experimental.proxy_tensor import make_fx
         from torch._prims.executor import execute
+        from torch._prims.nvfuser_executor import make_nvfuser_fusion
 
         a = torch.randn(3, 3, device=device)
@@ -190,8 +193,13 @@ class TestPrims(TestCase):
         with TorchRefsNvfuserCapabilityMode():
             gm = make_fx(func)()
 
-        with self.assertRaisesRegex(AssertionError, "There must be at least one argument"):
+        with warnings.catch_warnings(record=True) as caught:
             execute(gm, executor="strictly_nvfuser")
+        # fusion execution with no cuda input is handled by the nvprim aten fallback
+        self.assertTrue(any(NVPRIM_ATEN_FALLBACK_WARNING in str(w.message) for w in caught))
+
+        with self.assertRaisesRegex(AssertionError, "There must be at least one argument"):
+            make_nvfuser_fusion(gm)
 
         with self.assertRaisesRegex(AssertionError, "Number of placeholder nodes in the graph must match"):
             execute(gm, a, executor="strictly_nvfuser")
@@ -473,6 +481,47 @@ class TestPrims(TestCase):
         )
         self.assertTrue(includes_nvprims_var_mean)
 
+    @onlyCUDA
+    @skipCUDAIfRocm
+    @dtypes(torch.float32, torch.float16)
+    def test_cpu_tensor(self, device, dtype):
+        from torch.fx.experimental.proxy_tensor import make_fx
+        from torch._prims.context import TorchRefsNvfuserCapabilityMode
+        from torch._prims.executor import execute
+
+        def _wrapper(t0, t1, cpu_scalar):
+            return t0 + t1 + cpu_scalar
+
+        make_arg = partial(make_tensor, device=device, dtype=dtype)
+        a = make_arg((12, 1))
+        b = make_arg((12, 12))
+        c = torch.tensor(0.5)
+
+        with TorchRefsNvfuserCapabilityMode():
+            gm = make_fx(_wrapper)(a, b, c)
+
+        with warnings.catch_warnings(record=True) as caught:
+            actual = execute(gm, a, b, c, executor="nvfuser")
+        # the cpu scalar tensor is handled by nvfuser codegen, so it shouldn't fall back
+        self.assertFalse(any(NVPRIM_ATEN_FALLBACK_WARNING in str(w.message) for w in caught))
+
+        expected = execute(gm, a, b, c, executor="aten")
+        self.assertEqual(expected, actual)
+
+        call_function_nodes = list(filter(lambda n: n.op == "call_function", gm.graph.nodes))
+        includes_aten_add = any(
+            torch.ops.aten.add.default == node.target
+            for node in call_function_nodes
+        )
+        self.assertFalse(includes_aten_add)
+
+        with warnings.catch_warnings(record=True) as caught:
+            nvprim_aten_fallback = execute(gm, a.cpu(), b.cpu(), c, executor="nvfuser")
+        # non-scalar cpu tensors are handled by the nvprim aten fallback; assert the warning was raised
+        self.assertTrue(any(NVPRIM_ATEN_FALLBACK_WARNING in str(w.message) for w in caught))
+        self.assertEqual(expected, nvprim_aten_fallback)
+
     @onlyCUDA
     @skipCUDAIfRocm
     @dtypes(torch.float32)


@@ -30,6 +30,7 @@ class nvFuserTensorTemplate:
     size: tuple
     stride: tuple
     dtype: DataType
+    is_cpu: bool
 
 
 @dataclass(frozen=True)
@@ -41,7 +42,10 @@ def to_nvfuser_template_args(args):
     def to_nvfuser(arg):
         if isinstance(arg, torch.Tensor):
             return nvFuserTensorTemplate(
-                arg.size(), arg.stride(), getnvFuserDtype(arg.dtype)
+                arg.size(),
+                arg.stride(),
+                getnvFuserDtype(arg.dtype),
+                arg.is_cpu,  # type: ignore[attr-defined]
             )
         elif isinstance(arg, Number):
             return nvFuserScalarTemplate(getnvFuserDtype(number_type(arg)))
@@ -134,7 +138,7 @@ def make_nvfuser_fusion(gm: GraphModule, *nv_args_templates):
     def templates_to_nvfuser_inputs(arg):
         if isinstance(arg, nvFuserTensorTemplate):
-            x = fd.define_tensor(arg.size, arg.stride, arg.dtype)
+            x = fd.define_tensor(arg.size, arg.stride, arg.dtype, arg.is_cpu)
             return x
         elif isinstance(arg, nvFuserScalarTemplate):
             x = fd.define_scalar(arg.dtype)
@ -155,21 +159,36 @@ def make_nvfuser_fusion(gm: GraphModule, *nv_args_templates):
def nvfuser_execute(gm: GraphModule, *args):
flat_args, _ = tree_flatten(args)
# Construction of the fusion is expensive and cached based on the GraphModule
# and symbolic nvFuser args.
nv_template_args = to_nvfuser_template_args(flat_args)
fusion, unflatten_spec = make_nvfuser_fusion(gm, *nv_template_args) # type: ignore[misc]
# check for cuda only fusion
if any(isinstance(arg, torch.Tensor) and arg.is_cuda for arg in flat_args) and all( # type: ignore[attr-defined]
(
not isinstance(arg, torch.Tensor)
or (arg.is_cpu and arg.ndim == 0) # type: ignore[attr-defined]
or arg.is_cuda # type: ignore[attr-defined]
)
for arg in flat_args
):
# Inputs to fusion.execute correspond to the same template/symbolic inputs
# marked with `define_tensor/scalar`
concrete_fusion_inputs = tuple(
arg for arg in flat_args if isinstance(arg, (torch.Tensor, Number))
)
# Construction of the fusion is expensive and cached based on the GraphModule
# and symbolic nvFuser args.
nv_template_args = to_nvfuser_template_args(flat_args)
fusion, unflatten_spec = make_nvfuser_fusion(gm, *nv_template_args) # type: ignore[misc]
return tree_unflatten(
fusion.execute(concrete_fusion_inputs), # type: ignore[has-type]
unflatten_spec, # type: ignore[has-type]
)
# Inputs to fusion.execute correspond to the same template/symbolic inputs
# marked with `define_tensor/scalar`
concrete_fusion_inputs = tuple(
arg for arg in flat_args if isinstance(arg, (torch.Tensor, Number))
)
return tree_unflatten(
fusion.execute(concrete_fusion_inputs), # type: ignore[has-type]
unflatten_spec, # type: ignore[has-type]
)
else:
warn(
"nvfuser_executor is executed with non-cuda args, fallback to aten executor"
)
return gm.forward(*args)
class NvfuserPrimOperatorSupport(torch.fx.passes.operator_support.OperatorSupport):
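The gating condition above can be read as a standalone predicate. A hypothetical helper (not part of the patch) restating it: take the nvFuser path only when at least one input is a CUDA tensor and every tensor input is either on CUDA or a zero-dim CPU scalar; anything else warns and falls back to gm.forward.

import torch

def can_run_on_nvfuser(flat_args) -> bool:
    tensors = [arg for arg in flat_args if isinstance(arg, torch.Tensor)]
    # at least one CUDA tensor must be present...
    has_cuda_tensor = any(t.is_cuda for t in tensors)
    # ...and every tensor must be either a CUDA tensor or a zero-dim CPU scalar
    all_supported = all(t.is_cuda or (t.is_cpu and t.ndim == 0) for t in tensors)
    return has_cuda_tensor and all_supported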


@@ -241,11 +241,13 @@ struct InputTensorRecord : RecordFunctor {
       std::vector<size_t> _outputs,
       std::vector<int64_t> _symbolic_sizes,
      std::vector<bool> _contiguous_info,
-      NvfDataType _dtype)
+      NvfDataType _dtype,
+      bool _is_cpu = false)
       : RecordFunctor({}, std::move(_outputs)),
         symbolic_sizes(std::move(_symbolic_sizes)),
         contiguous_info(std::move(_contiguous_info)),
-        dtype(_dtype) {}
+        dtype(_dtype),
+        is_cpu(_is_cpu) {}
 
   virtual ~InputTensorRecord() = default;
 
   void operator()(FusionDefinition& fd) final {
@@ -256,6 +258,12 @@ struct InputTensorRecord : RecordFunctor {
                   .dtype(dtype)
                   .build();
 
+    if (symbolic_sizes.empty() && is_cpu) {
+      tv->setCpuScalar(true);
+    } else {
+      TORCH_CHECK(!is_cpu, "cpu non-scalar tensor is not supported");
+    }
+
     fd.setFusionState(outputs.at(0), tv);
     fd.addInput(tv);
   }
@@ -269,6 +277,8 @@ struct InputTensorRecord : RecordFunctor {
   std::vector<bool> contiguous_info;
   //! Tensor data type.
   NvfDataType dtype;
+  //! Whether the tensor is a CPU tensor (only zero-dim CPU scalars are supported).
+  bool is_cpu;
 };
 
 //! Specialized Record Functor for recording FusionDefinition outputs.
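Seen from Python, the TORCH_CHECK above means a CPU tensor of non-zero rank is rejected when the input record runs. A hypothetical session (the torch._C._nvfuser module path and TORCH_CHECK surfacing as RuntimeError are assumptions about the frontend of this era):

import torch
from torch._C._nvfuser import Fusion, FusionDefinition, DataType  # assumed path

fusion = Fusion()
try:
    with FusionDefinition(fusion) as fd:
        # a 2x2 CPU tensor is not a scalar, so the record's check fires
        fd.define_tensor([2, 2], [2, 1], DataType.Float, is_cpu=True)
except RuntimeError as err:
    assert "cpu non-scalar tensor is not supported" in str(err)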


@@ -117,7 +117,8 @@ void initNvFuserPythonBindings(PyObject* module) {
           [](nvfuser::FusionDefinition& self,
              std::vector<int64_t> sizes,
             std::vector<int64_t> strides,
-             NvfDataType dtype = NvfDataType::Float) -> nvfuser::Tensor* {
+             NvfDataType dtype = NvfDataType::Float,
+             bool is_cpu = false) -> nvfuser::Tensor* {
             TORCH_CHECK(
                 sizes.size() == strides.size(),
                 "The number of sizes does not match the number of strides.",
@@ -159,13 +160,15 @@ void initNvFuserPythonBindings(PyObject* module) {
                 {out->index},
                 std::move(maybe_symbolic_sizes),
                 std::move(contig_info),
-                dtype));
+                dtype,
+                is_cpu));
             return out;
           },
           py::arg("sizes"),
           py::arg("strides"),
           py::arg("dtype") = NvfDataType::Float,
+          py::arg("is_cpu") = false,
          py::return_value_policy::reference)
      .def(
          "define_constant",