From 1a33e944b58a75efe6154f1d02a32b80b7661edf Mon Sep 17 00:00:00 2001 From: jjsjann123 Date: Wed, 7 Sep 2022 05:22:37 +0000 Subject: [PATCH] nvfuser torchbench patch (#84411) 1. Patching nvfuser_execute to take aten nvprim fallback when no cuda tensors are provided as inputs 2. Extending support of nvfuser python API on cpu scalar tensor. Pull Request resolved: https://github.com/pytorch/pytorch/pull/84411 Approved by: https://github.com/ngimel, https://github.com/kevinstephano, https://github.com/IvanYashchuk --- test/test_prims.py | 51 ++++++++++++++++++- torch/_prims/nvfuser_executor.py | 49 ++++++++++++------ .../cuda/python_frontend/fusion_record.h | 14 ++++- .../cuda/python_frontend/python_bindings.cpp | 7 ++- 4 files changed, 101 insertions(+), 20 deletions(-) diff --git a/test/test_prims.py b/test/test_prims.py index c3942714bbbe..cf2d721cd1fd 100644 --- a/test/test_prims.py +++ b/test/test_prims.py @@ -2,6 +2,7 @@ from functools import partial from itertools import product +import warnings from warnings import catch_warnings import unittest @@ -23,6 +24,7 @@ import torch._refs as refs if TEST_SCIPY: import scipy.special +NVPRIM_ATEN_FALLBACK_WARNING = "fallback to aten executor" class TestPrims(TestCase): @onlyCUDA @@ -181,6 +183,7 @@ class TestPrims(TestCase): from torch._prims.context import TorchRefsNvfuserCapabilityMode from torch.fx.experimental.proxy_tensor import make_fx from torch._prims.executor import execute + from torch._prims.nvfuser_executor import make_nvfuser_fusion a = torch.randn(3, 3, device=device) @@ -190,8 +193,13 @@ class TestPrims(TestCase): with TorchRefsNvfuserCapabilityMode(): gm = make_fx(func)() - with self.assertRaisesRegex(AssertionError, "There must be at least one argument"): + with warnings.catch_warnings(record=True) as caught: execute(gm, executor="strictly_nvfuser") + # fusion execute with no cuda input is handled by nvprim aten fallback + self.assertTrue(any(NVPRIM_ATEN_FALLBACK_WARNING in str(w.message) for w in 
caught)) + + with self.assertRaisesRegex(AssertionError, "There must be at least one argument"): + make_nvfuser_fusion(gm) with self.assertRaisesRegex(AssertionError, "Number of placeholder nodes in the graph must match"): execute(gm, a, executor="strictly_nvfuser") @@ -473,6 +481,47 @@ class TestPrims(TestCase): ) self.assertTrue(includes_nvprims_var_mean) + @onlyCUDA + @skipCUDAIfRocm + @dtypes(torch.float32, torch.float16) + def test_cpu_tensor(self, device, dtype): + from torch.fx.experimental.proxy_tensor import make_fx + from torch._prims.context import TorchRefsNvfuserCapabilityMode + from torch._prims.executor import execute + + def _wrapper(t0, t1, cpu_scalar): + return t0 + t1 + cpu_scalar + + make_arg = partial(make_tensor, device=device, dtype=dtype) + a = make_arg((12, 1)) + b = make_arg((12, 12)) + c = torch.tensor(0.5) + + with TorchRefsNvfuserCapabilityMode(): + gm = make_fx(_wrapper)(a, b, c) + + with warnings.catch_warnings(record=True) as caught: + actual = execute(gm, a, b, c, executor="nvfuser") + # cpu scalar tensor is handled by nvfuser codegen, so it shouldn't fallback + self.assertFalse(any(NVPRIM_ATEN_FALLBACK_WARNING in str(w.message) for w in caught)) + + expected = execute(gm, a, b, c, executor="aten") + self.assertEqual(expected, actual) + + call_function_nodes = list(filter(lambda n: n.op == "call_function", gm.graph.nodes)) + includes_aten_add = any( + torch.ops.aten.add.default == node.target + for node in call_function_nodes + ) + self.assertFalse(includes_aten_add) + + with warnings.catch_warnings(record=True) as caught: + nvprim_aten_fallback = execute(gm, a.cpu(), b.cpu(), c, executor="nvfuser") + # cpu tensor is handled by nvprim aten fallback, assert that it's indeed in warning + self.assertTrue(any(NVPRIM_ATEN_FALLBACK_WARNING in str(w.message) for w in caught)) + + self.assertEqual(expected, nvprim_aten_fallback) + @onlyCUDA @skipCUDAIfRocm @dtypes(torch.float32) diff --git a/torch/_prims/nvfuser_executor.py 
b/torch/_prims/nvfuser_executor.py index 72bc71ef8c8d..6fa5f97b76cd 100644 --- a/torch/_prims/nvfuser_executor.py +++ b/torch/_prims/nvfuser_executor.py @@ -30,6 +30,7 @@ class nvFuserTensorTemplate: size: tuple stride: tuple dtype: DataType + is_cpu: bool @dataclass(frozen=True) @@ -41,7 +42,10 @@ def to_nvfuser_template_args(args): def to_nvfuser(arg): if isinstance(arg, torch.Tensor): return nvFuserTensorTemplate( - arg.size(), arg.stride(), getnvFuserDtype(arg.dtype) + arg.size(), + arg.stride(), + getnvFuserDtype(arg.dtype), + arg.is_cpu, # type: ignore[attr-defined] ) elif isinstance(arg, Number): return nvFuserScalarTemplate(getnvFuserDtype(number_type(arg))) @@ -134,7 +138,7 @@ def make_nvfuser_fusion(gm: GraphModule, *nv_args_templates): def templates_to_nvfuser_inputs(arg): if isinstance(arg, nvFuserTensorTemplate): - x = fd.define_tensor(arg.size, arg.stride, arg.dtype) + x = fd.define_tensor(arg.size, arg.stride, arg.dtype, arg.is_cpu) return x elif isinstance(arg, nvFuserScalarTemplate): x = fd.define_scalar(arg.dtype) @@ -155,21 +159,36 @@ def make_nvfuser_fusion(gm: GraphModule, *nv_args_templates): def nvfuser_execute(gm: GraphModule, *args): flat_args, _ = tree_flatten(args) - # Construction of the fusion is expensive and cached based on the GraphModule - # and symbolic nvFuser args. 
- nv_template_args = to_nvfuser_template_args(flat_args) - fusion, unflatten_spec = make_nvfuser_fusion(gm, *nv_template_args) # type: ignore[misc] + # check for cuda only fusion + if any(isinstance(arg, torch.Tensor) and arg.is_cuda for arg in flat_args) and all( # type: ignore[attr-defined] + ( + not isinstance(arg, torch.Tensor) + or (arg.is_cpu and arg.ndim == 0) # type: ignore[attr-defined] + or arg.is_cuda # type: ignore[attr-defined] + ) + for arg in flat_args + ): - # Inputs to fusion.execute correspond to the same template/symbolic inputs - # marked with `define_tensor/scalar` - concrete_fusion_inputs = tuple( - arg for arg in flat_args if isinstance(arg, (torch.Tensor, Number)) - ) + # Construction of the fusion is expensive and cached based on the GraphModule + # and symbolic nvFuser args. + nv_template_args = to_nvfuser_template_args(flat_args) + fusion, unflatten_spec = make_nvfuser_fusion(gm, *nv_template_args) # type: ignore[misc] - return tree_unflatten( - fusion.execute(concrete_fusion_inputs), # type: ignore[has-type] - unflatten_spec, # type: ignore[has-type] - ) + # Inputs to fusion.execute correspond to the same template/symbolic inputs + # marked with `define_tensor/scalar` + concrete_fusion_inputs = tuple( + arg for arg in flat_args if isinstance(arg, (torch.Tensor, Number)) + ) + + return tree_unflatten( + fusion.execute(concrete_fusion_inputs), # type: ignore[has-type] + unflatten_spec, # type: ignore[has-type] + ) + else: + warn( + "nvfuser_executor is executed with non-cuda args, fallback to aten executor" + ) + return gm.forward(*args) class NvfuserPrimOperatorSupport(torch.fx.passes.operator_support.OperatorSupport): diff --git a/torch/csrc/jit/codegen/cuda/python_frontend/fusion_record.h b/torch/csrc/jit/codegen/cuda/python_frontend/fusion_record.h index 6cd645dd7c03..4616bd114931 100644 --- a/torch/csrc/jit/codegen/cuda/python_frontend/fusion_record.h +++ b/torch/csrc/jit/codegen/cuda/python_frontend/fusion_record.h @@ -241,11 +241,13 
@@ struct InputTensorRecord : RecordFunctor { std::vector _outputs, std::vector _symbolic_sizes, std::vector _contiguous_info, - NvfDataType _dtype) + NvfDataType _dtype, + bool _is_cpu = false) : RecordFunctor({}, std::move(_outputs)), symbolic_sizes(std::move(_symbolic_sizes)), contiguous_info(std::move(_contiguous_info)), - dtype(_dtype) {} + dtype(_dtype), + is_cpu(_is_cpu) {} virtual ~InputTensorRecord() = default; void operator()(FusionDefinition& fd) final { @@ -256,6 +258,12 @@ struct InputTensorRecord : RecordFunctor { .dtype(dtype) .build(); + if (symbolic_sizes.empty() && is_cpu) { + tv->setCpuScalar(true); + } else { + TORCH_CHECK(!is_cpu, "cpu non-scalar tensor is not supported"); + } + fd.setFusionState(outputs.at(0), tv); fd.addInput(tv); } @@ -269,6 +277,8 @@ struct InputTensorRecord : RecordFunctor { std::vector contiguous_info; //! Tensor data type. NvfDataType dtype; + //! Whether the input is a CPU tensor (only CPU scalar tensors are supported). + bool is_cpu; }; //! Specialized Record Functor for recording FusionDefinition outputs. 
diff --git a/torch/csrc/jit/codegen/cuda/python_frontend/python_bindings.cpp b/torch/csrc/jit/codegen/cuda/python_frontend/python_bindings.cpp index d2ddd8fd937c..2d211560f70a 100644 --- a/torch/csrc/jit/codegen/cuda/python_frontend/python_bindings.cpp +++ b/torch/csrc/jit/codegen/cuda/python_frontend/python_bindings.cpp @@ -117,7 +117,8 @@ void initNvFuserPythonBindings(PyObject* module) { [](nvfuser::FusionDefinition& self, std::vector sizes, std::vector strides, - NvfDataType dtype = NvfDataType::Float) -> nvfuser::Tensor* { + NvfDataType dtype = NvfDataType::Float, + bool is_cpu = false) -> nvfuser::Tensor* { TORCH_CHECK( sizes.size() == strides.size(), "The number of sizes does not match the number of strides.", @@ -159,13 +160,15 @@ void initNvFuserPythonBindings(PyObject* module) { {out->index}, std::move(maybe_symbolic_sizes), std::move(contig_info), - dtype)); + dtype, + is_cpu)); return out; }, py::arg("sizes"), py::arg("strides"), py::arg("dtype") = NvfDataType::Float, + py::arg("is_cpu") = false, py::return_value_policy::reference) .def( "define_constant",