catches failure on nvprim speculative lowering (#85580)
Fixes #85517

Adds a try/except around the `get_isolated_graphmodule` tracing query inside `_is_func_unsupported_nvfuser`, and stops speculative lowering to nvprim when that query errors out.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/85580
Approved by: https://github.com/mruberry, https://github.com/IvanYashchuk
committed by PyTorch MergeBot
parent a807f1987a
commit cab6ffa0f7
test/test_prims.py
@@ -33,6 +33,7 @@ if TEST_SCIPY:
     import scipy.special

 NVPRIM_ATEN_FALLBACK_WARNING = "fallback to aten executor"
+GET_ISOLATED_GRAPHMODULE_ERROR = "get_isolated_graphmodule failed on decomposition"

 class TestPrims(TestCase):
     @onlyCUDA
@@ -778,6 +779,32 @@ class TestDecomp(TestCase):
         )
         self.assertFalse(includes_aten_to_copy)

+    @onlyCUDA
+    @skipCUDAIfRocm
+    @dtypes(torch.float16, torch.float32)
+    def test_masked_fill_decomposition_under_nvprim_context(self, device, dtype):
+        # masked_fill decomposition extracts cpu scalar tensor value when
+        # filling out a cuda tensor. This triggers data-dependent control flow
+        # on TorchRefsNvfuser speculative lowering.
+        from torch.fx.experimental.proxy_tensor import make_fx
+        from torch._prims.context import TorchRefsNvfuserCapabilityMode
+
+        x = torch.empty(2, 3, device=device).to(dtype=dtype)
+        mask = torch.ones_like(x).bool()
+        y = torch.tensor(0.3)  # cpu scalar tensor
+
+        def func(x, mask, y):
+            return torch.masked_fill(x, mask, y)
+
+        # mimics real use-case for TorchRefsNvfuserCapabilityMode context
+        gm = make_fx(func, decomposition_table={})(x, mask, y)
+
+        with warnings.catch_warnings(record=True) as caught:
+            with TorchRefsNvfuserCapabilityMode():
+                gm = make_fx(gm)(x, mask, y)
+        # masked_fill decomposition fails inside `get_isolated_graphmodule`
+        self.assertTrue(any(GET_ISOLATED_GRAPHMODULE_ERROR in str(w.message) for w in caught))
+
     @ops([op for op in op_db if op.supports_varargs], dtypes=OpDTypes.any_one)
     def test_decomposition_method_vararg(self, device, dtype, op):
         # some ops have vararg variants for the methods. this tests it.

torch/_prims/context.py
@@ -1,6 +1,6 @@
 import functools
 from contextlib import nullcontext
-from typing import Any, Callable, Dict, Sequence, Union
+from typing import Any, Callable, Dict, Sequence
 from warnings import warn

 import torch
@@ -68,25 +68,6 @@ def torch_to_refs_map():
     return r


-@functools.lru_cache(None)
-def nvfuser_decomp_table():
-    """
-    decomposition table needed for nvfuser
-    """
-    aten = torch.ops.aten
-    nvfuser_decompositions: Sequence[
-        Union[torch._ops.OpOverload, torch._ops.OpOverloadPacket]
-    ] = {  # type: ignore[assignment]
-        # AMP calls `to` in C++, which is not handled by torch mapping
-        aten._to_copy,
-    }
-
-    from torch._decomp import get_decompositions
-
-    decomp_table = get_decompositions(nvfuser_decompositions)
-    return decomp_table
-
-
 @functools.lru_cache(None)
 def all_prims():
     """
@@ -203,7 +184,17 @@ def _is_node_supported_nvfuser(node):

 def _is_func_unsupported_nvfuser(torch_function_mode, func, args, kwargs):
     with torch_function_mode:
-        gm = get_isolated_graphmodule(func, args, kwargs)
+        try:
+            gm = get_isolated_graphmodule(func, args, kwargs)
+        except Exception as e:
+            warn(
+                "get_isolated_graphmodule failed on decomposition: "
+                + func.__name__
+                + " with error message: "
+                + str(e)
+            )
+            # returns unsupported when tracing fails.
+            return True

     supported_ops = NvfuserPrimOperatorSupport()
     call_function_nodes = filter(lambda n: n.op == "call_function", gm.graph.nodes)
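The test added above doubles as a standalone repro of the behavior this commit fixes. A minimal sketch outside the test harness, assuming a CUDA device and a PyTorch build of this era (~1.13) where `torch._prims.context.TorchRefsNvfuserCapabilityMode` is available:

import warnings

import torch
from torch.fx.experimental.proxy_tensor import make_fx
from torch._prims.context import TorchRefsNvfuserCapabilityMode


def func(x, mask, y):
    return torch.masked_fill(x, mask, y)


x = torch.empty(2, 3, device="cuda")
mask = torch.ones_like(x).bool()
y = torch.tensor(0.3)  # cpu scalar fill value: triggers data-dependent control flow

# Trace once to get a GraphModule, then retrace it under the nvfuser
# capability mode, mimicking the real lowering path.
gm = make_fx(func, decomposition_table={})(x, mask, y)

with warnings.catch_warnings(record=True) as caught:
    with TorchRefsNvfuserCapabilityMode():
        gm = make_fx(gm)(x, mask, y)

# Before this patch, the masked_fill decomposition raised inside
# get_isolated_graphmodule; after it, tracing completes and the failed
# speculative lowering is only recorded as a warning.
assert any(
    "get_isolated_graphmodule failed on decomposition" in str(w.message)
    for w in caught
)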