Remove dynamo+nvfuser (#105789)

This PR removes the unmaintained Dynamo+nvFuser integration.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/105789
Approved by: https://github.com/jansel, https://github.com/jjsjann123, https://github.com/albanD
Authored by Ivan Yashchuk on 2023-08-08 13:29:31 +00:00
Committed by PyTorch MergeBot
parent ad22f0ffb4
commit 6030151d37
20 changed files with 11 additions and 3135 deletions

View File

@@ -52,10 +52,6 @@ Some of the most commonly used backends include:
- Description
* - ``torch.compile(m, backend="inductor")``
- Uses the TorchInductor backend. `Read more <https://dev-discuss.pytorch.org/t/torchinductor-a-pytorch-native-compiler-with-define-by-run-ir-and-symbolic-shapes/747>`__
* - ``torch.compile(m, backend="aot_ts_nvfuser")``
- nvFuser with AOT Autograd/TorchScript. `Read more <https://dev-discuss.pytorch.org/t/tracing-with-primitives-update-1-nvfuser-and-its-primitives/593>`__
* - ``torch.compile(m, backend="nvprims_nvfuser")``
- Tracing with nvFuser and its primitives. `Read more <https://dev-discuss.pytorch.org/t/tracing-with-primitives-update-1-nvfuser-and-its-primitives/593>`__
* - ``torch.compile(m, backend="cudagraphs")``
- CUDA graphs with AOT Autograd. `Read more <https://github.com/pytorch/torchdynamo/pull/757>`__
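For illustration, a minimal sketch of how the remaining backends in this table are selected (the toy function ``fn`` is hypothetical; a CUDA device is assumed for ``cudagraphs``):

import torch

def fn(x):
    return torch.sin(x) + torch.cos(x)

# Only the backend string changes; "inductor" is the default.
compiled_inductor = torch.compile(fn, backend="inductor")
compiled_cudagraphs = torch.compile(fn, backend="cudagraphs")

x = torch.randn(8, device="cuda")
compiled_inductor(x)
compiled_cudagraphs(x)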

View File

@@ -93,7 +93,7 @@ hub.
And that is not the only available backend, you can run in a REPL
``torch.compile.list_backends()`` to see all the available backends. Try out the
``cudagraphs`` or ``nvfuser`` next as inspiration.
``cudagraphs`` next as inspiration.
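A minimal sketch of inspecting the backend registry from a REPL (assuming a recent build; ``torch._dynamo.list_backends()`` is one way to reach the registry):

import torch._dynamo

# Prints registered backend names such as "cudagraphs", "inductor", "onnxrt".
print(torch._dynamo.list_backends())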
Using a pretrained model
~~~~~~~~~~~~~~~~~~~~~~~~

View File

@@ -202,7 +202,7 @@ If the error does not occur with the ``"eager"`` backend, then the
backend compiler is the source of the error (`example
error <https://gist.github.com/mlazos/2f13681e3cc6c43b3911f336327032de>`__).
There are `different choices <./torch.compiler.rst>`__
for backend compilers for TorchDynamo, with TorchInductor or nvfuser
for backend compilers for TorchDynamo, with TorchInductor
fitting the needs of most users. This section focuses on TorchInductor
as the motivating example, but some tools can also be used with other
backend compilers.
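A minimal sketch of the triage step described above (the toy function ``fn`` is hypothetical): if the ``"eager"`` backend runs cleanly but the default backend fails, the backend compiler is the likely source of the error.

import torch

def fn(x):
    return x.sin().cos()

x = torch.randn(4)
torch.compile(fn, backend="eager")(x)     # exercises TorchDynamo only
torch.compile(fn, backend="inductor")(x)  # exercises the full compiler stack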

View File

@@ -10,7 +10,6 @@ from torch._dynamo.backends.debugging import ExplainWithBackend
from torch._dynamo.backends.onnxrt import has_onnxruntime
from torch._dynamo.backends.tvm import has_tvm
from torch._dynamo.testing import same
from torch.testing._internal.common_utils import IS_FBCODE, skipIfRocm
from torch.testing._internal.inductor_utils import HAS_CUDA
requires_cuda = functools.partial(unittest.skipIf, not HAS_CUDA, "requires cuda")
@@ -123,21 +122,6 @@ class TestOptimizations(torch._dynamo.test_case.TestCase):
def test_aot_cudagraphs(self):
self._check_backend_works("cudagraphs")
@skipIfRocm
@requires_cuda()
def test_aot_ts_nvfuser(self):
self._check_backend_works("aot_ts_nvfuser")
@requires_cuda()
@unittest.skipIf(IS_FBCODE, "BackendCompilerError")
def test_nvprims_nvfuser(self):
self._check_backend_works("nvprims_nvfuser")
@requires_cuda()
@unittest.skipIf(IS_FBCODE, "BackendCompilerError")
def test_nvprims_aten(self):
self._check_backend_works("nvprims_aten")
@unittest.skipIf(not has_onnxruntime(), "requires onnxruntime")
def test_onnxrt(self):
self._check_backend_works("onnxrt")

View File

@@ -1,11 +0,0 @@
# Owner(s): ["module: nvfuser"]
try:
from _nvfuser.test_dynamo import * # noqa: F403,F401
except ImportError:
def run_tests():
return
pass
if __name__ == '__main__':
run_tests()

View File

@@ -484,26 +484,9 @@ class TestCommon(TestCase):
@unittest.skipIf(TEST_WITH_ASAN, "Skipped under ASAN")
@onlyCUDA
@ops(python_ref_db)
@parametrize('executor', ['aten', 'nvfuser'])
@parametrize('executor', ['aten',])
@skipIfTorchInductor("Takes too long for inductor")
def test_python_ref_executor(self, device, dtype, op, executor):
# TODO: Not all dtypes are supported with nvfuser
from torch._prims_common import _torch_dtype_to_nvfuser_dtype_map
if executor == "nvfuser" and dtype not in _torch_dtype_to_nvfuser_dtype_map:
raise unittest.SkipTest(f"nvfuser doesn't support dtype {dtype}")
# nvFuser tests are rather slow so we only run int32 and float32 types
if executor == "nvfuser" and dtype not in [torch.int32, torch.float32]:
raise unittest.SkipTest("skipped for speed")
if executor == "nvfuser" and not op.supports_nvfuser:
raise unittest.SkipTest(f"{op.name} doesn't support nvfuser")
# nvFuser doesn't support reduction operations on 0-dim tensors yet
skip_zero_dim = False
if executor == "nvfuser" and isinstance(op, ReductionPythonRefInfo):
skip_zero_dim = True
# skip zero-dim tensors for some composites of reduction operations and view
skip_zero_dim_ops = [
"_refs.logsumexp",
@@ -513,25 +496,16 @@ class TestCommon(TestCase):
"_refs.sum_to_size",
"ops.nvprims.view",
]
if executor == "nvfuser" and op.name in skip_zero_dim_ops:
skip_zero_dim = True
from torch._prims.executor import make_traced
from copy import copy
op = copy(op)
executor = "strictly_nvfuser" if executor == "nvfuser" else executor
op.op = partial(make_traced(op.op), executor=executor)
self._ref_test_helper(
contextlib.nullcontext,
device,
dtype,
op,
skip_zero_numel=("nvfuser" in executor), # nvfuser doesn't support zero-sized tensors
skip_zero_dim=skip_zero_dim,
skip_bfloat=("nvfuser" in executor), # nvfuser doesn't support bfloat tensors for pre-11 cuda TK
# # nvfuser doesn't support view consistency
# https://github.com/pytorch/pytorch/issues/84863
skip_view_consistency=("nvfuser" in executor),
)
@skipMeta

View File

@@ -2,14 +2,12 @@
from functools import partial
from itertools import product
import warnings
from warnings import catch_warnings
import unittest
import torch
from torch.testing import make_tensor
from torch.testing._internal.common_utils import (parametrize, run_tests, TestCase, TEST_SCIPY,
set_default_dtype, skipCUDAMemoryLeakCheckIf)
set_default_dtype)
from torch.testing._internal.common_device_type import (
instantiate_device_type_tests,
onlyCUDA,
@@ -28,7 +26,6 @@ import torch._prims as prims
from torch._prims_common import CUDARngStateHelper
from torch._prims.executor import make_traced
import torch._refs as refs
from torch.fx.experimental.proxy_tensor import make_fx
if TEST_SCIPY:
@@ -47,7 +44,7 @@ class TestPrims(TestCase):
traced = make_traced(_wrapper)
make_arg = partial(make_tensor, device=device, dtype=dtype)
for executor in ('aten', 'strictly_nvfuser'):
for executor in ('aten',):
fn = partial(traced, executor=executor)
# Same shape
shape = (5, 5)
@@ -87,20 +84,6 @@ class TestPrims(TestCase):
self.assertEqual(result.shape, b.shape)
self.assertEqual(a.unsqueeze(2), result)
# FIXME: This test exposes an issue in nvfuser
# Adds outermost, expands, and unsqueezes
"""
a = make_arg((1, 2, 3))
b = make_arg((4, 1, 7, 2, 3, 3), low=0.0, high=0.0)
result = fn(a, b, (1, 3, 4))
self.assertEqual(result.shape, b.shape)
a.unsqueeze_(3)
a.unsqueeze_(1)
a.unsqueeze_(0)
self.assertEqual(a.expand_as(result), result)
"""
@onlyCUDA
@dtypes(torch.float32)
def test_broadcast_in_dim_sum(self, device, dtype):
@@ -112,7 +95,7 @@ class TestPrims(TestCase):
traced = make_traced(_wrapper)
make_arg = partial(make_tensor, device=device, dtype=dtype)
for executor in ('aten', 'strictly_nvfuser'):
for executor in ('aten',):
fn = partial(traced, executor=executor)
shape = (5, 5)
a = make_arg(shape)
@@ -171,340 +154,6 @@ class TestPrims(TestCase):
with self.assertRaises(AssertionError):
fn(t, start, end)
@onlyCUDA
def test_nvfuser_impl_is_used(self, device):
# This test is to ensure that when the nvfuser implementation exists it is used
# Assuming one-to-one mapping between prims and nvfuser implementations
# This test is not intended to test the correctness of the nvfuser implementation
try:
from nvfuser import FusionDefinition as fd
except ImportError:
from nvfuser._C import FusionDefinition as fd
prim_nvfuser_ops = set(torch._prims.__all__).intersection(dir(fd.ops))
ops_without_nvfuser_impl = {
name
for name in prim_nvfuser_ops
if getattr(torch.ops.nvprims, name, None) is None
}
assert (
len(ops_without_nvfuser_impl) == 0
), (f"The following prims do not have 'impl_nvfuser' defined: {ops_without_nvfuser_impl} ",
"while there exists nvfuser implementations for them.")
def test_skip_ops_nvfuser_prims_mode(self, device):
# This test verifies that the NvfuserPrimsMode skips the specified
# functions. Skipping a function means that it's not converted into
# nvprims counterparts.
from torch._prims.context import NvfuserPrimsMode
a = make_tensor(5, 5, device=device, dtype=torch.float32)
def func(a):
return torch.ops.prims.sin.default(a)
skip_ops = {"prims.sin.default", }
with NvfuserPrimsMode(skip_ops=skip_ops):
gm = make_fx(func)(a)
includes_any_prims_sin = any(
node.target == torch.ops.prims.sin.default for node in gm.graph.nodes
)
self.assertTrue(includes_any_prims_sin)
include_any_nvprims_sin = any(
node.target == torch.ops.nvprims.sin.default for node in gm.graph.nodes
)
self.assertFalse(include_any_nvprims_sin)
def test_skip_ops_nvfuser_capability_mode(self, device):
# This test verifies that the NvfuserCapabilityMode skips the specified
# functions. Skipping a function means that specific
# reference/decomposition is not traced and there's no attempt to lower
# it to nvprims.
from torch._prims.context import TorchRefsNvfuserCapabilityMode
a = make_tensor(5, 5, device=device, dtype=torch.float32)
def func(a):
return torch.sin(a)
skip_ops = {"torch.sin", }
with TorchRefsNvfuserCapabilityMode(skip_ops=skip_ops):
gm = make_fx(func)(a)
includes_any_aten_sin = any(
node.target == torch.ops.aten.sin.default for node in gm.graph.nodes
)
self.assertTrue(includes_any_aten_sin)
include_any_nvprims_sin = any(
node.target == torch.ops.nvprims.sin.default for node in gm.graph.nodes
)
self.assertFalse(include_any_nvprims_sin)
def test_partitioner_tuple_output(self, device):
# This test verifies that the partitioner doesn't segment on nodes with
# tuple outputs.
from torch.fx.passes.infra.partitioner import CapabilityBasedPartitioner
from torch._prims.nvfuser_executor import NvfuserPrimOperatorSupport
a = make_tensor(5, 3, 3, device=device, dtype=torch.float32)
def func(x):
xx = torch.ops.nvprims.add(x, 1)
var, mean = torch.ops.nvprims.var_mean(x, correction=0)
var_cos = torch.ops.nvprims.cos(var)
mean_sin = torch.ops.nvprims.sin(mean)
return torch.ops.nvprims.add(var_cos, mean_sin)
gm = make_fx(func)(a)
supported_ops = NvfuserPrimOperatorSupport()
partitioner = CapabilityBasedPartitioner(
gm, supported_ops, allows_single_node_partition=False
)
partitions = partitioner.propose_partitions()
self.assertEqual(len(partitions), 1)
@onlyCUDA
@dtypes(torch.float32)
def test_full(self, device, dtype):
from torch.fx.experimental.proxy_tensor import make_fx
from torch._prims.context import TorchRefsNvfuserCapabilityMode
from torch._prims.executor import execute
def func1(size, value, b):
return (torch.full(size, value, dtype=dtype, device=device),)
def func2(size, value, b):
a = torch.full(size, value, dtype=dtype, device=device)
b_sin = b.sin()
return (torch.add(a, b_sin),)
def func3(size, value, b):
return (torch.full(size, value, dtype=dtype, device=device), b)
def func4(size, value, b):
b_sin = b.sin()
return (torch.full(size, value, dtype=dtype, device=device), b_sin)
def func5(size, value, b):
b_sin = b.sin()
a = torch.full(size, value, dtype=dtype, device=device)
a_sin = a.sin()
return (a, b_sin, a_sin)
for func in (func1, func3, func2, func3, func4, func5):
size = (3, 3)
value = 10
b = torch.randn(*size, dtype=dtype, device=device)
with TorchRefsNvfuserCapabilityMode():
gm = make_fx(func)(size, value, b)
out = execute(gm, size, value, b, executor="strictly_nvfuser")
self.assertEqual(out, func(size, value, b))
@onlyCUDA
def test_nvfuser_empty_fusion(self, device):
from torch.fx.experimental.proxy_tensor import make_fx
from torch._prims.executor import execute
a = torch.randn(3, 3, device=device)
def func(a, b, c):
return (a, b, c)
gm = make_fx(func)(a, a, a)
with self.assertRaisesRegex(AssertionError, "Graph must contain at least one call_function node"):
execute(gm, a, a, a, executor="strictly_nvfuser")
# Should pass with partitioned executor
out = execute(gm, a, a, a, executor="nvfuser")
self.assertEqual(out, (a, a, a))
@onlyCUDA
@dtypes(torch.float16, torch.uint8)
def test_nvprim_convert_element_type(self, device, dtype):
from torch.fx.experimental.proxy_tensor import make_fx
from torch._prims.executor import execute
from torch._prims.context import TorchRefsNvfuserCapabilityMode
from torch._prims_common import _torch_dtype_to_nvfuser_dtype_map
# initialize input as float32, which is different from `dtype` in the argument.
# this ensures that tracing will have a _to_copy node.
a = torch.randn(3, 3, device=device, dtype=torch.float32)
def func(x, dtype):
return x.to(dtype).to(x.dtype)
with TorchRefsNvfuserCapabilityMode():
gm = make_fx(func)(a, dtype)
execute(gm, a, dtype, executor="nvfuser")
call_function_nodes = list(filter(lambda n: n.op == "call_function", gm.graph.nodes))
includes_aten_to_copy = any(
torch.ops.aten._to_copy.default == node.target
for node in call_function_nodes
)
includes_nvprim_convert_element_type = any(
torch.ops.nvprims.convert_element_type.default == node.target
for node in call_function_nodes
)
nvprim_support_flag = _torch_dtype_to_nvfuser_dtype_map.get(dtype) is not None
self.assertEqual(includes_aten_to_copy, not nvprim_support_flag)
self.assertEqual(includes_nvprim_convert_element_type, nvprim_support_flag)
@onlyCUDA
def test_nvfuser_rand_like_fusion(self, device):
from torch._prims.context import TorchRefsNvfuserCapabilityMode
from torch.fx.experimental.proxy_tensor import make_fx
from torch._prims.executor import execute
a = torch.randn(3, 3, device=device)
def func(a):
return torch.rand_like(a)
with TorchRefsNvfuserCapabilityMode():
gm = make_fx(func)(a)
out = execute(gm, a, executor="strictly_nvfuser")
self.assertEqual(out.size(), a.size())
@skipCUDAMemoryLeakCheckIf(True) # https://github.com/pytorch/pytorch/issues/84529
@onlyCUDA
def test_nvfuser_no_args(self, device):
from torch._prims.context import TorchRefsNvfuserCapabilityMode
from torch.fx.experimental.proxy_tensor import make_fx
from torch._prims.executor import execute
from torch._prims.nvfuser_executor import make_nvfuser_fusion
a = torch.randn(3, 3, device=device)
def func():
return torch.sigmoid(a)
with TorchRefsNvfuserCapabilityMode():
gm = make_fx(func)()
with warnings.catch_warnings(record=True) as caught:
execute(gm, executor="strictly_nvfuser")
# fusion execute with no cuda input is handled by nvprim aten fallback
self.assertTrue(any(NVPRIM_ATEN_FALLBACK_WARNING in str(w.message) for w in caught))
with self.assertRaisesRegex(AssertionError, "There must be at least one argument"):
make_nvfuser_fusion(gm)
with self.assertRaisesRegex(AssertionError, "Number of placeholder nodes in the graph must match"):
execute(gm, a, executor="strictly_nvfuser")
# Should pass with partitioned executor
out = execute(gm, executor="nvfuser")
self.assertEqual(out, func())
@onlyCUDA
def test_nvfuser_constant_tensors(self, device):
from torch._prims.context import TorchRefsNvfuserCapabilityMode
from torch.fx.experimental.proxy_tensor import make_fx
from torch._prims.executor import execute
a = torch.randn(3, 3, device=device)
b = torch.randn(3, 3, device=device)
def func(b):
return a + b
with TorchRefsNvfuserCapabilityMode():
gm = make_fx(func)(b)
with self.assertRaisesRegex(AssertionError, "not supported yet"):
execute(gm, b, executor="strictly_nvfuser")
# Should pass with partitioned executor
out = execute(gm, b, executor="nvfuser")
self.assertEqual(out, gm(b))
@onlyCUDA
def test_nvfuser_executor_cached_noncontiguous(self, device):
# This test is to ensure that nvfuser computes correct results for noncontiguous tensors
from torch.fx.experimental.proxy_tensor import make_fx
from torch._prims.context import TorchRefsNvfuserCapabilityMode
from torch._prims.executor import execute
a = torch.randn(3, 3, device=device)
def func(a):
return torch.sigmoid(a)
with TorchRefsNvfuserCapabilityMode():
gm = make_fx(func)(a)
# First run to create the cache
execute(gm, a, executor="strictly_nvfuser")
# a.mT is noncontiguous, but it shouldn't affect correctness
expected = execute(gm, a.mT, executor="aten")
for use_python_cache in [True, False]:
params = {"use_python_fusion_cache": use_python_cache}
actual = execute(gm, a.mT, executor="strictly_nvfuser", executor_parameters=params)
self.assertEqual(expected, actual)
def test_nvfuser_capability_context(self, device):
# This test is to ensure that the torch calls are replaced with refs
# based on the nvfuser+prims capability
from torch.fx.experimental.proxy_tensor import make_fx
from torch._prims.context import TorchRefsNvfuserCapabilityMode
# It's assumed that digamma is not supported by nvfuser
# If it's ever supported, this test will need to be updated
self.assertTrue(getattr(torch.ops.nvprims, "digamma", None) is None)
a = torch.randn(3, 3, device=device)
def func(a):
return torch.digamma(a)
with TorchRefsNvfuserCapabilityMode():
gm = make_fx(func)(a)
# Check that the torch.digamma is not replaced with torch.ops.prims.digamma
call_function_nodes = list(filter(lambda n: n.op == "call_function", gm.graph.nodes))
includes_aten_digamma = any(
torch.ops.aten.digamma.default == node.target
for node in call_function_nodes
)
includes_prims_digamma = any(
torch.ops.prims.digamma.default == node.target
for node in call_function_nodes
)
self.assertTrue(includes_aten_digamma)
self.assertFalse(includes_prims_digamma)
# Check mixed case, sigmoid is replaced with refs, but digamma is not
def func(a):
return torch.sigmoid(torch.digamma(a))
with TorchRefsNvfuserCapabilityMode():
gm = make_fx(func)(a)
call_function_nodes = list(filter(lambda n: n.op == "call_function", gm.graph.nodes))
includes_aten_sigmoid = any(
torch.ops.aten.sigmoid.default == node.target
for node in call_function_nodes
)
includes_prims_digamma = any(
torch.ops.prims.digamma.default == node.target
for node in call_function_nodes
)
includes_nvprims_exp = any(
torch.ops.nvprims.exp.default == node.target
for node in call_function_nodes
)
self.assertFalse(includes_aten_sigmoid)
self.assertFalse(includes_prims_digamma)
self.assertTrue(includes_nvprims_exp)
def test_aten_overload_to_prims(self, device):
# This test is to ensure that the torch.ops.aten calls are replaced with refs
@@ -526,325 +175,6 @@ class TestPrims(TestCase):
)
self.assertTrue(all_prims_namespace)
@onlyCUDA
def test_nvfuser_executor_parameters(self, device):
from torch.fx.experimental.proxy_tensor import make_fx
from torch._prims.executor import execute
a = torch.randn(3, 4, device=device)
def func(a):
return torch.ops.nvprims.add(a, a)
gm = make_fx(func)(a)
expected = execute(gm, a, executor="aten")
# Shouldn't raise an error because unuseful parameters are ignored
params_dicts = [None, {}, {"none": None}]
for params in params_dicts:
actual = execute(gm, a, executor="nvfuser", executor_parameters=params)
self.assertEqual(expected, actual)
# Check caching parameter
for use_cache in [True, False]:
params = {"use_python_fusion_cache": use_cache}
actual = execute(gm, a, executor="nvfuser", executor_parameters=params)
self.assertEqual(expected, actual)
# Check allow_single_op_fusion parameter
for allow_single_op_fusion in [True, False]:
params = {"allow_single_op_fusion": allow_single_op_fusion}
actual = execute(gm, a, executor="nvfuser", executor_parameters=params)
self.assertEqual(expected, actual)
@onlyCUDA
def test_nvfuser_executor_partitioned(self, device):
# This test is to ensure that nvfuser partitioned executor works correctly
# It's assumed that digamma is not supported by nvfuser
# If it's ever supported, this test will need to be updated
self.assertTrue(getattr(torch.ops.nvprims, "digamma", None) is None)
from torch.fx.experimental.proxy_tensor import make_fx
from torch._prims.context import TorchRefsNvfuserCapabilityMode
from torch._prims.executor import execute
a = torch.randn(3, 4, device=device)
b = torch.rand(3, 1, device=device)
c = torch.rand(3, 4, device=device)
def func(a, b, c):
aa = torch.digamma(a) # not supported by nvfuser
d = torch.add(b, c)
dd = torch.sqrt(d)
return torch.mul(aa, dd.digamma())
with TorchRefsNvfuserCapabilityMode():
gm = make_fx(func)(a, b, c)
expected = execute(gm, a, b, c, executor="aten")
actual = execute(gm, a, b, c, executor="nvfuser")
self.assertEqual(expected, actual)
@onlyCUDA
def test_nvfuser_executor_partitioned_no_partitions_error(self, device):
# This test is to ensure that nvfuser partitioned executor works correctly
# It's assumed that digamma is not supported by nvfuser
# If it's ever supported, this test will need to be updated
self.assertTrue(getattr(torch.ops.nvprims, "digamma", None) is None)
from torch.fx.experimental.proxy_tensor import make_fx
from torch._prims.context import TorchRefsNvfuserCapabilityMode
from torch._prims.executor import execute
a = torch.randn(3, 4, device=device)
def func(a):
return torch.digamma(a) # not supported by nvfuser
with TorchRefsNvfuserCapabilityMode():
gm = make_fx(func)(a)
with catch_warnings(record=True) as w:
# Trigger warning
execute(gm, a, executor="nvfuser")
# Check warning occurs
self.assertEqual(len(w), 1)
self.assertTrue("is not supported by nvFuser" in str(w[-1].message))
def test_nvprims(self, device):
# This test is to ensure that nvfuser specific prims are exposed
# and can be traced with make_fx
from torch.fx.experimental.proxy_tensor import make_fx
def func(a):
return torch.ops.nvprims.add(a, a)
a = torch.randn(3, 4, device=device)
gm = make_fx(func)(a)
for node in gm.graph.nodes:
if node.op == "call_function":
self.assertTrue(node.name == "add")
self.assertTrue(node.target == torch.ops.nvprims.add.default)
self.assertFalse(node.target == torch.ops.prims.add.default)
self.assertFalse(node.target == torch.ops.aten.add.default)
@onlyCUDA
@dtypes(torch.float32, torch.float64)
def test_native_batch_norm_nvprims(self, device, dtype):
from torch._prims.context import TorchRefsNvfuserCapabilityMode
from torch._prims.executor import execute
# This test verifies that native_batch_norm is translated into nvprims
# and can be executed with nvFuser
from torch.fx.experimental.proxy_tensor import make_fx
from torch.testing._internal.common_methods_invocations import (
sample_inputs_native_batch_norm,
)
samples = sample_inputs_native_batch_norm(
None, device, dtype, requires_grad=False
)
batch_norms = [
torch.native_batch_norm,
torch.ops.aten.native_batch_norm,
torch.ops.aten.native_batch_norm.default,
torch.ops.nvprims.native_batch_norm.default,
]
for sample, batch_norm in product(samples, batch_norms):
if sample.input.numel() == 0:
continue
def func(
input, weight, bias, running_mean, running_var, training, momentum, eps
):
return batch_norm(
input,
weight,
bias,
running_mean,
running_var,
training,
momentum,
eps,
)
with TorchRefsNvfuserCapabilityMode():
gm = make_fx(func)(sample.input, *sample.args)
call_function_nodes = list(
filter(lambda n: n.op == "call_function", gm.graph.nodes)
)
includes_aten_batch_norm = any(
torch.ops.aten.native_batch_norm.default == node.target
for node in call_function_nodes
)
self.assertFalse(includes_aten_batch_norm)
includes_nvprims_batch_norm = any(
torch.ops.nvprims.native_batch_norm.default == node.target
for node in call_function_nodes
)
self.assertTrue(includes_nvprims_batch_norm)
# Check that the graph can be executed with nvFuser
out = execute(gm, sample.input, *sample.args, executor="strictly_nvfuser")
self.assertEqual(out, gm(sample.input, *sample.args))
@onlyCUDA
@dtypes(torch.float32, torch.float64)
def test_cudnn_batch_norm_nvprims(self, device, dtype):
from torch._prims.context import TorchRefsNvfuserCapabilityMode
from torch._prims.executor import execute
# This test verifies that cudnn_batch_norm is translated into nvprims
# and can be executed with nvFuser
from torch.fx.experimental.proxy_tensor import make_fx
from torch.testing._internal.common_methods_invocations import (
sample_inputs_native_batch_norm,
)
samples = sample_inputs_native_batch_norm(
None, device, dtype, requires_grad=False
)
for sample in samples:
if sample.input.numel() == 0:
continue
def func(
input, weight, bias, running_mean, running_var, training, momentum, eps
):
return torch.ops.aten.cudnn_batch_norm.default(
input,
weight,
bias,
running_mean,
running_var,
training,
momentum,
eps,
)
with TorchRefsNvfuserCapabilityMode():
gm = make_fx(func)(sample.input, *sample.args)
call_function_nodes = list(
filter(lambda n: n.op == "call_function", gm.graph.nodes)
)
includes_aten_batch_norm = any(
torch.ops.aten.cudnn_batch_norm.default == node.target
for node in call_function_nodes
)
self.assertFalse(includes_aten_batch_norm)
includes_nvprims_batch_norm = any(
torch.ops.nvprims.native_batch_norm.default == node.target
for node in call_function_nodes
)
self.assertTrue(includes_nvprims_batch_norm)
# Check that the graph can be executed with nvFuser
out = execute(gm, sample.input, *sample.args, executor="nvfuser")
ref_out = gm(sample.input, *sample.args)
for idx, (left, right) in enumerate(zip(out, ref_out)):
# Nvfuser does not support torch.uint8 dtype so check reserve output against 0 scalar
if idx == 3:
self.assertTrue(torch.all(torch.eq(left, 0)))
else:
self.assertEqual(left, right)
# decomposition of native_batch_norm_backward uses a casting, which prevents nvprim lowering on CPU build
@onlyCUDA
@dtypes(torch.float32, torch.float16)
def test_batch_norm_backward_nvprims(self, device, dtype):
# This test verifies that the backward pass of batch norm is correctly decomposed into nvprims
from torch.fx.experimental.proxy_tensor import make_fx
from torch._prims.context import TorchRefsNvfuserCapabilityMode
from torch.testing._internal.common_methods_invocations import sample_inputs_batch_norm
samples_iter = sample_inputs_batch_norm(None, device, dtype, requires_grad=True)
sample = next(samples_iter)
grad = torch.randn_like(sample.input)
def func1(grad, input, weight, rm, rv, eps, train):
return torch.ops.aten.native_batch_norm_backward.default(
grad, input, weight, rm, rv, rm, rv, train, eps, [True, True, True]
)
def func2(grad, input, weight, rm, rv, eps, train):
return torch.ops.aten.cudnn_batch_norm_backward.default(
input, grad, weight, rm, rv, rm, rv, eps, grad
)
args = sample.args
kwargs = sample.kwargs
all_args = [grad, sample.input, args[2], args[0], args[1], kwargs['eps'], kwargs['training']]
for func in (func1, func2):
with TorchRefsNvfuserCapabilityMode():
gm = make_fx(func)(*all_args)
call_function_nodes = list(filter(lambda n: n.op == "call_function", gm.graph.nodes))
includes_batch_norm_backward = any(
torch.ops.aten.native_batch_norm_backward.default == node.target
for node in call_function_nodes
)
self.assertFalse(includes_batch_norm_backward)
all_nvprims = all(
str(node.target).startswith("nvprims") for node in call_function_nodes
)
self.assertTrue(all_nvprims)
@onlyCUDA
@dtypes(torch.float32)
def test_silu_backward_no_filled_tensor(self, device, dtype):
# This test verifies a workaround for
# https://github.com/pytorch/pytorch/issues/86612
from torch.fx.experimental.proxy_tensor import make_fx
from functorch import functionalize
from torch._prims.nvfuser_executor import _remove_empty_like_fill
from torch._prims.context import TorchRefsNvfuserCapabilityMode
def func(a):
out = torch.nn.functional.silu(a)
grad = torch.ones_like(out)
return torch.autograd.grad([out], [a], [grad])
make_arg = partial(make_tensor, device=device, dtype=dtype, requires_grad=True)
a = make_arg((3, 4))
gm = make_fx(func)(a)
# functionalize(gm) doesn't work with non-detached inputs
gm = make_fx(functionalize(gm))(a.detach())
# replace aten.sub with nvprims.sub
with TorchRefsNvfuserCapabilityMode():
gm = make_fx(gm)(a)
# Check that the graph contains empty_like
any_aten_empty_like = any(
node.target == torch.ops.aten.empty_like.default for node in gm.graph.nodes
)
self.assertTrue(any_aten_empty_like)
any_aten_fill = any(
node.target == torch.ops.aten.fill.Scalar for node in gm.graph.nodes
)
self.assertTrue(any_aten_fill)
# Now remove the empty_like and fill
gm = _remove_empty_like_fill(gm)
any_aten_empty_like = any(
node.target == torch.ops.aten.empty_like.default for node in gm.graph.nodes
)
self.assertFalse(any_aten_empty_like)
any_aten_fill = any(
node.target == torch.ops.aten.fill.Scalar for node in gm.graph.nodes
)
self.assertFalse(any_aten_fill)
self.assertEqual(gm(a), func(a))
@onlyCUDA
@dtypes(torch.float32)
@parametrize("correction", [0, 1])
@@ -855,7 +185,7 @@ class TestPrims(TestCase):
traced = make_traced(_wrapper)
make_arg = partial(make_tensor, device=device, dtype=dtype)
for executor in ('aten', 'strictly_nvfuser'):
for executor in ('aten',):
fn = partial(traced, executor=executor)
shape = (5, 5)
a = make_arg(shape)
@@ -865,160 +195,6 @@ class TestPrims(TestCase):
self.assertTrue(result.is_contiguous)
self.assertEqual(_wrapper(a), result)
@onlyCUDA
@dtypes(torch.float16, torch.float32)
@parametrize("correction", [0, 1])
@parametrize("keepdim", [True, False])
def test_var_mean(self, device, dtype, correction, keepdim):
from torch.fx.experimental.proxy_tensor import make_fx
from torch._prims.context import TorchRefsNvfuserCapabilityMode
def _wrapper(a):
return torch.var_mean(a, [0, 1], correction=correction, keepdim=keepdim)
make_arg = partial(make_tensor, device=device, dtype=dtype)
with TorchRefsNvfuserCapabilityMode():
gm = make_fx(_wrapper)(make_arg((5, 5)))
call_function_nodes = list(filter(lambda n: n.op == "call_function", gm.graph.nodes))
includes_nvprims_var_mean = any(
torch.ops.nvprims.var_mean.main == node.target
for node in call_function_nodes
)
self.assertTrue(includes_nvprims_var_mean)
@onlyCUDA
@dtypes(torch.float16, torch.float32)
def test_nvprims_view(self, device, dtype):
from torch.fx.experimental.proxy_tensor import make_fx
from torch._prims.context import TorchRefsNvfuserCapabilityMode
from torch._prims.executor import execute
make_arg = partial(make_tensor, device=device, dtype=dtype)
a = make_arg((3, 4, 5))
def func1(a):
return a.view(tuple(reversed(a.shape)))
def func2(a):
return a.reshape(tuple(reversed(a.shape)))
def func3(a):
return torch.view_copy(a, tuple(reversed(a.shape)))
def func4(a):
return torch.reshape(a, tuple(reversed(a.shape)))
def func5(a):
return torch.ops.aten.view.default(a, tuple(reversed(a.shape)))
def func6(a):
return torch.ops.aten._unsafe_view.default(a, tuple(reversed(a.shape)))
def func7(a):
return torch.ops.aten.view_copy.default(a, tuple(reversed(a.shape)))
for func in (func1, func2, func3, func4, func5, func6, func7):
with TorchRefsNvfuserCapabilityMode():
gm = make_fx(func)(a)
call_function_nodes = list(filter(lambda n: n.op == "call_function", gm.graph.nodes))
includes_nvprims_view = any(
torch.ops.nvprims.view.default == node.target
for node in call_function_nodes
)
self.assertTrue(includes_nvprims_view)
# Try executing the graph
out = execute(gm, a, executor="strictly_nvfuser")
self.assertEqual(out, func(a))
@onlyCUDA
@dtypes(torch.float16, torch.float32)
def test_nvprims_view_partitioner(self, device, dtype):
# This test verifies that views that are not fused with other ops are
# correctly overriden to call aten implementation.
from torch.fx.experimental.proxy_tensor import make_fx
from torch._prims.context import TorchRefsNvfuserCapabilityMode
from torch._prims.nvfuser_executor import maybe_partition_graph
make_arg = partial(make_tensor, device=device, dtype=dtype)
a = make_arg((4, 5))
b = make_arg((5, 4))
def func(a, b):
aa = a.view(b.shape)
aa = aa.view(a.shape)
return aa.digamma()
with TorchRefsNvfuserCapabilityMode():
gm = make_fx(func)(a, b)
gm, _ = maybe_partition_graph(gm, False, False)
out = gm(a, b)
self.assertEqual(out, func(a, b))
@onlyCUDA
@dtypes(torch.float32, torch.float16)
def test_cpu_tensor(self, device, dtype):
from torch.fx.experimental.proxy_tensor import make_fx
from torch._prims.context import TorchRefsNvfuserCapabilityMode
from torch._prims.executor import execute
def _wrapper(t0, t1, cpu_scalar):
return t0 + t1 + cpu_scalar
make_arg = partial(make_tensor, device=device, dtype=dtype)
a = make_arg((12, 1))
b = make_arg((12, 12))
c = torch.tensor(0.5)
with TorchRefsNvfuserCapabilityMode():
gm = make_fx(_wrapper)(a, b, c)
with warnings.catch_warnings(record=True) as caught:
actual = execute(gm, a, b, c, executor="nvfuser")
# cpu scalar tensor is handled by nvfuser codegen, so it shouldn't fallback
self.assertFalse(any(NVPRIM_ATEN_FALLBACK_WARNING in str(w.message) for w in caught))
expected = execute(gm, a, b, c, executor="aten")
self.assertEqual(expected, actual)
call_function_nodes = list(filter(lambda n: n.op == "call_function", gm.graph.nodes))
includes_aten_add = any(
torch.ops.aten.add.default == node.target
for node in call_function_nodes
)
self.assertFalse(includes_aten_add)
with warnings.catch_warnings(record=True) as caught:
nvprim_aten_fallback = execute(gm, a.cpu(), b.cpu(), c, executor="nvfuser")
# cpu tensor is handled by nvprim aten fallback, assert that it's indeed in warning
self.assertTrue(any(NVPRIM_ATEN_FALLBACK_WARNING in str(w.message) for w in caught))
self.assertEqual(expected, nvprim_aten_fallback)
@onlyCUDA
@dtypes(torch.float32)
def test_pytree_input_output(self, device, dtype):
@make_traced
def fn(a, b_dict):
b = b_dict["b"]
d = {}
d["c"] = torch.add(a, b)
return (d, torch.add(a, d["c"]))
make_arg = partial(make_tensor, device=device, dtype=dtype)
a = make_arg((5, 5))
b = make_arg((1, 5))
b_dict = {"b": b}
result_aten = fn(a, b_dict, executor="aten")
result_nvfuser = fn(a, b_dict, executor="strictly_nvfuser")
self.assertEqual(result_aten, result_nvfuser)
@dtypes(torch.float32)
def test_memory_format_strides(self, device, dtype):
shapes = (
@@ -1227,70 +403,6 @@ instantiate_device_type_tests(TestRefs, globals())
class TestDecomp(TestCase):
@onlyCUDA
@dtypes(torch.float16, torch.float32)
def test_decomposition_type_promotion_nvprim_amp(self, device, dtype):
x = torch.rand(5, device=device).to(dtype)
y = torch.rand(5, device=device).to(dtype)
from torch._prims.context import TorchRefsNvfuserCapabilityMode, _is_func_unsupported_nvfuser
from torch.fx.experimental.proxy_tensor import make_fx
op = torch.ops.aten.leaky_relu_backward.default
op_decomp = torch._decomp.decomposition_table.get(op)
def fn0(*arg):
return _is_func_unsupported_nvfuser(TorchRefsNvfuserCapabilityMode(), op, op_decomp, arg, {})
def fn1(x):
x = x * 2
x = x @ x
x = x * 2
return x
self.assertFalse(fn0(x, y, 0.3, False))
with TorchRefsNvfuserCapabilityMode():
# Autocast context has C++ level ATen calls that are hidden from
# TorchRefsNvfuserCapabilityMode that works only on Python level.
# The first call to make_fx records autocast C++ calls directly and
# doesn't have the chance to translate to nvprims. After the first
# call, "gm" contains explicit calls to torch.ops.aten and nothing
# is hidden, so the second call to make_fx actually translates
# recorded autocast dtype conversions to nvprims.
with torch.autocast("cuda"):
gm = make_fx(fn1)(x)
gm = make_fx(gm)(x)
call_function_nodes = list(filter(lambda n: n.op == "call_function", gm.graph.nodes))
includes_aten_to_copy = any(
torch.ops.aten._to_copy.default == node.target
for node in call_function_nodes
)
self.assertFalse(includes_aten_to_copy)
@onlyCUDA
@dtypes(torch.float16, torch.float32)
def test_masked_fill_decomposition_under_nvprim_context(self, device, dtype):
# Test masked_fill decomposition doesn't trigger data-dependent control flow
# on TorchRefsNvfuser speculative lowering.
from torch.fx.experimental.proxy_tensor import make_fx
from torch._prims.context import TorchRefsNvfuserCapabilityMode
x = torch.empty(2, 3, device=device).to(dtype=dtype)
mask = torch.ones_like(x).bool()
y = torch.tensor(0.3) # cpu scalar tensor
def func(x, mask, y):
return torch.masked_fill(x, mask, y)
# mimics real use-case for TorchRefsNvfuserCapabilityMode context
gm = make_fx(func, decomposition_table={})(x, mask, y)
with warnings.catch_warnings(record=True) as caught:
with TorchRefsNvfuserCapabilityMode():
gm = make_fx(gm)(x, mask, y)
# masked_fill decomposition fails inside `get_isolated_graphmodule`
self.assertFalse(any(GET_ISOLATED_GRAPHMODULE_ERROR in str(w.message) for w in caught))
@ops([op for op in op_db if op.supports_varargs], dtypes=OpDTypes.any_one)
def test_decomposition_method_vararg(self, device, dtype, op):
# some ops have vararg variants for the methods. this tests it.

View File

@@ -1,145 +0,0 @@
# Owner(s): ["module: nvfuser"]
import unittest
import warnings
from functools import partial
import torch
import torch._dynamo as torchdynamo
from torch.testing import make_tensor
from torch.testing._internal.common_utils import (
IS_WINDOWS,
run_tests,
skipIfTorchDynamo,
TEST_WITH_ROCM,
TestCase,
)
from torch.testing._internal.jit_utils import RUN_CUDA
RUN_NVFUSER = RUN_CUDA and not TEST_WITH_ROCM
def is_pre_volta():
if not RUN_NVFUSER:
return False
prop = torch.cuda.get_device_properties(torch.cuda.current_device())
return prop.major < 7
def is_networkx_available():
try:
import networkx # noqa: F401
return True
except ImportError:
return False
@skipIfTorchDynamo("Not a suitable test for TorchDynamo")
@unittest.skipIf(IS_WINDOWS, "TorchDynamo is not supported on Windows")
@unittest.skipIf(not RUN_NVFUSER, "requires CUDA")
@unittest.skipIf(is_pre_volta(), "Only supported on Volta and newer devices.")
class TestNvFuserDynamo(TestCase):
def test_basic(self):
input1 = make_tensor((2, 4, 8), device="cuda", dtype=torch.float32)
input2 = make_tensor((2, 4, 8), device="cuda", dtype=torch.float32)
@torchdynamo.optimize("nvprims_nvfuser")
def func(a, b):
return a.sin() + b.cos()
# No warnings and no errors
with warnings.catch_warnings(record=True) as w:
nvfuser_result = func(input1, input2)
self.assertEqual(len(w), 0)
eager_result = func.__wrapped__(input1, input2)
self.assertEqual(eager_result, nvfuser_result)
@unittest.skipIf(not is_networkx_available(), "networkx not available")
def test_min_cut(self):
from functorch.compile import default_partition
from torch._dynamo.backends.nvfuser import nvprims_fw_bw_partition_fn
def get_fw_bw_graph(f, inps, partitioner):
from functorch.compile import aot_function
# Helper functions are taken from functorch/test_aotdispatch.py
def extract_graph(fx_g, _, graph_cell):
graph_cell[0] = fx_g
return fx_g
fw_graph_cell = [None]
bw_graph_cell = [None]
aot_function(
f,
fw_compiler=partial(extract_graph, graph_cell=fw_graph_cell),
bw_compiler=partial(extract_graph, graph_cell=bw_graph_cell),
partition_fn=partitioner,
)(*inps).sum().backward()
return (fw_graph_cell[0], bw_graph_cell[0])
def get_ins_outs(fx_g):
ins = []
outs = []
for n in fx_g.graph.nodes:
if n.op == "placeholder":
ins.append(n)
elif n.op == "output":
outs = tuple(n.args[0])
return ins, outs
def get_num_ins_outs(fx_g):
return tuple(len(i) for i in get_ins_outs(fx_g))
def func(x):
return x * x * x
input1 = make_tensor(
(3,), device="cpu", dtype=torch.float32, requires_grad=True
)
fw_graph, bw_graph = get_fw_bw_graph(func, [input1], default_partition)
self.assertEqual(get_num_ins_outs(fw_graph), (1, 3))
self.assertEqual(get_num_ins_outs(bw_graph), (3, 1))
input1 = make_tensor(
(3,), device="cpu", dtype=torch.float32, requires_grad=True
)
fw_graph, bw_graph = get_fw_bw_graph(func, [input1], nvprims_fw_bw_partition_fn)
self.assertEqual(get_num_ins_outs(fw_graph), (1, 2))
self.assertEqual(get_num_ins_outs(bw_graph), (2, 1))
def test_batch_norm_implicit_dtype_promotion(self):
input1 = make_tensor((2, 3, 4, 5), device="cuda", dtype=torch.float32)
input2 = make_tensor((5, 5), device="cuda", dtype=torch.float32)
w = make_tensor((3), device="cuda", dtype=torch.float32)
b = make_tensor((3), device="cuda", dtype=torch.float32)
@torchdynamo.optimize("nvprims_nvfuser")
def func(mat1, mat2, w, b):
o = torch.matmul(mat1, mat2)
return torch.batch_norm(o, w, b, None, None, True, 1e-2, 1e-5, True)
# No warnings and no errors
with torch.cuda.amp.autocast():
with warnings.catch_warnings(record=True) as warning:
nvfuser_result = func(input1, input2, w, b)
self.assertEqual(len(warning), 0)
eager_result = func.__wrapped__(input1, input2, w, b)
self.assertEqual(eager_result, nvfuser_result)
def test_dtype_correctness(self):
input1 = make_tensor((2, 4, 8), device="cuda", dtype=torch.float16)
@torchdynamo.optimize("nvprims_nvfuser")
def func(a):
tmp = a + 1.0
# nvfuser would promote output to fp32 in math, FusionDefinition should cast output dtype back
return torch.where(tmp > 0, tmp, 0.0)
nvfuser_result = func(input1)
eager_result = func.__wrapped__(input1)
self.assertEqual(eager_result, nvfuser_result)
if __name__ == "__main__":
run_tests()
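For reference, the decorator pattern used throughout this removed test file survives for the remaining backends; a minimal sketch (not part of this PR, CUDA build assumed) using the ``inductor`` backend instead:

import torch
import torch._dynamo as torchdynamo
from torch.testing import make_tensor

@torchdynamo.optimize("inductor")
def func(a, b):
    return a.sin() + b.cos()

input1 = make_tensor((2, 4, 8), device="cuda", dtype=torch.float32)
input2 = make_tensor((2, 4, 8), device="cuda", dtype=torch.float32)
compiled_result = func(input1, input2)
eager_result = func.__wrapped__(input1, input2)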

View File

@@ -1,95 +0,0 @@
import logging
from functools import partial
import torch
from ..backends.common import aot_autograd, mem_efficient_fusion_kwargs
from .registry import register_backend, register_debug_backend
log = logging.getLogger(__name__)
def prims_executor(gm, inputs, *, executor):
from functorch.compile import make_boxed_func
# This function is called once per forward/backward pass of a graph in AOT
# Autograd. We use it to set up the nvFuser-specific FX graph and return
# execute function.
from torch._prims.context import TorchRefsNvfuserCapabilityMode
from torch._prims.executor import execute
from torch.fx.experimental.proxy_tensor import make_fx
# AOT Autograd might not use the partitioner, so we need to make sure that
# the graph is transformed to use nvFuser-compatible nodes.
if not getattr(gm, "_nvprim_transformed", False):
with TorchRefsNvfuserCapabilityMode():
gm = make_fx(gm)(*inputs)
# Then we return a callable that executes the "gm" graph
return make_boxed_func(partial(execute, gm, executor=executor))
def nvprims_fw_bw_partition_fn(joint_module, joint_inputs, *, num_fwd_outputs):
# This function is called once per forward+backward pass of a graph in AOT
# Autograd. We use it to set up the nvFuser-specific FX graph that is later
# passed to the executor.
from functorch.compile import min_cut_rematerialization_partition
from torch._prims.context import TorchRefsNvfuserCapabilityMode
from torch.fx.experimental.proxy_tensor import make_fx
# AOT Autograd expects arguments of the traced function to be named exactly
# "primals, tangents"
def func(primals, tangents):
return joint_module(primals, tangents)
# First we trace the graph conditionally decomposing nodes
# that can be sent to the nvfuser executor
with TorchRefsNvfuserCapabilityMode():
prim_gm = make_fx(func)(*joint_inputs)
# all nvprims for now
recomputable_ops = {
getattr(torch.ops.nvprims, prim)
for prim in dir(torch.ops.nvprims)
if isinstance(getattr(torch.ops.nvprims, prim), torch._ops.OpOverloadPacket)
and getattr(torch.ops.nvprims, prim).is_recomputable
}
fw_gm, bw_gm = min_cut_rematerialization_partition(
prim_gm,
joint_inputs,
recomputable_ops=recomputable_ops,
num_fwd_outputs=num_fwd_outputs,
)
# AOT Autograd might not use the partitioner, so we need to make sure that
# the graph is marked as already transformed to use nvFuser-compatible nodes
fw_gm._nvprim_transformed = True
bw_gm._nvprim_transformed = True
return fw_gm, bw_gm
def create_nvprims_backend(*, executor):
return aot_autograd(
fw_compiler=partial(prims_executor, executor=executor),
bw_compiler=partial(prims_executor, executor=executor),
partition_fn=nvprims_fw_bw_partition_fn,
)
aot_nvprims_nvfuser = create_nvprims_backend(executor="nvfuser")
aot_nvprims_aten = create_nvprims_backend(executor="aten")
# "nvprims" is a subset of PrimTorch primitives that are guaranteed to be
# supported by nvFuser. This is the preferred backend for nvFuser+PrimTorch.
register_backend(name="nvprims_nvfuser", compiler_fn=aot_nvprims_nvfuser)
# This is useful for debugging. Can be removed later.
register_debug_backend(name="nvprims_aten", compiler_fn=aot_nvprims_aten)
# Use min cut rematerialization and TorchScript+nvFuser with AOT Autograd
# aot_ts_nvfuser uses the memory efficient fusion algorithm from AOT Autograd.
# It uses min cut rematerialization algorithm, uses nvFuser as the
# compiler backend, and TorchScript as the frontend.
aot_mem_efficient_fusion = aot_autograd(**mem_efficient_fusion_kwargs(use_decomps=True))
aot_mem_efficient_fusion.backend_ctx_ctor = lambda: torch.jit.fuser("fuser2")
register_backend(name="aot_ts_nvfuser", compiler_fn=aot_mem_efficient_fusion)

View File

@@ -13,7 +13,6 @@ import torch.library
from torch import sym_float, Tensor, TypedStorage
from torch._C import _get_default_device
from torch._prims.debug_prims import register_debug_prims
from torch._prims.nvfuser_prims import register_nvprims
from torch._prims.rng_prims import register_rng_prims
from torch._prims_common import (
Dim,
@@ -2964,6 +2963,5 @@ fft_c2r = _make_prim(
doc=_fft_c2r_doc,
)
register_nvprims()
register_rng_prims()
register_debug_prims()

View File

@@ -1,7 +1,6 @@
import functools
from contextlib import nullcontext
from typing import Any, Callable, Dict, Optional, Sequence
from warnings import warn
import torch
@@ -13,10 +12,8 @@ import torch._refs.nn
import torch._refs.nn.functional
import torch._refs.special
import torch.overrides
from torch._prims.nvfuser_executor import NvfuserPrimOperatorSupport
from torch._prims_common import torch_function_passthrough
from torch.fx.experimental.proxy_tensor import get_isolated_graphmodule
@functools.lru_cache(None)
@@ -82,55 +79,6 @@ def all_prims():
return {torch._prims.__dict__.get(s) for s in torch._prims.__all__}
class NvfuserPrimsMode(torch.overrides.TorchFunctionMode):
"""
Switches the interpretation of torch.ops.prims.* functions to
use nvFuser's prims in torch.ops.nvprims.*
>>> # xdoctest: +SKIP("undefined vars")
>>> with NvfuserPrimsMode():
... torch.ops.prims.add(x, y) # calls torch.ops.nvprims.add(x, y)
By default, this context manager will fall back on the torch.ops.prims* if the
nvprim does not exist.
It's possible to skip certain prims by passing their names to the skip_ops
argument. skip_ops is expected to be a sequence of strings, e.g.,
["prims.add.default"] In order to check the expected name of a prim, one can
use the `torch.overrides.resolve_name`.
>>> # xdoctest: +SKIP("undefined vars")
>>> with NvfuserPrimsMode(skips_ops=("prims.add.default")):
... torch.ops.prims.add.default(x, y) # does not call torch.ops.nvprims.add.default(x, y)
"""
def __init__(self, *, skip_ops=()):
self.skip_ops = skip_ops
def __torch_function__(
self,
orig_func: Callable,
types: Sequence,
args: Sequence[Any] = (),
kwargs: Optional[Dict] = None,
):
if kwargs is None:
kwargs = {}
# If the function is in the skip list, then we don't want to
# remap it to the nvprims.
if torch.overrides.resolve_name(orig_func) in self.skip_ops:
return orig_func(*args, **kwargs)
if isinstance(orig_func, (torch._ops.OpOverload, torch._ops.OpOverloadPacket)):
namespace = str(orig_func).split(".")[0]
name = str(orig_func).split(".")[1]
if namespace == "prims":
nvfunc = getattr(torch.ops.nvprims, name, None)
if nvfunc is not None:
return nvfunc(*args, **kwargs)
return orig_func(*args, **kwargs)
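# A minimal, self-contained sketch (not part of this file) of the
# TorchFunctionMode interception pattern that NvfuserPrimsMode relies on:
# every overridable torch.* call entering the mode can be inspected or
# rerouted before it executes.
class _ExampleLoggingMode(torch.overrides.TorchFunctionMode):
    def __torch_function__(self, orig_func, types, args=(), kwargs=None):
        kwargs = kwargs or {}
        print("intercepted:", torch.overrides.resolve_name(orig_func))
        return orig_func(*args, **kwargs)
# Usage: `with _ExampleLoggingMode(): torch.add(x, y)` prints "intercepted: torch.add".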
class TorchRefsMode(torch.overrides.TorchFunctionMode):
"""
Switches the interpretation of torch.* functions and Tensor methods to
@@ -194,235 +142,3 @@ class TorchRefsMode(torch.overrides.TorchFunctionMode):
f"no _refs support for {torch.overrides.resolve_name(orig_func)}"
)
return orig_func(*args, **kwargs)
def _is_node_supported_nvfuser(node):
return (
node.op == "call_function"
and getattr(node.target, "impl_nvfuser", None) is not None
)
def _is_func_unsupported_nvfuser(
torch_function_mode, orig_func, func, args, kwargs, *, skip_ops=()
):
"""
This function traces the `func` under `torch_function_mode` and checks if
any of the traced nodes are not supported by nvFuser. If so, we should
fallback to the original function.
`skip_ops` argument is expected to be a list of strings of function names
that would match with `torch.overrides.resolve_name`.
Args:
torch_function_mode: The torch_function_mode context manager. orig_func:
The original function, its name will be used to check if
it should be skipped.
func: The function to be traced. args: The args to be passed to the
function. kwargs: The kwargs to be passed to the function.
Keyword args:
skip_ops: A list of ops to skip when checking if the function is
supported.
"""
# One supported case is easy to check: if the resolved name of the original
# function in the skip list, skip it.
if torch.overrides.resolve_name(orig_func) in skip_ops:
return True
with torch_function_mode:
try:
gm = get_isolated_graphmodule(func, args, kwargs)
except Exception as e:
warn(
"get_isolated_graphmodule failed on decomposition: "
+ func.__name__
+ " with error message: "
+ str(e)
)
# returns unsupported when tracing fails.
return True
supported_ops = NvfuserPrimOperatorSupport()
call_function_nodes = filter(lambda n: n.op == "call_function", gm.graph.nodes)
any_unsupported = any(
not supported_ops.is_node_supported(None, node) for node in call_function_nodes
)
return any_unsupported
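# A minimal usage sketch (not part of this file) of the check described above:
# trace the callable with get_isolated_graphmodule, then test every
# call_function node against the nvFuser support rules.
def _example_is_fully_supported(func, args, kwargs):
    gm = get_isolated_graphmodule(func, args, kwargs)
    supported_ops = NvfuserPrimOperatorSupport()
    return all(
        supported_ops.is_node_supported(None, node)
        for node in gm.graph.nodes
        if node.op == "call_function"
    )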
class TorchRefsNvfuserCapabilityMode(TorchRefsMode):
def __init__(self, *, skip_ops=()):
aten_ops_to_skip = (
"aten._log_softmax.default",
"aten._log_softmax_backward_data.default",
"aten.expand.default",
)
self.skip_ops = tuple(skip_ops) + aten_ops_to_skip
super().__init__(
strict=False,
should_fallback_fn=functools.partial(
_is_func_unsupported_nvfuser,
skip_ops=tuple(skip_ops) + aten_ops_to_skip,
),
prims_mode_cls=functools.partial(NvfuserPrimsMode, skip_ops=skip_ops),
)
# TODO: remove this once version from _decomp/decompositions.py is working
# with this context manager
# This is a workaround for AOT Autograd graphs
def _cudnn_batch_norm(
self,
input,
weight,
bias,
running_mean,
running_var,
training,
exponential_average_factor,
epsilon,
):
a, b, c = torch.ops.nvprims.native_batch_norm(
input,
weight,
bias,
running_mean,
running_var,
training,
exponential_average_factor,
epsilon,
)
if training:
return (a, b, c, input.new_zeros((0,), dtype=torch.uint8))
return (
a,
weight.new_zeros((0,)),
weight.new_zeros((0,)),
input.new_zeros((0,), dtype=torch.uint8),
)
# This is a workaround for AOT Autograd graphs
def _cudnn_batch_norm_backward(
self,
input,
grad_output,
weight,
running_mean,
running_var,
save_mean,
save_var,
epsilon,
reserveSpace,
):
func = torch._decomp.decomposition_table[
torch.ops.aten.native_batch_norm_backward.default
]
return func(
grad_output,
input,
weight,
running_mean,
running_var,
save_mean,
save_var,
True,
epsilon,
[True, True, True],
)
def _is_var_mean(self, func):
return "torch.var_mean" == torch.overrides.resolve_name(func) or (
(isinstance(func, (torch._ops.OpOverload, torch._ops.OpOverloadPacket)))
and "aten.var_mean" in str(func)
)
def _is_view_or_reshape(self, func):
allowed_ops = {
"torch.Tensor.view",
"torch.Tensor.reshape",
"torch.view_copy",
"torch.reshape",
"aten.view.default",
"aten._unsafe_view.default",
"aten.view_copy.default",
} - set(self.skip_ops)
return torch.overrides.resolve_name(func) in allowed_ops
def _is_native_batch_norm(self, func):
return "torch.native_batch_norm" == torch.overrides.resolve_name(func) or (
func == torch.ops.aten.native_batch_norm.default
or func == torch.ops.aten.native_batch_norm
)
def _is_rand_like(self, func):
result = "torch.rand_like" == torch.overrides.resolve_name(func) or (
func == torch.ops.aten.rand_like or func == torch.ops.aten.rand_like.default
)
return result
def _is_full(self, func):
result = "torch.full" == torch.overrides.resolve_name(func) or (
func
in [
torch.ops.aten.full,
torch.ops.aten.full.names,
]
)
return result
def __torch_function__(
self,
orig_func: Callable,
types: Sequence,
args: Sequence[Any] = (),
kwargs: Optional[Dict] = None,
):
if kwargs is None:
kwargs = {}
# First we intercept calls for nvfuser-specific prims bypassing generic torch._refs
if self._is_var_mean(orig_func):
return torch.ops.nvprims.var_mean(*args, **kwargs)
if (
orig_func == torch.ops.aten.cudnn_batch_norm.default
or orig_func == torch.ops.aten.cudnn_batch_norm
):
with self:
return self._cudnn_batch_norm(*args, **kwargs)
# A workaround for AOT Autograd graphs
# See https://github.com/pytorch/pytorch/pull/86115#issue-1394883782
if (
orig_func == torch.ops.aten.cudnn_batch_norm_backward.default
or orig_func == torch.ops.aten.cudnn_batch_norm_backward
):
with self:
return self._cudnn_batch_norm_backward(*args, **kwargs)
if self._is_view_or_reshape(orig_func):
a, *shape = args
shape = torch._prims_common.extract_shape_from_varargs(
shape, validate=False
) # type: ignore[assignment]
if len(kwargs) > 0:
warn("view has ignored kwargs!")
return torch.ops.nvprims.view(a, shape)
if orig_func == torch.ops.aten._reshape_alias.default:
a, shape, stride = args
if len(kwargs) > 0:
warn("view has ignored kwargs!")
return torch.ops.nvprims.view(a, shape)
if self._is_native_batch_norm(orig_func):
return torch.ops.nvprims.native_batch_norm(*args, **kwargs)
if self._is_rand_like(orig_func):
if len(kwargs) > 0:
warn("rand_like has ignored kwargs!")
return torch.ops.nvprims.rand_like(*args)
if self._is_full(orig_func):
return torch.ops.nvprims.full(*args, **kwargs)
# Then we use TorchRefsMode to interpret the rest
return super().__torch_function__(orig_func, types, args, kwargs)

View File

@@ -1,7 +1,6 @@
from typing import Callable, Optional
from torch._prims.context import NvfuserPrimsMode, TorchRefsMode
from torch._prims.nvfuser_executor import nvfuser_execute, nvfuser_execute_partitioned
from torch._prims.context import TorchRefsMode
from torch.fx import GraphModule
from torch.fx.experimental.proxy_tensor import make_fx, wrapper_and_args_for_make_fx
@@ -21,14 +20,8 @@ def execute(
if executor == "aten":
return gm.forward(*args)
elif executor == "nvfuser":
return nvfuser_execute_partitioned(
gm, *args, executor_parameters=executor_parameters
)
elif executor == "strictly_nvfuser":
return nvfuser_execute(gm, *args, executor_parameters=executor_parameters)
msg = f"Received unexpected value for 'executor': {executor}. Allowed values are: aten, nvfuser."
msg = f"Received unexpected value for 'executor': {executor}. Allowed values are: aten."
raise ValueError(msg)
@@ -53,16 +46,14 @@ def make_traced(fn: Callable):
a = torch.randn((1, 2, 3, 4, 5), device='cuda')
b = torch.randn((1, 2, 3, 4, 5), device='cuda')
result = traced_foo(a, b, executor='nvfuser')
Executor may be either 'aten' or 'nvfuser'.
result = traced_foo(a, b, executor='aten')
"""
def _traced(*args, executor="aten", **kwargs):
# TODO: caching
wrapped, all_args = wrapper_and_args_for_make_fx(fn, args, kwargs)
with NvfuserPrimsMode(), TorchRefsMode():
with TorchRefsMode():
gm = make_fx(wrapped)(all_args)
return execute(gm, all_args, executor=executor)
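A minimal sketch of the code path that remains after this change: ``make_traced`` wraps the function, traces it under ``TorchRefsMode``, and dispatches to the ATen executor.

import torch
from torch._prims.executor import make_traced

def foo(a, b):
    return torch.add(a, b)

traced_foo = make_traced(foo)
a = torch.randn(2, 3)
b = torch.randn(2, 3)
out = traced_foo(a, b, executor="aten")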

View File

@@ -1,510 +0,0 @@
import operator
from copy import deepcopy
from dataclasses import dataclass
from functools import lru_cache
from types import MappingProxyType
from warnings import warn
import torch
import torch.fx
import torch.overrides
from torch._prims_common import (
_torch_dtype_to_nvfuser_dtype_map,
getnvFuserDtype,
Number,
number_type,
)
from torch.fx import GraphModule
from torch.fx.passes.infra.partitioner import CapabilityBasedPartitioner
from torch.utils._pytree import tree_flatten, tree_map, tree_unflatten
if torch.cuda.is_available():
try:
from nvfuser import ( # type: ignore[attr-defined, import]
DataType,
FusionDefinition,
Tensor,
)
def create_fusion_definition():
fd = FusionDefinition()
return fd, fd
except ImportError:
from nvfuser._C import ( # type: ignore[import]
DataType,
Fusion,
FusionDefinition,
Tensor,
)
def create_fusion_definition():
fusion = Fusion()
return fusion, FusionDefinition(fusion)
else:
DataType = None
import os
@lru_cache(None)
def get_nvprim_dump_nvtx():
return os.getenv("PYTORCH_NVFUSER_DUMP_NVTX")
DEFAULT_NVFUSER_PYTHON_CONFIG = MappingProxyType(
{
"use_python_fusion_cache": True,
"allow_single_op_fusion": False,
}
)
# nvFuserTensorTemplate and nvFuserScalarTemplate are helper objects
# for cached construction of the nvFuser's Fusion
# TODO: change what is stored in the cache for nvFuser's Tensor objects
# https://github.com/pytorch/pytorch/issues/80551
@dataclass(frozen=True)
class nvFuserTensorTemplate:
symbolic_shape: tuple
contiguity: tuple
dtype: DataType
is_cpu: bool
@dataclass(frozen=True)
class nvFuserScalarTemplate:
dtype: DataType
@lru_cache(maxsize=2048)
def compute_symbolic_shape(shape):
"""Computes the symbolic shape of a tensor.
nvFuser specializes on size-1 dimensions as broadcasted dimensions.
-1 is used to represent any size."""
return tuple(1 if s == 1 else -1 for s in shape)
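# A minimal usage sketch (not part of this file) of the helper above: size-1
# dimensions are kept so nvFuser can specialize on broadcasts, while every
# other size is erased to -1 so the cache key ignores concrete extents.
assert compute_symbolic_shape((1, 3, 4)) == (1, -1, -1)
assert compute_symbolic_shape((5, 1, 1)) == (-1, 1, 1)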
@lru_cache(maxsize=2048)
def compute_contiguity(shape, strides):
"""Computes the contiguity information to simplify internal indexing.
Contiguous dimensions are represented by True, strided dimensions
are represented by False.
"""
try:
from nvfuser import compute_contiguity # type: ignore[attr-defined]
except ImportError:
from nvfuser._C import compute_contiguity
return tuple(compute_contiguity(shape, strides))
def to_nvfuser_template_args(args):
def to_nvfuser(arg):
if isinstance(arg, torch.Tensor):
return nvFuserTensorTemplate(
compute_symbolic_shape(arg.size()),
compute_contiguity(arg.size(), arg.stride()),
getnvFuserDtype(arg.dtype),
arg.is_cpu, # type: ignore[attr-defined]
)
elif isinstance(arg, Number):
return nvFuserScalarTemplate(getnvFuserDtype(number_type(arg)))
else:
return arg
return tree_map(to_nvfuser, args)
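# Illustration (not part of the original file): two contiguous CUDA float tensors
# of shapes (1, 8, 16) and (1, 4, 32) both map to the symbolic shape (1, -1, -1)
# and, being fully contiguous, should produce the same contiguity tuple, so the
# frozen templates compare equal and the lru_cache on make_nvfuser_fusion below
# reuses the already-built Fusion instead of constructing a new one.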
def _any_get_attr_used(call_function_nodes):
return any(
filter(
# bug in mypy https://github.com/python/mypy/issues/12682
lambda n: any( # type: ignore[arg-type]
a.op == "get_attr" for a in n.args if isinstance(a, torch.fx.Node) # type: ignore[attr-defined]
),
call_function_nodes,
)
)
# MyPy bug: https://github.com/python/mypy/issues/5107
@lru_cache(maxsize=1024) # type: ignore[arg-type]
def make_nvfuser_fusion(gm: GraphModule, *nv_args_templates):
if not torch.cuda.is_available():
raise RuntimeError(
"Attempting to use nvFuser trace executor but CUDA is not available!"
)
# Everything in the graph must support nvfuser
for node in gm.graph.nodes:
if node.op == "call_function" and node.target == operator.getitem:
continue
if (
node.op == "call_function"
and getattr(node.target, "impl_nvfuser", None) is None
):
raise ValueError(
"All call_function nodes in the graph must support nvfuser. "
f"Node {node} with target {node.target} does not support nvfuser"
)
graph_input_nodes = list(filter(lambda n: n.op == "placeholder", gm.graph.nodes))
call_function_nodes = list(
filter(lambda n: n.op == "call_function", gm.graph.nodes)
)
assert len(graph_input_nodes) == len(
nv_args_templates
), "Number of placeholder nodes in the graph must match number of args"
assert len(nv_args_templates) > 0, "There must be at least one argument"
assert (
len(call_function_nodes) > 0
), "Graph must contain at least one call_function node"
assert not _any_get_attr_used(
call_function_nodes
), "Constant tensors that are saved in the graph and used as arguments are not supported yet"
# Checking output dtypes
output_node = next(filter(lambda n: n.op == "output", gm.graph.nodes))
orig_flat_out, _ = tree_flatten(output_node.args[0])
fusion, fd = create_fusion_definition()
with fd:
def _to_nvfuser_constant(arg):
if isinstance(arg, Number):
return fd.define_constant(arg)
else:
return arg
class FusionInterpreter(torch.fx.Interpreter):
def run_node(self, node):
# Squeeze requires original shape of args[0]
if node.target in (
torch.ops.nvprims.squeeze,
torch.ops.nvprims.squeeze.default,
):
original_shape = list(node.args[0].meta["tensor_meta"].shape)
assert len(node.args) == 2
args, kwargs = self.fetch_args_kwargs_from_env(node)
args = args[:1] + (original_shape,) + args[1:]
return self.call_function(node.target, args, node.kwargs)
if node.target in (
torch.ops.nvprims.native_batch_norm,
torch.ops.nvprims.native_batch_norm.default,
):
args, kwargs = self.fetch_args_kwargs_from_env(node)
assert len(args) == 8
training = args[5]
args6_end = tuple(_to_nvfuser_constant(arg) for arg in args[6:])
args = args[:5] + (training,) + args6_end
return node.target.impl_nvfuser(fd, *args, **kwargs)
return super().run_node(node)
def call_function(self, target, args, kwargs):
# This handles tuple unpacking
if target == operator.getitem:
assert isinstance(args[0], tuple)
return target(*args, **kwargs)
args = tuple(_to_nvfuser_constant(arg) for arg in args)
target = target.impl_nvfuser
args = (fd,) + args
return target(*args, **kwargs)
def output(self, target, args, kwargs):
flat_out, unflatten_spec = tree_flatten(args[0])
for o, orig_o in zip(flat_out, orig_flat_out):
# casting outputs to the original data type
# ensures that outputs produced by the fusion always agree with the original GraphModule
out_dtype = _torch_dtype_to_nvfuser_dtype_map.get(orig_o.meta["tensor_meta"].dtype) # type: ignore[union-attr]
assert isinstance(
o, Tensor
), "output from codegen has to be tensor type"
fd.add_output(fd.ops.cast(o, dtype=out_dtype))
return args[0]
def templates_to_nvfuser_inputs(arg):
if isinstance(arg, nvFuserTensorTemplate):
x = fd.define_tensor(
arg.symbolic_shape, arg.contiguity, arg.dtype, arg.is_cpu
)
return x
elif isinstance(arg, nvFuserScalarTemplate):
x = fd.define_scalar(arg.dtype)
return x
else:
return arg
# Transforms graph to call nvfuser lowerings
nv_args = tuple(
templates_to_nvfuser_inputs(nv_arg) for nv_arg in nv_args_templates
)
out = FusionInterpreter(gm).run(*nv_args)
flat_out, unflatten_spec = tree_flatten(out)
return fusion, unflatten_spec
def nvfuser_execute(gm: GraphModule, *args, executor_parameters=None):
executor_parameters = executor_parameters or DEFAULT_NVFUSER_PYTHON_CONFIG
flat_args, _ = tree_flatten(args)
# check for cuda only fusion
if any(isinstance(arg, torch.Tensor) and arg.is_cuda for arg in flat_args) and all( # type: ignore[attr-defined]
(
not isinstance(arg, torch.Tensor)
or (arg.is_cpu and arg.ndim == 0) # type: ignore[attr-defined]
or arg.is_cuda # type: ignore[attr-defined]
)
for arg in flat_args
):
# Construction of the fusion is expensive and cached based on the GraphModule
# and symbolic nvFuser args.
nv_template_args = to_nvfuser_template_args(flat_args)
use_cache = executor_parameters.get(
"use_python_fusion_cache",
DEFAULT_NVFUSER_PYTHON_CONFIG["use_python_fusion_cache"],
)
if use_cache:
fusion, unflatten_spec = make_nvfuser_fusion(gm, *nv_template_args) # type: ignore[misc]
else:
fusion, unflatten_spec = make_nvfuser_fusion.__wrapped__(gm, *nv_template_args) # type: ignore[misc]
# Inputs to fusion.execute correspond to the same template/symbolic inputs
# marked with `define_tensor/scalar`
concrete_fusion_inputs = tuple(
arg for arg in flat_args if isinstance(arg, (torch.Tensor, Number))
)
if get_nvprim_dump_nvtx():
torch.cuda.nvtx.range_push(
"fusion: {}, graph: {}".format(
fusion.id(),
str(
[
{
"op": n.op,
"name": n.name,
"args": n.args,
"kwargs": n.kwargs,
}
for n in gm.graph.nodes
]
),
)
)
warn("nvfuser integration in primTorch is deprecated")
result = tree_unflatten(
fusion.execute(concrete_fusion_inputs), # type: ignore[has-type]
unflatten_spec, # type: ignore[has-type]
)
if get_nvprim_dump_nvtx():
torch.cuda.nvtx.range_pop()
return result
else:
warn(
"nvfuser_executor is executed with non-cuda args, fallback to aten executor"
)
return gm.forward(*args)
class NvfuserPrimOperatorSupport(torch.fx.passes.operator_support.OperatorSupport):
def is_node_supported(self, submodules, node: torch.fx.Node) -> bool:
# special case to stop lowering to nvprim when converting to an unsupported type
if (
node.op == "call_function"
and node.target == torch.ops.nvprims.convert_element_type.default
):
return (
_torch_dtype_to_nvfuser_dtype_map.get(node.args[1]) is not None
and _torch_dtype_to_nvfuser_dtype_map.get(
node.args[0].meta["tensor_meta"].dtype # type: ignore[union-attr]
)
is not None
)
return node.op == "call_function" and (
getattr(node.target, "impl_nvfuser", None) is not None
or node.target == operator.getitem
)
class PartitionedInterpreter(torch.fx.Interpreter):
def call_module(self, target, args, kwargs):
assert isinstance(target, str)
assert len(kwargs) == 0
submod = self.fetch_attr(target)
# CapabilityBasedPartitioner hardcodes the name of the subgraphs with supported_ops as "fused_" + subgraph id
if target.startswith("fused_"):
return nvfuser_execute(submod, *args)
else:
return super().call_module(target, args, kwargs)
class NvfuserGraphModule(torch.nn.Module):
def __init__(self, gm, use_python_fusion_cache):
super().__init__()
self.gm = gm
self.executor_parameters = {"use_python_fusion_cache": use_python_fusion_cache}
def __call__(self, *args):
return nvfuser_execute(
self.gm, *args, executor_parameters=self.executor_parameters
)
# A set of operators that are supported by nvFuser
# but should not form a fusion group solely on their own
_non_compute_ops = [
"torch.ops." + str(getattr(torch.ops.nvprims, prim).default)
for prim in dir(torch.ops.nvprims)
if isinstance(getattr(torch.ops.nvprims, prim), torch._ops.OpOverloadPacket)
and getattr(torch.ops.nvprims, prim).return_type
== torch._prims_common.RETURN_TYPE.VIEW
]
_allowed_single_node_partition_ops = [
"torch.ops.nvprims.native_batch_norm.default",
"torch.ops.nvprims.var_mean.default",
"torch.ops.nvprims.var_mean.main",
]
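# Illustration (not part of the original file): these two lists steer the
# partitioner in maybe_partition_graph below. View-like nvprims such as
# nvprims.view_of count as non-compute, so a candidate partition made only of
# such ops is trimmed away by remove_bookend_non_compute_ops, while var_mean and
# native_batch_norm are considered worth sending to nvFuser even as single-node
# partitions when single-op fusion is otherwise disabled.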
def _remove_empty_like_fill(gm: GraphModule):
# Remove empty_like + fill nodes that prevent lowering to nvprims
# This is a workaround for suboptimal traces of C++ code `(1 - tensor)`
# https://github.com/pytorch/pytorch/issues/86612
def pattern(scalar, tensor):
# pattern for the C++ trace of `scalar - tensor`. We are looking for the
# pattern of aten and nvprims.sub specifically because we want to remove
# the empty_like + fill nodes after lowering of the AOT Autograd trace to
# nvprims. In the future, nvFuser might support fill and empty_like, and
# this workaround can be removed.
empty_like = torch.ops.aten.empty_like.default(
tensor, memory_format=torch.preserve_format
)
fill = torch.ops.aten.fill.Scalar(empty_like, scalar)
sub = torch.ops.nvprims.sub.default(fill, tensor)
return sub
def replacement(scalar, tensor):
return torch.ops.nvprims.sub.default(scalar, tensor)
torch.fx.replace_pattern(gm, pattern, replacement)
return gm
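# Illustration (not part of the original file): with this rewrite, a trace of
# `1 - tensor` that AOT Autograd produced as
#   empty_like(tensor) -> fill(empty, 1) -> nvprims.sub(filled, tensor)
# collapses to the single node
#   nvprims.sub(1, tensor)
# so the unsupported aten.empty_like/aten.fill nodes no longer force a partition
# boundary around the subtraction.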
# MyPy bug: https://github.com/python/mypy/issues/5107
@lru_cache(maxsize=1024) # type: ignore[arg-type]
def maybe_partition_graph(
gm: GraphModule, allow_single_op_fusion: bool, use_python_fusion_cache: bool
):
gm = _remove_empty_like_fill(gm)
supported_ops = NvfuserPrimOperatorSupport()
call_function_nodes = list(
filter(lambda n: n.op == "call_function", gm.graph.nodes)
)
# the graph is partitioned only if at least one node is not supported by nvFuser
any_unsupported = any(
not supported_ops.is_node_supported(None, node) for node in call_function_nodes
)
any_unsupported |= len(call_function_nodes) == 0
# When there are constant tensors in the graph, we can't partition it
# because deepcopy fails. Here we just return the original graph to be
# executed by eager mode
# https://github.com/pytorch/pytorch/issues/84415
if (
_any_get_attr_used(call_function_nodes)
or len(list(filter(lambda n: n.op == "placeholder", gm.graph.nodes))) == 0
):
return gm, True
if any_unsupported:
# CapabilityBasedPartitioner modifies the graph in-place so we need to make a copy of the graph
gm = deepcopy(gm)
partitioner = CapabilityBasedPartitioner(
gm,
supported_ops,
allows_single_node_partition=allow_single_op_fusion,
non_compute_ops=_non_compute_ops,
allowed_single_node_partition_ops=_allowed_single_node_partition_ops,
)
partitions = partitioner.propose_partitions()
partitioner.remove_bookend_non_compute_ops(partitions)
if len(partitions) == 0:
warn(
"No partition found for the graph. "
+ "This is likely because the graph is not supported by nvFuser. "
+ "Please use the eager ATen mode to execute the graph.",
category=RuntimeWarning,
)
partitioned_graph = partitioner.fuse_partitions(partitions)
# Replacing graph's fused submodules with a wrapper module with
# __call__() method that calls nvfuser_execute.
# This avoids the need to call the interpreter on the graph
for node in partitioned_graph.graph.nodes:
# TODO: use a better way to identify fused submodule
if node.op == "call_module" and "fused_" in node.name:
nvfuser_submodule = getattr(partitioned_graph, node.name)
partitioned_graph.delete_submodule(node.target)
gm.add_submodule(
node.target,
NvfuserGraphModule(nvfuser_submodule, use_python_fusion_cache),
)
# Go through the graph and replace all the nodes that were converted to
# nvprims but won't be sent to nvFuser with a call to PyTorch's eager
# mode. This is necessary because torch.ops.* calls have higher overhead than
# calling eager mode directly.
for node in partitioned_graph.graph.nodes:
if node.op == "call_function" and str(node.target).startswith("nvprims."):
if getattr(node.target, "impl_aten", None) is not None:
node.target = node.target.impl_aten
partitioned_graph.graph.eliminate_dead_code()
partitioned_graph.recompile()
return partitioned_graph, any_unsupported
else:
return gm, any_unsupported
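# Illustration (not part of the original file): for a graph containing, say,
# [nvprims.add, aten.nonzero, nvprims.mul], the partitioner above groups add and
# mul into fused_* submodules that are wrapped in NvfuserGraphModule, leaves
# aten.nonzero to run eagerly, and retargets any nvprims node left outside a
# fusion to its impl_aten so it does not pay the torch.ops dispatch overhead.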
class NVTXInterpreter(torch.fx.Interpreter):
def run_node(self, n):
torch.cuda.nvtx.range_push(
f"name: {n.name}, args: {n.args}, op: {n.op}, kwargs: {n.kwargs}"
)
result = super().run_node(n)
torch.cuda.nvtx.range_pop()
return result
def nvfuser_execute_partitioned(gm: GraphModule, *args, executor_parameters=None):
executor_parameters = executor_parameters or DEFAULT_NVFUSER_PYTHON_CONFIG
# maybe_partition_graph function is cached so we can't use non-hashable arguments
allow_single_op_fusion = executor_parameters.get(
"allow_single_op_fusion",
DEFAULT_NVFUSER_PYTHON_CONFIG["allow_single_op_fusion"],
)
use_python_fusion_cache = executor_parameters.get(
"use_python_fusion_cache",
DEFAULT_NVFUSER_PYTHON_CONFIG["use_python_fusion_cache"],
)
# When possible it's better to use nvfuser_execute directly
# because it avoids GraphModule's overhead
gm, is_partitioned = maybe_partition_graph(
gm,
allow_single_op_fusion=allow_single_op_fusion,
use_python_fusion_cache=use_python_fusion_cache,
)
if is_partitioned:
if get_nvprim_dump_nvtx():
return NVTXInterpreter(gm).run(*args)
else:
return gm(*args)
else:
return nvfuser_execute(gm, *args, executor_parameters=executor_parameters)

View File

@ -1,840 +0,0 @@
# Module for defining "primitive" operations executable by nvFuser. This
# list exists to decouple the main set of primitives from the ones that provide a
# lowering of the op to nvFuser's Python interface. Mostly torch.ops.nvprims is
# a subset of the primitives in torch.ops.prims, but some additional primitives
# can be added in the future for the corresponding higher-level torch/aten
# functions.
from typing import Any, Dict, Optional, Tuple
import torch
import torch._prims_common as utils
from torch._prims_common import (
DimsSequenceType,
elementwise_dtypes,
ELEMENTWISE_TYPE_PROMOTION_KIND,
getnvFuserDtype,
make_contiguous_strides_for,
NumberType,
ShapeType,
TensorLikeType,
)
from torch._prims_common.wrappers import (
_maybe_convert_to_dtype,
backwards_not_supported,
elementwise_type_promotion_wrapper,
)
nvprim_namespace = "nvprims"
nvprim = torch.library.Library(nvprim_namespace, "DEF")
nvprim_impl = torch.library.Library(
nvprim_namespace, "IMPL", "CompositeExplicitAutograd"
)
nvprim_implicit_impl = torch.library.Library(
nvprim_namespace, "IMPL", "CompositeImplicitAutograd"
)
nvprim_autograd_impl = torch.library.Library(nvprim_namespace, "IMPL", "Autograd")
nvprim_meta_impl = torch.library.Library(nvprim_namespace, "IMPL", "Meta")
nvprim_names = [
"abs",
"acos",
"asin",
"atan",
"atanh",
"cos",
"cosh",
"clone",
"bitwise_not",
"ceil",
"erf",
"erfc",
"exp",
"expm1",
"floor",
"imag",
"isfinite",
"lgamma",
"log",
"log1p",
"log2",
"log10",
"real",
"reciprocal",
"neg",
"round",
"rsqrt",
"sign",
"sin",
"sinh",
"sqrt",
"tan",
"tanh",
"transpose",
"trunc",
"add",
"atan2",
"bitwise_and",
"bitwise_or",
"bitwise_xor",
"div",
"eq",
"fmod",
"ge",
"gt",
"le",
"lt",
"mul",
"ne",
"pow",
"remainder",
"sub",
"squeeze",
"view_of",
"broadcast_in_dim",
"where",
"convert_element_type",
"sum",
"var",
"amax",
"amin",
]
_nvfuser_impls: Dict[str, Any] = {}
_nvfuser_unary_ops = {
"abs",
"acos",
"asin",
"atan",
"atanh",
"cos",
"cosh",
"bitwise_not",
"ceil",
"erf",
"erfc",
"exp",
"expm1",
"floor",
"imag",
"isfinite",
"lgamma",
"log",
"log1p",
"log2",
"log10",
"reciprocal",
"neg",
"real",
"round",
"rsqrt",
"sign",
"sin",
"sinh",
"sqrt",
"tan",
"tanh",
"trunc",
}
def _assert_nvfuser_op_exists(fname: str):
try:
try:
from nvfuser import ( # type: ignore[import, attr-defined]
FusionDefinition as fd,
)
except ImportError:
from nvfuser._C import FusionDefinition as fd # type: ignore[import]
assert getattr(fd.Operators, fname)
except ImportError:
# Not all PyTorch builds have nvfuser
pass
for fname in _nvfuser_unary_ops:
exec(
f"""
# Ensure that the nvfuser implementation exists
_assert_nvfuser_op_exists("{fname}")
def _{fname}_nvfuser(fd, a):
return fd.ops.{fname}(a) # type: ignore[attr-defined]
_nvfuser_impls["{fname}"] = _{fname}_nvfuser
"""
)
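# Illustration (not part of the original file): each pass through the loop above
# exec's a small module-level definition; for fname == "abs" it effectively runs
#   def _abs_nvfuser(fd, a):
#       return fd.ops.abs(a)
#   _nvfuser_impls["abs"] = _abs_nvfuser
# so every unary prim gets a one-line lowering to the matching FusionDefinition op.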
_nvfuser_binary_ops = {
"add",
"atan2",
"bitwise_and",
"bitwise_or",
"bitwise_xor",
"div",
"eq",
"fmod",
"ge",
"gt",
"le",
"lt",
"mul",
"ne",
"pow",
"remainder",
"sub",
}
for fname in _nvfuser_binary_ops:
exec(
f"""
# Ensure that the nvfuser implementation exists
_assert_nvfuser_op_exists("{fname}")
def _{fname}_nvfuser(fd, a, b):
return fd.ops.{fname}(a, b) # type: ignore[attr-defined]
_nvfuser_impls["{fname}"] = _{fname}_nvfuser
"""
)
_nvfuser_ternary_ops = {
"where",
}
for fname in _nvfuser_ternary_ops:
exec(
f"""
# Ensure that the nvfuser implementation exists
_assert_nvfuser_op_exists("{fname}")
def _{fname}_nvfuser(fd, a, b, c):
return fd.ops.{fname}(a, b, c) # type: ignore[attr-defined]
_nvfuser_impls["{fname}"] = _{fname}_nvfuser
"""
)
def _native_batch_norm_nvfuser(
fd, input, weight, bias, running_mean, running_var, training, momentum, eps
):
"""
if weight is None:
weight = fd.define_null_tensor()
if bias is None:
bias = fd.define_null_tensor()
if running_mean is None:
running_mean = fd.define_null_tensor()
if running_var is None:
running_var = fd.define_null_tensor()
"""
return fd.ops.batch_norm(
input,
weight,
bias,
running_mean,
running_var,
momentum,
eps,
training,
)
def _broadcast_in_dim_nvfuser(
fd: Any,
a: TensorLikeType,
shape: ShapeType,
broadcast_dimensions: ShapeType,
):
return fd.ops.broadcast_in_dim(a, shape, broadcast_dimensions) # type: ignore[attr-defined]
def _convert_element_type_nvfuser(fd: Any, a: TensorLikeType, dtype: torch.dtype):
nvfuser_dtype = getnvFuserDtype(dtype)
return fd.ops.cast(a, nvfuser_dtype) # type: ignore[attr-defined]
def _transpose_nvfuser(fd, a, dims):
return fd.ops.permute(a, dims) # type: ignore[attr-defined]
def _squeeze_nvfuser(fd, a, a_shape, dimensions):
for idx in sorted(dimensions, reverse=True):
a = fd.ops.squeeze(a, a_shape, idx)
a_shape = a_shape[:idx] + a_shape[idx + 1 :]
return a
def _view_of_nvfuser(fd, a):
return fd.ops.set(a)
def _view_nvfuser(
fd,
a,
a_shape,
new_shape,
):
try:
return fd.ops.view(a, a_shape, new_shape)
except AttributeError:
return fd.ops.reshape(a, a_shape, new_shape)
def _sum_nvfuser(
fd: Any,
a: TensorLikeType,
dims: DimsSequenceType,
):
keep_dims = False
try:
from nvfuser import DataType # type: ignore[import, attr-defined]
except ImportError:
from nvfuser._C import DataType # type: ignore[import]
output_dtype = DataType.Null
return fd.ops.sum(a, dims, keep_dims, output_dtype)
def _var_nvfuser(
fd: Any,
a: TensorLikeType,
dims: DimsSequenceType,
*,
correction: float,
):
keep_dims = False
return fd.ops.var(a, dims, correction, keep_dims)
def _var_mean_nvfuser(
fd: Any,
a: TensorLikeType,
dims: DimsSequenceType,
unbiased: Optional[bool] = None,
keepdim: bool = False,
*,
correction: float,
):
# Unbiased arg shouldn't be set when this function is called
assert unbiased is None
# Ignore the keepdim arg because currently it's automatically converted into nvFuser's symbolic scalar;
# keepdim is handled by the reference implementation
keepdim = False
return fd.ops.var_mean(a, dims, correction, keepdim)
def _rand_like_nvfuser(fd: Any, a: TensorLikeType):
return fd.ops.rand_like(a)
def _amax_nvfuser(
fd: Any,
a: TensorLikeType,
dims: DimsSequenceType,
):
keep_dims = False
return fd.ops.max(a, dims, keep_dims)
def _amin_nvfuser(
fd: Any,
a: TensorLikeType,
dims: DimsSequenceType,
):
keep_dims = False
return fd.ops.min(a, dims, keep_dims)
def _clone_nvfuser(fd: Any, input: TensorLikeType, *, memory_format=None):
return fd.ops.set(input)
def _full_nvfuser(
fd: Any,
shape: ShapeType,
fill_value: NumberType,
*,
dtype: Optional[torch.dtype] = None,
layout: Optional[torch.layout] = None,
device: Optional[torch.device] = None,
pin_memory: bool = False,
requires_grad: bool = False,
):
assert device != torch.device("cpu")
assert layout is None or layout is torch.strided
assert pin_memory is False
assert requires_grad is False
dtype = dtype if dtype is not None else utils.type_to_dtype(type(fill_value))
nvfuser_dtype = getnvFuserDtype(dtype)
return fd.ops.full(shape, fill_value, nvfuser_dtype)
_nvfuser_impls["native_batch_norm"] = _native_batch_norm_nvfuser
_nvfuser_impls["broadcast_in_dim"] = _broadcast_in_dim_nvfuser
_nvfuser_impls["convert_element_type"] = _convert_element_type_nvfuser
_nvfuser_impls["clone"] = _clone_nvfuser
_nvfuser_impls["transpose"] = _transpose_nvfuser
_nvfuser_impls["squeeze"] = _squeeze_nvfuser
_nvfuser_impls["view_of"] = _view_of_nvfuser
_nvfuser_impls["view"] = _view_nvfuser
_nvfuser_impls["rand_like"] = _rand_like_nvfuser
_nvfuser_impls["sum"] = _sum_nvfuser
_nvfuser_impls["var"] = _var_nvfuser
_nvfuser_impls["var_mean"] = _var_mean_nvfuser
_nvfuser_impls["amax"] = _amax_nvfuser
_nvfuser_impls["amin"] = _amin_nvfuser
_nvfuser_impls["full"] = _full_nvfuser
def register_full():
name = "full"
nvprim.define(
"full(SymInt[] size, Scalar fill_value, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, "
+ "bool? pin_memory=None, bool? requires_grad=None) -> Tensor"
)
def _meta_impl(
size,
fill_value,
*,
out=None,
dtype=None,
layout=None,
device=None,
pin_memory=False,
requires_grad=False,
):
strides = make_contiguous_strides_for(size)
return torch._prims.TensorMeta(
None,
shape=size,
strides=strides,
dtype=dtype,
device=device,
)
def _prim_impl(
size,
fill_value,
*,
out=None,
dtype=None,
layout=None,
device=None,
pin_memory=False,
requires_grad=False,
):
return torch.full(
size,
fill_value,
out=out,
dtype=dtype,
layout=layout,
device=device,
pin_memory=pin_memory,
requires_grad=requires_grad,
)
nvprim_impl.impl(name, _prim_impl)
nvprim_meta_impl.impl(name, _meta_impl)
prim_packet = getattr(torch._ops.ops.nvprims, name)
prim = prim_packet.default
nvprim_autograd_impl.impl(name, backwards_not_supported(prim))
for p in (prim_packet, prim):
p.__doc__ = "Create a tensor with given size and filled with value"
p.impl_nvfuser = _nvfuser_impls["full"]
p.is_recomputable = _nvfuser_is_recomputable["full"]
p.return_type = torch._prims_common.RETURN_TYPE.NEW # type: ignore[attr-defined]
# functorch.compile.min_cut_rematerialization_partition accepts a list of
# operators that may be recomputed in the backward pass; any operator not in
# that list is never recomputed. The table below records which nvprims are
# safe to recompute.
_nvfuser_is_recomputable: Dict[str, bool] = {
# Reductions are not allowed to be recomputed
"amax": False,
"amin": False,
"sum": False,
"var": False,
"var_mean": False,
# Normalizations are not allowed to be recomputed
"native_batch_norm": False,
# Random ops are not allowed to be recomputed
"rand_like": False,
# Everything else is allowed to be recomputed
"abs": True,
"acos": True,
"add": True,
"asin": True,
"atan": True,
"atan2": True,
"atanh": True,
"bitwise_and": True,
"bitwise_not": True,
"bitwise_or": True,
"bitwise_xor": True,
"broadcast_in_dim": True,
"ceil": True,
"clone": True,
"convert_element_type": True,
"cos": True,
"cosh": True,
"div": True,
"eq": True,
"erf": True,
"erfc": True,
"exp": True,
"expm1": True,
"floor": True,
"fmod": True,
"full": True,
"ge": True,
"gt": True,
"imag": True,
"isfinite": True,
"le": True,
"lgamma": True,
"log": True,
"log10": True,
"log1p": True,
"log2": True,
"lt": True,
"mul": True,
"ne": True,
"neg": True,
"pow": True,
"real": True,
"reciprocal": True,
"remainder": True,
"round": True,
"rsqrt": True,
"sign": True,
"sin": True,
"sinh": True,
"sqrt": True,
"squeeze": True,
"sub": True,
"tan": True,
"tanh": True,
"transpose": True,
"trunc": True,
"view": True,
"view_of": True,
"where": True,
}
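# Illustration (not part of the original file): a consumer of this table would
# typically collect the ops marked True, e.g.
#   recomputable = [name for name, ok in _nvfuser_is_recomputable.items() if ok]
# and pass that list to the min-cut partitioner, so cheap pointwise ops can be
# recomputed in the backward pass while reductions, normalizations, and random
# ops are always saved.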
def register_native_batch_norm():
"""This function is used to register the native_batch_norm function in torch.ops.nvprims module."""
name = "native_batch_norm"
nvprim.define(
f"{name}(Tensor input, Tensor? weight, Tensor? bias, Tensor? running_mean, Tensor? running_var, "
+ "bool training, float momentum, float eps)"
+ " -> (Tensor, Tensor, Tensor)"
)
def _prim_impl(
input, weight, bias, running_mean, running_var, training, momentum, eps
):
return torch.native_batch_norm(
input, weight, bias, running_mean, running_var, training, momentum, eps
)
nvprim_impl.impl(name, _prim_impl)
prim_packet = torch._ops.ops.nvprims.native_batch_norm
prim = prim_packet.default
def _native_batch_norm_ref(
input: torch.Tensor,
weight: Optional[torch.Tensor],
bias: Optional[torch.Tensor],
running_mean: Optional[torch.Tensor],
running_var: Optional[torch.Tensor],
training: bool,
momentum: float,
eps: float,
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
if torch._prims_common.is_complex_dtype(input.dtype):
raise NotImplementedError("Complex tensors are not supported")
# note: BN only promotes input to dtype of weight/bias, but keeps the same output dtype
result_dtype = input.dtype
computation_dtype, _ = elementwise_dtypes(
input,
weight,
bias,
type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.NO_OPMATH,
)
input_ = _maybe_convert_to_dtype(input, computation_dtype)
output, mean, rstd = prim(
input_, weight, bias, running_mean, running_var, training, momentum, eps
)
output_ = _maybe_convert_to_dtype(output, result_dtype) # type: ignore[arg-type]
return (output_, mean, rstd) # type: ignore[return-value]
def _native_batch_norm_autograd(
input: torch.Tensor,
weight: Optional[torch.Tensor],
bias: Optional[torch.Tensor],
running_mean: Optional[torch.Tensor],
running_var: Optional[torch.Tensor],
training: bool,
momentum: float,
eps: float,
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
# This wrapper is needed to convert prims calls inside
# _native_batch_norm_ref to nvprims calls
from torch._prims.context import NvfuserPrimsMode
with NvfuserPrimsMode():
return backwards_not_supported(_native_batch_norm_ref)(
input, weight, bias, running_mean, running_var, training, momentum, eps
)
nvprim_autograd_impl.impl(name, _native_batch_norm_autograd)
for p in (prim_packet, prim):
p.__doc__ = "Computes batch normalization."
p.impl_nvfuser = _nvfuser_impls["native_batch_norm"]
p.is_recomputable = _nvfuser_is_recomputable["native_batch_norm"]
p.return_type = torch._prims_common.RETURN_TYPE.NEW # type: ignore[attr-defined]
def register_rand_like():
name = "rand_like"
nvprim.define(
"rand_like(Tensor self, *, ScalarType? dtype=None, Layout? layout=None, "
+ "Device? device=None, bool? pin_memory=None, MemoryFormat? memory_format=None) -> Tensor"
)
def _meta_rand_like(
self,
*,
dtype=None,
layout=None,
device=None,
pin_memory=None,
memory_format=None,
):
strides = make_contiguous_strides_for(self.shape)
return torch._prims.TensorMeta(
self,
shape=self.shape,
strides=strides,
dtype=dtype,
device=device,
)
def _prim_impl(
self,
*,
dtype=None,
layout=None,
device=None,
pin_memory=None,
memory_format=None,
):
return torch.rand_like(
self,
dtype=dtype,
layout=layout,
device=device,
pin_memory=pin_memory,
memory_format=memory_format,
)
nvprim_impl.impl(name, _prim_impl)
nvprim_meta_impl.impl(name, _meta_rand_like)
prim_packet = getattr(torch._ops.ops.nvprims, name)
prim = prim_packet.default
nvprim_autograd_impl.impl(name, backwards_not_supported(prim))
for p in (prim_packet, prim):
p.__doc__ = "Computes rand_like"
p.impl_nvfuser = _nvfuser_impls["rand_like"]
p.is_recomputable = _nvfuser_is_recomputable["rand_like"]
p.return_type = torch._prims_common.RETURN_TYPE.NEW # type: ignore[attr-defined]
def register_var_mean():
"""This function is used to register the var_mean function in torch.ops.nvprims module."""
name = "var_mean.main"
# This overload must be the default for correct dispatching of var_mean(Tensor, bool)
nvprim.define("var_mean(Tensor inp, bool unbiased) -> (Tensor, Tensor)")
# This signature tries to combine several overloads of the torch.var_mean function into one overload.
nvprim.define(
f"{name}(Tensor inp, int[1]? dim=None, bool? unbiased=None, bool keepdim=False, *, float? correction=None)"
+ " -> (Tensor, Tensor)"
)
# This function is used for device="meta" Tensors.
def _meta_var_mean(inp, dim=None, unbiased=None, keepdim=False, *, correction=None):
if torch._prims_common.is_complex_dtype(inp.dtype):
output_dtype = torch._prims_common.corresponding_real_dtype(inp.dtype)
else:
output_dtype = inp.dtype
var = torch._prims._reduction_meta(inp, dim, output_dtype=output_dtype)
mean = torch._prims._reduction_meta(inp, dim, output_dtype=inp.dtype)
if keepdim:
output_shape = [
inp.shape[i] if i not in dim else 1 for i in range(inp.ndim)
]
broadcast_dims = [i for i in range(inp.ndim) if i not in dim]
var = torch._ops.ops.nvprims.broadcast_in_dim(
var, output_shape, broadcast_dims
)
mean = torch._ops.ops.nvprims.broadcast_in_dim(
mean, output_shape, broadcast_dims
)
return (var, mean)
# This function is used under _AutoDispatchBelowAutograd context
def _prim_impl(inp, dim=None, unbiased=None, keepdim=False, *, correction=None):
correction = torch._prims_common.set_correction(unbiased, correction)
return torch.var_mean(inp, dim, correction=correction, keepdim=keepdim)
nvprim_impl.impl(name, _prim_impl)
nvprim_meta_impl.impl(name, _meta_var_mean)
prim_packet = torch._ops.ops.nvprims.var_mean
prim = prim_packet.main
def _unbiased_overload_impl(inp, unbiased):
return prim(inp, dim=None, unbiased=unbiased)
nvprim_implicit_impl.impl("var_mean", _unbiased_overload_impl)
@elementwise_type_promotion_wrapper(
type_promoting_args=("a",),
type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.COMPLEX_TO_FLOAT,
)
def _var_mean_ref(a, dim=None, unbiased=None, keepdim=False, *, correction=None):
correction = torch._prims_common.set_correction(unbiased, correction)
# reduces over all dimensions if dim=() is passed
if dim == () or dim == []:
dim = None
dim = torch._prims_common.reduction_dims(a.shape, dim)
# For complex tensors eager computes the variance as the sum of variances of
# the real and imaginary parts
# TODO: Creating a complex tensor from real and imaginary parts is not supported
if torch._prims_common.is_complex_dtype(a.dtype):
raise NotImplementedError("Complex tensors are not supported")
var_mean = prim(a, dim, correction=correction)
if keepdim:
output_shape = [a.shape[i] if i not in dim else 1 for i in range(a.ndim)]
broadcast_dims = [i for i in range(a.ndim) if i not in dim]
var, mean = var_mean
var = torch._ops.ops.nvprims.broadcast_in_dim(
var, output_shape, broadcast_dims
)
mean = torch._ops.ops.nvprims.broadcast_in_dim(
mean, output_shape, broadcast_dims
)
var_mean = (var, mean)
return var_mean
def _var_mean_autograd(
a, dim=None, unbiased=None, keepdim=False, *, correction=None
):
# This wrapper is needed to convert prims calls inside
# elementwise_type_promotion_wrapper to nvprims calls
from torch._prims.context import NvfuserPrimsMode
with NvfuserPrimsMode():
return backwards_not_supported(_var_mean_ref)(
a, dim, unbiased, keepdim, correction=correction
)
nvprim_autograd_impl.impl(name, _var_mean_autograd)
for p in (prim_packet, prim):
p.__doc__ = "Computes the variance and mean of x over the list of dimensions specified in the dim argument"
p.impl_nvfuser = _nvfuser_impls["var_mean"]
p.is_recomputable = _nvfuser_is_recomputable["var_mean"]
p.return_type = torch._prims_common.RETURN_TYPE.NEW # type: ignore[attr-defined]
def _nvprims_view_impl_aten(a, original_shape, new_shape):
return a.reshape(new_shape)
def register_view():
"""This function is used to register the view function in torch.ops.view module."""
# View is implemented as a decomposition into prims.split_dim,
# prims.collapse_dim, and prims.reshape, but we would like to intercept
# non-decomposed view for now
name = "view"
nvprim.define("view(Tensor inp, SymInt[] original_shape, SymInt[] shape) -> Tensor")
nvprim.define("view.shape(Tensor inp, SymInt[] shape) -> Tensor")
# This function is used under _AutoDispatchBelowAutograd context
def _prim_impl(a, original_shape, new_shape):
return a.reshape(new_shape)
nvprim_impl.impl(name, _prim_impl)
prim_packet = torch._ops.ops.nvprims.view
prim = prim_packet.default
def _view_no_original_shape_overload_impl(a, shape):
if list(a.shape) == list(shape):
return torch.ops.nvprims.view_of(a)
return torch.ops.nvprims.view.default(a, a.shape, shape)
nvprim_implicit_impl.impl("view.shape", _view_no_original_shape_overload_impl)
nvprim_autograd_impl.impl(name, backwards_not_supported(prim))
for p in (prim_packet, prim):
p.__doc__ = "Creates a tensor with the specified shape containing a copy of the data in a."
p.impl_nvfuser = _nvfuser_impls["view"]
p.is_recomputable = _nvfuser_is_recomputable["view"]
p.return_type = torch._prims_common.RETURN_TYPE.VIEW # type: ignore[attr-defined]
p.impl_aten = _nvprims_view_impl_aten
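# Illustration (not part of the original file): the view.shape overload above
# lets callers omit the original shape. For a tensor `a` of shape (2, 3):
#   torch.ops.nvprims.view(a, (2, 3))  should resolve to nvprims.view_of(a)  (aliasing no-op)
#   torch.ops.nvprims.view(a, (3, 2))  should resolve to nvprims.view.default(a, (2, 3), (3, 2))
# so the _view_nvfuser lowering always receives both the original and the new shape.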
def register_nvprims():
"""Registers all nvFuser primitives in the torch.ops.nvprims module."""
register_var_mean()
register_view()
register_native_batch_norm()
register_rand_like()
register_full()
for name in nvprim_names:
main_prim = getattr(torch._ops.ops.prims, name)
nvprim.define(main_prim.schema)
nvprim_impl.impl(name, main_prim.prim_impl)
nvprim_meta_impl.impl(name, main_prim.prim_meta_impl)
prim_packet = getattr(torch._ops.ops.nvprims, name)
prim = prim_packet.default
nvprim_autograd_impl.impl(name, backwards_not_supported(prim))
for p in (prim_packet, prim):
p.__doc__ = main_prim.__doc__
p.impl_nvfuser = _nvfuser_impls[name]
p.is_recomputable = _nvfuser_is_recomputable.get(name, False)
p.return_type = main_prim.return_type # type: ignore[attr-defined]
p.impl_aten = main_prim.impl_aten

View File

@ -25,39 +25,6 @@ import sympy
import torch
from torch import sym_float, sym_int, sym_max
try:
try:
from nvfuser import DataType # type: ignore[import, attr-defined]
except ImportError:
from nvfuser._C import DataType # type: ignore[import]
_torch_dtype_to_nvfuser_dtype_map = {
torch.cdouble: DataType.ComplexDouble,
torch.cfloat: DataType.ComplexFloat,
torch.double: DataType.Double,
torch.float: DataType.Float,
torch.half: DataType.Half,
torch.bfloat16: DataType.BFloat16,
torch.long: DataType.Int,
torch.int: DataType.Int32,
torch.uint8: DataType.Int32,
torch.bool: DataType.Bool,
# Python scalars
complex: DataType.ComplexDouble,
float: DataType.Double,
int: DataType.Int,
bool: DataType.Bool,
}
except ImportError:
_torch_dtype_to_nvfuser_dtype_map = {}
def getnvFuserDtype(dtype: Union[torch.dtype, NumberTypeType]):
"""
Translates from torch.dtype to nvFuser's DataType enum
"""
return _torch_dtype_to_nvfuser_dtype_map[dtype]
ShapeType = Union[torch.Size, List[int], Tuple[int, ...]]
StrideType = Union[List[int], Tuple[int, ...]]

File diff suppressed because it is too large

View File

@ -50,14 +50,12 @@ class SpectralFuncPythonRefInfo(SpectralFuncInfo):
op=None, # the function variant of the operation, populated as torch.<name> if None
torch_opinfo_name, # the string name of the corresponding torch opinfo
torch_opinfo_variant="",
supports_nvfuser=True,
**kwargs,
): # additional kwargs override kwargs inherited from the torch opinfo
self.torch_opinfo_name = torch_opinfo_name
self.torch_opinfo = _find_referenced_opinfo(
torch_opinfo_name, torch_opinfo_variant, op_db=op_db
)
self.supports_nvfuser = supports_nvfuser
assert isinstance(self.torch_opinfo, SpectralFuncInfo)
inherited = self.torch_opinfo._original_spectral_func_args
@ -652,103 +650,83 @@ python_ref_db: List[OpInfo] = [
SpectralFuncPythonRefInfo(
"_refs.fft.fft",
torch_opinfo_name="fft.fft",
supports_nvfuser=False,
),
SpectralFuncPythonRefInfo(
"_refs.fft.ifft",
torch_opinfo_name="fft.ifft",
supports_nvfuser=False,
),
SpectralFuncPythonRefInfo(
"_refs.fft.rfft",
torch_opinfo_name="fft.rfft",
supports_nvfuser=False,
),
SpectralFuncPythonRefInfo(
"_refs.fft.irfft",
torch_opinfo_name="fft.irfft",
supports_nvfuser=False,
),
SpectralFuncPythonRefInfo(
"_refs.fft.hfft",
torch_opinfo_name="fft.hfft",
supports_nvfuser=False,
),
SpectralFuncPythonRefInfo(
"_refs.fft.ihfft",
torch_opinfo_name="fft.ihfft",
supports_nvfuser=False,
),
SpectralFuncPythonRefInfo(
"_refs.fft.fftn",
torch_opinfo_name="fft.fftn",
supports_nvfuser=False,
),
SpectralFuncPythonRefInfo(
"_refs.fft.ifftn",
torch_opinfo_name="fft.ifftn",
supports_nvfuser=False,
),
SpectralFuncPythonRefInfo(
"_refs.fft.rfftn",
torch_opinfo_name="fft.rfftn",
supports_nvfuser=False,
),
SpectralFuncPythonRefInfo(
"_refs.fft.irfftn",
torch_opinfo_name="fft.irfftn",
supports_nvfuser=False,
),
SpectralFuncPythonRefInfo(
"_refs.fft.hfftn",
torch_opinfo_name="fft.hfftn",
supports_nvfuser=False,
),
SpectralFuncPythonRefInfo(
"_refs.fft.ihfftn",
torch_opinfo_name="fft.ihfftn",
supports_nvfuser=False,
),
SpectralFuncPythonRefInfo(
"_refs.fft.fft2",
torch_opinfo_name="fft.fft2",
supports_nvfuser=False,
),
SpectralFuncPythonRefInfo(
"_refs.fft.ifft2",
torch_opinfo_name="fft.ifft2",
supports_nvfuser=False,
),
SpectralFuncPythonRefInfo(
"_refs.fft.rfft2",
torch_opinfo_name="fft.rfft2",
supports_nvfuser=False,
),
SpectralFuncPythonRefInfo(
"_refs.fft.irfft2",
torch_opinfo_name="fft.irfft2",
supports_nvfuser=False,
),
SpectralFuncPythonRefInfo(
"_refs.fft.hfft2",
torch_opinfo_name="fft.hfft2",
supports_nvfuser=False,
),
SpectralFuncPythonRefInfo(
"_refs.fft.ihfft2",
torch_opinfo_name="fft.ihfft2",
supports_nvfuser=False,
),
PythonRefInfo(
"_refs.fft.fftshift",
op_db=op_db,
torch_opinfo_name="fft.fftshift",
supports_nvfuser=False,
),
PythonRefInfo(
"_refs.fft.ifftshift",
op_db=op_db,
torch_opinfo_name="fft.ifftshift",
supports_nvfuser=False,
),
]

View File

@ -2406,22 +2406,18 @@ python_ref_db: List[OpInfo] = [
"_refs.linalg.diagonal",
torch_opinfo_name="linalg.diagonal",
supports_out=False,
supports_nvfuser=False,
op_db=op_db,
),
ReductionPythonRefInfo(
"_refs.linalg.vector_norm",
torch_opinfo_name="linalg.vector_norm",
supports_out=True,
supports_nvfuser=False, # clone_default
op_db=op_db,
),
PythonRefInfo(
"_refs.linalg.matrix_norm",
torch_opinfo_name="linalg.matrix_norm",
supports_out=True,
# Uses svdvals which does not support nvfuser
supports_nvfuser=False,
# Uses vector_norm inside and vector_norm is affected by
# https://github.com/pytorch/pytorch/issues/77216
validate_view_consistency=False,
@ -2431,8 +2427,6 @@ python_ref_db: List[OpInfo] = [
"_refs.linalg.norm",
torch_opinfo_name="linalg.norm",
supports_out=True,
# Uses svdvals which does not support nvfuser
supports_nvfuser=False,
# Uses vector_norm inside and vector_norm is affected by
# https://github.com/pytorch/pytorch/issues/77216
validate_view_consistency=False,
@ -2442,14 +2436,12 @@ python_ref_db: List[OpInfo] = [
"_refs.linalg.svd",
torch_opinfo_name="linalg.svd",
supports_out=True,
supports_nvfuser=False,
op_db=op_db,
),
PythonRefInfo(
"_refs.linalg.svdvals",
torch_opinfo_name="linalg.svdvals",
supports_out=True,
supports_nvfuser=False,
op_db=op_db,
),
]

View File

@ -690,67 +690,56 @@ python_ref_db: List[OpInfo] = [
ElementwiseUnaryPythonRefInfo(
"_refs.special.bessel_j0",
torch_opinfo_name="special.bessel_j0",
supports_nvfuser=False,
op_db=op_db,
),
ElementwiseUnaryPythonRefInfo(
"_refs.special.bessel_j1",
torch_opinfo_name="special.bessel_j1",
supports_nvfuser=False,
op_db=op_db,
),
ElementwiseUnaryPythonRefInfo(
"_refs.special.entr",
torch_opinfo_name="special.entr",
supports_nvfuser=False,
op_db=op_db,
),
ElementwiseUnaryPythonRefInfo(
"_refs.special.erfcx",
torch_opinfo_name="special.erfcx",
supports_nvfuser=False,
op_db=op_db,
),
ElementwiseUnaryPythonRefInfo(
"_refs.special.i0e",
torch_opinfo_name="special.i0e",
supports_nvfuser=False,
op_db=op_db,
),
ElementwiseUnaryPythonRefInfo(
"_refs.special.i1",
torch_opinfo_name="special.i1",
supports_nvfuser=False,
op_db=op_db,
),
ElementwiseUnaryPythonRefInfo(
"_refs.special.i1e",
torch_opinfo_name="special.i1e",
supports_nvfuser=False,
op_db=op_db,
),
ElementwiseUnaryPythonRefInfo(
"_refs.special.log_ndtr",
torch_opinfo_name="special.log_ndtr",
supports_nvfuser=False,
op_db=op_db,
),
ElementwiseUnaryPythonRefInfo(
"_refs.special.ndtr",
torch_opinfo_name="special.ndtr",
supports_nvfuser=False,
op_db=op_db,
),
ElementwiseUnaryPythonRefInfo(
"_refs.special.ndtri",
torch_opinfo_name="special.ndtri",
supports_nvfuser=False,
op_db=op_db,
),
ElementwiseUnaryPythonRefInfo(
"_refs.special.spherical_bessel_j0",
torch_opinfo_name="special.spherical_bessel_j0",
supports_nvfuser=False,
op_db=op_db,
),
#
@ -760,7 +749,6 @@ python_ref_db: List[OpInfo] = [
"_refs.special.zeta",
torch_opinfo_name="special.zeta",
supports_one_python_scalar=True,
supports_nvfuser=False,
op_db=op_db,
skips=(
# Reference reference_inputs nans and infs on cuda and nan, inf, 0., -inf for cpu

View File

@ -100,7 +100,6 @@ class PythonRefInfo(OpInfo):
torch_opinfo_name, # the string name of the corresponding torch opinfo
torch_opinfo_variant_name="", # the variant name for corresponding torch opinfo
validate_view_consistency=True,
supports_nvfuser=True,
**kwargs,
): # additional kwargs override kwargs inherited from the torch opinfo
self.torch_opinfo_name = torch_opinfo_name
@ -109,7 +108,6 @@ class PythonRefInfo(OpInfo):
torch_opinfo_name, torch_opinfo_variant_name, op_db=op_db
)
self.validate_view_consistency = validate_view_consistency
self.supports_nvfuser = supports_nvfuser
assert isinstance(self.torch_opinfo, OpInfo)
inherited = self.torch_opinfo._original_opinfo_args
@ -130,7 +128,6 @@ class ReductionPythonRefInfo(ReductionOpInfo):
op_db=None, # The database of opinfos to search for the parent opinfo
torch_opinfo_name, # the string name of the corresponding torch opinfo
torch_opinfo_variant_name="", # the variant name for corresponding torch opinfo
supports_nvfuser=True,
**kwargs,
): # additional kwargs override kwargs inherited from the torch opinfo
self.torch_opinfo_name = torch_opinfo_name
@ -138,7 +135,6 @@ class ReductionPythonRefInfo(ReductionOpInfo):
self.torch_opinfo = _find_referenced_opinfo(
torch_opinfo_name, torch_opinfo_variant_name, op_db=op_db
)
self.supports_nvfuser = supports_nvfuser
assert isinstance(self.torch_opinfo, ReductionOpInfo)
inherited = self.torch_opinfo._original_reduction_args
@ -164,7 +160,6 @@ class ElementwiseUnaryPythonRefInfo(UnaryUfuncInfo):
torch_opinfo_name, # the string name of the corresponding torch opinfo
torch_opinfo_variant_name="", # the variant name for corresponding torch opinfo
validate_view_consistency=True,
supports_nvfuser=True,
**kwargs,
): # additional kwargs override kwargs inherited from the torch opinfo
self.torch_opinfo_name = torch_opinfo_name
@ -173,7 +168,6 @@ class ElementwiseUnaryPythonRefInfo(UnaryUfuncInfo):
torch_opinfo_name, torch_opinfo_variant_name, op_db=op_db
)
self.validate_view_consistency = validate_view_consistency
self.supports_nvfuser = supports_nvfuser
assert isinstance(self.torch_opinfo, UnaryUfuncInfo)
inherited = self.torch_opinfo._original_unary_ufunc_args
@ -195,7 +189,6 @@ class ElementwiseBinaryPythonRefInfo(BinaryUfuncInfo):
op_db=None, # The database of opinfos to search for the parent opinfo
torch_opinfo_name, # the string name of the corresponding torch opinfo
torch_opinfo_variant_name="", # the variant name for corresponding torch opinfo
supports_nvfuser=True,
**kwargs,
): # additional kwargs override kwargs inherited from the torch opinfo
self.torch_opinfo_name = torch_opinfo_name
@ -203,7 +196,6 @@ class ElementwiseBinaryPythonRefInfo(BinaryUfuncInfo):
self.torch_opinfo = _find_referenced_opinfo(
torch_opinfo_name, torch_opinfo_variant_name, op_db=op_db
)
self.supports_nvfuser = supports_nvfuser
assert isinstance(self.torch_opinfo, BinaryUfuncInfo)
inherited = self.torch_opinfo._original_binary_ufunc_args