Mirror of https://github.com/pytorch/pytorch.git, synced 2025-10-20 21:14:14 +08:00
catches failure on nvprim speculative lowering (#85580)
Fixes #85517

Added a try/except around the `get_isolated_graphmodule` tracing call inside `_is_func_unsupported_nvfuser`. Stops speculative lowering to nvprims when the query errors out.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/85580
Approved by: https://github.com/mruberry, https://github.com/IvanYashchuk
committed by PyTorch MergeBot
parent a807f1987a
commit cab6ffa0f7
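The diff below adds only the regression test; the fix itself, per the commit message, wraps the `get_isolated_graphmodule` call inside `_is_func_unsupported_nvfuser` in a try/except. A minimal sketch of that shape, assuming the helper's signature, the import path, and the fallback logic (the real PyTorch source differs in detail):

# Sketch reconstructed from the commit message -- not the verbatim
# PyTorch source. Signature and surrounding support check are assumptions.
from warnings import warn

from torch.fx.experimental.proxy_tensor import get_isolated_graphmodule


def _is_func_unsupported_nvfuser(torch_function_mode, func, args, kwargs):
    with torch_function_mode:
        try:
            # Speculatively trace the decomposition of `func` in isolation.
            gm = get_isolated_graphmodule(func, args, kwargs)
        except Exception as e:
            # Tracing itself can raise, e.g. on data-dependent control flow
            # (masked_fill reading a cpu scalar tensor, as in the test below).
            # Warn and report the op as unsupported so lowering falls back
            # to the aten executor instead of propagating the error.
            warn(f"get_isolated_graphmodule failed on decomposition: {func} with error {e}")
            return True

    # ...otherwise walk gm.graph and report unsupported only if some call
    # does not map onto nvprims (elided here).
    return False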
@@ -33,6 +33,7 @@ if TEST_SCIPY:
     import scipy.special

 NVPRIM_ATEN_FALLBACK_WARNING = "fallback to aten executor"
+GET_ISOLATED_GRAPHMODULE_ERROR = "get_isolated_graphmodule failed on decomposition"

 class TestPrims(TestCase):
     @onlyCUDA
@@ -778,6 +779,32 @@ class TestDecomp(TestCase):
         )
         self.assertFalse(includes_aten_to_copy)

+    @onlyCUDA
+    @skipCUDAIfRocm
+    @dtypes(torch.float16, torch.float32)
+    def test_masked_fill_decomposition_under_nvprim_context(self, device, dtype):
+        # masked_fill decomposition extracts cpu scalar tensor value when
+        # filling out a cuda tensor. This triggers data-dependent control flow
+        # on TorchRefsNvfuser speculative lowering.
+        from torch.fx.experimental.proxy_tensor import make_fx
+        from torch._prims.context import TorchRefsNvfuserCapabilityMode
+
+        x = torch.empty(2, 3, device=device).to(dtype=dtype)
+        mask = torch.ones_like(x).bool()
+        y = torch.tensor(0.3)  # cpu scalar tensor
+
+        def func(x, mask, y):
+            return torch.masked_fill(x, mask, y)
+
+        # mimics real use-case for TorchRefsNvfuserCapabilityMode context
+        gm = make_fx(func, decomposition_table={})(x, mask, y)
+
+        with warnings.catch_warnings(record=True) as caught:
+            with TorchRefsNvfuserCapabilityMode():
+                gm = make_fx(gm)(x, mask, y)
+        # masked_fill decomposition fails inside `get_isolated_graphmodule`
+        self.assertTrue(any(GET_ISOLATED_GRAPHMODULE_ERROR in str(w.message) for w in caught))
+
     @ops([op for op in op_db if op.supports_varargs], dtypes=OpDTypes.any_one)
     def test_decomposition_method_vararg(self, device, dtype, op):
         # some ops have vararg variants for the methods. this tests it.