Mirror of https://github.com/pytorch/pytorch.git, synced 2025-10-20 21:14:14 +08:00

Remove dynamo+nvfuser (#105789)

This PR removes the unmaintained Dynamo+nvFuser integration.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/105789
Approved by: https://github.com/jansel, https://github.com/jjsjann123, https://github.com/albanD

Committed by: PyTorch MergeBot
Parent: ad22f0ffb4
Commit: 6030151d37
@@ -52,10 +52,6 @@ Some of the most commonly used backends include:
     - Description
   * - ``torch.compile(m, backend="inductor")``
     - Uses the TorchInductor backend. `Read more <https://dev-discuss.pytorch.org/t/torchinductor-a-pytorch-native-compiler-with-define-by-run-ir-and-symbolic-shapes/747>`__
-  * - ``torch.compile(m, backend="aot_ts_nvfuser")``
-    - nvFuser with AOT Autograd/TorchScript. `Read more <https://dev-discuss.pytorch.org/t/tracing-with-primitives-update-1-nvfuser-and-its-primitives/593>`__
-  * - ``torch.compile(m, backend="nvprims_nvfuser")``
-    - Tracing with nvFuser and its primitives. `Read more <https://dev-discuss.pytorch.org/t/tracing-with-primitives-update-1-nvfuser-and-its-primitives/593>`__
   * - ``torch.compile(m, backend="cudagraphs")``
     - CUDA graphs with AOT Autograd. `Read more <https://github.com/pytorch/torchdynamo/pull/757>`__
@@ -93,7 +93,7 @@ hub.

And that is not the only available backend, you can run in a REPL
``torch.compile.list_backends()`` to see all the available backends. Try out the
-``cudagraphs`` or ``nvfuser`` next as inspiration.
+``cudagraphs`` next as inspiration.

Using a pretrained model
~~~~~~~~~~~~~~~~~~~~~~~~
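For readers following the documentation change above, a minimal sketch of how a backend is selected and how the available backends can be listed (the module and tensor shapes here are made-up examples, not part of the docs):

    import torch

    class TinyModel(torch.nn.Module):
        def forward(self, x):
            return torch.nn.functional.relu(x).sum()

    m = TinyModel()
    # "inductor" selects the TorchInductor backend described in the table above.
    compiled = torch.compile(m, backend="inductor")
    out = compiled(torch.randn(8, 8))

    # The backend registry can be inspected at runtime; after this PR the
    # nvfuser-based entries no longer appear in it.
    print(torch._dynamo.list_backends())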
@@ -202,7 +202,7 @@ If the error does not occur with the ``"eager"`` backend, then the
backend compiler is the source of the error (`example
error <https://gist.github.com/mlazos/2f13681e3cc6c43b3911f336327032de%5D>`__).
There are `different choices <./torch.compiler.rst>`__
-for backend compilers for TorchDynamo, with TorchInductor or nvfuser
+for backend compilers for TorchDynamo, with TorchInductor
fitting the needs of most users. This section focuses on TorchInductor
as the motivating example, but some tools can also be used with other
backend compilers.
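A short sketch of the debugging flow described in this hunk, assuming a toy function and TorchInductor as the backend under suspicion (illustrative only):

    import torch

    def fn(x):
        return torch.sin(x) + torch.cos(x)

    x = torch.randn(16)

    # "eager" runs the captured graph with plain PyTorch kernels; if this step
    # already fails, the problem is in tracing or user code, not the backend.
    torch.compile(fn, backend="eager")(x)

    # If "eager" works but this fails, the backend compiler (TorchInductor here)
    # is the likely source of the error.
    torch.compile(fn, backend="inductor")(x)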
@@ -10,7 +10,6 @@ from torch._dynamo.backends.debugging import ExplainWithBackend
from torch._dynamo.backends.onnxrt import has_onnxruntime
from torch._dynamo.backends.tvm import has_tvm
from torch._dynamo.testing import same
from torch.testing._internal.common_utils import IS_FBCODE, skipIfRocm
from torch.testing._internal.inductor_utils import HAS_CUDA

requires_cuda = functools.partial(unittest.skipIf, not HAS_CUDA, "requires cuda")
@@ -123,21 +122,6 @@ class TestOptimizations(torch._dynamo.test_case.TestCase):
    def test_aot_cudagraphs(self):
        self._check_backend_works("cudagraphs")

-    @skipIfRocm
-    @requires_cuda()
-    def test_aot_ts_nvfuser(self):
-        self._check_backend_works("aot_ts_nvfuser")
-
-    @requires_cuda()
-    @unittest.skipIf(IS_FBCODE, "BackendCompilerError")
-    def test_nvprims_nvfuser(self):
-        self._check_backend_works("nvprims_nvfuser")
-
-    @requires_cuda()
-    @unittest.skipIf(IS_FBCODE, "BackendCompilerError")
-    def test_nvprims_aten(self):
-        self._check_backend_works("nvprims_aten")
-
    @unittest.skipIf(not has_onnxruntime(), "requires onnxruntime")
    def test_onnxrt(self):
        self._check_backend_works("onnxrt")
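The deleted tests above all call the file's `_check_backend_works` helper, which is not shown in this diff. A hypothetical sketch of what such a backend smoke test does (model, shapes, and tolerances are assumptions):

    import torch

    def check_backend_works(backend_name):
        # Run a tiny model eagerly, then compiled with the requested backend,
        # and check that both produce the same result.
        model = torch.nn.Sequential(torch.nn.Linear(8, 8), torch.nn.ReLU())
        x = torch.randn(4, 8)
        expected = model(x)
        compiled = torch.compile(model, backend=backend_name)
        torch.testing.assert_close(compiled(x), expected)

    check_backend_works("eager")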
@@ -1,11 +0,0 @@
-# Owner(s): ["module: nvfuser"]
-
-try:
-    from _nvfuser.test_dynamo import *  # noqa: F403,F401
-except ImportError:
-    def run_tests():
-        return
-    pass
-
-if __name__ == '__main__':
-    run_tests()
@@ -484,26 +484,9 @@ class TestCommon(TestCase):
    @unittest.skipIf(TEST_WITH_ASAN, "Skipped under ASAN")
    @onlyCUDA
    @ops(python_ref_db)
-    @parametrize('executor', ['aten', 'nvfuser'])
+    @parametrize('executor', ['aten',])
    @skipIfTorchInductor("Takes too long for inductor")
    def test_python_ref_executor(self, device, dtype, op, executor):
        # TODO: Not all dtypes are supported with nvfuser
        from torch._prims_common import _torch_dtype_to_nvfuser_dtype_map
        if executor == "nvfuser" and dtype not in _torch_dtype_to_nvfuser_dtype_map:
            raise unittest.SkipTest(f"nvfuser doesn't support dtype {dtype}")

        # nvFuser tests are rather slow so we only run int32 and float32 types
        if executor == "nvfuser" and dtype not in [torch.int32, torch.float32]:
            raise unittest.SkipTest("skipped for speed")

        if executor == "nvfuser" and not op.supports_nvfuser:
            raise unittest.SkipTest(f"{op.name} doesn't support nvfuser")

        # nvFuser doesn't support reduction operations on 0-dim tensors yet
        skip_zero_dim = False
        if executor == "nvfuser" and isinstance(op, ReductionPythonRefInfo):
            skip_zero_dim = True

        # skip zero-dim tensors for some composites of reduction operations and view
        skip_zero_dim_ops = [
            "_refs.logsumexp",
@@ -513,25 +496,16 @@ class TestCommon(TestCase):
            "_refs.sum_to_size",
            "ops.nvprims.view",
        ]
        if executor == "nvfuser" and op.name in skip_zero_dim_ops:
            skip_zero_dim = True

        from torch._prims.executor import make_traced
        from copy import copy
        op = copy(op)
        executor = "strictly_nvfuser" if executor == "nvfuser" else executor
        op.op = partial(make_traced(op.op), executor=executor)
        self._ref_test_helper(
            contextlib.nullcontext,
            device,
            dtype,
            op,
            skip_zero_numel=("nvfuser" in executor),  # nvfuser doesn't support zero-sized tensors
            skip_zero_dim=skip_zero_dim,
            skip_bfloat=("nvfuser" in executor),  # nvfuser doesn't support bfloat tensors for pre-11 cuda TK
            # nvfuser doesn't support view consistency
            # https://github.com/pytorch/pytorch/issues/84863
            skip_view_consistency=("nvfuser" in executor),
        )

    @skipMeta
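The surviving "aten" path of this test goes through `torch._prims.executor.make_traced`, the same entry point the deleted nvfuser branches used. A small sketch of that call pattern, based on how it appears elsewhere in this diff (the traced function and shapes are illustrative):

    import torch
    from torch._prims.executor import make_traced

    def _wrapper(a, b):
        return torch.add(a, b)

    traced = make_traced(_wrapper)
    a = torch.randn(5, 5)
    b = torch.randn(5, 5)
    # After this PR only the "aten" executor remains.
    out = traced(a, b, executor="aten")
    torch.testing.assert_close(out, a + b)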
@ -2,14 +2,12 @@
|
||||
|
||||
from functools import partial
|
||||
from itertools import product
|
||||
import warnings
|
||||
from warnings import catch_warnings
|
||||
import unittest
|
||||
|
||||
import torch
|
||||
from torch.testing import make_tensor
|
||||
from torch.testing._internal.common_utils import (parametrize, run_tests, TestCase, TEST_SCIPY,
|
||||
set_default_dtype, skipCUDAMemoryLeakCheckIf)
|
||||
set_default_dtype)
|
||||
from torch.testing._internal.common_device_type import (
|
||||
instantiate_device_type_tests,
|
||||
onlyCUDA,
|
||||
@ -28,7 +26,6 @@ import torch._prims as prims
|
||||
from torch._prims_common import CUDARngStateHelper
|
||||
from torch._prims.executor import make_traced
|
||||
import torch._refs as refs
|
||||
from torch.fx.experimental.proxy_tensor import make_fx
|
||||
|
||||
|
||||
if TEST_SCIPY:
|
||||
@ -47,7 +44,7 @@ class TestPrims(TestCase):
|
||||
traced = make_traced(_wrapper)
|
||||
make_arg = partial(make_tensor, device=device, dtype=dtype)
|
||||
|
||||
for executor in ('aten', 'strictly_nvfuser'):
|
||||
for executor in ('aten',):
|
||||
fn = partial(traced, executor=executor)
|
||||
# Same shape
|
||||
shape = (5, 5)
|
||||
@ -87,20 +84,6 @@ class TestPrims(TestCase):
|
||||
self.assertEqual(result.shape, b.shape)
|
||||
self.assertEqual(a.unsqueeze(2), result)
|
||||
|
||||
# FIXME: This test exposes an issue in nvfuser
|
||||
# Adds outermost, expands, and unsqueezes
|
||||
"""
|
||||
a = make_arg((1, 2, 3))
|
||||
b = make_arg((4, 1, 7, 2, 3, 3), low=0.0, high=0.0)
|
||||
result = fn(a, b, (1, 3, 4))
|
||||
|
||||
self.assertEqual(result.shape, b.shape)
|
||||
a.unsqueeze_(3)
|
||||
a.unsqueeze_(1)
|
||||
a.unsqueeze_(0)
|
||||
self.assertEqual(a.expand_as(result), result)
|
||||
"""
|
||||
|
||||
@onlyCUDA
|
||||
@dtypes(torch.float32)
|
||||
def test_broadcast_in_dim_sum(self, device, dtype):
|
||||
@ -112,7 +95,7 @@ class TestPrims(TestCase):
|
||||
traced = make_traced(_wrapper)
|
||||
make_arg = partial(make_tensor, device=device, dtype=dtype)
|
||||
|
||||
for executor in ('aten', 'strictly_nvfuser'):
|
||||
for executor in ('aten',):
|
||||
fn = partial(traced, executor=executor)
|
||||
shape = (5, 5)
|
||||
a = make_arg(shape)
|
||||
@ -171,340 +154,6 @@ class TestPrims(TestCase):
|
||||
with self.assertRaises(AssertionError):
|
||||
fn(t, start, end)
|
||||
|
||||
@onlyCUDA
|
||||
def test_nvfuser_impl_is_used(self, device):
|
||||
# This test is to ensure that when the nvfuser implementation exists it is used
|
||||
# Assuming one-to-one mapping between prims and nvfuser implementations
|
||||
# This test is not intended to test the correctness of the nvfuser implementation
|
||||
try:
|
||||
from nvfuser import FusionDefinition as fd
|
||||
except ImportError:
|
||||
from nvfuser._C import FusionDefinition as fd
|
||||
|
||||
|
||||
prim_nvfuser_ops = set(torch._prims.__all__).intersection(dir(fd.ops))
|
||||
ops_without_nvfuser_impl = {
|
||||
name
|
||||
for name in prim_nvfuser_ops
|
||||
if getattr(torch.ops.nvprims, name, None) is None
|
||||
}
|
||||
assert (
|
||||
len(ops_without_nvfuser_impl) == 0
|
||||
), (f"The following prims do not have 'impl_nvfuser' defined: {ops_without_nvfuser_impl} ",
|
||||
"while there exists nvfuser implementations for them.")
|
||||
|
||||
def test_skip_ops_nvfuser_prims_mode(self, device):
|
||||
# This test verifies that the NvfuserPrimsMode skips the specified
|
||||
# functions. Skipping a function means that it's not converted into
|
||||
# nvprims counterparts.
|
||||
from torch._prims.context import NvfuserPrimsMode
|
||||
|
||||
a = make_tensor(5, 5, device=device, dtype=torch.float32)
|
||||
|
||||
def func(a):
|
||||
return torch.ops.prims.sin.default(a)
|
||||
|
||||
skip_ops = {"prims.sin.default", }
|
||||
with NvfuserPrimsMode(skip_ops=skip_ops):
|
||||
gm = make_fx(func)(a)
|
||||
|
||||
includes_any_prims_sin = any(
|
||||
node.target == torch.ops.prims.sin.default for node in gm.graph.nodes
|
||||
)
|
||||
self.assertTrue(includes_any_prims_sin)
|
||||
include_any_nvprims_sin = any(
|
||||
node.target == torch.ops.nvprims.sin.default for node in gm.graph.nodes
|
||||
)
|
||||
self.assertFalse(include_any_nvprims_sin)
|
||||
|
||||
def test_skip_ops_nvfuser_capability_mode(self, device):
|
||||
# This test verifies that the NvfuserCapabilityMode skips the specified
|
||||
# functions. Skipping a function means that specific
|
||||
# reference/decomposition is not traced and there's no attempt to lower
|
||||
# it to nvprims.
|
||||
from torch._prims.context import TorchRefsNvfuserCapabilityMode
|
||||
|
||||
a = make_tensor(5, 5, device=device, dtype=torch.float32)
|
||||
|
||||
def func(a):
|
||||
return torch.sin(a)
|
||||
|
||||
skip_ops = {"torch.sin", }
|
||||
with TorchRefsNvfuserCapabilityMode(skip_ops=skip_ops):
|
||||
gm = make_fx(func)(a)
|
||||
|
||||
includes_any_aten_sin = any(
|
||||
node.target == torch.ops.aten.sin.default for node in gm.graph.nodes
|
||||
)
|
||||
self.assertTrue(includes_any_aten_sin)
|
||||
include_any_nvprims_sin = any(
|
||||
node.target == torch.ops.nvprims.sin.default for node in gm.graph.nodes
|
||||
)
|
||||
self.assertFalse(include_any_nvprims_sin)
|
||||
|
||||
def test_partitioner_tuple_output(self, device):
|
||||
# This test verifies that the partitioner doesn't segment on nodes with
|
||||
# tuple outputs.
|
||||
from torch.fx.passes.infra.partitioner import CapabilityBasedPartitioner
|
||||
from torch._prims.nvfuser_executor import NvfuserPrimOperatorSupport
|
||||
|
||||
a = make_tensor(5, 3, 3, device=device, dtype=torch.float32)
|
||||
|
||||
def func(x):
|
||||
xx = torch.ops.nvprims.add(x, 1)
|
||||
var, mean = torch.ops.nvprims.var_mean(x, correction=0)
|
||||
var_cos = torch.ops.nvprims.cos(var)
|
||||
mean_sin = torch.ops.nvprims.sin(mean)
|
||||
return torch.ops.nvprims.add(var_cos, mean_sin)
|
||||
|
||||
gm = make_fx(func)(a)
|
||||
supported_ops = NvfuserPrimOperatorSupport()
|
||||
partitioner = CapabilityBasedPartitioner(
|
||||
gm, supported_ops, allows_single_node_partition=False
|
||||
)
|
||||
partitions = partitioner.propose_partitions()
|
||||
self.assertEqual(len(partitions), 1)
|
||||
|
||||
@onlyCUDA
|
||||
@dtypes(torch.float32)
|
||||
def test_full(self, device, dtype):
|
||||
from torch.fx.experimental.proxy_tensor import make_fx
|
||||
from torch._prims.context import TorchRefsNvfuserCapabilityMode
|
||||
from torch._prims.executor import execute
|
||||
|
||||
def func1(size, value, b):
|
||||
return (torch.full(size, value, dtype=dtype, device=device),)
|
||||
|
||||
def func2(size, value, b):
|
||||
a = torch.full(size, value, dtype=dtype, device=device)
|
||||
b_sin = b.sin()
|
||||
return (torch.add(a, b_sin),)
|
||||
|
||||
def func3(size, value, b):
|
||||
return (torch.full(size, value, dtype=dtype, device=device), b)
|
||||
|
||||
def func4(size, value, b):
|
||||
b_sin = b.sin()
|
||||
return (torch.full(size, value, dtype=dtype, device=device), b_sin)
|
||||
|
||||
def func5(size, value, b):
|
||||
b_sin = b.sin()
|
||||
a = torch.full(size, value, dtype=dtype, device=device)
|
||||
a_sin = a.sin()
|
||||
return (a, b_sin, a_sin)
|
||||
|
||||
for func in (func1, func3, func2, func3, func4, func5):
|
||||
size = (3, 3)
|
||||
value = 10
|
||||
b = torch.randn(*size, dtype=dtype, device=device)
|
||||
|
||||
with TorchRefsNvfuserCapabilityMode():
|
||||
gm = make_fx(func)(size, value, b)
|
||||
|
||||
out = execute(gm, size, value, b, executor="strictly_nvfuser")
|
||||
self.assertEqual(out, func(size, value, b))
|
||||
|
||||
@onlyCUDA
|
||||
def test_nvfuser_empty_fusion(self, device):
|
||||
from torch.fx.experimental.proxy_tensor import make_fx
|
||||
from torch._prims.executor import execute
|
||||
|
||||
a = torch.randn(3, 3, device=device)
|
||||
|
||||
def func(a, b, c):
|
||||
return (a, b, c)
|
||||
|
||||
gm = make_fx(func)(a, a, a)
|
||||
|
||||
with self.assertRaisesRegex(AssertionError, "Graph must contain at least one call_function node"):
|
||||
execute(gm, a, a, a, executor="strictly_nvfuser")
|
||||
|
||||
# Should pass with partitioned executor
|
||||
out = execute(gm, a, a, a, executor="nvfuser")
|
||||
self.assertEqual(out, (a, a, a))
|
||||
|
||||
@onlyCUDA
|
||||
@dtypes(torch.float16, torch.uint8)
|
||||
def test_nvprim_convert_element_type(self, device, dtype):
|
||||
from torch.fx.experimental.proxy_tensor import make_fx
|
||||
from torch._prims.executor import execute
|
||||
from torch._prims.context import TorchRefsNvfuserCapabilityMode
|
||||
from torch._prims_common import _torch_dtype_to_nvfuser_dtype_map
|
||||
|
||||
# initialize input as float32, which is different from `dtype` in the argument.
|
||||
# this ensures that tracing will have a _to_copy node.
|
||||
a = torch.randn(3, 3, device=device, dtype=torch.float32)
|
||||
|
||||
def func(x, dtype):
|
||||
return x.to(dtype).to(x.dtype)
|
||||
|
||||
with TorchRefsNvfuserCapabilityMode():
|
||||
gm = make_fx(func)(a, dtype)
|
||||
execute(gm, a, dtype, executor="nvfuser")
|
||||
|
||||
call_function_nodes = list(filter(lambda n: n.op == "call_function", gm.graph.nodes))
|
||||
includes_aten_to_copy = any(
|
||||
torch.ops.aten._to_copy.default == node.target
|
||||
for node in call_function_nodes
|
||||
)
|
||||
includes_nvprim_convert_element_type = any(
|
||||
torch.ops.nvprims.convert_element_type.default == node.target
|
||||
for node in call_function_nodes
|
||||
)
|
||||
nvprim_support_flag = _torch_dtype_to_nvfuser_dtype_map.get(dtype) is not None
|
||||
self.assertEqual(includes_aten_to_copy, not nvprim_support_flag)
|
||||
self.assertEqual(includes_nvprim_convert_element_type, nvprim_support_flag)
|
||||
|
||||
@onlyCUDA
|
||||
def test_nvfuser_rand_like_fusion(self, device):
|
||||
from torch._prims.context import TorchRefsNvfuserCapabilityMode
|
||||
from torch.fx.experimental.proxy_tensor import make_fx
|
||||
from torch._prims.executor import execute
|
||||
|
||||
a = torch.randn(3, 3, device=device)
|
||||
|
||||
def func(a):
|
||||
return torch.rand_like(a)
|
||||
|
||||
with TorchRefsNvfuserCapabilityMode():
|
||||
gm = make_fx(func)(a)
|
||||
|
||||
out = execute(gm, a, executor="strictly_nvfuser")
|
||||
self.assertEqual(out.size(), a.size())
|
||||
|
||||
@skipCUDAMemoryLeakCheckIf(True) # https://github.com/pytorch/pytorch/issues/84529
|
||||
@onlyCUDA
|
||||
def test_nvfuser_no_args(self, device):
|
||||
from torch._prims.context import TorchRefsNvfuserCapabilityMode
|
||||
from torch.fx.experimental.proxy_tensor import make_fx
|
||||
from torch._prims.executor import execute
|
||||
from torch._prims.nvfuser_executor import make_nvfuser_fusion
|
||||
|
||||
a = torch.randn(3, 3, device=device)
|
||||
|
||||
def func():
|
||||
return torch.sigmoid(a)
|
||||
|
||||
with TorchRefsNvfuserCapabilityMode():
|
||||
gm = make_fx(func)()
|
||||
|
||||
with warnings.catch_warnings(record=True) as caught:
|
||||
execute(gm, executor="strictly_nvfuser")
|
||||
# fusion execute with no cuda input is handled by nvprim aten fallback
|
||||
self.assertTrue(any(NVPRIM_ATEN_FALLBACK_WARNING in str(w.message) for w in caught))
|
||||
|
||||
with self.assertRaisesRegex(AssertionError, "There must be at least one argument"):
|
||||
make_nvfuser_fusion(gm)
|
||||
|
||||
with self.assertRaisesRegex(AssertionError, "Number of placeholder nodes in the graph must match"):
|
||||
execute(gm, a, executor="strictly_nvfuser")
|
||||
|
||||
# Should pass with partitioned executor
|
||||
out = execute(gm, executor="nvfuser")
|
||||
self.assertEqual(out, func())
|
||||
|
||||
@onlyCUDA
|
||||
def test_nvfuser_constant_tensors(self, device):
|
||||
from torch._prims.context import TorchRefsNvfuserCapabilityMode
|
||||
from torch.fx.experimental.proxy_tensor import make_fx
|
||||
from torch._prims.executor import execute
|
||||
|
||||
a = torch.randn(3, 3, device=device)
|
||||
b = torch.randn(3, 3, device=device)
|
||||
|
||||
def func(b):
|
||||
return a + b
|
||||
|
||||
with TorchRefsNvfuserCapabilityMode():
|
||||
gm = make_fx(func)(b)
|
||||
|
||||
with self.assertRaisesRegex(AssertionError, "not supported yet"):
|
||||
execute(gm, b, executor="strictly_nvfuser")
|
||||
|
||||
# Should pass with partitioned executor
|
||||
out = execute(gm, b, executor="nvfuser")
|
||||
self.assertEqual(out, gm(b))
|
||||
|
||||
@onlyCUDA
|
||||
def test_nvfuser_executor_cached_noncontiguous(self, device):
|
||||
# This test is to ensure that nvfuser computes correct results for noncontiguous tensors
|
||||
from torch.fx.experimental.proxy_tensor import make_fx
|
||||
from torch._prims.context import TorchRefsNvfuserCapabilityMode
|
||||
from torch._prims.executor import execute
|
||||
|
||||
a = torch.randn(3, 3, device=device)
|
||||
|
||||
def func(a):
|
||||
return torch.sigmoid(a)
|
||||
|
||||
with TorchRefsNvfuserCapabilityMode():
|
||||
gm = make_fx(func)(a)
|
||||
|
||||
# First run to create the cache
|
||||
execute(gm, a, executor="strictly_nvfuser")
|
||||
|
||||
# a.mT is noncontiguous, but it shouldn't affect correctness
|
||||
expected = execute(gm, a.mT, executor="aten")
|
||||
for use_python_cache in [True, False]:
|
||||
params = {"use_python_fusion_cache": use_python_cache}
|
||||
actual = execute(gm, a.mT, executor="strictly_nvfuser", executor_parameters=params)
|
||||
self.assertEqual(expected, actual)
|
||||
|
||||
def test_nvfuser_capability_context(self, device):
|
||||
# This test is to ensure that the torch calls are replaced with refs
|
||||
# based on the nvfuser+prims capability
|
||||
from torch.fx.experimental.proxy_tensor import make_fx
|
||||
from torch._prims.context import TorchRefsNvfuserCapabilityMode
|
||||
|
||||
# It's assumed that digamma is not supported by nvfuser
|
||||
# If it's ever supported, this test will need to be updated
|
||||
self.assertTrue(getattr(torch.ops.nvprims, "digamma", None) is None)
|
||||
|
||||
a = torch.randn(3, 3, device=device)
|
||||
|
||||
def func(a):
|
||||
return torch.digamma(a)
|
||||
|
||||
with TorchRefsNvfuserCapabilityMode():
|
||||
gm = make_fx(func)(a)
|
||||
|
||||
# Check that the torch.digamma is not replaced with torch.ops.prims.digamma
|
||||
call_function_nodes = list(filter(lambda n: n.op == "call_function", gm.graph.nodes))
|
||||
includes_aten_digamma = any(
|
||||
torch.ops.aten.digamma.default == node.target
|
||||
for node in call_function_nodes
|
||||
)
|
||||
includes_prims_digamma = any(
|
||||
torch.ops.prims.digamma.default == node.target
|
||||
for node in call_function_nodes
|
||||
)
|
||||
self.assertTrue(includes_aten_digamma)
|
||||
self.assertFalse(includes_prims_digamma)
|
||||
|
||||
# Check mixed case, sigmoid is replaced with refs, but digamma is not
|
||||
def func(a):
|
||||
return torch.sigmoid(torch.digamma(a))
|
||||
|
||||
with TorchRefsNvfuserCapabilityMode():
|
||||
gm = make_fx(func)(a)
|
||||
|
||||
call_function_nodes = list(filter(lambda n: n.op == "call_function", gm.graph.nodes))
|
||||
includes_aten_sigmoid = any(
|
||||
torch.ops.aten.sigmoid.default == node.target
|
||||
for node in call_function_nodes
|
||||
)
|
||||
includes_prims_digamma = any(
|
||||
torch.ops.prims.digamma.default == node.target
|
||||
for node in call_function_nodes
|
||||
)
|
||||
includes_nvprims_exp = any(
|
||||
torch.ops.nvprims.exp.default == node.target
|
||||
for node in call_function_nodes
|
||||
)
|
||||
self.assertFalse(includes_aten_sigmoid)
|
||||
self.assertFalse(includes_prims_digamma)
|
||||
self.assertTrue(includes_nvprims_exp)
|
||||
|
||||
|
||||
def test_aten_overload_to_prims(self, device):
|
||||
# This test is to ensure that the torch.ops.aten calls are replaced with refs
|
||||
@ -526,325 +175,6 @@ class TestPrims(TestCase):
|
||||
)
|
||||
self.assertTrue(all_prims_namespace)
|
||||
|
||||
|
||||
@onlyCUDA
|
||||
def test_nvfuser_executor_parameters(self, device):
|
||||
from torch.fx.experimental.proxy_tensor import make_fx
|
||||
from torch._prims.executor import execute
|
||||
|
||||
a = torch.randn(3, 4, device=device)
|
||||
|
||||
def func(a):
|
||||
return torch.ops.nvprims.add(a, a)
|
||||
|
||||
gm = make_fx(func)(a)
|
||||
|
||||
expected = execute(gm, a, executor="aten")
|
||||
# Shouldn't raise an error because unuseful parameters are ignored
|
||||
params_dicts = [None, {}, {"none": None}]
|
||||
for params in params_dicts:
|
||||
actual = execute(gm, a, executor="nvfuser", executor_parameters=params)
|
||||
self.assertEqual(expected, actual)
|
||||
|
||||
# Check caching parameter
|
||||
for use_cache in [True, False]:
|
||||
params = {"use_python_fusion_cache": use_cache}
|
||||
actual = execute(gm, a, executor="nvfuser", executor_parameters=params)
|
||||
self.assertEqual(expected, actual)
|
||||
|
||||
# Check allow_single_op_fusion parameter
|
||||
for allow_single_op_fusion in [True, False]:
|
||||
params = {"allow_single_op_fusion": allow_single_op_fusion}
|
||||
actual = execute(gm, a, executor="nvfuser", executor_parameters=params)
|
||||
self.assertEqual(expected, actual)
|
||||
|
||||
|
||||
@onlyCUDA
|
||||
def test_nvfuser_executor_partitioned(self, device):
|
||||
# This test is to ensure that nvfuser partitioned executor works correctly
|
||||
# It's assumed that digamma is not supported by nvfuser
|
||||
# If it's ever supported, this test will need to be updated
|
||||
self.assertTrue(getattr(torch.ops.nvprims, "digamma", None) is None)
|
||||
|
||||
from torch.fx.experimental.proxy_tensor import make_fx
|
||||
from torch._prims.context import TorchRefsNvfuserCapabilityMode
|
||||
from torch._prims.executor import execute
|
||||
|
||||
a = torch.randn(3, 4, device=device)
|
||||
b = torch.rand(3, 1, device=device)
|
||||
c = torch.rand(3, 4, device=device)
|
||||
|
||||
def func(a, b, c):
|
||||
aa = torch.digamma(a) # not supported by nvfuser
|
||||
d = torch.add(b, c)
|
||||
dd = torch.sqrt(d)
|
||||
return torch.mul(aa, dd.digamma())
|
||||
|
||||
with TorchRefsNvfuserCapabilityMode():
|
||||
gm = make_fx(func)(a, b, c)
|
||||
|
||||
expected = execute(gm, a, b, c, executor="aten")
|
||||
actual = execute(gm, a, b, c, executor="nvfuser")
|
||||
self.assertEqual(expected, actual)
|
||||
|
||||
@onlyCUDA
|
||||
def test_nvfuser_executor_partitioned_no_partitions_error(self, device):
|
||||
# This test is to ensure that nvfuser partitioned executor works correctly
|
||||
# It's assumed that digamma is not supported by nvfuser
|
||||
# If it's ever supported, this test will need to be updated
|
||||
self.assertTrue(getattr(torch.ops.nvprims, "digamma", None) is None)
|
||||
|
||||
from torch.fx.experimental.proxy_tensor import make_fx
|
||||
from torch._prims.context import TorchRefsNvfuserCapabilityMode
|
||||
from torch._prims.executor import execute
|
||||
|
||||
a = torch.randn(3, 4, device=device)
|
||||
|
||||
def func(a):
|
||||
return torch.digamma(a) # not supported by nvfuser
|
||||
|
||||
with TorchRefsNvfuserCapabilityMode():
|
||||
gm = make_fx(func)(a)
|
||||
|
||||
with catch_warnings(record=True) as w:
|
||||
# Trigger warning
|
||||
execute(gm, a, executor="nvfuser")
|
||||
# Check warning occurs
|
||||
self.assertEqual(len(w), 1)
|
||||
self.assertTrue("is not supported by nvFuser" in str(w[-1].message))
|
||||
|
||||
def test_nvprims(self, device):
|
||||
# This test is to ensure that nvfuser specific prims are exposed
|
||||
# and can be traced with make_fx
|
||||
from torch.fx.experimental.proxy_tensor import make_fx
|
||||
|
||||
def func(a):
|
||||
return torch.ops.nvprims.add(a, a)
|
||||
|
||||
a = torch.randn(3, 4, device=device)
|
||||
gm = make_fx(func)(a)
|
||||
|
||||
for node in gm.graph.nodes:
|
||||
if node.op == "call_function":
|
||||
self.assertTrue(node.name == "add")
|
||||
self.assertTrue(node.target == torch.ops.nvprims.add.default)
|
||||
self.assertFalse(node.target == torch.ops.prims.add.default)
|
||||
self.assertFalse(node.target == torch.ops.aten.add.default)
|
||||
|
||||
@onlyCUDA
|
||||
@dtypes(torch.float32, torch.float64)
|
||||
def test_native_batch_norm_nvprims(self, device, dtype):
|
||||
from torch._prims.context import TorchRefsNvfuserCapabilityMode
|
||||
from torch._prims.executor import execute
|
||||
|
||||
# This test verifies that native_batch_norm is translated into nvprims
|
||||
# and can be executed with nvFuser
|
||||
from torch.fx.experimental.proxy_tensor import make_fx
|
||||
from torch.testing._internal.common_methods_invocations import (
|
||||
sample_inputs_native_batch_norm,
|
||||
)
|
||||
|
||||
samples = sample_inputs_native_batch_norm(
|
||||
None, device, dtype, requires_grad=False
|
||||
)
|
||||
batch_norms = [
|
||||
torch.native_batch_norm,
|
||||
torch.ops.aten.native_batch_norm,
|
||||
torch.ops.aten.native_batch_norm.default,
|
||||
torch.ops.nvprims.native_batch_norm.default,
|
||||
]
|
||||
for sample, batch_norm in product(samples, batch_norms):
|
||||
if sample.input.numel() == 0:
|
||||
continue
|
||||
|
||||
def func(
|
||||
input, weight, bias, running_mean, running_var, training, momentum, eps
|
||||
):
|
||||
return batch_norm(
|
||||
input,
|
||||
weight,
|
||||
bias,
|
||||
running_mean,
|
||||
running_var,
|
||||
training,
|
||||
momentum,
|
||||
eps,
|
||||
)
|
||||
|
||||
with TorchRefsNvfuserCapabilityMode():
|
||||
gm = make_fx(func)(sample.input, *sample.args)
|
||||
|
||||
call_function_nodes = list(
|
||||
filter(lambda n: n.op == "call_function", gm.graph.nodes)
|
||||
)
|
||||
includes_aten_batch_norm = any(
|
||||
torch.ops.aten.native_batch_norm.default == node.target
|
||||
for node in call_function_nodes
|
||||
)
|
||||
self.assertFalse(includes_aten_batch_norm)
|
||||
|
||||
includes_nvprims_batch_norm = any(
|
||||
torch.ops.nvprims.native_batch_norm.default == node.target
|
||||
for node in call_function_nodes
|
||||
)
|
||||
self.assertTrue(includes_nvprims_batch_norm)
|
||||
|
||||
# Check that the graph can be executed with nvFuser
|
||||
out = execute(gm, sample.input, *sample.args, executor="strictly_nvfuser")
|
||||
self.assertEqual(out, gm(sample.input, *sample.args))
|
||||
|
||||
@onlyCUDA
|
||||
@dtypes(torch.float32, torch.float64)
|
||||
def test_cudnn_batch_norm_nvprims(self, device, dtype):
|
||||
from torch._prims.context import TorchRefsNvfuserCapabilityMode
|
||||
from torch._prims.executor import execute
|
||||
|
||||
# This test verifies that cudnn_batch_norm is translated into nvprims
|
||||
# and can be executed with nvFuser
|
||||
from torch.fx.experimental.proxy_tensor import make_fx
|
||||
from torch.testing._internal.common_methods_invocations import (
|
||||
sample_inputs_native_batch_norm,
|
||||
)
|
||||
|
||||
samples = sample_inputs_native_batch_norm(
|
||||
None, device, dtype, requires_grad=False
|
||||
)
|
||||
for sample in samples:
|
||||
if sample.input.numel() == 0:
|
||||
continue
|
||||
|
||||
def func(
|
||||
input, weight, bias, running_mean, running_var, training, momentum, eps
|
||||
):
|
||||
return torch.ops.aten.cudnn_batch_norm.default(
|
||||
input,
|
||||
weight,
|
||||
bias,
|
||||
running_mean,
|
||||
running_var,
|
||||
training,
|
||||
momentum,
|
||||
eps,
|
||||
)
|
||||
|
||||
with TorchRefsNvfuserCapabilityMode():
|
||||
gm = make_fx(func)(sample.input, *sample.args)
|
||||
|
||||
call_function_nodes = list(
|
||||
filter(lambda n: n.op == "call_function", gm.graph.nodes)
|
||||
)
|
||||
includes_aten_batch_norm = any(
|
||||
torch.ops.aten.cudnn_batch_norm.default == node.target
|
||||
for node in call_function_nodes
|
||||
)
|
||||
self.assertFalse(includes_aten_batch_norm)
|
||||
|
||||
includes_nvprims_batch_norm = any(
|
||||
torch.ops.nvprims.native_batch_norm.default == node.target
|
||||
for node in call_function_nodes
|
||||
)
|
||||
self.assertTrue(includes_nvprims_batch_norm)
|
||||
|
||||
# Check that the graph can be executed with nvFuser
|
||||
out = execute(gm, sample.input, *sample.args, executor="nvfuser")
|
||||
ref_out = gm(sample.input, *sample.args)
|
||||
for idx, (left, right) in enumerate(zip(out, ref_out)):
|
||||
# Nvfuser does not support torch.uint8 dtype so check reserve output against 0 scalar
|
||||
if idx == 3:
|
||||
self.assertTrue(torch.all(torch.eq(left, 0)))
|
||||
else:
|
||||
self.assertEqual(left, right)
|
||||
|
||||
# decomposition of native_batch_norm_backward uses a casting, which prevents nvprim lowering on CPU build
|
||||
@onlyCUDA
|
||||
@dtypes(torch.float32, torch.float16)
|
||||
def test_batch_norm_backward_nvprims(self, device, dtype):
|
||||
# This test verifies that the backward pass of batch norm is correctly decomposed into nvprims
|
||||
from torch.fx.experimental.proxy_tensor import make_fx
|
||||
from torch._prims.context import TorchRefsNvfuserCapabilityMode
|
||||
from torch.testing._internal.common_methods_invocations import sample_inputs_batch_norm
|
||||
|
||||
samples_iter = sample_inputs_batch_norm(None, device, dtype, requires_grad=True)
|
||||
sample = next(samples_iter)
|
||||
grad = torch.randn_like(sample.input)
|
||||
|
||||
def func1(grad, input, weight, rm, rv, eps, train):
|
||||
return torch.ops.aten.native_batch_norm_backward.default(
|
||||
grad, input, weight, rm, rv, rm, rv, train, eps, [True, True, True]
|
||||
)
|
||||
|
||||
def func2(grad, input, weight, rm, rv, eps, train):
|
||||
return torch.ops.aten.cudnn_batch_norm_backward.default(
|
||||
input, grad, weight, rm, rv, rm, rv, eps, grad
|
||||
)
|
||||
|
||||
args = sample.args
|
||||
kwargs = sample.kwargs
|
||||
all_args = [grad, sample.input, args[2], args[0], args[1], kwargs['eps'], kwargs['training']]
|
||||
|
||||
for func in (func1, func2):
|
||||
with TorchRefsNvfuserCapabilityMode():
|
||||
gm = make_fx(func)(*all_args)
|
||||
|
||||
call_function_nodes = list(filter(lambda n: n.op == "call_function", gm.graph.nodes))
|
||||
includes_batch_norm_backward = any(
|
||||
torch.ops.aten.native_batch_norm_backward.default == node.target
|
||||
for node in call_function_nodes
|
||||
)
|
||||
self.assertFalse(includes_batch_norm_backward)
|
||||
all_nvprims = all(
|
||||
str(node.target).startswith("nvprims") for node in call_function_nodes
|
||||
)
|
||||
self.assertTrue(all_nvprims)
|
||||
|
||||
@onlyCUDA
|
||||
@dtypes(torch.float32)
|
||||
def test_silu_backward_no_filled_tensor(self, device, dtype):
|
||||
# This test verifies a workaround for
|
||||
# https://github.com/pytorch/pytorch/issues/86612
|
||||
from torch.fx.experimental.proxy_tensor import make_fx
|
||||
from functorch import functionalize
|
||||
from torch._prims.nvfuser_executor import _remove_empty_like_fill
|
||||
from torch._prims.context import TorchRefsNvfuserCapabilityMode
|
||||
|
||||
def func(a):
|
||||
out = torch.nn.functional.silu(a)
|
||||
grad = torch.ones_like(out)
|
||||
return torch.autograd.grad([out], [a], [grad])
|
||||
|
||||
make_arg = partial(make_tensor, device=device, dtype=dtype, requires_grad=True)
|
||||
a = make_arg((3, 4))
|
||||
gm = make_fx(func)(a)
|
||||
# functionalize(gm) doesn't work with non-detached inputs
|
||||
gm = make_fx(functionalize(gm))(a.detach())
|
||||
|
||||
# replace aten.sub with nvprims.sub
|
||||
with TorchRefsNvfuserCapabilityMode():
|
||||
gm = make_fx(gm)(a)
|
||||
|
||||
# Check that the graph contains empty_like
|
||||
any_aten_empty_like = any(
|
||||
node.target == torch.ops.aten.empty_like.default for node in gm.graph.nodes
|
||||
)
|
||||
self.assertTrue(any_aten_empty_like)
|
||||
any_aten_fill = any(
|
||||
node.target == torch.ops.aten.fill.Scalar for node in gm.graph.nodes
|
||||
)
|
||||
self.assertTrue(any_aten_fill)
|
||||
|
||||
# Now remove the empty_like and fill
|
||||
gm = _remove_empty_like_fill(gm)
|
||||
any_aten_empty_like = any(
|
||||
node.target == torch.ops.aten.empty_like.default for node in gm.graph.nodes
|
||||
)
|
||||
self.assertFalse(any_aten_empty_like)
|
||||
any_aten_fill = any(
|
||||
node.target == torch.ops.aten.fill.Scalar for node in gm.graph.nodes
|
||||
)
|
||||
self.assertFalse(any_aten_fill)
|
||||
self.assertEqual(gm(a), func(a))
|
||||
|
||||
|
||||
@onlyCUDA
|
||||
@dtypes(torch.float32)
|
||||
@parametrize("correction", [0, 1])
|
||||
@ -855,7 +185,7 @@ class TestPrims(TestCase):
|
||||
traced = make_traced(_wrapper)
|
||||
make_arg = partial(make_tensor, device=device, dtype=dtype)
|
||||
|
||||
for executor in ('aten', 'strictly_nvfuser'):
|
||||
for executor in ('aten',):
|
||||
fn = partial(traced, executor=executor)
|
||||
shape = (5, 5)
|
||||
a = make_arg(shape)
|
||||
@ -865,160 +195,6 @@ class TestPrims(TestCase):
|
||||
self.assertTrue(result.is_contiguous)
|
||||
self.assertEqual(_wrapper(a), result)
|
||||
|
||||
@onlyCUDA
|
||||
@dtypes(torch.float16, torch.float32)
|
||||
@parametrize("correction", [0, 1])
|
||||
@parametrize("keepdim", [True, False])
|
||||
def test_var_mean(self, device, dtype, correction, keepdim):
|
||||
from torch.fx.experimental.proxy_tensor import make_fx
|
||||
from torch._prims.context import TorchRefsNvfuserCapabilityMode
|
||||
|
||||
|
||||
def _wrapper(a):
|
||||
return torch.var_mean(a, [0, 1], correction=correction, keepdim=keepdim)
|
||||
|
||||
make_arg = partial(make_tensor, device=device, dtype=dtype)
|
||||
|
||||
with TorchRefsNvfuserCapabilityMode():
|
||||
gm = make_fx(_wrapper)(make_arg((5, 5)))
|
||||
|
||||
call_function_nodes = list(filter(lambda n: n.op == "call_function", gm.graph.nodes))
|
||||
includes_nvprims_var_mean = any(
|
||||
torch.ops.nvprims.var_mean.main == node.target
|
||||
for node in call_function_nodes
|
||||
)
|
||||
self.assertTrue(includes_nvprims_var_mean)
|
||||
|
||||
@onlyCUDA
|
||||
@dtypes(torch.float16, torch.float32)
|
||||
def test_nvprims_view(self, device, dtype):
|
||||
from torch.fx.experimental.proxy_tensor import make_fx
|
||||
from torch._prims.context import TorchRefsNvfuserCapabilityMode
|
||||
from torch._prims.executor import execute
|
||||
|
||||
make_arg = partial(make_tensor, device=device, dtype=dtype)
|
||||
a = make_arg((3, 4, 5))
|
||||
|
||||
def func1(a):
|
||||
return a.view(tuple(reversed(a.shape)))
|
||||
|
||||
def func2(a):
|
||||
return a.reshape(tuple(reversed(a.shape)))
|
||||
|
||||
def func3(a):
|
||||
return torch.view_copy(a, tuple(reversed(a.shape)))
|
||||
|
||||
def func4(a):
|
||||
return torch.reshape(a, tuple(reversed(a.shape)))
|
||||
|
||||
def func5(a):
|
||||
return torch.ops.aten.view.default(a, tuple(reversed(a.shape)))
|
||||
|
||||
def func6(a):
|
||||
return torch.ops.aten._unsafe_view.default(a, tuple(reversed(a.shape)))
|
||||
|
||||
def func7(a):
|
||||
return torch.ops.aten.view_copy.default(a, tuple(reversed(a.shape)))
|
||||
|
||||
for func in (func1, func2, func3, func4, func5, func6, func7):
|
||||
with TorchRefsNvfuserCapabilityMode():
|
||||
gm = make_fx(func)(a)
|
||||
|
||||
call_function_nodes = list(filter(lambda n: n.op == "call_function", gm.graph.nodes))
|
||||
includes_nvprims_view = any(
|
||||
torch.ops.nvprims.view.default == node.target
|
||||
for node in call_function_nodes
|
||||
)
|
||||
self.assertTrue(includes_nvprims_view)
|
||||
|
||||
# Try executing the graph
|
||||
out = execute(gm, a, executor="strictly_nvfuser")
|
||||
self.assertEqual(out, func(a))
|
||||
|
||||
@onlyCUDA
|
||||
@dtypes(torch.float16, torch.float32)
|
||||
def test_nvprims_view_partitioner(self, device, dtype):
|
||||
# This test verifies that views that are not fused with other ops are
|
||||
# correctly overriden to call aten implementation.
|
||||
from torch.fx.experimental.proxy_tensor import make_fx
|
||||
from torch._prims.context import TorchRefsNvfuserCapabilityMode
|
||||
from torch._prims.nvfuser_executor import maybe_partition_graph
|
||||
|
||||
make_arg = partial(make_tensor, device=device, dtype=dtype)
|
||||
a = make_arg((4, 5))
|
||||
b = make_arg((5, 4))
|
||||
|
||||
def func(a, b):
|
||||
aa = a.view(b.shape)
|
||||
aa = aa.view(a.shape)
|
||||
return aa.digamma()
|
||||
|
||||
with TorchRefsNvfuserCapabilityMode():
|
||||
gm = make_fx(func)(a, b)
|
||||
gm, _ = maybe_partition_graph(gm, False, False)
|
||||
|
||||
out = gm(a, b)
|
||||
self.assertEqual(out, func(a, b))
|
||||
|
||||
@onlyCUDA
|
||||
@dtypes(torch.float32, torch.float16)
|
||||
def test_cpu_tensor(self, device, dtype):
|
||||
from torch.fx.experimental.proxy_tensor import make_fx
|
||||
from torch._prims.context import TorchRefsNvfuserCapabilityMode
|
||||
from torch._prims.executor import execute
|
||||
|
||||
def _wrapper(t0, t1, cpu_scalar):
|
||||
return t0 + t1 + cpu_scalar
|
||||
|
||||
make_arg = partial(make_tensor, device=device, dtype=dtype)
|
||||
a = make_arg((12, 1))
|
||||
b = make_arg((12, 12))
|
||||
c = torch.tensor(0.5)
|
||||
|
||||
with TorchRefsNvfuserCapabilityMode():
|
||||
gm = make_fx(_wrapper)(a, b, c)
|
||||
|
||||
with warnings.catch_warnings(record=True) as caught:
|
||||
actual = execute(gm, a, b, c, executor="nvfuser")
|
||||
# cpu scalar tensor is handled by nvfuser codegen, so it shouldn't fallback
|
||||
self.assertFalse(any(NVPRIM_ATEN_FALLBACK_WARNING in str(w.message) for w in caught))
|
||||
|
||||
expected = execute(gm, a, b, c, executor="aten")
|
||||
self.assertEqual(expected, actual)
|
||||
|
||||
call_function_nodes = list(filter(lambda n: n.op == "call_function", gm.graph.nodes))
|
||||
includes_aten_add = any(
|
||||
torch.ops.aten.add.default == node.target
|
||||
for node in call_function_nodes
|
||||
)
|
||||
self.assertFalse(includes_aten_add)
|
||||
|
||||
with warnings.catch_warnings(record=True) as caught:
|
||||
nvprim_aten_fallback = execute(gm, a.cpu(), b.cpu(), c, executor="nvfuser")
|
||||
# cpu tensor is handled by nvprim aten fallback, assert that it's indeed in warning
|
||||
self.assertTrue(any(NVPRIM_ATEN_FALLBACK_WARNING in str(w.message) for w in caught))
|
||||
|
||||
self.assertEqual(expected, nvprim_aten_fallback)
|
||||
|
||||
@onlyCUDA
|
||||
@dtypes(torch.float32)
|
||||
def test_pytree_input_output(self, device, dtype):
|
||||
@make_traced
|
||||
def fn(a, b_dict):
|
||||
b = b_dict["b"]
|
||||
d = {}
|
||||
d["c"] = torch.add(a, b)
|
||||
return (d, torch.add(a, d["c"]))
|
||||
|
||||
make_arg = partial(make_tensor, device=device, dtype=dtype)
|
||||
a = make_arg((5, 5))
|
||||
b = make_arg((1, 5))
|
||||
b_dict = {"b": b}
|
||||
|
||||
result_aten = fn(a, b_dict, executor="aten")
|
||||
result_nvfuser = fn(a, b_dict, executor="strictly_nvfuser")
|
||||
self.assertEqual(result_aten, result_nvfuser)
|
||||
|
||||
@dtypes(torch.float32)
|
||||
def test_memory_format_strides(self, device, dtype):
|
||||
shapes = (
|
||||
@ -1227,70 +403,6 @@ instantiate_device_type_tests(TestRefs, globals())
|
||||
|
||||
|
||||
class TestDecomp(TestCase):
|
||||
@onlyCUDA
|
||||
@dtypes(torch.float16, torch.float32)
|
||||
def test_decomposition_type_promotion_nvprim_amp(self, device, dtype):
|
||||
x = torch.rand(5, device=device).to(dtype)
|
||||
y = torch.rand(5, device=device).to(dtype)
|
||||
|
||||
from torch._prims.context import TorchRefsNvfuserCapabilityMode, _is_func_unsupported_nvfuser
|
||||
from torch.fx.experimental.proxy_tensor import make_fx
|
||||
op = torch.ops.aten.leaky_relu_backward.default
|
||||
op_decomp = torch._decomp.decomposition_table.get(op)
|
||||
|
||||
def fn0(*arg):
|
||||
return _is_func_unsupported_nvfuser(TorchRefsNvfuserCapabilityMode(), op, op_decomp, arg, {})
|
||||
|
||||
def fn1(x):
|
||||
x = x * 2
|
||||
x = x @ x
|
||||
x = x * 2
|
||||
return x
|
||||
|
||||
self.assertFalse(fn0(x, y, 0.3, False))
|
||||
with TorchRefsNvfuserCapabilityMode():
|
||||
|
||||
# Autocast context has C++ level ATen calls that are hidden from
|
||||
# TorchRefsNvfuserCapabilityMode that works only on Python level.
|
||||
# The first call to make_fx records autocast C++ calls directly and
|
||||
# doesn't have the chance to translate to nvprims. After the first
|
||||
# call, "gm" contains explicit calls to torch.ops.aten and nothing
|
||||
# is hidden, so the second call to make_fx actually translates
|
||||
# recorded autocast dtype conversions to nvprims.
|
||||
with torch.autocast("cuda"):
|
||||
gm = make_fx(fn1)(x)
|
||||
gm = make_fx(gm)(x)
|
||||
call_function_nodes = list(filter(lambda n: n.op == "call_function", gm.graph.nodes))
|
||||
includes_aten_to_copy = any(
|
||||
torch.ops.aten._to_copy.default == node.target
|
||||
for node in call_function_nodes
|
||||
)
|
||||
self.assertFalse(includes_aten_to_copy)
|
||||
|
||||
@onlyCUDA
|
||||
@dtypes(torch.float16, torch.float32)
|
||||
def test_masked_fill_decomposition_under_nvprim_context(self, device, dtype):
|
||||
# Test masked_fill decomposition doesn't trigger data-dependent control flow
|
||||
# on TorchRefsNvfuser speculative lowering.
|
||||
from torch.fx.experimental.proxy_tensor import make_fx
|
||||
from torch._prims.context import TorchRefsNvfuserCapabilityMode
|
||||
|
||||
x = torch.empty(2, 3, device=device).to(dtype=dtype)
|
||||
mask = torch.ones_like(x).bool()
|
||||
y = torch.tensor(0.3) # cpu scalar tensor
|
||||
|
||||
def func(x, mask, y):
|
||||
return torch.masked_fill(x, mask, y)
|
||||
|
||||
# mimics real use-case for TorchRefsNvfuserCapabilityMode context
|
||||
gm = make_fx(func, decomposition_table={})(x, mask, y)
|
||||
|
||||
with warnings.catch_warnings(record=True) as caught:
|
||||
with TorchRefsNvfuserCapabilityMode():
|
||||
gm = make_fx(gm)(x, mask, y)
|
||||
# masked_fill decomposition fails inside `get_isolated_graphmodule`
|
||||
self.assertFalse(any(GET_ISOLATED_GRAPHMODULE_ERROR in str(w.message) for w in caught))
|
||||
|
||||
@ops([op for op in op_db if op.supports_varargs], dtypes=OpDTypes.any_one)
|
||||
def test_decomposition_method_vararg(self, device, dtype, op):
|
||||
# some ops have vararg variants for the methods. this tests it.
|
||||
|
third_party/nvfuser/python_tests/test_dynamo.py (vendored): 145 lines removed

@@ -1,145 +0,0 @@
|
||||
# Owner(s): ["module: nvfuser"]
|
||||
|
||||
import unittest
|
||||
import warnings
|
||||
from functools import partial
|
||||
|
||||
import torch
|
||||
import torch._dynamo as torchdynamo
|
||||
from torch.testing import make_tensor
|
||||
from torch.testing._internal.common_utils import (
|
||||
IS_WINDOWS,
|
||||
run_tests,
|
||||
skipIfTorchDynamo,
|
||||
TEST_WITH_ROCM,
|
||||
TestCase,
|
||||
)
|
||||
from torch.testing._internal.jit_utils import RUN_CUDA
|
||||
|
||||
RUN_NVFUSER = RUN_CUDA and not TEST_WITH_ROCM
|
||||
|
||||
|
||||
def is_pre_volta():
|
||||
if not RUN_NVFUSER:
|
||||
return False
|
||||
prop = torch.cuda.get_device_properties(torch.cuda.current_device())
|
||||
return prop.major < 7
|
||||
|
||||
|
||||
def is_networkx_available():
|
||||
try:
|
||||
import networkx # noqa: F401
|
||||
|
||||
return True
|
||||
except ImportError:
|
||||
return False
|
||||
|
||||
|
||||
@skipIfTorchDynamo("Not a suitable test for TorchDynamo")
|
||||
@unittest.skipIf(IS_WINDOWS, "TorchDynamo is not supported on Windows")
|
||||
@unittest.skipIf(not RUN_NVFUSER, "requires CUDA")
|
||||
@unittest.skipIf(is_pre_volta(), "Only supported on Volta and newer devices.")
|
||||
class TestNvFuserDynamo(TestCase):
|
||||
def test_basic(self):
|
||||
input1 = make_tensor((2, 4, 8), device="cuda", dtype=torch.float32)
|
||||
input2 = make_tensor((2, 4, 8), device="cuda", dtype=torch.float32)
|
||||
|
||||
@torchdynamo.optimize("nvprims_nvfuser")
|
||||
def func(a, b):
|
||||
return a.sin() + b.cos()
|
||||
|
||||
# No warnings and no errors
|
||||
with warnings.catch_warnings(record=True) as w:
|
||||
nvfuser_result = func(input1, input2)
|
||||
self.assertEqual(len(w), 0)
|
||||
eager_result = func.__wrapped__(input1, input2)
|
||||
self.assertEqual(eager_result, nvfuser_result)
|
||||
|
||||
@unittest.skipIf(not is_networkx_available(), "networkx not available")
|
||||
def test_min_cut(self):
|
||||
from functorch.compile import default_partition
|
||||
from torch._dynamo.backends.nvfuser import nvprims_fw_bw_partition_fn
|
||||
|
||||
def get_fw_bw_graph(f, inps, partitioner):
|
||||
from functorch.compile import aot_function
|
||||
|
||||
# Helper functions are taken from functorch/test_aotdispatch.py
|
||||
def extract_graph(fx_g, _, graph_cell):
|
||||
graph_cell[0] = fx_g
|
||||
return fx_g
|
||||
|
||||
fw_graph_cell = [None]
|
||||
bw_graph_cell = [None]
|
||||
aot_function(
|
||||
f,
|
||||
fw_compiler=partial(extract_graph, graph_cell=fw_graph_cell),
|
||||
bw_compiler=partial(extract_graph, graph_cell=bw_graph_cell),
|
||||
partition_fn=partitioner,
|
||||
)(*inps).sum().backward()
|
||||
return (fw_graph_cell[0], bw_graph_cell[0])
|
||||
|
||||
def get_ins_outs(fx_g):
|
||||
ins = []
|
||||
outs = []
|
||||
for n in fx_g.graph.nodes:
|
||||
if n.op == "placeholder":
|
||||
ins.append(n)
|
||||
elif n.op == "output":
|
||||
outs = tuple(n.args[0])
|
||||
return ins, outs
|
||||
|
||||
def get_num_ins_outs(fx_g):
|
||||
return tuple(len(i) for i in get_ins_outs(fx_g))
|
||||
|
||||
def func(x):
|
||||
return x * x * x
|
||||
|
||||
input1 = make_tensor(
|
||||
(3,), device="cpu", dtype=torch.float32, requires_grad=True
|
||||
)
|
||||
fw_graph, bw_graph = get_fw_bw_graph(func, [input1], default_partition)
|
||||
self.assertEqual(get_num_ins_outs(fw_graph), (1, 3))
|
||||
self.assertEqual(get_num_ins_outs(bw_graph), (3, 1))
|
||||
|
||||
input1 = make_tensor(
|
||||
(3,), device="cpu", dtype=torch.float32, requires_grad=True
|
||||
)
|
||||
fw_graph, bw_graph = get_fw_bw_graph(func, [input1], nvprims_fw_bw_partition_fn)
|
||||
self.assertEqual(get_num_ins_outs(fw_graph), (1, 2))
|
||||
self.assertEqual(get_num_ins_outs(bw_graph), (2, 1))
|
||||
|
||||
def test_batch_norm_implicit_dtype_promotion(self):
|
||||
input1 = make_tensor((2, 3, 4, 5), device="cuda", dtype=torch.float32)
|
||||
input2 = make_tensor((5, 5), device="cuda", dtype=torch.float32)
|
||||
w = make_tensor((3), device="cuda", dtype=torch.float32)
|
||||
b = make_tensor((3), device="cuda", dtype=torch.float32)
|
||||
|
||||
@torchdynamo.optimize("nvprims_nvfuser")
|
||||
def func(mat1, mat2, w, b):
|
||||
o = torch.matmul(mat1, mat2)
|
||||
return torch.batch_norm(o, w, b, None, None, True, 1e-2, 1e-5, True)
|
||||
|
||||
# No warnings and no errors
|
||||
with torch.cuda.amp.autocast():
|
||||
with warnings.catch_warnings(record=True) as warning:
|
||||
nvfuser_result = func(input1, input2, w, b)
|
||||
self.assertEqual(len(warning), 0)
|
||||
eager_result = func.__wrapped__(input1, input2, w, b)
|
||||
self.assertEqual(eager_result, nvfuser_result)
|
||||
|
||||
def test_dtype_correctness(self):
|
||||
input1 = make_tensor((2, 4, 8), device="cuda", dtype=torch.float16)
|
||||
|
||||
@torchdynamo.optimize("nvprims_nvfuser")
|
||||
def func(a):
|
||||
tmp = a + 1.0
|
||||
# nvfuser would promote output to fp32 in math, FusionDefinition should cast output dtype back
|
||||
return torch.where(tmp > 0, tmp, 0.0)
|
||||
|
||||
nvfuser_result = func(input1)
|
||||
eager_result = func.__wrapped__(input1)
|
||||
self.assertEqual(eager_result, nvfuser_result)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
run_tests()
|
@ -1,95 +0,0 @@
|
||||
import logging
|
||||
from functools import partial
|
||||
|
||||
import torch
|
||||
from ..backends.common import aot_autograd, mem_efficient_fusion_kwargs
|
||||
from .registry import register_backend, register_debug_backend
|
||||
|
||||
log = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def prims_executor(gm, inputs, *, executor):
|
||||
from functorch.compile import make_boxed_func
|
||||
|
||||
# This function is called once per forward/backward pass of a graph in AOT
|
||||
# Autograd. We use it to set up the nvFuser-specific FX graph and return
|
||||
# execute function.
|
||||
from torch._prims.context import TorchRefsNvfuserCapabilityMode
|
||||
from torch._prims.executor import execute
|
||||
from torch.fx.experimental.proxy_tensor import make_fx
|
||||
|
||||
# AOT Autograd might not use the partitioner, so we need to make sure that
|
||||
# the graph is transformed to use nvFuser-compatible nodes.
|
||||
if not getattr(gm, "_nvprim_transformed", False):
|
||||
with TorchRefsNvfuserCapabilityMode():
|
||||
gm = make_fx(gm)(*inputs)
|
||||
|
||||
# Then we return a callable that executes the "gm" graph
|
||||
return make_boxed_func(partial(execute, gm, executor=executor))
|
||||
|
||||
|
||||
def nvprims_fw_bw_partition_fn(joint_module, joint_inputs, *, num_fwd_outputs):
|
||||
# This function is called once per forward+backward pass of a graph in AOT
|
||||
# Autograd. We use it to set up the nvFuser-specific FX graph that is later
|
||||
# passed to the executor.
|
||||
from functorch.compile import min_cut_rematerialization_partition
|
||||
|
||||
from torch._prims.context import TorchRefsNvfuserCapabilityMode
|
||||
from torch.fx.experimental.proxy_tensor import make_fx
|
||||
|
||||
# AOT Autograd expects arguments of the traced function to be named exactly
|
||||
# "primals, tangents"
|
||||
def func(primals, tangents):
|
||||
return joint_module(primals, tangents)
|
||||
|
||||
# First we trace the graph conditionally decomposing nodes
|
||||
# that can be sent to the nvfuser executor
|
||||
with TorchRefsNvfuserCapabilityMode():
|
||||
prim_gm = make_fx(func)(*joint_inputs)
|
||||
|
||||
# all nvprims for now
|
||||
recomputable_ops = {
|
||||
getattr(torch.ops.nvprims, prim)
|
||||
for prim in dir(torch.ops.nvprims)
|
||||
if isinstance(getattr(torch.ops.nvprims, prim), torch._ops.OpOverloadPacket)
|
||||
and getattr(torch.ops.nvprims, prim).is_recomputable
|
||||
}
|
||||
|
||||
fw_gm, bw_gm = min_cut_rematerialization_partition(
|
||||
prim_gm,
|
||||
joint_inputs,
|
||||
recomputable_ops=recomputable_ops,
|
||||
num_fwd_outputs=num_fwd_outputs,
|
||||
)
|
||||
# AOT Autograd might not use the partitioner, so we need to make sure that
|
||||
# the graph is marked as already transformed to use nvFuser-compatible nodes
|
||||
fw_gm._nvprim_transformed = True
|
||||
bw_gm._nvprim_transformed = True
|
||||
return fw_gm, bw_gm
|
||||
|
||||
|
||||
def create_nvprims_backend(*, executor):
|
||||
return aot_autograd(
|
||||
fw_compiler=partial(prims_executor, executor=executor),
|
||||
bw_compiler=partial(prims_executor, executor=executor),
|
||||
partition_fn=nvprims_fw_bw_partition_fn,
|
||||
)
|
||||
|
||||
|
||||
aot_nvprims_nvfuser = create_nvprims_backend(executor="nvfuser")
|
||||
aot_nvprims_aten = create_nvprims_backend(executor="aten")
|
||||
|
||||
# "nvprims" is a subset of PrimTorch primitives that are guaranteed to be
|
||||
# supported by nvFuser. This is the preferred backend for nvFuser+PrimTorch.
|
||||
register_backend(name="nvprims_nvfuser", compiler_fn=aot_nvprims_nvfuser)
|
||||
# This is useful for debugging. Can be removed later.
|
||||
register_debug_backend(name="nvprims_aten", compiler_fn=aot_nvprims_aten)
|
||||
|
||||
|
||||
# Use min cut rematerialization and TorchScript+nvFuser with AOT Autograd
|
||||
# aot_ts_nvfuser uses the memory efficient fusion algorithm from AOT Autograd.
|
||||
# It uses min cut rematerialization algorithm, uses nvFuser as the
|
||||
# compiler backend, and TorchScript as the frontend.
|
||||
aot_mem_efficient_fusion = aot_autograd(**mem_efficient_fusion_kwargs(use_decomps=True))
|
||||
aot_mem_efficient_fusion.backend_ctx_ctor = lambda: torch.jit.fuser("fuser2")
|
||||
register_backend(name="aot_ts_nvfuser", compiler_fn=aot_mem_efficient_fusion)
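For reference, the registration mechanism this removed file relied on is generic: a Dynamo backend is just a callable that takes an FX graph and example inputs. A hedged sketch of that pattern using a custom callable backend (the backend itself is made up for illustration):

    import torch

    def my_backend(gm: torch.fx.GraphModule, example_inputs):
        # A backend receives the captured FX graph plus example inputs and must
        # return a callable; returning gm.forward simply runs the graph eagerly.
        print(f"compiling a graph with {len(list(gm.graph.nodes))} nodes")
        return gm.forward

    fn = torch.compile(lambda x: x.sin() + 1, backend=my_backend)
    fn(torch.randn(4))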
|
@@ -13,7 +13,6 @@ import torch.library
from torch import sym_float, Tensor, TypedStorage
from torch._C import _get_default_device
from torch._prims.debug_prims import register_debug_prims
-from torch._prims.nvfuser_prims import register_nvprims
from torch._prims.rng_prims import register_rng_prims
from torch._prims_common import (
    Dim,
@@ -2964,6 +2963,5 @@ fft_c2r = _make_prim(
    doc=_fft_c2r_doc,
)

-register_nvprims()
register_rng_prims()
register_debug_prims()
@@ -1,7 +1,6 @@
import functools
from contextlib import nullcontext
from typing import Any, Callable, Dict, Optional, Sequence
from warnings import warn

import torch

@@ -13,10 +12,8 @@ import torch._refs.nn
import torch._refs.nn.functional
import torch._refs.special
import torch.overrides
-from torch._prims.nvfuser_executor import NvfuserPrimOperatorSupport

from torch._prims_common import torch_function_passthrough
-from torch.fx.experimental.proxy_tensor import get_isolated_graphmodule


@functools.lru_cache(None)
@ -82,55 +79,6 @@ def all_prims():
|
||||
return {torch._prims.__dict__.get(s) for s in torch._prims.__all__}
|
||||
|
||||
|
||||
class NvfuserPrimsMode(torch.overrides.TorchFunctionMode):
|
||||
"""
|
||||
Switches the interpretation of torch.ops.prims.* functions to
|
||||
use nvFuser's prims in torch.ops.nvprims.*
|
||||
|
||||
>>> # xdoctest: +SKIP("undefined vars")
|
||||
>>> with NvfuserPrimsMode():
|
||||
... torch.ops.prims.add(x, y) # calls torch.ops.nvprims.add(x, y)
|
||||
|
||||
By default, this context manager will fall back on the torch.ops.prims* if the
|
||||
nvprim does not exist.
|
||||
It's possible to skip certain prims by passing their names to the skip_ops
|
||||
argument. skip_ops is expected to be a sequence of strings, e.g.,
|
||||
["prims.add.default"] In order to check the expected name of a prim, one can
|
||||
use the `torch.overrides.resolve_name`.
|
||||
|
||||
>>> # xdoctest: +SKIP("undefined vars")
|
||||
>>> with NvfuserPrimsMode(skips_ops=("prims.add.default")):
|
||||
... torch.ops.prims.add.default(x, y) # does not call torch.ops.nvprims.add.default(x, y)
|
||||
"""
|
||||
|
||||
def __init__(self, *, skip_ops=()):
|
||||
self.skip_ops = skip_ops
|
||||
|
||||
def __torch_function__(
|
||||
self,
|
||||
orig_func: Callable,
|
||||
types: Sequence,
|
||||
args: Sequence[Any] = (),
|
||||
kwargs: Optional[Dict] = None,
|
||||
):
|
||||
if kwargs is None:
|
||||
kwargs = {}
|
||||
|
||||
# If the function is in the skip list, then we don't want to
|
||||
# remap it to the nvprims.
|
||||
if torch.overrides.resolve_name(orig_func) in self.skip_ops:
|
||||
return orig_func(*args, **kwargs)
|
||||
|
||||
if isinstance(orig_func, (torch._ops.OpOverload, torch._ops.OpOverloadPacket)):
|
||||
namespace = str(orig_func).split(".")[0]
|
||||
name = str(orig_func).split(".")[1]
|
||||
if namespace == "prims":
|
||||
nvfunc = getattr(torch.ops.nvprims, name, None)
|
||||
if nvfunc is not None:
|
||||
return nvfunc(*args, **kwargs)
|
||||
return orig_func(*args, **kwargs)
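The removed `NvfuserPrimsMode` above is one instance of the general `torch.overrides.TorchFunctionMode` pattern. A minimal, self-contained sketch of that pattern (the mode name and the logging behavior are illustrative, not part of the removed code):

    import torch
    from torch.overrides import TorchFunctionMode, resolve_name

    class LoggingMode(TorchFunctionMode):
        """Log every intercepted torch.* call instead of remapping it."""

        def __torch_function__(self, orig_func, types, args=(), kwargs=None):
            if kwargs is None:
                kwargs = {}
            print(f"intercepted: {resolve_name(orig_func)}")
            return orig_func(*args, **kwargs)

    with LoggingMode():
        torch.add(torch.ones(2), torch.ones(2))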
|
||||
|
||||
|
||||
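As an illustration of the ``TorchFunctionMode`` mechanism the class above was built on, here is a minimal standalone sketch that only logs intercepted calls instead of remapping them (``LoggingMode`` is a made-up example, not part of the removed code)::

    import torch
    from torch.overrides import TorchFunctionMode, resolve_name

    class LoggingMode(TorchFunctionMode):
        # Every torch.* call made inside the context is routed through
        # __torch_function__, where it can be redirected or just forwarded.
        def __torch_function__(self, func, types, args=(), kwargs=None):
            kwargs = kwargs or {}
            print("intercepted:", resolve_name(func))
            return func(*args, **kwargs)

    with LoggingMode():
        torch.add(torch.ones(2), torch.ones(2))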
class TorchRefsMode(torch.overrides.TorchFunctionMode):
|
||||
"""
|
||||
Switches the interpretation of torch.* functions and Tensor methods to
|
||||
@ -194,235 +142,3 @@ class TorchRefsMode(torch.overrides.TorchFunctionMode):
|
||||
f"no _refs support for {torch.overrides.resolve_name(orig_func)}"
|
||||
)
|
||||
return orig_func(*args, **kwargs)
|
||||
|
||||
|
||||
def _is_node_supported_nvfuser(node):
|
||||
return (
|
||||
node.op == "call_function"
|
||||
and getattr(node.target, "impl_nvfuser", None) is not None
|
||||
)
|
||||
|
||||
|
||||
def _is_func_unsupported_nvfuser(
|
||||
torch_function_mode, orig_func, func, args, kwargs, *, skip_ops=()
|
||||
):
|
||||
"""
|
||||
This function traces the `func` under `torch_function_mode` and checks if
|
||||
any of the traced nodes are not supported by nvFuser. If so, we should
|
||||
fallback to the original function.
|
||||
|
||||
`skip_ops` argument is expected to be a list of strings of function names
|
||||
that would match with `torch.overrides.resolve_name`.
|
||||
|
||||
Args:
|
||||
torch_function_mode: The torch_function_mode context manager. orig_func:
|
||||
The original function, its name will be used to check if
|
||||
it should be skipped.
|
||||
func: The function to be traced. args: The args to be passed to the
|
||||
function. kwargs: The kwargs to be passed to the function.
|
||||
Keyword args:
|
||||
skip_ops: A list of ops to skip when checking if the function is
|
||||
supported.
|
||||
"""
|
||||
# One supported case is easy to check: if the resolved name of the original
|
||||
# function in the skip list, skip it.
|
||||
if torch.overrides.resolve_name(orig_func) in skip_ops:
|
||||
return True
|
||||
|
||||
with torch_function_mode:
|
||||
try:
|
||||
gm = get_isolated_graphmodule(func, args, kwargs)
|
||||
except Exception as e:
|
||||
warn(
|
||||
"get_isolated_graphmodule failed on decomposition: "
|
||||
+ func.__name__
|
||||
+ " with error message: "
|
||||
+ str(e)
|
||||
)
|
||||
# returns unsupported when tracing fails.
|
||||
return True
|
||||
|
||||
supported_ops = NvfuserPrimOperatorSupport()
|
||||
call_function_nodes = filter(lambda n: n.op == "call_function", gm.graph.nodes)
|
||||
any_unsupported = any(
|
||||
not supported_ops.is_node_supported(None, node) for node in call_function_nodes
|
||||
)
|
||||
return any_unsupported
|
||||
|
||||
|
||||
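The support check above traces the candidate function in isolation and inspects which ops it decomposes into. A small illustrative sketch of that tracing utility, with a made-up ``softplus`` candidate that is not taken from the removed code::

    import torch
    from torch.fx.experimental.proxy_tensor import get_isolated_graphmodule

    def candidate(x):
        return torch.nn.functional.softplus(x)

    # Trace the function into its own GraphModule without touching the
    # surrounding tracing context, then list the ops it lowered to.
    gm = get_isolated_graphmodule(candidate, (torch.randn(4),), {})
    for node in gm.graph.nodes:
        if node.op == "call_function":
            print(node.target)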
class TorchRefsNvfuserCapabilityMode(TorchRefsMode):
|
||||
def __init__(self, *, skip_ops=()):
|
||||
aten_ops_to_skip = (
|
||||
"aten._log_softmax.default",
|
||||
"aten._log_softmax_backward_data.default",
|
||||
"aten.expand.default",
|
||||
)
|
||||
self.skip_ops = tuple(skip_ops) + aten_ops_to_skip
|
||||
super().__init__(
|
||||
strict=False,
|
||||
should_fallback_fn=functools.partial(
|
||||
_is_func_unsupported_nvfuser,
|
||||
skip_ops=tuple(skip_ops) + aten_ops_to_skip,
|
||||
),
|
||||
prims_mode_cls=functools.partial(NvfuserPrimsMode, skip_ops=skip_ops),
|
||||
)
|
||||
|
||||
# TODO: remove this once version from _decomp/decompositions.py is working
|
||||
# with this context manager
|
||||
# This is a workaround for AOT Autograd graphs
|
||||
def _cudnn_batch_norm(
|
||||
self,
|
||||
input,
|
||||
weight,
|
||||
bias,
|
||||
running_mean,
|
||||
running_var,
|
||||
training,
|
||||
exponential_average_factor,
|
||||
epsilon,
|
||||
):
|
||||
a, b, c = torch.ops.nvprims.native_batch_norm(
|
||||
input,
|
||||
weight,
|
||||
bias,
|
||||
running_mean,
|
||||
running_var,
|
||||
training,
|
||||
exponential_average_factor,
|
||||
epsilon,
|
||||
)
|
||||
if training:
|
||||
return (a, b, c, input.new_zeros((0,), dtype=torch.uint8))
|
||||
return (
|
||||
a,
|
||||
weight.new_zeros((0,)),
|
||||
weight.new_zeros((0,)),
|
||||
input.new_zeros((0,), dtype=torch.uint8),
|
||||
)
|
||||
|
||||
# This is a workaround for AOT Autograd graphs
|
||||
def _cudnn_batch_norm_backward(
|
||||
self,
|
||||
input,
|
||||
grad_output,
|
||||
weight,
|
||||
running_mean,
|
||||
running_var,
|
||||
save_mean,
|
||||
save_var,
|
||||
epsilon,
|
||||
reserveSpace,
|
||||
):
|
||||
func = torch._decomp.decomposition_table[
|
||||
torch.ops.aten.native_batch_norm_backward.default
|
||||
]
|
||||
return func(
|
||||
grad_output,
|
||||
input,
|
||||
weight,
|
||||
running_mean,
|
||||
running_var,
|
||||
save_mean,
|
||||
save_var,
|
||||
True,
|
||||
epsilon,
|
||||
[True, True, True],
|
||||
)
|
||||
|
||||
def _is_var_mean(self, func):
|
||||
return "torch.var_mean" == torch.overrides.resolve_name(func) or (
|
||||
(isinstance(func, (torch._ops.OpOverload, torch._ops.OpOverloadPacket)))
|
||||
and "aten.var_mean" in str(func)
|
||||
)
|
||||
|
||||
def _is_view_or_reshape(self, func):
|
||||
allowed_ops = {
|
||||
"torch.Tensor.view",
|
||||
"torch.Tensor.reshape",
|
||||
"torch.view_copy",
|
||||
"torch.reshape",
|
||||
"aten.view.default",
|
||||
"aten._unsafe_view.default",
|
||||
"aten.view_copy.default",
|
||||
} - set(self.skip_ops)
|
||||
return torch.overrides.resolve_name(func) in allowed_ops
|
||||
|
||||
def _is_native_batch_norm(self, func):
|
||||
return "torch.native_batch_norm" == torch.overrides.resolve_name(func) or (
|
||||
func == torch.ops.aten.native_batch_norm.default
|
||||
or func == torch.ops.aten.native_batch_norm
|
||||
)
|
||||
|
||||
def _is_rand_like(self, func):
|
||||
result = "torch.rand_like" == torch.overrides.resolve_name(func) or (
|
||||
func == torch.ops.aten.rand_like or func == torch.ops.aten.rand_like.default
|
||||
)
|
||||
return result
|
||||
|
||||
def _is_full(self, func):
|
||||
result = "torch.full" == torch.overrides.resolve_name(func) or (
|
||||
func
|
||||
in [
|
||||
torch.ops.aten.full,
|
||||
torch.ops.aten.full.names,
|
||||
]
|
||||
)
|
||||
return result
|
||||
|
||||
def __torch_function__(
|
||||
self,
|
||||
orig_func: Callable,
|
||||
types: Sequence,
|
||||
args: Sequence[Any] = (),
|
||||
kwargs: Optional[Dict] = None,
|
||||
):
|
||||
if kwargs is None:
|
||||
kwargs = {}
|
||||
# First we intercept calls for nvfuser-specific prims bypassing generic torch._refs
|
||||
if self._is_var_mean(orig_func):
|
||||
return torch.ops.nvprims.var_mean(*args, **kwargs)
|
||||
|
||||
if (
|
||||
orig_func == torch.ops.aten.cudnn_batch_norm.default
|
||||
or orig_func == torch.ops.aten.cudnn_batch_norm
|
||||
):
|
||||
with self:
|
||||
return self._cudnn_batch_norm(*args, **kwargs)
|
||||
|
||||
# A workaround for AOT Autograd graphs
|
||||
# See https://github.com/pytorch/pytorch/pull/86115#issue-1394883782
|
||||
if (
|
||||
orig_func == torch.ops.aten.cudnn_batch_norm_backward.default
|
||||
or orig_func == torch.ops.aten.cudnn_batch_norm_backward
|
||||
):
|
||||
with self:
|
||||
return self._cudnn_batch_norm_backward(*args, **kwargs)
|
||||
|
||||
if self._is_view_or_reshape(orig_func):
|
||||
a, *shape = args
|
||||
shape = torch._prims_common.extract_shape_from_varargs(
|
||||
shape, validate=False
|
||||
) # type: ignore[assignment]
|
||||
if len(kwargs) > 0:
|
||||
warn("view has ignored kwargs!")
|
||||
return torch.ops.nvprims.view(a, shape)
|
||||
|
||||
if orig_func == torch.ops.aten._reshape_alias.default:
|
||||
a, shape, stride = args
|
||||
if len(kwargs) > 0:
|
||||
warn("view has ignored kwargs!")
|
||||
return torch.ops.nvprims.view(a, shape)
|
||||
|
||||
if self._is_native_batch_norm(orig_func):
|
||||
return torch.ops.nvprims.native_batch_norm(*args, **kwargs)
|
||||
|
||||
if self._is_rand_like(orig_func):
|
||||
if len(kwargs) > 0:
|
||||
warn("rand_like has ignored kwargs!")
|
||||
return torch.ops.nvprims.rand_like(*args)
|
||||
|
||||
if self._is_full(orig_func):
|
||||
return torch.ops.nvprims.full(*args, **kwargs)
|
||||
|
||||
# Then we use TorchRefsMode to interpret the rest
|
||||
return super().__torch_function__(orig_func, types, args, kwargs)
|
||||
|
@ -1,7 +1,6 @@
from typing import Callable, Optional

from torch._prims.context import NvfuserPrimsMode, TorchRefsMode
from torch._prims.nvfuser_executor import nvfuser_execute, nvfuser_execute_partitioned
from torch._prims.context import TorchRefsMode

from torch.fx import GraphModule
from torch.fx.experimental.proxy_tensor import make_fx, wrapper_and_args_for_make_fx
@ -21,14 +20,8 @@ def execute(

    if executor == "aten":
        return gm.forward(*args)
    elif executor == "nvfuser":
        return nvfuser_execute_partitioned(
            gm, *args, executor_parameters=executor_parameters
        )
    elif executor == "strictly_nvfuser":
        return nvfuser_execute(gm, *args, executor_parameters=executor_parameters)

    msg = f"Received unexpected value for 'executor': {executor}. Allowed values are: aten, nvfuser."
    msg = f"Received unexpected value for 'executor': {executor}. Allowed values are: aten."
    raise ValueError(msg)


@ -53,16 +46,14 @@ def make_traced(fn: Callable):

    a = torch.randn((1, 2, 3, 4, 5), device='cuda')
    b = torch.randn((1, 2, 3, 4, 5), device='cuda')
    result = traced_foo(a, b, executor='nvfuser')

    Executor may be either 'aten' or 'nvfuser'.
    result = traced_foo(a, b, executor='aten')
    """

    def _traced(*args, executor="aten", **kwargs):
        # TODO: caching
        wrapped, all_args = wrapper_and_args_for_make_fx(fn, args, kwargs)

        with NvfuserPrimsMode(), TorchRefsMode():
        with TorchRefsMode():
            gm = make_fx(wrapped)(all_args)
        return execute(gm, all_args, executor=executor)

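With the nvFuser path gone, ``make_traced`` only traces under ``TorchRefsMode`` and runs through ATen. A minimal illustrative sketch of that remaining flow, using made-up ``foo``/tensor names::

    import torch
    from torch.fx.experimental.proxy_tensor import make_fx

    def foo(a, b):
        return torch.sigmoid(a) + b

    a = torch.randn(2, 3)
    b = torch.randn(2, 3)

    # make_fx records the ATen-level operations into an FX GraphModule ...
    gm = make_fx(foo)(a, b)
    # ... which the "aten" executor then simply runs through its forward().
    out = gm(a, b)
    assert torch.allclose(out, foo(a, b))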
@ -1,510 +0,0 @@
|
||||
import operator
|
||||
from copy import deepcopy
|
||||
from dataclasses import dataclass
|
||||
from functools import lru_cache
|
||||
from types import MappingProxyType
|
||||
from warnings import warn
|
||||
|
||||
import torch
|
||||
import torch.fx
|
||||
import torch.overrides
|
||||
from torch._prims_common import (
|
||||
_torch_dtype_to_nvfuser_dtype_map,
|
||||
getnvFuserDtype,
|
||||
Number,
|
||||
number_type,
|
||||
)
|
||||
|
||||
from torch.fx import GraphModule
|
||||
from torch.fx.passes.infra.partitioner import CapabilityBasedPartitioner
|
||||
from torch.utils._pytree import tree_flatten, tree_map, tree_unflatten
|
||||
|
||||
if torch.cuda.is_available():
|
||||
try:
|
||||
from nvfuser import ( # type: ignore[attr-defined, import]
|
||||
DataType,
|
||||
FusionDefinition,
|
||||
Tensor,
|
||||
)
|
||||
|
||||
def create_fusion_definition():
|
||||
fd = FusionDefinition()
|
||||
return fd, fd
|
||||
|
||||
except ImportError:
|
||||
from nvfuser._C import ( # type: ignore[import]
|
||||
DataType,
|
||||
Fusion,
|
||||
FusionDefinition,
|
||||
Tensor,
|
||||
)
|
||||
|
||||
def create_fusion_definition():
|
||||
fusion = Fusion()
|
||||
return fusion, FusionDefinition(fusion)
|
||||
|
||||
else:
|
||||
DataType = None
|
||||
|
||||
import os
|
||||
|
||||
|
||||
@lru_cache(None)
|
||||
def get_nvprim_dump_nvtx():
|
||||
return os.getenv("PYTORCH_NVFUSER_DUMP_NVTX")
|
||||
|
||||
|
||||
DEFAULT_NVFUSER_PYTHON_CONFIG = MappingProxyType(
|
||||
{
|
||||
"use_python_fusion_cache": True,
|
||||
"allow_single_op_fusion": False,
|
||||
}
|
||||
)
|
||||
|
||||
|
||||
# nvFuserTensorTemplate and nvFuserScalarTemplate are helper objects
|
||||
# for cached construction of the nvFuser's Fusion
|
||||
# TODO: change what is stored in the cache for nvFuser's Tensor objects
|
||||
# https://github.com/pytorch/pytorch/issues/80551
|
||||
@dataclass(frozen=True)
|
||||
class nvFuserTensorTemplate:
|
||||
symbolic_shape: tuple
|
||||
contiguity: tuple
|
||||
dtype: DataType
|
||||
is_cpu: bool
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class nvFuserScalarTemplate:
|
||||
dtype: DataType
|
||||
|
||||
|
||||
@lru_cache(maxsize=2048)
|
||||
def compute_symbolic_shape(shape):
|
||||
"""Computes the symbolic shape of a tensor.
|
||||
nvFuser specializes on size-1 dimensions as broadcasted dimensions.
|
||||
-1 is used to represent any size."""
|
||||
return tuple(1 if s == 1 else -1 for s in shape)
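A concrete worked example of the convention in this docstring, restated standalone since the helper itself lives in the deleted module::

    from functools import lru_cache

    @lru_cache(maxsize=2048)
    def compute_symbolic_shape(shape):
        # Size-1 dims are kept as broadcast candidates; all other extents
        # become -1 so differently sized inputs share a cache entry.
        return tuple(1 if s == 1 else -1 for s in shape)

    print(compute_symbolic_shape((1, 3, 1, 5)))  # (1, -1, 1, -1)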
|
||||
|
||||
|
||||
@lru_cache(maxsize=2048)
|
||||
def compute_contiguity(shape, strides):
|
||||
"""Computes the contiguity information to simplify internal indexing.
|
||||
Contiguous dimensions are represented by True, strided dimensions
|
||||
are represented by False.
|
||||
"""
|
||||
try:
|
||||
from nvfuser import compute_contiguity # type: ignore[attr-defined]
|
||||
except ImportError:
|
||||
from nvfuser._C import compute_contiguity
|
||||
|
||||
return tuple(compute_contiguity(shape, strides))
|
||||
|
||||
|
||||
def to_nvfuser_template_args(args):
|
||||
def to_nvfuser(arg):
|
||||
if isinstance(arg, torch.Tensor):
|
||||
return nvFuserTensorTemplate(
|
||||
compute_symbolic_shape(arg.size()),
|
||||
compute_contiguity(arg.size(), arg.stride()),
|
||||
getnvFuserDtype(arg.dtype),
|
||||
arg.is_cpu, # type: ignore[attr-defined]
|
||||
)
|
||||
elif isinstance(arg, Number):
|
||||
return nvFuserScalarTemplate(getnvFuserDtype(number_type(arg)))
|
||||
else:
|
||||
return arg
|
||||
|
||||
return tree_map(to_nvfuser, args)
|
||||
|
||||
|
||||
def _any_get_attr_used(call_function_nodes):
|
||||
return any(
|
||||
filter(
|
||||
# bug in mypy https://github.com/python/mypy/issues/12682
|
||||
lambda n: any( # type: ignore[arg-type]
|
||||
a.op == "get_attr" for a in n.args if isinstance(a, torch.fx.Node) # type: ignore[attr-defined]
|
||||
),
|
||||
call_function_nodes,
|
||||
)
|
||||
)
|
||||
|
||||
|
||||
# MyPy bug: https://github.com/python/mypy/issues/5107
|
||||
@lru_cache(maxsize=1024) # type: ignore[arg-type]
|
||||
def make_nvfuser_fusion(gm: GraphModule, *nv_args_templates):
|
||||
if not torch.cuda.is_available():
|
||||
raise RuntimeError(
|
||||
"Attempting to use nvFuser trace executor but CUDA is not available!"
|
||||
)
|
||||
|
||||
# Everything in the graph must support nvfuser
|
||||
for node in gm.graph.nodes:
|
||||
if node.op == "call_function" and node.target == operator.getitem:
|
||||
continue
|
||||
if (
|
||||
node.op == "call_function"
|
||||
and getattr(node.target, "impl_nvfuser", None) is None
|
||||
):
|
||||
raise ValueError(
|
||||
"All call_function nodes in the graph must support nvfuser. "
|
||||
f"Node {node} with target {node.target} does not support nvfuser"
|
||||
)
|
||||
|
||||
graph_input_nodes = list(filter(lambda n: n.op == "placeholder", gm.graph.nodes))
|
||||
call_function_nodes = list(
|
||||
filter(lambda n: n.op == "call_function", gm.graph.nodes)
|
||||
)
|
||||
assert len(graph_input_nodes) == len(
|
||||
nv_args_templates
|
||||
), "Number of placeholder nodes in the graph must match number of args"
|
||||
assert len(nv_args_templates) > 0, "There must be at least one argument"
|
||||
assert (
|
||||
len(call_function_nodes) > 0
|
||||
), "Graph must contain at least one call_function node"
|
||||
assert not _any_get_attr_used(
|
||||
call_function_nodes
|
||||
), "Constant tensors that are saved in the graph and used as arguments are not supported yet"
|
||||
|
||||
# Checking output dtypes
|
||||
output_node = next(filter(lambda n: n.op == "output", gm.graph.nodes))
|
||||
orig_flat_out, _ = tree_flatten(output_node.args[0])
|
||||
|
||||
fusion, fd = create_fusion_definition()
|
||||
with fd:
|
||||
|
||||
def _to_nvfuser_constant(arg):
|
||||
if isinstance(arg, Number):
|
||||
return fd.define_constant(arg)
|
||||
else:
|
||||
return arg
|
||||
|
||||
class FusionInterpreter(torch.fx.Interpreter):
|
||||
def run_node(self, node):
|
||||
# Squeeze requires original shape of args[0]
|
||||
if node.target in (
|
||||
torch.ops.nvprims.squeeze,
|
||||
torch.ops.nvprims.squeeze.default,
|
||||
):
|
||||
original_shape = list(node.args[0].meta["tensor_meta"].shape)
|
||||
assert len(node.args) == 2
|
||||
args, kwargs = self.fetch_args_kwargs_from_env(node)
|
||||
args = args[:1] + (original_shape,) + args[1:]
|
||||
return self.call_function(node.target, args, node.kwargs)
|
||||
|
||||
if node.target in (
|
||||
torch.ops.nvprims.native_batch_norm,
|
||||
torch.ops.nvprims.native_batch_norm.default,
|
||||
):
|
||||
args, kwargs = self.fetch_args_kwargs_from_env(node)
|
||||
assert len(args) == 8
|
||||
training = args[5]
|
||||
args6_end = tuple(_to_nvfuser_constant(arg) for arg in args[6:])
|
||||
args = args[:5] + (training,) + args6_end
|
||||
return node.target.impl_nvfuser(fd, *args, **kwargs)
|
||||
|
||||
return super().run_node(node)
|
||||
|
||||
def call_function(self, target, args, kwargs):
|
||||
# This handles tuple unpacking
|
||||
if target == operator.getitem:
|
||||
assert isinstance(args[0], tuple)
|
||||
return target(*args, **kwargs)
|
||||
args = tuple(_to_nvfuser_constant(arg) for arg in args)
|
||||
target = target.impl_nvfuser
|
||||
args = (fd,) + args
|
||||
return target(*args, **kwargs)
|
||||
|
||||
def output(self, target, args, kwargs):
|
||||
flat_out, unflatten_spec = tree_flatten(args[0])
|
||||
for o, orig_o in zip(flat_out, orig_flat_out):
|
||||
# casting outputs to the original data type
|
||||
# ensures outputs produced by fusion would always agree with original GraphModule
|
||||
out_dtype = _torch_dtype_to_nvfuser_dtype_map.get(orig_o.meta["tensor_meta"].dtype) # type: ignore[union-attr]
|
||||
assert isinstance(
|
||||
o, Tensor
|
||||
), "output from codegen has to be tensor type"
|
||||
fd.add_output(fd.ops.cast(o, dtype=out_dtype))
|
||||
return args[0]
|
||||
|
||||
def templates_to_nvfuser_inputs(arg):
|
||||
if isinstance(arg, nvFuserTensorTemplate):
|
||||
x = fd.define_tensor(
|
||||
arg.symbolic_shape, arg.contiguity, arg.dtype, arg.is_cpu
|
||||
)
|
||||
return x
|
||||
elif isinstance(arg, nvFuserScalarTemplate):
|
||||
x = fd.define_scalar(arg.dtype)
|
||||
return x
|
||||
else:
|
||||
return arg
|
||||
|
||||
# Transforms graph to call nvfuser lowerings
|
||||
nv_args = tuple(
|
||||
templates_to_nvfuser_inputs(nv_arg) for nv_arg in nv_args_templates
|
||||
)
|
||||
out = FusionInterpreter(gm).run(*nv_args)
|
||||
flat_out, unflatten_spec = tree_flatten(out)
|
||||
|
||||
return fusion, unflatten_spec
|
||||
|
||||
|
||||
def nvfuser_execute(gm: GraphModule, *args, executor_parameters=None):
|
||||
executor_parameters = executor_parameters or DEFAULT_NVFUSER_PYTHON_CONFIG
|
||||
flat_args, _ = tree_flatten(args)
|
||||
|
||||
# check for cuda only fusion
|
||||
if any(isinstance(arg, torch.Tensor) and arg.is_cuda for arg in flat_args) and all( # type: ignore[attr-defined]
|
||||
(
|
||||
not isinstance(arg, torch.Tensor)
|
||||
or (arg.is_cpu and arg.ndim == 0) # type: ignore[attr-defined]
|
||||
or arg.is_cuda # type: ignore[attr-defined]
|
||||
)
|
||||
for arg in flat_args
|
||||
):
|
||||
# Construction of the fusion is expensive and cached based on the GraphModule
|
||||
# and symbolic nvFuser args.
|
||||
nv_template_args = to_nvfuser_template_args(flat_args)
|
||||
use_cache = executor_parameters.get(
|
||||
"use_python_fusion_cache",
|
||||
DEFAULT_NVFUSER_PYTHON_CONFIG["use_python_fusion_cache"],
|
||||
)
|
||||
if use_cache:
|
||||
fusion, unflatten_spec = make_nvfuser_fusion(gm, *nv_template_args) # type: ignore[misc]
|
||||
else:
|
||||
fusion, unflatten_spec = make_nvfuser_fusion.__wrapped__(gm, *nv_template_args) # type: ignore[misc]
|
||||
|
||||
# Inputs to fusion.execute correspond to the same template/symbolic inputs
|
||||
# marked with `define_tensor/scalar`
|
||||
concrete_fusion_inputs = tuple(
|
||||
arg for arg in flat_args if isinstance(arg, (torch.Tensor, Number))
|
||||
)
|
||||
|
||||
if get_nvprim_dump_nvtx():
|
||||
torch.cuda.nvtx.range_push(
|
||||
"fusion: {}, graph: {}".format(
|
||||
fusion.id(),
|
||||
str(
|
||||
[
|
||||
{
|
||||
"op": n.op,
|
||||
"name": n.name,
|
||||
"args": n.args,
|
||||
"kwargs": n.kwargs,
|
||||
}
|
||||
for n in gm.graph.nodes
|
||||
]
|
||||
),
|
||||
)
|
||||
)
|
||||
warn("nvfuser integration in primTorch is deprecated")
|
||||
result = tree_unflatten(
|
||||
fusion.execute(concrete_fusion_inputs), # type: ignore[has-type]
|
||||
unflatten_spec, # type: ignore[has-type]
|
||||
)
|
||||
if get_nvprim_dump_nvtx():
|
||||
torch.cuda.nvtx.range_pop()
|
||||
return result
|
||||
else:
|
||||
warn(
|
||||
"nvfuser_executor is executed with non-cuda args, fallback to aten executor"
|
||||
)
|
||||
return gm.forward(*args)
|
||||
|
||||
|
||||
class NvfuserPrimOperatorSupport(torch.fx.passes.operator_support.OperatorSupport):
|
||||
def is_node_supported(self, submodules, node: torch.fx.Node) -> bool:
|
||||
# special case to stop lowering to nvprim when converting to an unsupported type
|
||||
if (
|
||||
node.op == "call_function"
|
||||
and node.target == torch.ops.nvprims.convert_element_type.default
|
||||
):
|
||||
return (
|
||||
_torch_dtype_to_nvfuser_dtype_map.get(node.args[1]) is not None
|
||||
and _torch_dtype_to_nvfuser_dtype_map.get(
|
||||
node.args[0].meta["tensor_meta"].dtype # type: ignore[union-attr]
|
||||
)
|
||||
is not None
|
||||
)
|
||||
return node.op == "call_function" and (
|
||||
getattr(node.target, "impl_nvfuser", None) is not None
|
||||
or node.target == operator.getitem
|
||||
)
|
||||
|
||||
|
||||
class PartitionedInterpreter(torch.fx.Interpreter):
|
||||
def call_module(self, target, args, kwargs):
|
||||
assert isinstance(target, str)
|
||||
assert len(kwargs) == 0
|
||||
submod = self.fetch_attr(target)
|
||||
# CapabilityBasedPartitioner hardcodes the name of the subgraphs with supported_ops as "fused_" + subgraph id
|
||||
if target.startswith("fused_"):
|
||||
return nvfuser_execute(submod, *args)
|
||||
else:
|
||||
return super().call_module(target, args, kwargs)
|
||||
|
||||
|
||||
class NvfuserGraphModule(torch.nn.Module):
|
||||
def __init__(self, gm, use_python_fusion_cache):
|
||||
super().__init__()
|
||||
self.gm = gm
|
||||
self.executor_parameters = {"use_python_fusion_cache": use_python_fusion_cache}
|
||||
|
||||
def __call__(self, *args):
|
||||
return nvfuser_execute(
|
||||
self.gm, *args, executor_parameters=self.executor_parameters
|
||||
)
|
||||
|
||||
|
||||
# A set of operators that are supported by nvFuser
|
||||
# but should not form a fusion group solely on their own
|
||||
_non_compute_ops = [
|
||||
"torch.ops." + str(getattr(torch.ops.nvprims, prim).default)
|
||||
for prim in dir(torch.ops.nvprims)
|
||||
if isinstance(getattr(torch.ops.nvprims, prim), torch._ops.OpOverloadPacket)
|
||||
and getattr(torch.ops.nvprims, prim).return_type
|
||||
== torch._prims_common.RETURN_TYPE.VIEW
|
||||
]
|
||||
|
||||
_allowed_single_node_partition_ops = [
|
||||
"torch.ops.nvprims.native_batch_norm.default",
|
||||
"torch.ops.nvprims.var_mean.default",
|
||||
"torch.ops.nvprims.var_mean.main",
|
||||
]
|
||||
|
||||
|
||||
def _remove_empty_like_fill(gm: GraphModule):
|
||||
# Remove empty_like + fill nodes that prevent lowering to nvprims
|
||||
# This is a workaround for nonoptimal traces of C++ code `(1 - tensor)`
|
||||
# https://github.com/pytorch/pytorch/issues/86612
|
||||
|
||||
def pattern(scalar, tensor):
|
||||
# pattern for C++ trace of `scalar - tensor`. We are looking for the
|
||||
# pattern of aten and nvprims.sub specifically because we want to remove
|
||||
# the empty_like + fill nodes after lowering of AOT Autograd trace to
|
||||
# nvprims In the future, nvFuser might support fill, and empty_like and
|
||||
# this workaround can be removed.
|
||||
empty_like = torch.ops.aten.empty_like.default(
|
||||
tensor, memory_format=torch.preserve_format
|
||||
)
|
||||
fill = torch.ops.aten.fill.Scalar(empty_like, scalar)
|
||||
sub = torch.ops.nvprims.sub.default(fill, tensor)
|
||||
return sub
|
||||
|
||||
def replacement(scalar, tensor):
|
||||
return torch.ops.nvprims.sub.default(scalar, tensor)
|
||||
|
||||
torch.fx.replace_pattern(gm, pattern, replacement)
|
||||
return gm
|
||||
|
||||
|
||||
# MyPy bug: https://github.com/python/mypy/issues/5107
|
||||
@lru_cache(maxsize=1024) # type: ignore[arg-type]
|
||||
def maybe_partition_graph(
|
||||
gm: GraphModule, allow_single_op_fusion: bool, use_python_fusion_cache: bool
|
||||
):
|
||||
gm = _remove_empty_like_fill(gm)
|
||||
supported_ops = NvfuserPrimOperatorSupport()
|
||||
call_function_nodes = list(
|
||||
filter(lambda n: n.op == "call_function", gm.graph.nodes)
|
||||
)
|
||||
# the graph is partitioned only if at least one node is not supported by nvFuser
|
||||
any_unsupported = any(
|
||||
not supported_ops.is_node_supported(None, node) for node in call_function_nodes
|
||||
)
|
||||
any_unsupported |= len(call_function_nodes) == 0
|
||||
|
||||
# When there are constant tensors in the graph, we can't partition it
|
||||
# because deepcopy fails. Here we just return the original graph to be
|
||||
# executed by eager mode
|
||||
# https://github.com/pytorch/pytorch/issues/84415
|
||||
if (
|
||||
_any_get_attr_used(call_function_nodes)
|
||||
or len(list(filter(lambda n: n.op == "placeholder", gm.graph.nodes))) == 0
|
||||
):
|
||||
return gm, True
|
||||
|
||||
if any_unsupported:
|
||||
# CapabilityBasedPartitioner modifies the graph in-place so we need to make a copy of the graph
|
||||
gm = deepcopy(gm)
|
||||
partitioner = CapabilityBasedPartitioner(
|
||||
gm,
|
||||
supported_ops,
|
||||
allows_single_node_partition=allow_single_op_fusion,
|
||||
non_compute_ops=_non_compute_ops,
|
||||
allowed_single_node_partition_ops=_allowed_single_node_partition_ops,
|
||||
)
|
||||
partitions = partitioner.propose_partitions()
|
||||
partitioner.remove_bookend_non_compute_ops(partitions)
|
||||
if len(partitions) == 0:
|
||||
warn(
|
||||
"No partition found for the graph. "
|
||||
+ "This is likely because the graph is not supported by nvFuser. "
|
||||
+ "Please use the eager ATen mode to execute the graph.",
|
||||
category=RuntimeWarning,
|
||||
)
|
||||
partitioned_graph = partitioner.fuse_partitions(partitions)
|
||||
|
||||
# Replacing graph's fused submodules with a wrapper module with
|
||||
# __call__() method that calls nvfuser_execute.
|
||||
# This avoids the need to call the interpreter on the graph
|
||||
for node in partitioned_graph.graph.nodes:
|
||||
# TODO: use a better way to identify fused submodule
|
||||
if node.op == "call_module" and "fused_" in node.name:
|
||||
nvfuser_submodule = getattr(partitioned_graph, node.name)
|
||||
partitioned_graph.delete_submodule(node.target)
|
||||
gm.add_submodule(
|
||||
node.target,
|
||||
NvfuserGraphModule(nvfuser_submodule, use_python_fusion_cache),
|
||||
)
|
||||
|
||||
# Go through the graph and replace all the nodes that were converted to
|
||||
# nvprims but won't be sent to nvFuser with a call to PyTorch's eager
|
||||
# mode. This is necessary because torch.ops.* have higher overhead than
|
||||
# calling the eager mode directly.
|
||||
for node in partitioned_graph.graph.nodes:
|
||||
if node.op == "call_function" and str(node.target).startswith("nvprims."):
|
||||
if getattr(node.target, "impl_aten", None) is not None:
|
||||
node.target = node.target.impl_aten
|
||||
partitioned_graph.graph.eliminate_dead_code()
|
||||
partitioned_graph.recompile()
|
||||
return partitioned_graph, any_unsupported
|
||||
else:
|
||||
return gm, any_unsupported
|
||||
|
||||
|
||||
class NVTXInterpreter(torch.fx.Interpreter):
|
||||
def run_node(self, n):
|
||||
torch.cuda.nvtx.range_push(
|
||||
f"name: {n.name}, args: {n.args}, op: {n.op}, kwargs: {n.kwargs}"
|
||||
)
|
||||
result = super().run_node(n)
|
||||
torch.cuda.nvtx.range_pop()
|
||||
return result
|
||||
|
||||
|
||||
def nvfuser_execute_partitioned(gm: GraphModule, *args, executor_parameters=None):
|
||||
executor_parameters = executor_parameters or DEFAULT_NVFUSER_PYTHON_CONFIG
|
||||
# maybe_partition_graph function is cached so we can't use non-hashable arguments
|
||||
allow_single_op_fusion = executor_parameters.get(
|
||||
"allow_single_op_fusion",
|
||||
DEFAULT_NVFUSER_PYTHON_CONFIG["allow_single_op_fusion"],
|
||||
)
|
||||
use_python_fusion_cache = executor_parameters.get(
|
||||
"use_python_fusion_cache",
|
||||
DEFAULT_NVFUSER_PYTHON_CONFIG["use_python_fusion_cache"],
|
||||
)
|
||||
# When possible it's better to use nvfuser_execute directly
|
||||
# because it avoids GraphModule's overhead
|
||||
gm, is_partitioned = maybe_partition_graph(
|
||||
gm,
|
||||
allow_single_op_fusion=allow_single_op_fusion,
|
||||
use_python_fusion_cache=use_python_fusion_cache,
|
||||
)
|
||||
if is_partitioned:
|
||||
if get_nvprim_dump_nvtx():
|
||||
return NVTXInterpreter(gm).run(*args)
|
||||
else:
|
||||
return gm(*args)
|
||||
else:
|
||||
return nvfuser_execute(gm, *args, executor_parameters=executor_parameters)
|
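The partitioning code above was built on torch.fx's capability-based partitioner. For reference, a self-contained sketch of that API with a deliberately trivial, made-up support rule (only additions count as "supported")::

    import operator
    import torch
    import torch.fx
    from torch.fx.passes.infra.partitioner import CapabilityBasedPartitioner
    from torch.fx.passes.operator_support import OperatorSupport

    class AddOnlySupport(OperatorSupport):
        # Pretend the backend can only handle additions.
        def is_node_supported(self, submodules, node: torch.fx.Node) -> bool:
            return node.op == "call_function" and node.target in (torch.add, operator.add)

    def f(x, y):
        return torch.relu(x + y) + y

    gm = torch.fx.symbolic_trace(f)
    partitioner = CapabilityBasedPartitioner(
        gm, AddOnlySupport(), allows_single_node_partition=True
    )
    partitions = partitioner.propose_partitions()
    fused = partitioner.fuse_partitions(partitions)  # supported nodes become fused_* submodules
    print(fused.graph)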
@ -1,840 +0,0 @@
|
||||
# Module for defining "primitive" operations executable by the nvFuser. This
|
||||
# list exists to decouple main set of primitives from the ones that provide a
|
||||
# lowering of the op to nvFuser’s Python interface. Mostly torch.ops.nvprims is
|
||||
# a subset of the primitives in torch.ops.prims, but some additional primitives
|
||||
# can be added in the future for the corresponding higher-level torch/aten
|
||||
# functions.
|
||||
|
||||
from typing import Any, Dict, Optional, Tuple
|
||||
|
||||
import torch
|
||||
import torch._prims_common as utils
|
||||
|
||||
from torch._prims_common import (
|
||||
DimsSequenceType,
|
||||
elementwise_dtypes,
|
||||
ELEMENTWISE_TYPE_PROMOTION_KIND,
|
||||
getnvFuserDtype,
|
||||
make_contiguous_strides_for,
|
||||
NumberType,
|
||||
ShapeType,
|
||||
TensorLikeType,
|
||||
)
|
||||
|
||||
from torch._prims_common.wrappers import (
|
||||
_maybe_convert_to_dtype,
|
||||
backwards_not_supported,
|
||||
elementwise_type_promotion_wrapper,
|
||||
)
|
||||
|
||||
nvprim_namespace = "nvprims"
|
||||
nvprim = torch.library.Library(nvprim_namespace, "DEF")
|
||||
nvprim_impl = torch.library.Library(
|
||||
nvprim_namespace, "IMPL", "CompositeExplicitAutograd"
|
||||
)
|
||||
nvprim_implicit_impl = torch.library.Library(
|
||||
nvprim_namespace, "IMPL", "CompositeImplicitAutograd"
|
||||
)
|
||||
nvprim_autograd_impl = torch.library.Library(nvprim_namespace, "IMPL", "Autograd")
|
||||
nvprim_meta_impl = torch.library.Library(nvprim_namespace, "IMPL", "Meta")
|
||||
|
||||
nvprim_names = [
|
||||
"abs",
|
||||
"acos",
|
||||
"asin",
|
||||
"atan",
|
||||
"atanh",
|
||||
"cos",
|
||||
"cosh",
|
||||
"clone",
|
||||
"bitwise_not",
|
||||
"ceil",
|
||||
"erf",
|
||||
"erfc",
|
||||
"exp",
|
||||
"expm1",
|
||||
"floor",
|
||||
"imag",
|
||||
"isfinite",
|
||||
"lgamma",
|
||||
"log",
|
||||
"log1p",
|
||||
"log2",
|
||||
"log10",
|
||||
"real",
|
||||
"reciprocal",
|
||||
"neg",
|
||||
"round",
|
||||
"rsqrt",
|
||||
"sign",
|
||||
"sin",
|
||||
"sinh",
|
||||
"sqrt",
|
||||
"tan",
|
||||
"tanh",
|
||||
"transpose",
|
||||
"trunc",
|
||||
"add",
|
||||
"atan2",
|
||||
"bitwise_and",
|
||||
"bitwise_or",
|
||||
"bitwise_xor",
|
||||
"div",
|
||||
"eq",
|
||||
"fmod",
|
||||
"ge",
|
||||
"gt",
|
||||
"le",
|
||||
"lt",
|
||||
"mul",
|
||||
"ne",
|
||||
"pow",
|
||||
"remainder",
|
||||
"sub",
|
||||
"squeeze",
|
||||
"view_of",
|
||||
"broadcast_in_dim",
|
||||
"where",
|
||||
"convert_element_type",
|
||||
"sum",
|
||||
"var",
|
||||
"amax",
|
||||
"amin",
|
||||
]
|
||||
|
||||
_nvfuser_impls: Dict[str, Any] = {}
|
||||
|
||||
_nvfuser_unary_ops = {
|
||||
"abs",
|
||||
"acos",
|
||||
"asin",
|
||||
"atan",
|
||||
"atanh",
|
||||
"cos",
|
||||
"cosh",
|
||||
"bitwise_not",
|
||||
"ceil",
|
||||
"erf",
|
||||
"erfc",
|
||||
"exp",
|
||||
"expm1",
|
||||
"floor",
|
||||
"imag",
|
||||
"isfinite",
|
||||
"lgamma",
|
||||
"log",
|
||||
"log1p",
|
||||
"log2",
|
||||
"log10",
|
||||
"reciprocal",
|
||||
"neg",
|
||||
"real",
|
||||
"round",
|
||||
"rsqrt",
|
||||
"sign",
|
||||
"sin",
|
||||
"sinh",
|
||||
"sqrt",
|
||||
"tan",
|
||||
"tanh",
|
||||
"trunc",
|
||||
}
|
||||
|
||||
|
||||
def _assert_nvfuser_op_exists(fname: str):
|
||||
try:
|
||||
try:
|
||||
from nvfuser import ( # type: ignore[import, attr-defined]
|
||||
FusionDefinition as fd,
|
||||
)
|
||||
except ImportError:
|
||||
from nvfuser._C import FusionDefinition as fd # type: ignore[import]
|
||||
|
||||
assert getattr(fd.Operators, fname)
|
||||
except ImportError:
|
||||
# Not all PyTorch builds have nvfuser
|
||||
pass
|
||||
|
||||
|
||||
for fname in _nvfuser_unary_ops:
|
||||
exec(
|
||||
f"""
|
||||
# Ensure that the nvfuser implementation exists
|
||||
_assert_nvfuser_op_exists("{fname}")
|
||||
|
||||
def _{fname}_nvfuser(fd, a):
|
||||
return fd.ops.{fname}(a) # type: ignore[attr-defined]
|
||||
|
||||
_nvfuser_impls["{fname}"] = _{fname}_nvfuser
|
||||
"""
|
||||
)
|
||||
|
||||
_nvfuser_binary_ops = {
|
||||
"add",
|
||||
"atan2",
|
||||
"bitwise_and",
|
||||
"bitwise_or",
|
||||
"bitwise_xor",
|
||||
"div",
|
||||
"eq",
|
||||
"fmod",
|
||||
"ge",
|
||||
"gt",
|
||||
"le",
|
||||
"lt",
|
||||
"mul",
|
||||
"ne",
|
||||
"pow",
|
||||
"remainder",
|
||||
"sub",
|
||||
}
|
||||
|
||||
for fname in _nvfuser_binary_ops:
|
||||
exec(
|
||||
f"""
|
||||
# Ensure that the nvfuser implementation exists
|
||||
_assert_nvfuser_op_exists("{fname}")
|
||||
|
||||
def _{fname}_nvfuser(fd, a, b):
|
||||
return fd.ops.{fname}(a, b) # type: ignore[attr-defined]
|
||||
|
||||
_nvfuser_impls["{fname}"] = _{fname}_nvfuser
|
||||
"""
|
||||
)
|
||||
|
||||
_nvfuser_ternary_ops = {
|
||||
"where",
|
||||
}
|
||||
|
||||
for fname in _nvfuser_ternary_ops:
|
||||
exec(
|
||||
f"""
|
||||
# Ensure that the nvfuser implementation exists
|
||||
_assert_nvfuser_op_exists("{fname}")
|
||||
|
||||
def _{fname}_nvfuser(fd, a, b, c):
|
||||
return fd.ops.{fname}(a, b, c) # type: ignore[attr-defined]
|
||||
|
||||
_nvfuser_impls["{fname}"] = _{fname}_nvfuser
|
||||
"""
|
||||
)
|
||||
|
||||
|
||||
def _native_batch_norm_nvfuser(
|
||||
fd, input, weight, bias, running_mean, running_var, training, momentum, eps
|
||||
):
|
||||
"""
|
||||
if weight is None:
|
||||
weight = fd.define_null_tensor()
|
||||
if bias is None:
|
||||
bias = fd.define_null_tensor()
|
||||
if running_mean is None:
|
||||
running_mean = fd.define_null_tensor()
|
||||
if running_var is None:
|
||||
running_var = fd.define_null_tensor()
|
||||
"""
|
||||
return fd.ops.batch_norm(
|
||||
input,
|
||||
weight,
|
||||
bias,
|
||||
running_mean,
|
||||
running_var,
|
||||
momentum,
|
||||
eps,
|
||||
training,
|
||||
)
|
||||
|
||||
|
||||
def _broadcast_in_dim_nvfuser(
|
||||
fd: Any,
|
||||
a: TensorLikeType,
|
||||
shape: ShapeType,
|
||||
broadcast_dimensions: ShapeType,
|
||||
):
|
||||
return fd.ops.broadcast_in_dim(a, shape, broadcast_dimensions) # type: ignore[attr-defined]
|
||||
|
||||
|
||||
def _convert_element_type_nvfuser(fd: Any, a: TensorLikeType, dtype: torch.dtype):
|
||||
nvfuser_dtype = getnvFuserDtype(dtype)
|
||||
return fd.ops.cast(a, nvfuser_dtype) # type: ignore[attr-defined]
|
||||
|
||||
|
||||
def _transpose_nvfuser(fd, a, dims):
|
||||
return fd.ops.permute(a, dims) # type: ignore[attr-defined]
|
||||
|
||||
|
||||
def _squeeze_nvfuser(fd, a, a_shape, dimensions):
|
||||
for idx in sorted(dimensions, reverse=True):
|
||||
a = fd.ops.squeeze(a, a_shape, idx)
|
||||
a_shape = a_shape[:idx] + a_shape[idx + 1 :]
|
||||
return a
|
||||
|
||||
|
||||
def _view_of_nvfuser(fd, a):
|
||||
return fd.ops.set(a)
|
||||
|
||||
|
||||
def _view_nvfuser(
|
||||
fd,
|
||||
a,
|
||||
a_shape,
|
||||
new_shape,
|
||||
):
|
||||
try:
|
||||
return fd.ops.view(a, a_shape, new_shape)
|
||||
except AttributeError:
|
||||
return fd.ops.reshape(a, a_shape, new_shape)
|
||||
|
||||
|
||||
def _sum_nvfuser(
|
||||
fd: Any,
|
||||
a: TensorLikeType,
|
||||
dims: DimsSequenceType,
|
||||
):
|
||||
keep_dims = False
|
||||
try:
|
||||
from nvfuser import DataType # type: ignore[import, attr-defined]
|
||||
except ImportError:
|
||||
from nvfuser._C import DataType # type: ignore[import]
|
||||
|
||||
output_dtype = DataType.Null
|
||||
return fd.ops.sum(a, dims, keep_dims, output_dtype)
|
||||
|
||||
|
||||
def _var_nvfuser(
|
||||
fd: Any,
|
||||
a: TensorLikeType,
|
||||
dims: DimsSequenceType,
|
||||
*,
|
||||
correction: float,
|
||||
):
|
||||
keep_dims = False
|
||||
return fd.ops.var(a, dims, correction, keep_dims)
|
||||
|
||||
|
||||
def _var_mean_nvfuser(
|
||||
fd: Any,
|
||||
a: TensorLikeType,
|
||||
dims: DimsSequenceType,
|
||||
unbiased: Optional[bool] = None,
|
||||
keepdim: bool = False,
|
||||
*,
|
||||
correction: float,
|
||||
):
|
||||
# Unbiased arg shouldn't be set when this function is called
|
||||
assert unbiased is None
|
||||
# Ignore keepdim arg, because currently it's automatically converted into nvfuser's symbolic scalar
|
||||
# keepdim is handled by the reference implementation
|
||||
keepdim = False
|
||||
return fd.ops.var_mean(a, dims, correction, keepdim)
|
||||
|
||||
|
||||
def _rand_like_nvfuser(fd: Any, a: TensorLikeType):
|
||||
return fd.ops.rand_like(a)
|
||||
|
||||
|
||||
def _amax_nvfuser(
|
||||
fd: Any,
|
||||
a: TensorLikeType,
|
||||
dims: DimsSequenceType,
|
||||
):
|
||||
keep_dims = False
|
||||
return fd.ops.max(a, dims, keep_dims)
|
||||
|
||||
|
||||
def _amin_nvfuser(
|
||||
fd: Any,
|
||||
a: TensorLikeType,
|
||||
dims: DimsSequenceType,
|
||||
):
|
||||
keep_dims = False
|
||||
return fd.ops.min(a, dims, keep_dims)
|
||||
|
||||
|
||||
def _clone_nvfuser(fd: Any, input: TensorLikeType, *, memory_format=None):
|
||||
return fd.ops.set(input)
|
||||
|
||||
|
||||
def _full_nvfuser(
|
||||
fd: Any,
|
||||
shape: ShapeType,
|
||||
fill_value: NumberType,
|
||||
*,
|
||||
dtype: Optional[torch.dtype] = None,
|
||||
layout: Optional[torch.layout] = None,
|
||||
device: Optional[torch.device] = None,
|
||||
pin_memory: bool = False,
|
||||
requires_grad: bool = False,
|
||||
):
|
||||
assert device != torch.device("cpu")
|
||||
assert layout is None or layout is torch.strided
|
||||
assert pin_memory is False
|
||||
assert requires_grad is False
|
||||
dtype = dtype if dtype is not None else utils.type_to_dtype(type(fill_value))
|
||||
nvfuser_dtype = getnvFuserDtype(dtype)
|
||||
return fd.ops.full(shape, fill_value, nvfuser_dtype)
|
||||
|
||||
|
||||
_nvfuser_impls["native_batch_norm"] = _native_batch_norm_nvfuser
|
||||
_nvfuser_impls["broadcast_in_dim"] = _broadcast_in_dim_nvfuser
|
||||
_nvfuser_impls["convert_element_type"] = _convert_element_type_nvfuser
|
||||
_nvfuser_impls["clone"] = _clone_nvfuser
|
||||
_nvfuser_impls["transpose"] = _transpose_nvfuser
|
||||
_nvfuser_impls["squeeze"] = _squeeze_nvfuser
|
||||
_nvfuser_impls["view_of"] = _view_of_nvfuser
|
||||
_nvfuser_impls["view"] = _view_nvfuser
|
||||
_nvfuser_impls["rand_like"] = _rand_like_nvfuser
|
||||
_nvfuser_impls["sum"] = _sum_nvfuser
|
||||
_nvfuser_impls["var"] = _var_nvfuser
|
||||
_nvfuser_impls["var_mean"] = _var_mean_nvfuser
|
||||
_nvfuser_impls["amax"] = _amax_nvfuser
|
||||
_nvfuser_impls["amin"] = _amin_nvfuser
|
||||
_nvfuser_impls["full"] = _full_nvfuser
|
||||
|
||||
|
||||
def register_full():
|
||||
name = "full"
|
||||
|
||||
nvprim.define(
|
||||
"full(SymInt[] size, Scalar fill_value, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, "
|
||||
+ "bool? pin_memory=None, bool? requires_grad=None) -> Tensor"
|
||||
)
|
||||
|
||||
def _meta_impl(
|
||||
size,
|
||||
fill_value,
|
||||
*,
|
||||
out=None,
|
||||
dtype=None,
|
||||
layout=None,
|
||||
device=None,
|
||||
pin_memory=False,
|
||||
requires_grad=False,
|
||||
):
|
||||
strides = make_contiguous_strides_for(size)
|
||||
return torch._prims.TensorMeta(
|
||||
None,
|
||||
shape=size,
|
||||
strides=strides,
|
||||
dtype=dtype,
|
||||
device=device,
|
||||
)
|
||||
|
||||
def _prim_impl(
|
||||
size,
|
||||
fill_value,
|
||||
*,
|
||||
out=None,
|
||||
dtype=None,
|
||||
layout=None,
|
||||
device=None,
|
||||
pin_memory=False,
|
||||
requires_grad=False,
|
||||
):
|
||||
return torch.full(
|
||||
size,
|
||||
fill_value,
|
||||
out=out,
|
||||
dtype=dtype,
|
||||
layout=layout,
|
||||
device=device,
|
||||
pin_memory=pin_memory,
|
||||
requires_grad=requires_grad,
|
||||
)
|
||||
|
||||
nvprim_impl.impl(name, _prim_impl)
|
||||
nvprim_meta_impl.impl(name, _meta_impl)
|
||||
|
||||
prim_packet = getattr(torch._ops.ops.nvprims, name)
|
||||
prim = prim_packet.default
|
||||
nvprim_autograd_impl.impl(name, backwards_not_supported(prim))
|
||||
for p in (prim_packet, prim):
|
||||
p.__doc__ = "Create a tensor with given size and filled with value"
|
||||
p.impl_nvfuser = _nvfuser_impls["full"]
|
||||
p.is_recomputable = _nvfuser_is_recomputable["full"]
|
||||
p.return_type = torch._prims_common.RETURN_TYPE.NEW # type: ignore[attr-defined]
|
||||
|
||||
|
||||
# functorch.compile.min_cut_rematerialization_partition accepts a list of
|
||||
# operators that can be recomputed in the backward pass. This list is used to
|
||||
# determine which operators can be recomputed. If an operator is not in this
|
||||
# list, it will not be recomputed.
|
||||
_nvfuser_is_recomputable: Dict[str, bool] = {
|
||||
# Reductions are not allowed to be recomputed
|
||||
"amax": False,
|
||||
"amin": False,
|
||||
"sum": False,
|
||||
"var": False,
|
||||
"var_mean": False,
|
||||
# Normalizations are not allowed to be recomputed
|
||||
"native_batch_norm": False,
|
||||
# Random ops are not allowed to be recomputed
|
||||
"rand_like": False,
|
||||
# Everything else is allowed to be recomputed
|
||||
"abs": True,
|
||||
"acos": True,
|
||||
"add": True,
|
||||
"asin": True,
|
||||
"atan": True,
|
||||
"atan2": True,
|
||||
"atanh": True,
|
||||
"bitwise_and": True,
|
||||
"bitwise_not": True,
|
||||
"bitwise_or": True,
|
||||
"bitwise_xor": True,
|
||||
"broadcast_in_dim": True,
|
||||
"ceil": True,
|
||||
"clone": True,
|
||||
"convert_element_type": True,
|
||||
"cos": True,
|
||||
"cosh": True,
|
||||
"div": True,
|
||||
"eq": True,
|
||||
"erf": True,
|
||||
"erfc": True,
|
||||
"exp": True,
|
||||
"expm1": True,
|
||||
"floor": True,
|
||||
"fmod": True,
|
||||
"full": True,
|
||||
"ge": True,
|
||||
"gt": True,
|
||||
"imag": True,
|
||||
"isfinite": True,
|
||||
"le": True,
|
||||
"lgamma": True,
|
||||
"log": True,
|
||||
"log10": True,
|
||||
"log1p": True,
|
||||
"log2": True,
|
||||
"lt": True,
|
||||
"mul": True,
|
||||
"ne": True,
|
||||
"neg": True,
|
||||
"pow": True,
|
||||
"real": True,
|
||||
"reciprocal": True,
|
||||
"remainder": True,
|
||||
"round": True,
|
||||
"rsqrt": True,
|
||||
"sign": True,
|
||||
"sin": True,
|
||||
"sinh": True,
|
||||
"sqrt": True,
|
||||
"squeeze": True,
|
||||
"sub": True,
|
||||
"tan": True,
|
||||
"tanh": True,
|
||||
"transpose": True,
|
||||
"trunc": True,
|
||||
"view": True,
|
||||
"view_of": True,
|
||||
"where": True,
|
||||
}
|
||||
|
||||
|
||||
def register_native_batch_norm():
|
||||
"""This function is used to register the native_batch_norm function in torch.ops.nvprims module."""
|
||||
name = "native_batch_norm"
|
||||
|
||||
nvprim.define(
|
||||
f"{name}(Tensor input, Tensor? weight, Tensor? bias, Tensor? running_mean, Tensor? running_var, "
|
||||
+ "bool training, float momentum, float eps)"
|
||||
+ " -> (Tensor, Tensor, Tensor)"
|
||||
)
|
||||
|
||||
def _prim_impl(
|
||||
input, weight, bias, running_mean, running_var, training, momentum, eps
|
||||
):
|
||||
return torch.native_batch_norm(
|
||||
input, weight, bias, running_mean, running_var, training, momentum, eps
|
||||
)
|
||||
|
||||
nvprim_impl.impl(name, _prim_impl)
|
||||
prim_packet = torch._ops.ops.nvprims.native_batch_norm
|
||||
prim = prim_packet.default
|
||||
|
||||
def _native_batch_norm_ref(
|
||||
input: torch.Tensor,
|
||||
weight: Optional[torch.Tensor],
|
||||
bias: Optional[torch.Tensor],
|
||||
running_mean: Optional[torch.Tensor],
|
||||
running_var: Optional[torch.Tensor],
|
||||
training: bool,
|
||||
momentum: float,
|
||||
eps: float,
|
||||
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
|
||||
if torch._prims_common.is_complex_dtype(input.dtype):
|
||||
raise NotImplementedError("Complex tensors are not supported")
|
||||
|
||||
# note: BN only promotes input to dtype of weight/bias, but keeps the same output dtype
|
||||
result_dtype = input.dtype
|
||||
computation_dtype, _ = elementwise_dtypes(
|
||||
input,
|
||||
weight,
|
||||
bias,
|
||||
type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.NO_OPMATH,
|
||||
)
|
||||
|
||||
input_ = _maybe_convert_to_dtype(input, computation_dtype)
|
||||
output, mean, rstd = prim(
|
||||
input_, weight, bias, running_mean, running_var, training, momentum, eps
|
||||
)
|
||||
output_ = _maybe_convert_to_dtype(output, result_dtype) # type: ignore[arg-type]
|
||||
return (output_, mean, rstd) # type: ignore[return-value]
|
||||
|
||||
def _native_batch_norm_autograd(
|
||||
input: torch.Tensor,
|
||||
weight: Optional[torch.Tensor],
|
||||
bias: Optional[torch.Tensor],
|
||||
running_mean: Optional[torch.Tensor],
|
||||
running_var: Optional[torch.Tensor],
|
||||
training: bool,
|
||||
momentum: float,
|
||||
eps: float,
|
||||
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
|
||||
# This wrapper is needed to convert prims calls inside
|
||||
# _native_batch_norm_ref to nvprims calls
|
||||
from torch._prims.context import NvfuserPrimsMode
|
||||
|
||||
with NvfuserPrimsMode():
|
||||
return backwards_not_supported(_native_batch_norm_ref)(
|
||||
input, weight, bias, running_mean, running_var, training, momentum, eps
|
||||
)
|
||||
|
||||
nvprim_autograd_impl.impl(name, _native_batch_norm_autograd)
|
||||
|
||||
for p in (prim_packet, prim):
|
||||
p.__doc__ = "Computes batch normalization."
|
||||
p.impl_nvfuser = _nvfuser_impls["native_batch_norm"]
|
||||
p.is_recomputable = _nvfuser_is_recomputable["native_batch_norm"]
|
||||
p.return_type = torch._prims_common.RETURN_TYPE.NEW # type: ignore[attr-defined]
|
||||
|
||||
|
||||
def register_rand_like():
|
||||
name = "rand_like"
|
||||
|
||||
nvprim.define(
|
||||
"rand_like(Tensor self, *, ScalarType? dtype=None, Layout? layout=None, "
|
||||
+ "Device? device=None, bool? pin_memory=None, MemoryFormat? memory_format=None) -> Tensor"
|
||||
)
|
||||
|
||||
def _meta_rand_like(
|
||||
self,
|
||||
*,
|
||||
dtype=None,
|
||||
layout=None,
|
||||
device=None,
|
||||
pin_memory=None,
|
||||
memory_format=None,
|
||||
):
|
||||
strides = make_contiguous_strides_for(self.shape)
|
||||
return torch._prims.TensorMeta(
|
||||
self,
|
||||
shape=self.shape,
|
||||
strides=strides,
|
||||
dtype=dtype,
|
||||
device=device,
|
||||
)
|
||||
|
||||
def _prim_impl(
|
||||
self,
|
||||
*,
|
||||
dtype=None,
|
||||
layout=None,
|
||||
device=None,
|
||||
pin_memory=None,
|
||||
memory_format=None,
|
||||
):
|
||||
return torch.rand_like(
|
||||
self,
|
||||
dtype=dtype,
|
||||
layout=layout,
|
||||
device=device,
|
||||
pin_memory=pin_memory,
|
||||
memory_format=memory_format,
|
||||
)
|
||||
|
||||
nvprim_impl.impl(name, _prim_impl)
|
||||
nvprim_meta_impl.impl(name, _meta_rand_like)
|
||||
|
||||
prim_packet = getattr(torch._ops.ops.nvprims, name)
|
||||
prim = prim_packet.default
|
||||
|
||||
nvprim_autograd_impl.impl(name, backwards_not_supported(prim))
|
||||
|
||||
for p in (prim_packet, prim):
|
||||
p.__doc__ = "Computes rand_like"
|
||||
p.impl_nvfuser = _nvfuser_impls["rand_like"]
|
||||
p.is_recomputable = _nvfuser_is_recomputable["rand_like"]
|
||||
p.return_type = torch._prims_common.RETURN_TYPE.NEW # type: ignore[attr-defined]
|
||||
|
||||
|
||||
def register_var_mean():
|
||||
"""This function is used to register the var_mean function in torch.ops.nvprims module."""
|
||||
name = "var_mean.main"
|
||||
|
||||
# This overload must be default for correct dispatching of var_mean(Tensor, bool)
|
||||
nvprim.define("var_mean(Tensor inp, bool unbiased) -> (Tensor, Tensor)")
|
||||
|
||||
# This signature tries to combine several overloads of the torch.var_mean function into one overload.
|
||||
nvprim.define(
|
||||
f"{name}(Tensor inp, int[1]? dim=None, bool? unbiased=None, bool keepdim=False, *, float? correction=None)"
|
||||
+ " -> (Tensor, Tensor)"
|
||||
)
|
||||
|
||||
# This function is used for device="meta" Tensors.
|
||||
def _meta_var_mean(inp, dim=None, unbiased=None, keepdim=False, *, correction=None):
|
||||
if torch._prims_common.is_complex_dtype(inp.dtype):
|
||||
output_dtype = torch._prims_common.corresponding_real_dtype(inp.dtype)
|
||||
else:
|
||||
output_dtype = inp.dtype
|
||||
var = torch._prims._reduction_meta(inp, dim, output_dtype=output_dtype)
|
||||
mean = torch._prims._reduction_meta(inp, dim, output_dtype=inp.dtype)
|
||||
if keepdim:
|
||||
output_shape = [
|
||||
inp.shape[i] if i not in dim else 1 for i in range(inp.ndim)
|
||||
]
|
||||
broadcast_dims = [i for i in range(inp.ndim) if i not in dim]
|
||||
var = torch._ops.ops.nvprims.broadcast_in_dim(
|
||||
var, output_shape, broadcast_dims
|
||||
)
|
||||
mean = torch._ops.ops.nvprims.broadcast_in_dim(
|
||||
mean, output_shape, broadcast_dims
|
||||
)
|
||||
return (var, mean)
|
||||
|
||||
# This function is used under _AutoDispatchBelowAutograd context
|
||||
def _prim_impl(inp, dim=None, unbiased=None, keepdim=False, *, correction=None):
|
||||
correction = torch._prims_common.set_correction(unbiased, correction)
|
||||
return torch.var_mean(inp, dim, correction=correction, keepdim=keepdim)
|
||||
|
||||
nvprim_impl.impl(name, _prim_impl)
|
||||
nvprim_meta_impl.impl(name, _meta_var_mean)
|
||||
|
||||
prim_packet = torch._ops.ops.nvprims.var_mean
|
||||
prim = prim_packet.main
|
||||
|
||||
def _unbiased_overload_impl(inp, unbiased):
|
||||
return prim(inp, dim=None, unbiased=unbiased)
|
||||
|
||||
nvprim_implicit_impl.impl("var_mean", _unbiased_overload_impl)
|
||||
|
||||
@elementwise_type_promotion_wrapper(
|
||||
type_promoting_args=("a",),
|
||||
type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.COMPLEX_TO_FLOAT,
|
||||
)
|
||||
def _var_mean_ref(a, dim=None, unbiased=None, keepdim=False, *, correction=None):
|
||||
correction = torch._prims_common.set_correction(unbiased, correction)
|
||||
# reduces over all dimensions if dim=() is passed
|
||||
if dim == () or dim == []:
|
||||
dim = None
|
||||
dim = torch._prims_common.reduction_dims(a.shape, dim)
|
||||
|
||||
# For complex tensors eager computes the variance as the sum of variances of
|
||||
# the real and imaginary parts
|
||||
# TODO: Creating a complex tensor from real and imaginary parts is not supported
|
||||
if torch._prims_common.is_complex_dtype(a.dtype):
|
||||
raise NotImplementedError("Complex tensors are not supported")
|
||||
|
||||
var_mean = prim(a, dim, correction=correction)
|
||||
|
||||
if keepdim:
|
||||
output_shape = [a.shape[i] if i not in dim else 1 for i in range(a.ndim)]
|
||||
broadcast_dims = [i for i in range(a.ndim) if i not in dim]
|
||||
var, mean = var_mean
|
||||
var = torch._ops.ops.nvprims.broadcast_in_dim(
|
||||
var, output_shape, broadcast_dims
|
||||
)
|
||||
mean = torch._ops.ops.nvprims.broadcast_in_dim(
|
||||
mean, output_shape, broadcast_dims
|
||||
)
|
||||
var_mean = (var, mean)
|
||||
return var_mean
|
||||
|
||||
def _var_mean_autograd(
|
||||
a, dim=None, unbiased=None, keepdim=False, *, correction=None
|
||||
):
|
||||
# This wrapper is needed to convert prims calls inside
|
||||
# elementwise_type_promotion_wrapper to nvprims calls
|
||||
from torch._prims.context import NvfuserPrimsMode
|
||||
|
||||
with NvfuserPrimsMode():
|
||||
return backwards_not_supported(_var_mean_ref)(
|
||||
a, dim, unbiased, keepdim, correction=correction
|
||||
)
|
||||
|
||||
nvprim_autograd_impl.impl(name, _var_mean_autograd)
|
||||
|
||||
for p in (prim_packet, prim):
|
||||
p.__doc__ = "Computes the variance and mean of x over the list of dimensions specified in the dim argument"
|
||||
p.impl_nvfuser = _nvfuser_impls["var_mean"]
|
||||
p.is_recomputable = _nvfuser_is_recomputable["var_mean"]
|
||||
p.return_type = torch._prims_common.RETURN_TYPE.NEW # type: ignore[attr-defined]
|
||||
|
||||
|
||||
def _nvprims_view_impl_aten(a, original_shape, new_shape):
|
||||
return a.reshape(new_shape)
|
||||
|
||||
|
||||
def register_view():
|
||||
"""This function is used to register the view function in torch.ops.view module."""
|
||||
# View is implemented as a decomposition into prims.split_dim,
|
||||
# prims.collapse_dim, and prims.reshape, but we would like to intercept
|
||||
# non-decomposed view for now
|
||||
name = "view"
|
||||
|
||||
nvprim.define("view(Tensor inp, SymInt[] original_shape, SymInt[] shape) -> Tensor")
|
||||
nvprim.define("view.shape(Tensor inp, SymInt[] shape) -> Tensor")
|
||||
|
||||
# This function is used under _AutoDispatchBelowAutograd context
|
||||
def _prim_impl(a, original_shape, new_shape):
|
||||
return a.reshape(new_shape)
|
||||
|
||||
nvprim_impl.impl(name, _prim_impl)
|
||||
|
||||
prim_packet = torch._ops.ops.nvprims.view
|
||||
prim = prim_packet.default
|
||||
|
||||
def _view_no_original_shape_overload_impl(a, shape):
|
||||
if list(a.shape) == list(shape):
|
||||
return torch.ops.nvprims.view_of(a)
|
||||
return torch.ops.nvprims.view.default(a, a.shape, shape)
|
||||
|
||||
nvprim_implicit_impl.impl("view.shape", _view_no_original_shape_overload_impl)
|
||||
nvprim_autograd_impl.impl(name, backwards_not_supported(prim))
|
||||
|
||||
for p in (prim_packet, prim):
|
||||
p.__doc__ = "Creates a tensor with the specified shape containing a copy of the data in a."
|
||||
p.impl_nvfuser = _nvfuser_impls["view"]
|
||||
p.is_recomputable = _nvfuser_is_recomputable["view"]
|
||||
p.return_type = torch._prims_common.RETURN_TYPE.VIEW # type: ignore[attr-defined]
|
||||
p.impl_aten = _nvprims_view_impl_aten
|
||||
|
||||
|
||||
def register_nvprims():
|
||||
"""Registers all nvFuser primitives in the torch.ops.nvprims module."""
|
||||
register_var_mean()
|
||||
register_view()
|
||||
register_native_batch_norm()
|
||||
register_rand_like()
|
||||
register_full()
|
||||
|
||||
for name in nvprim_names:
|
||||
main_prim = getattr(torch._ops.ops.prims, name)
|
||||
|
||||
nvprim.define(main_prim.schema)
|
||||
nvprim_impl.impl(name, main_prim.prim_impl)
|
||||
nvprim_meta_impl.impl(name, main_prim.prim_meta_impl)
|
||||
|
||||
prim_packet = getattr(torch._ops.ops.nvprims, name)
|
||||
prim = prim_packet.default
|
||||
|
||||
nvprim_autograd_impl.impl(name, backwards_not_supported(prim))
|
||||
|
||||
for p in (prim_packet, prim):
|
||||
p.__doc__ = main_prim.__doc__
|
||||
p.impl_nvfuser = _nvfuser_impls[name]
|
||||
p.is_recomputable = _nvfuser_is_recomputable.get(name, False)
|
||||
p.return_type = main_prim.return_type # type: ignore[attr-defined]
|
||||
p.impl_aten = main_prim.impl_aten
|
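The deleted module above registers its operators through ``torch.library``. A minimal sketch of that registration pattern, using a made-up ``examplens`` namespace purely for illustration::

    import torch

    example_def = torch.library.Library("examplens", "DEF")
    example_impl = torch.library.Library("examplens", "IMPL", "CompositeExplicitAutograd")

    # Declare the operator schema, then attach a Python implementation.
    example_def.define("add_one(Tensor self) -> Tensor")

    def _add_one(self):
        return self + 1

    example_impl.impl("add_one", _add_one)

    print(torch.ops.examplens.add_one(torch.zeros(3)))  # tensor([1., 1., 1.])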
@ -25,39 +25,6 @@ import sympy
import torch
from torch import sym_float, sym_int, sym_max

try:
    try:
        from nvfuser import DataType  # type: ignore[import, attr-defined]
    except ImportError:
        from nvfuser._C import DataType  # type: ignore[import]

    _torch_dtype_to_nvfuser_dtype_map = {
        torch.cdouble: DataType.ComplexDouble,
        torch.cfloat: DataType.ComplexFloat,
        torch.double: DataType.Double,
        torch.float: DataType.Float,
        torch.half: DataType.Half,
        torch.bfloat16: DataType.BFloat16,
        torch.long: DataType.Int,
        torch.int: DataType.Int32,
        torch.uint8: DataType.Int32,
        torch.bool: DataType.Bool,
        # Python scalars
        complex: DataType.ComplexDouble,
        float: DataType.Double,
        int: DataType.Int,
        bool: DataType.Bool,
    }
except ImportError:
    _torch_dtype_to_nvfuser_dtype_map = {}


def getnvFuserDtype(dtype: Union[torch.dtype, NumberTypeType]):
    """
    Translates from torch.dtype to nvFuser's DataType enum
    """
    return _torch_dtype_to_nvfuser_dtype_map[dtype]


ShapeType = Union[torch.Size, List[int], Tuple[int, ...]]
StrideType = Union[List[int], Tuple[int, ...]]

File diff suppressed because it is too large
@ -50,14 +50,12 @@ class SpectralFuncPythonRefInfo(SpectralFuncInfo):
        op=None,  # the function variant of the operation, populated as torch.<name> if None
        torch_opinfo_name,  # the string name of the corresponding torch opinfo
        torch_opinfo_variant="",
        supports_nvfuser=True,
        **kwargs,
    ):  # additional kwargs override kwargs inherited from the torch opinfo
        self.torch_opinfo_name = torch_opinfo_name
        self.torch_opinfo = _find_referenced_opinfo(
            torch_opinfo_name, torch_opinfo_variant, op_db=op_db
        )
        self.supports_nvfuser = supports_nvfuser
        assert isinstance(self.torch_opinfo, SpectralFuncInfo)

        inherited = self.torch_opinfo._original_spectral_func_args
@ -652,103 +650,83 @@ python_ref_db: List[OpInfo] = [
|
||||
SpectralFuncPythonRefInfo(
|
||||
"_refs.fft.fft",
|
||||
torch_opinfo_name="fft.fft",
|
||||
supports_nvfuser=False,
|
||||
),
|
||||
SpectralFuncPythonRefInfo(
|
||||
"_refs.fft.ifft",
|
||||
torch_opinfo_name="fft.ifft",
|
||||
supports_nvfuser=False,
|
||||
),
|
||||
SpectralFuncPythonRefInfo(
|
||||
"_refs.fft.rfft",
|
||||
torch_opinfo_name="fft.rfft",
|
||||
supports_nvfuser=False,
|
||||
),
|
||||
SpectralFuncPythonRefInfo(
|
||||
"_refs.fft.irfft",
|
||||
torch_opinfo_name="fft.irfft",
|
||||
supports_nvfuser=False,
|
||||
),
|
||||
SpectralFuncPythonRefInfo(
|
||||
"_refs.fft.hfft",
|
||||
torch_opinfo_name="fft.hfft",
|
||||
supports_nvfuser=False,
|
||||
),
|
||||
SpectralFuncPythonRefInfo(
|
||||
"_refs.fft.ihfft",
|
||||
torch_opinfo_name="fft.ihfft",
|
||||
supports_nvfuser=False,
|
||||
),
|
||||
SpectralFuncPythonRefInfo(
|
||||
"_refs.fft.fftn",
|
||||
torch_opinfo_name="fft.fftn",
|
||||
supports_nvfuser=False,
|
||||
),
|
||||
SpectralFuncPythonRefInfo(
|
||||
"_refs.fft.ifftn",
|
||||
torch_opinfo_name="fft.ifftn",
|
||||
supports_nvfuser=False,
|
||||
),
|
||||
SpectralFuncPythonRefInfo(
|
||||
"_refs.fft.rfftn",
|
||||
torch_opinfo_name="fft.rfftn",
|
||||
supports_nvfuser=False,
|
||||
),
|
||||
SpectralFuncPythonRefInfo(
|
||||
"_refs.fft.irfftn",
|
||||
torch_opinfo_name="fft.irfftn",
|
||||
supports_nvfuser=False,
|
||||
),
|
||||
SpectralFuncPythonRefInfo(
|
||||
"_refs.fft.hfftn",
|
||||
torch_opinfo_name="fft.hfftn",
|
||||
supports_nvfuser=False,
|
||||
),
|
||||
SpectralFuncPythonRefInfo(
|
||||
"_refs.fft.ihfftn",
|
||||
torch_opinfo_name="fft.ihfftn",
|
||||
supports_nvfuser=False,
|
||||
),
|
||||
SpectralFuncPythonRefInfo(
|
||||
"_refs.fft.fft2",
|
||||
torch_opinfo_name="fft.fft2",
|
||||
supports_nvfuser=False,
|
||||
),
|
||||
SpectralFuncPythonRefInfo(
|
||||
"_refs.fft.ifft2",
|
||||
torch_opinfo_name="fft.ifft2",
|
||||
supports_nvfuser=False,
|
||||
),
|
||||
SpectralFuncPythonRefInfo(
|
||||
"_refs.fft.rfft2",
|
||||
torch_opinfo_name="fft.rfft2",
|
||||
supports_nvfuser=False,
|
||||
),
|
||||
SpectralFuncPythonRefInfo(
|
||||
"_refs.fft.irfft2",
|
||||
torch_opinfo_name="fft.irfft2",
|
||||
supports_nvfuser=False,
|
||||
),
|
||||
SpectralFuncPythonRefInfo(
|
||||
"_refs.fft.hfft2",
|
||||
torch_opinfo_name="fft.hfft2",
|
||||
supports_nvfuser=False,
|
||||
),
|
||||
SpectralFuncPythonRefInfo(
|
||||
"_refs.fft.ihfft2",
|
||||
torch_opinfo_name="fft.ihfft2",
|
||||
supports_nvfuser=False,
|
||||
),
|
||||
PythonRefInfo(
|
||||
"_refs.fft.fftshift",
|
||||
op_db=op_db,
|
||||
torch_opinfo_name="fft.fftshift",
|
||||
supports_nvfuser=False,
|
||||
),
|
||||
PythonRefInfo(
|
||||
"_refs.fft.ifftshift",
|
||||
op_db=op_db,
|
||||
torch_opinfo_name="fft.ifftshift",
|
||||
supports_nvfuser=False,
|
||||
),
|
||||
]
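Each entry above pairs a Python reference under torch._refs with its eager counterpart by name (for example "_refs.fft.fft" with "fft.fft"). The following is only an illustrative check of what that pairing encodes, not the OpInfo-driven test machinery itself.

# Illustrative check: the Python reference should agree numerically with the
# eager op it references.
import torch
import torch._refs.fft  # the module backing the "_refs.fft.*" names above

x = torch.randn(8, dtype=torch.cfloat)
eager = torch.fft.fft(x)
ref = torch._refs.fft.fft(x)
torch.testing.assert_close(ref, eager)
print("torch._refs.fft.fft matches torch.fft.fft on this input")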
@@ -2406,22 +2406,18 @@ python_ref_db: List[OpInfo] = [
        "_refs.linalg.diagonal",
        torch_opinfo_name="linalg.diagonal",
        supports_out=False,
        supports_nvfuser=False,
        op_db=op_db,
    ),
    ReductionPythonRefInfo(
        "_refs.linalg.vector_norm",
        torch_opinfo_name="linalg.vector_norm",
        supports_out=True,
        supports_nvfuser=False,  # clone_default
        op_db=op_db,
    ),
    PythonRefInfo(
        "_refs.linalg.matrix_norm",
        torch_opinfo_name="linalg.matrix_norm",
        supports_out=True,
        # Uses svdvals which does not support nvfuser
        supports_nvfuser=False,
        # Uses vector_norm inside and vector_norm is affected by
        # https://github.com/pytorch/pytorch/issues/77216
        validate_view_consistency=False,
@@ -2431,8 +2427,6 @@ python_ref_db: List[OpInfo] = [
        "_refs.linalg.norm",
        torch_opinfo_name="linalg.norm",
        supports_out=True,
        # Uses svdvals which does not support nvfuser
        supports_nvfuser=False,
        # Uses vector_norm inside and vector_norm is affected by
        # https://github.com/pytorch/pytorch/issues/77216
        validate_view_consistency=False,
@@ -2442,14 +2436,12 @@ python_ref_db: List[OpInfo] = [
        "_refs.linalg.svd",
        torch_opinfo_name="linalg.svd",
        supports_out=True,
        supports_nvfuser=False,
        op_db=op_db,
    ),
    PythonRefInfo(
        "_refs.linalg.svdvals",
        torch_opinfo_name="linalg.svdvals",
        supports_out=True,
        supports_nvfuser=False,
        op_db=op_db,
    ),
]
@@ -690,67 +690,56 @@ python_ref_db: List[OpInfo] = [
    ElementwiseUnaryPythonRefInfo(
        "_refs.special.bessel_j0",
        torch_opinfo_name="special.bessel_j0",
        supports_nvfuser=False,
        op_db=op_db,
    ),
    ElementwiseUnaryPythonRefInfo(
        "_refs.special.bessel_j1",
        torch_opinfo_name="special.bessel_j1",
        supports_nvfuser=False,
        op_db=op_db,
    ),
    ElementwiseUnaryPythonRefInfo(
        "_refs.special.entr",
        torch_opinfo_name="special.entr",
        supports_nvfuser=False,
        op_db=op_db,
    ),
    ElementwiseUnaryPythonRefInfo(
        "_refs.special.erfcx",
        torch_opinfo_name="special.erfcx",
        supports_nvfuser=False,
        op_db=op_db,
    ),
    ElementwiseUnaryPythonRefInfo(
        "_refs.special.i0e",
        torch_opinfo_name="special.i0e",
        supports_nvfuser=False,
        op_db=op_db,
    ),
    ElementwiseUnaryPythonRefInfo(
        "_refs.special.i1",
        torch_opinfo_name="special.i1",
        supports_nvfuser=False,
        op_db=op_db,
    ),
    ElementwiseUnaryPythonRefInfo(
        "_refs.special.i1e",
        torch_opinfo_name="special.i1e",
        supports_nvfuser=False,
        op_db=op_db,
    ),
    ElementwiseUnaryPythonRefInfo(
        "_refs.special.log_ndtr",
        torch_opinfo_name="special.log_ndtr",
        supports_nvfuser=False,
        op_db=op_db,
    ),
    ElementwiseUnaryPythonRefInfo(
        "_refs.special.ndtr",
        torch_opinfo_name="special.ndtr",
        supports_nvfuser=False,
        op_db=op_db,
    ),
    ElementwiseUnaryPythonRefInfo(
        "_refs.special.ndtri",
        torch_opinfo_name="special.ndtri",
        supports_nvfuser=False,
        op_db=op_db,
    ),
    ElementwiseUnaryPythonRefInfo(
        "_refs.special.spherical_bessel_j0",
        torch_opinfo_name="special.spherical_bessel_j0",
        supports_nvfuser=False,
        op_db=op_db,
    ),
    #
@@ -760,7 +749,6 @@ python_ref_db: List[OpInfo] = [
        "_refs.special.zeta",
        torch_opinfo_name="special.zeta",
        supports_one_python_scalar=True,
        supports_nvfuser=False,
        op_db=op_db,
        skips=(
            # Reference reference_inputs nans and infs on cuda and nan, inf, 0., -inf for cpu
@@ -100,7 +100,6 @@ class PythonRefInfo(OpInfo):
        torch_opinfo_name,  # the string name of the corresponding torch opinfo
        torch_opinfo_variant_name="",  # the variant name for corresponding torch opinfo
        validate_view_consistency=True,
        supports_nvfuser=True,
        **kwargs,
    ):  # additional kwargs override kwargs inherited from the torch opinfo
        self.torch_opinfo_name = torch_opinfo_name
@@ -109,7 +108,6 @@ class PythonRefInfo(OpInfo):
            torch_opinfo_name, torch_opinfo_variant_name, op_db=op_db
        )
        self.validate_view_consistency = validate_view_consistency
        self.supports_nvfuser = supports_nvfuser
        assert isinstance(self.torch_opinfo, OpInfo)

        inherited = self.torch_opinfo._original_opinfo_args
@@ -130,7 +128,6 @@ class ReductionPythonRefInfo(ReductionOpInfo):
        op_db=None,  # The database of opinfos to search for the parent opinfo
        torch_opinfo_name,  # the string name of the corresponding torch opinfo
        torch_opinfo_variant_name="",  # the variant name for corresponding torch opinfo
        supports_nvfuser=True,
        **kwargs,
    ):  # additional kwargs override kwargs inherited from the torch opinfo
        self.torch_opinfo_name = torch_opinfo_name
@@ -138,7 +135,6 @@ class ReductionPythonRefInfo(ReductionOpInfo):
        self.torch_opinfo = _find_referenced_opinfo(
            torch_opinfo_name, torch_opinfo_variant_name, op_db=op_db
        )
        self.supports_nvfuser = supports_nvfuser
        assert isinstance(self.torch_opinfo, ReductionOpInfo)

        inherited = self.torch_opinfo._original_reduction_args
@@ -164,7 +160,6 @@ class ElementwiseUnaryPythonRefInfo(UnaryUfuncInfo):
        torch_opinfo_name,  # the string name of the corresponding torch opinfo
        torch_opinfo_variant_name="",  # the variant name for corresponding torch opinfo
        validate_view_consistency=True,
        supports_nvfuser=True,
        **kwargs,
    ):  # additional kwargs override kwargs inherited from the torch opinfo
        self.torch_opinfo_name = torch_opinfo_name
@@ -173,7 +168,6 @@ class ElementwiseUnaryPythonRefInfo(UnaryUfuncInfo):
            torch_opinfo_name, torch_opinfo_variant_name, op_db=op_db
        )
        self.validate_view_consistency = validate_view_consistency
        self.supports_nvfuser = supports_nvfuser
        assert isinstance(self.torch_opinfo, UnaryUfuncInfo)

        inherited = self.torch_opinfo._original_unary_ufunc_args
@@ -195,7 +189,6 @@ class ElementwiseBinaryPythonRefInfo(BinaryUfuncInfo):
        op_db=None,  # The database of opinfos to search for the parent opinfo
        torch_opinfo_name,  # the string name of the corresponding torch opinfo
        torch_opinfo_variant_name="",  # the variant name for corresponding torch opinfo
        supports_nvfuser=True,
        **kwargs,
    ):  # additional kwargs override kwargs inherited from the torch opinfo
        self.torch_opinfo_name = torch_opinfo_name
@@ -203,7 +196,6 @@ class ElementwiseBinaryPythonRefInfo(BinaryUfuncInfo):
        self.torch_opinfo = _find_referenced_opinfo(
            torch_opinfo_name, torch_opinfo_variant_name, op_db=op_db
        )
        self.supports_nvfuser = supports_nvfuser
        assert isinstance(self.torch_opinfo, BinaryUfuncInfo)

        inherited = self.torch_opinfo._original_binary_ufunc_args