nvprim op support runtime checks on dtype compatibility on prims.convert_element_type (#85566)
I'm seeing an issue where we lower `_to_copy` into `nvprims.convert_element_type`. When the cast targets a dtype that nvFuser does not support, this raises a runtime error. I added a quick check in the lowering pass so that each op can peek at the fx.Node and make a runtime decision on whether the given op should be lowered to nvprims.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/85566
Approved by: https://github.com/IvanYashchuk, https://github.com/ngimel
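The idea, in a minimal standalone sketch (not the actual PyTorch source): an operator-support check can inspect an fx.Node at partition time and refuse the lowering when either the input dtype or the target dtype has no nvFuser equivalent. `SUPPORTED_DTYPES` and `convert_element_type_is_supported` below are hypothetical stand-ins for the private `_torch_dtype_to_nvfuser_dtype_map` lookup the real code uses.

# Minimal sketch of a per-node dtype check, assuming the node was produced by
# tracing so that node.meta carries a "tensor_meta" entry for tensor inputs.
# SUPPORTED_DTYPES is an illustrative stand-in for _torch_dtype_to_nvfuser_dtype_map.
import torch
import torch.fx

SUPPORTED_DTYPES = {torch.float16, torch.float32, torch.float64,
                    torch.int32, torch.int64, torch.bool}

def convert_element_type_is_supported(node: torch.fx.Node) -> bool:
    # For a convert_element_type call, args[0] is the input tensor's fx.Node
    # and args[1] is the target torch.dtype.
    target_dtype = node.args[1]
    tensor_meta = node.args[0].meta.get("tensor_meta")
    if tensor_meta is None:
        return False  # no dtype metadata recorded; be conservative
    return target_dtype in SUPPORTED_DTYPES and tensor_meta.dtype in SUPPORTED_DTYPES

The actual check added by this PR (see the diff below) follows the same shape, calling `_torch_dtype_to_nvfuser_dtype_map.get(...)` on both the target dtype and the input's `tensor_meta.dtype`.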
Committed by PyTorch MergeBot
Parent: 01292cc9e4
Commit: fd553c46f4
@@ -235,6 +235,38 @@ class TestPrims(TestCase):
         out = execute(gm, a, a, a, executor="nvfuser")
         self.assertEqual(out, (a, a, a))
 
+    @onlyCUDA
+    @dtypes(torch.float16, torch.uint8)
+    def test_nvprim_convert_element_type(self, device, dtype):
+        from torch.fx.experimental.proxy_tensor import make_fx
+        from torch._prims.executor import execute
+        from torch._prims.context import TorchRefsNvfuserCapabilityMode
+        from torch._prims_common import _torch_dtype_to_nvfuser_dtype_map
+
+        # initialize input as float32, which is different from `dtype` in the argument.
+        # this ensures that tracing will have a _to_copy node.
+        a = torch.randn(3, 3, device=device, dtype=torch.float32)
+
+        def func(x, dtype):
+            return x.to(dtype).to(x.dtype)
+
+        with TorchRefsNvfuserCapabilityMode():
+            gm = make_fx(func)(a, dtype)
+            execute(gm, a, dtype, executor="nvfuser")
+
+            call_function_nodes = list(filter(lambda n: n.op == "call_function", gm.graph.nodes))
+            includes_aten_to_copy = any(
+                torch.ops.aten._to_copy.default == node.target
+                for node in call_function_nodes
+            )
+            includes_nvprim_convert_element_type = any(
+                torch.ops.nvprims.convert_element_type.default == node.target
+                for node in call_function_nodes
+            )
+            nvprim_support_flag = _torch_dtype_to_nvfuser_dtype_map.get(dtype) is not None
+            self.assertEqual(includes_aten_to_copy, not nvprim_support_flag)
+            self.assertEqual(includes_nvprim_convert_element_type, nvprim_support_flag)
+
     @onlyCUDA
     @skipCUDAIfRocm
     def test_nvfuser_rand_like_fusion(self, device):
@@ -516,6 +548,8 @@ class TestPrims(TestCase):
             self.assertFalse(node.target == torch.ops.prims.add.default)
             self.assertFalse(node.target == torch.ops.aten.add.default)
 
+    # decomposition of native_batch_norm_backward uses a casting, which prevents nvprim lowering on CPU build
+    @onlyCUDA
     @dtypes(torch.float32, torch.float16)
     def test_batch_norm_backward_nvprims(self, device, dtype):
         # This test verifies that the backward pass of batch norm is correctly decomposed into nvprims
@@ -6,7 +6,12 @@ from warnings import warn
 
 import torch
 import torch.overrides
-from torch._prims_common import getnvFuserDtype, Number, number_type
+from torch._prims_common import (
+    _torch_dtype_to_nvfuser_dtype_map,
+    getnvFuserDtype,
+    Number,
+    number_type,
+)
 
 from torch.fx import GraphModule
 from torch.fx.passes.infra.partitioner import CapabilityBasedPartitioner
@@ -208,6 +213,18 @@ def nvfuser_execute(gm: GraphModule, *args, executor_parameters=None):
 
 class NvfuserPrimOperatorSupport(torch.fx.passes.operator_support.OperatorSupport):
     def is_node_supported(self, submodules, node: torch.fx.Node) -> bool:
+        # special case to stop lowering to nvprim when converting to an unsupported type
+        if (
+            node.op == "call_function"
+            and node.target == torch.ops.nvprims.convert_element_type.default
+        ):
+            return (
+                _torch_dtype_to_nvfuser_dtype_map.get(node.args[1]) is not None
+                and _torch_dtype_to_nvfuser_dtype_map.get(
+                    node.args[0].meta["tensor_meta"].dtype  # type: ignore[union-attr]
+                )
+                is not None
+            )
         return (
             node.op == "call_function"
             and getattr(node.target, "impl_nvfuser", None) is not None
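As a usage note, a quick way to see which branch the test above expects for a given dtype is to query the same map directly. This sketch assumes a PyTorch build from this era that still exposes the private `_torch_dtype_to_nvfuser_dtype_map`:

import torch
from torch._prims_common import _torch_dtype_to_nvfuser_dtype_map

for dt in (torch.float16, torch.uint8):
    supported = _torch_dtype_to_nvfuser_dtype_map.get(dt) is not None
    # Supported dtypes keep the nvprims.convert_element_type lowering;
    # unsupported ones fall back to aten._to_copy, which is what the test asserts.
    print(dt, "->", "nvprims.convert_element_type" if supported else "aten._to_copy")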