nvprim op support: runtime checks on dtype compatibility for prims.convert_element_type (#85566)

I'm seeing an issue where we lower `_to_copy` into `nvprims.convert_element_type`. In cases where we are casting to a dtype that's not supported by nvFuser, this raises a runtime error.

I added a quick check in the lowering logic so that each op can peek at its fx.Node and make a runtime decision on whether the given op should be lowered to nvprim.
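Concretely, the compatibility check comes down to a lookup in `_torch_dtype_to_nvfuser_dtype_map` (imported from `torch._prims_common`, as in the diff below): a dtype with no entry there has no nvFuser counterpart, so the op is not lowered to nvprims and the test below expects `aten._to_copy` to remain in the trace in that case. A minimal sketch of that decision; the helper and the loop are illustrative only and not part of this PR:

import torch
from torch._prims_common import _torch_dtype_to_nvfuser_dtype_map

def nvfuser_supports_dtype(dtype):
    # Illustrative helper: a dtype is treated as nvFuser-compatible exactly
    # when it has an entry in the mapping (the same lookup used in this PR).
    return _torch_dtype_to_nvfuser_dtype_map.get(dtype) is not None

for dt in (torch.float16, torch.uint8):
    print(dt, "-> lower to nvprims" if nvfuser_supports_dtype(dt) else "-> keep aten._to_copy")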
Pull Request resolved: https://github.com/pytorch/pytorch/pull/85566
Approved by: https://github.com/IvanYashchuk, https://github.com/ngimel
Author: jjsjann123
Date: 2022-09-30 23:19:25 +00:00
Committed by: PyTorch MergeBot
Parent: 01292cc9e4
Commit: fd553c46f4
2 changed files with 52 additions and 1 deletion


@@ -235,6 +235,38 @@ class TestPrims(TestCase):
        out = execute(gm, a, a, a, executor="nvfuser")
        self.assertEqual(out, (a, a, a))

    @onlyCUDA
    @dtypes(torch.float16, torch.uint8)
    def test_nvprim_convert_element_type(self, device, dtype):
        from torch.fx.experimental.proxy_tensor import make_fx
        from torch._prims.executor import execute
        from torch._prims.context import TorchRefsNvfuserCapabilityMode
        from torch._prims_common import _torch_dtype_to_nvfuser_dtype_map

        # initialize input as float32, which is different from `dtype` in the argument.
        # this ensures that tracing will have a _to_copy node.
        a = torch.randn(3, 3, device=device, dtype=torch.float32)

        def func(x, dtype):
            return x.to(dtype).to(x.dtype)

        with TorchRefsNvfuserCapabilityMode():
            gm = make_fx(func)(a, dtype)
        execute(gm, a, dtype, executor="nvfuser")

        call_function_nodes = list(filter(lambda n: n.op == "call_function", gm.graph.nodes))
        includes_aten_to_copy = any(
            torch.ops.aten._to_copy.default == node.target
            for node in call_function_nodes
        )
        includes_nvprim_convert_element_type = any(
            torch.ops.nvprims.convert_element_type.default == node.target
            for node in call_function_nodes
        )
        nvprim_support_flag = _torch_dtype_to_nvfuser_dtype_map.get(dtype) is not None
        self.assertEqual(includes_aten_to_copy, not nvprim_support_flag)
        self.assertEqual(includes_nvprim_convert_element_type, nvprim_support_flag)

    @onlyCUDA
    @skipCUDAIfRocm
    def test_nvfuser_rand_like_fusion(self, device):
@@ -516,6 +548,8 @@ class TestPrims(TestCase):
            self.assertFalse(node.target == torch.ops.prims.add.default)
            self.assertFalse(node.target == torch.ops.aten.add.default)

    # decomposition of native_batch_norm_backward uses a casting, which prevents nvprim lowering on CPU build
    @onlyCUDA
    @dtypes(torch.float32, torch.float16)
    def test_batch_norm_backward_nvprims(self, device, dtype):
        # This test verifies that the backward pass of batch norm is correctly decomposed into nvprims


@@ -6,7 +6,12 @@ from warnings import warn
import torch
import torch.overrides
from torch._prims_common import getnvFuserDtype, Number, number_type
from torch._prims_common import (
    _torch_dtype_to_nvfuser_dtype_map,
    getnvFuserDtype,
    Number,
    number_type,
)
from torch.fx import GraphModule
from torch.fx.passes.infra.partitioner import CapabilityBasedPartitioner
@@ -208,6 +213,18 @@ def nvfuser_execute(gm: GraphModule, *args, executor_parameters=None):
class NvfuserPrimOperatorSupport(torch.fx.passes.operator_support.OperatorSupport):
    def is_node_supported(self, submodules, node: torch.fx.Node) -> bool:
        # special case to stop lowering to nvprim when converting to an unsupported type
        if (
            node.op == "call_function"
            and node.target == torch.ops.nvprims.convert_element_type.default
        ):
            return (
                _torch_dtype_to_nvfuser_dtype_map.get(node.args[1]) is not None
                and _torch_dtype_to_nvfuser_dtype_map.get(
                    node.args[0].meta["tensor_meta"].dtype  # type: ignore[union-attr]
                )
                is not None
            )
        return (
            node.op == "call_function"
            and getattr(node.target, "impl_nvfuser", None) is not None