# Owner(s): ["module: primTorch"]

from functools import partial
from itertools import product
import warnings
from warnings import catch_warnings
import unittest

import torch
from torch.testing import make_tensor
from torch.testing._internal.common_utils import parametrize, run_tests, TestCase, TEST_SCIPY, skipCUDAMemoryLeakCheckIf
from torch.testing._internal.common_device_type import (
    instantiate_device_type_tests,
    onlyCUDA,
    skipCUDAIfRocm,
    dtypes,
)
from torch.testing._internal.logging_tensor import LoggingTensor, capture_logs, log_input
import torch._prims as prims
from torch._prims.executor import make_traced
import torch._refs as refs


if TEST_SCIPY:
    import scipy.special

NVPRIM_ATEN_FALLBACK_WARNING = "fallback to aten executor"
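# The nvfuser executor is expected to emit a warning containing this substring
# when it falls back to the ATen executor; tests below match on it in the
# warnings they capture.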

class TestPrims(TestCase):
    @onlyCUDA
    @skipCUDAIfRocm
    @dtypes(torch.float32)
    def test_broadcast_in_dim(self, device, dtype):
        def _wrapper(a, b, broadcast_dimensions):
            return prims.broadcast_in_dim(a, b.shape, broadcast_dimensions)
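        # For reference: broadcast_in_dim places input dim i at output position
        # broadcast_dimensions[i] and broadcasts the remaining output dims, so a
        # (5, 5) input with output shape (3, 3, 5, 5) and dims (2, 3) keeps the
        # input as the two innermost dimensions.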

        traced = make_traced(_wrapper)
        make_arg = partial(make_tensor, device=device, dtype=dtype)

        for executor in ('aten', 'strictly_nvfuser'):
            fn = partial(traced, executor=executor)
            # Same shape
            shape = (5, 5)
            a = make_arg(shape)
            b = make_arg(shape, low=0.0, high=0.0)
            result = fn(a, b, (0, 1))

            self.assertEqual(result.shape, a.shape)
            self.assertTrue(result.is_contiguous())
            self.assertEqual(a, result)

            # Error input: reordering dims
            with self.assertRaises(Exception):
                result = fn(a, b, (1, 0))

            # Adding outermost dimensions
            a = make_arg((5, 5))
            b = make_arg((3, 3, 5, 5), low=0.0, high=0.0)
            result = fn(a, b, (2, 3))

            self.assertEqual(result.shape, b.shape)
            self.assertEqual(a.broadcast_to(b.shape), result)

            # Expands
            a = make_arg((1, 5, 1))
            b = make_arg((3, 5, 7), low=0.0, high=0.0)
            result = fn(a, b, (0, 1, 2))

            self.assertEqual(result.shape, b.shape)
            self.assertEqual(a.expand_as(result), result)

            # Unsqueezes
            a = make_arg((1, 2, 3))
            b = make_arg((1, 2, 1, 3), low=0.0, high=0.0)
            result = fn(a, b, (0, 1, 3))

            self.assertEqual(result.shape, b.shape)
            self.assertEqual(a.unsqueeze(2), result)

            # FIXME: This test exposes an issue in nvfuser
            # Adds outermost, expands, and unsqueezes
            """
            a = make_arg((1, 2, 3))
            b = make_arg((4, 1, 7, 2, 3, 3), low=0.0, high=0.0)
            result = fn(a, b, (1, 3, 4))

            self.assertEqual(result.shape, b.shape)
            a.unsqueeze_(3)
            a.unsqueeze_(1)
            a.unsqueeze_(0)
            self.assertEqual(a.expand_as(result), result)
            """

    @onlyCUDA
    @skipCUDAIfRocm
    @dtypes(torch.float32)
    def test_broadcast_in_dim_sum(self, device, dtype):
        def _wrapper(a):
            a_sum = prims.sum(a, [0, 1])
            a_bc = prims.broadcast_in_dim(a_sum, [], [])
            return a_bc
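        # Broadcasting the 0-dim sum result to an empty shape with no mapped
        # dims is effectively an identity; this exercises the scalar-output
        # path of both executors.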

        traced = make_traced(_wrapper)
        make_arg = partial(make_tensor, device=device, dtype=dtype)

        for executor in ('aten', 'strictly_nvfuser'):
            fn = partial(traced, executor=executor)
            shape = (5, 5)
            a = make_arg(shape)
            result = fn(a)

            self.assertEqual(result.shape, ())
            self.assertTrue(result.is_contiguous())
            self.assertEqual(_wrapper(a), result)

    @unittest.skipIf(not TEST_SCIPY, "SciPy not found")
    @dtypes(torch.float64, torch.long)
    def test_cbrt_prim(self, device, dtype):
        make_arg = partial(make_tensor, device=device, dtype=dtype)
        batches = [(), (1,), (2,), (0, 1), (1, 1), (2, 2)]
        shapes = [(), (0,), (1,), (5,)]

        try:
            # Sets the default dtype to NumPy's default dtype of double
            cur_default = torch.get_default_dtype()
            torch.set_default_dtype(torch.double)

            # Tested here, as this op is not currently exposed or tested in ATen
            for b, s in product(batches, shapes):
                x = make_arg(b + s)
                y = prims.cbrt(x)

                x_np = x.cpu().numpy()
                y_np = scipy.special.cbrt(x_np)
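                # scipy.special.cbrt computes the real cube root (defined for
                # negative inputs too), which is the reference behavior that
                # prims.cbrt is checked against here.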

                self.assertEqual(y, y_np, exact_device=False)
        finally:
            torch.set_default_dtype(cur_default)

    @onlyCUDA
    @skipCUDAIfRocm
    def test_nvfuser_impl_is_used(self, device):
        # This test is to ensure that when the nvfuser implementation exists it is used
        # Assuming one-to-one mapping between prims and nvfuser implementations
        # This test is not intended to test the correctness of the nvfuser implementation
        from torch._C._nvfuser import FusionDefinition as fd

        prim_nvfuser_ops = set(torch._prims.__all__).intersection(dir(fd.ops))
        ops_without_nvfuser_impl = {
            name
            for name in prim_nvfuser_ops
            if getattr(torch.ops.nvprims, name, None) is None
        }
        assert len(ops_without_nvfuser_impl) == 0, (
            f"The following prims do not have 'impl_nvfuser' defined: {ops_without_nvfuser_impl}, "
            "while nvfuser implementations exist for them."
        )

    @onlyCUDA
    @skipCUDAIfRocm
    def test_nvfuser_empty_fusion(self, device):
        from torch.fx.experimental.proxy_tensor import make_fx
        from torch._prims.executor import execute

        a = torch.randn(3, 3, device=device)

        def func(a, b, c):
            return (a, b, c)

        gm = make_fx(func)(a, a, a)

        with self.assertRaisesRegex(AssertionError, "Graph must contain at least one call_function node"):
            execute(gm, a, a, a, executor="strictly_nvfuser")

        # Should pass with partitioned executor
        out = execute(gm, a, a, a, executor="nvfuser")
        self.assertEqual(out, (a, a, a))

    @skipCUDAMemoryLeakCheckIf(True)  # https://github.com/pytorch/pytorch/issues/84529
    @onlyCUDA
    @skipCUDAIfRocm
    def test_nvfuser_no_args(self, device):
        from torch._prims.context import TorchRefsNvfuserCapabilityMode
        from torch.fx.experimental.proxy_tensor import make_fx
        from torch._prims.executor import execute
        from torch._prims.nvfuser_executor import make_nvfuser_fusion

        a = torch.randn(3, 3, device=device)

        def func():
            return torch.sigmoid(a)
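        # Note: func takes no arguments and closes over `a`, so the traced graph
        # has zero placeholder nodes; the strict nvfuser path cannot build a
        # fusion for it and is expected to fall back to ATen below.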

        with TorchRefsNvfuserCapabilityMode():
            gm = make_fx(func)()

        with warnings.catch_warnings(record=True) as caught:
            execute(gm, executor="strictly_nvfuser")
        # Fusion execution with no CUDA inputs is handled by the nvprims ATen fallback
        self.assertTrue(any(NVPRIM_ATEN_FALLBACK_WARNING in str(w.message) for w in caught))

        with self.assertRaisesRegex(AssertionError, "There must be at least one argument"):
            make_nvfuser_fusion(gm)

        with self.assertRaisesRegex(AssertionError, "Number of placeholder nodes in the graph must match"):
            execute(gm, a, executor="strictly_nvfuser")

        # Should pass with partitioned executor
        out = execute(gm, executor="nvfuser")
        self.assertEqual(out, func())

    @onlyCUDA
    @skipCUDAIfRocm
    def test_nvfuser_constant_tensors(self, device):
        from torch._prims.context import TorchRefsNvfuserCapabilityMode
        from torch.fx.experimental.proxy_tensor import make_fx
        from torch._prims.executor import execute

        a = torch.randn(3, 3, device=device)
        b = torch.randn(3, 3, device=device)

        def func(b):
            return a + b

        with TorchRefsNvfuserCapabilityMode():
            gm = make_fx(func)(b)

        with self.assertRaisesRegex(AssertionError, "not supported yet"):
            execute(gm, b, executor="strictly_nvfuser")

        # Should pass with partitioned executor
        out = execute(gm, b, executor="nvfuser")
        self.assertEqual(out, gm(b))

    @onlyCUDA
    @skipCUDAIfRocm
    def test_nvfuser_executor_cached_noncontiguous(self, device):
        # This test is to ensure that nvfuser computes correct results for noncontiguous tensors
        from torch.fx.experimental.proxy_tensor import make_fx
        from torch._prims.context import TorchRefsMode
        from torch._prims.executor import execute

        a = torch.randn(3, 3, device=device)

        def func(a):
            return torch.sigmoid(a)

        with TorchRefsMode():
            gm = make_fx(func)(a)

        # First run to create the cache
        execute(gm, a, executor="nvfuser")
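        # The fusion was traced and cached on a contiguous input; presumably the
        # executor's cache must key on strides (not just shapes), or it would
        # reuse a fusion that reads the transposed input incorrectly.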

        # a.mT is noncontiguous, but it shouldn't affect correctness
        expected = execute(gm, a.mT, executor="aten")
        actual = execute(gm, a.mT, executor="nvfuser")
        self.assertEqual(expected, actual)

    def test_nvfuser_capability_context(self, device):
        # This test is to ensure that the torch calls are replaced with refs
        # based on the nvfuser+prims capability
        from torch.fx.experimental.proxy_tensor import make_fx
        from torch._prims.context import TorchRefsNvfuserCapabilityMode

        # It's assumed that digamma is not supported by nvfuser
        # If it's ever supported, this test will need to be updated
        self.assertTrue(getattr(torch.ops.nvprims, "digamma", None) is None)

        a = torch.randn(3, 3, device=device)

        def func(a):
            return torch.digamma(a)

        with TorchRefsNvfuserCapabilityMode():
            gm = make_fx(func)(a)

        # Check that torch.digamma is not replaced with torch.ops.prims.digamma
        call_function_nodes = list(filter(lambda n: n.op == "call_function", gm.graph.nodes))
        includes_aten_digamma = any(
            torch.ops.aten.digamma.default == node.target
            for node in call_function_nodes
        )
        includes_prims_digamma = any(
            torch.ops.prims.digamma.default == node.target
            for node in call_function_nodes
        )
        self.assertTrue(includes_aten_digamma)
        self.assertFalse(includes_prims_digamma)

        # Check the mixed case: sigmoid is replaced with refs, but digamma is not
        def func(a):
            return torch.sigmoid(torch.digamma(a))

        with TorchRefsNvfuserCapabilityMode():
            gm = make_fx(func)(a)

        call_function_nodes = list(filter(lambda n: n.op == "call_function", gm.graph.nodes))
        includes_aten_sigmoid = any(
            torch.ops.aten.sigmoid.default == node.target
            for node in call_function_nodes
        )
        includes_prims_digamma = any(
            torch.ops.prims.digamma.default == node.target
            for node in call_function_nodes
        )
        includes_nvprims_exp = any(
            torch.ops.nvprims.exp.default == node.target
            for node in call_function_nodes
        )
        self.assertFalse(includes_aten_sigmoid)
        self.assertFalse(includes_prims_digamma)
        self.assertTrue(includes_nvprims_exp)


    def test_aten_overload_to_prims(self, device):
        # This test is to ensure that the torch.ops.aten calls are replaced with refs
        from torch.fx.experimental.proxy_tensor import make_fx
        from torch._prims.context import TorchRefsMode

        a = torch.randn(3, 3, device=device)

        def func(a):
            return torch.ops.aten.sigmoid.default(torch.ops.aten.digamma.default(a))

        with TorchRefsMode():
            gm = make_fx(func)(a)

        # Check that all call_function nodes are prims
        call_function_nodes = list(filter(lambda n: n.op == "call_function", gm.graph.nodes))
        all_prims_namespace = all(
            node.target.name.startswith("prims") for node in call_function_nodes
        )
        self.assertTrue(all_prims_namespace)


    @onlyCUDA
    @skipCUDAIfRocm
    def test_nvfuser_executor_partitioned(self, device):
        # This test is to ensure that the nvfuser partitioned executor works correctly
        # It's assumed that digamma is not supported by nvfuser
        # If it's ever supported, this test will need to be updated
        self.assertTrue(getattr(torch.ops.nvprims, "digamma", None) is None)

        from torch.fx.experimental.proxy_tensor import make_fx
        from torch._prims.context import TorchRefsMode
        from torch._prims.executor import execute

        a = torch.randn(3, 4, device=device)
        b = torch.rand(3, 1, device=device)
        c = torch.rand(3, 4, device=device)

        def func(a, b, c):
            aa = torch.digamma(a)  # not supported by nvfuser
            d = torch.add(b, c)
            dd = torch.sqrt(d)
            return torch.mul(aa, dd.digamma())

        with TorchRefsMode():
            gm = make_fx(func)(a, b, c)

        expected = execute(gm, a, b, c, executor="aten")
        actual = execute(gm, a, b, c, executor="nvfuser")
        self.assertEqual(expected, actual)

    @onlyCUDA
    @skipCUDAIfRocm
    def test_nvfuser_executor_partitioned_no_partitions_error(self, device):
        # This test is to ensure that the nvfuser partitioned executor works correctly
        # It's assumed that digamma is not supported by nvfuser
        # If it's ever supported, this test will need to be updated
        self.assertTrue(getattr(torch.ops.nvprims, "digamma", None) is None)

        from torch.fx.experimental.proxy_tensor import make_fx
        from torch._prims.context import TorchRefsMode
        from torch._prims.executor import execute

        a = torch.randn(3, 4, device=device)

        def func(a):
            return torch.digamma(a)  # not supported by nvfuser

        with TorchRefsMode():
            gm = make_fx(func)(a)

        with catch_warnings(record=True) as w:
            # Trigger warning
            execute(gm, a, executor="nvfuser")
            # Check warning occurs
            self.assertEqual(len(w), 1)
            self.assertTrue("is not supported by nvFuser" in str(w[-1].message))

    def test_nvprims(self, device):
        # This test is to ensure that nvfuser-specific prims are exposed
        # and can be traced with make_fx
        from torch.fx.experimental.proxy_tensor import make_fx

        def func(a):
            return torch.ops.nvprims.add(a, a)

        a = torch.randn(3, 4, device=device)
        gm = make_fx(func)(a)

        for node in gm.graph.nodes:
            if node.op == "call_function":
                self.assertTrue(node.name == "add")
                self.assertTrue(node.target == torch.ops.nvprims.add.default)
                self.assertFalse(node.target == torch.ops.prims.add.default)
                self.assertFalse(node.target == torch.ops.aten.add.default)

    @dtypes(torch.float32, torch.float16)
    def test_batch_norm_backward_nvprims(self, device, dtype):
        # This test verifies that the backward pass of batch norm is correctly decomposed into nvprims
        from torch.fx.experimental.proxy_tensor import make_fx
        from torch._prims.context import TorchRefsNvfuserCapabilityMode
        from torch.testing._internal.common_methods_invocations import sample_inputs_batch_norm

        samples_iter = sample_inputs_batch_norm(None, device, dtype, requires_grad=True)
        sample = next(samples_iter)
        grad = torch.randn_like(sample.input)

        def func(grad, input, weight, rm, rv, eps, train):
            return torch.ops.aten.native_batch_norm_backward.default(
                grad, input, weight, rm, rv, rm, rv, train, eps, [True, True, True]
            )

        args = sample.args
        kwargs = sample.kwargs
        all_args = [grad, sample.input, args[2], args[0], args[1], kwargs['eps'], kwargs['training']]
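        # Assumed layout of sample.args for batch norm samples: args[0] is
        # running_mean, args[1] is running_var, and args[2] is weight; the
        # running stats are also reused above as save_mean/save_invstd stand-ins.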
        with TorchRefsNvfuserCapabilityMode():
            gm = make_fx(func)(*all_args)

        call_function_nodes = list(filter(lambda n: n.op == "call_function", gm.graph.nodes))
        includes_batch_norm_backward = any(
            torch.ops.aten.native_batch_norm_backward.default == node.target
            for node in call_function_nodes
        )
        self.assertFalse(includes_batch_norm_backward)

    @onlyCUDA
    @skipCUDAIfRocm
    @dtypes(torch.float32)
    @parametrize("correction", [0, 1])
    def test_var(self, device, dtype, correction):
        def _wrapper(a):
            return prims.var(a, [0, 1], correction=correction)
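        # correction is the degrees-of-freedom adjustment: 0 gives the biased
        # (population) variance, 1 the Bessel-corrected sample variance.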

        traced = make_traced(_wrapper)
        make_arg = partial(make_tensor, device=device, dtype=dtype)

        for executor in ('aten', 'strictly_nvfuser'):
            fn = partial(traced, executor=executor)
            shape = (5, 5)
            a = make_arg(shape)
            result = fn(a)

            self.assertEqual(result.shape, ())
            self.assertTrue(result.is_contiguous())
            self.assertEqual(_wrapper(a), result)

    @onlyCUDA
    @skipCUDAIfRocm
    @dtypes(torch.float16, torch.float32)
    @parametrize("correction", [0, 1])
    @parametrize("keepdim", [True, False])
    def test_var_mean(self, device, dtype, correction, keepdim):
        from torch.fx.experimental.proxy_tensor import make_fx
        from torch._prims.context import TorchRefsNvfuserCapabilityMode

        def _wrapper(a):
            return torch.var_mean(a, [0, 1], correction=correction, keepdim=keepdim)

        make_arg = partial(make_tensor, device=device, dtype=dtype)

        with TorchRefsNvfuserCapabilityMode():
            gm = make_fx(_wrapper)(make_arg((5, 5)))

        call_function_nodes = list(filter(lambda n: n.op == "call_function", gm.graph.nodes))
        includes_nvprims_var_mean = any(
            torch.ops.nvprims.var_mean.main == node.target
            for node in call_function_nodes
        )
        self.assertTrue(includes_nvprims_var_mean)

    @onlyCUDA
    @skipCUDAIfRocm
    @dtypes(torch.float32, torch.float16)
    def test_cpu_tensor(self, device, dtype):
        from torch.fx.experimental.proxy_tensor import make_fx
        from torch._prims.context import TorchRefsNvfuserCapabilityMode
        from torch._prims.executor import execute

        def _wrapper(t0, t1, cpu_scalar):
            return t0 + t1 + cpu_scalar

        make_arg = partial(make_tensor, device=device, dtype=dtype)
        a = make_arg((12, 1))
        b = make_arg((12, 12))
        c = torch.tensor(0.5)
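        # c is a 0-dim CPU tensor; nvfuser codegen accepts CPU scalar tensors as
        # fusion inputs, so mixing it with the CUDA inputs should not trigger
        # the ATen fallback below.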

        with TorchRefsNvfuserCapabilityMode():
            gm = make_fx(_wrapper)(a, b, c)

        with warnings.catch_warnings(record=True) as caught:
            actual = execute(gm, a, b, c, executor="nvfuser")
        # A CPU scalar tensor is handled by nvfuser codegen, so it shouldn't fall back
        self.assertFalse(any(NVPRIM_ATEN_FALLBACK_WARNING in str(w.message) for w in caught))

        expected = execute(gm, a, b, c, executor="aten")
        self.assertEqual(expected, actual)

        call_function_nodes = list(filter(lambda n: n.op == "call_function", gm.graph.nodes))
        includes_aten_add = any(
            torch.ops.aten.add.default == node.target
            for node in call_function_nodes
        )
        self.assertFalse(includes_aten_add)

        with warnings.catch_warnings(record=True) as caught:
            nvprim_aten_fallback = execute(gm, a.cpu(), b.cpu(), c, executor="nvfuser")
        # A non-scalar CPU tensor is handled by the nvprims ATen fallback;
        # assert that the fallback warning was indeed raised
        self.assertTrue(any(NVPRIM_ATEN_FALLBACK_WARNING in str(w.message) for w in caught))

        self.assertEqual(expected, nvprim_aten_fallback)

    @onlyCUDA
    @skipCUDAIfRocm
    @dtypes(torch.float32)
    def test_pytree_input_output(self, device, dtype):
        @make_traced
        def fn(a, b_dict):
            b = b_dict["b"]
            d = {}
            d["c"] = torch.add(a, b)
            return (d, torch.add(a, d["c"]))

        make_arg = partial(make_tensor, device=device, dtype=dtype)
        a = make_arg((5, 5))
        b = make_arg((1, 5))
        b_dict = {"b": b}

        result_aten = fn(a, b_dict, executor="aten")
        result_nvfuser = fn(a, b_dict, executor="strictly_nvfuser")
        self.assertEqual(result_aten, result_nvfuser)

    @dtypes(torch.float32)
    def test_memory_format_strides(self, device, dtype):
        shapes = (
            (),
            (0,),
            (1,),
            (5,),
            (1, 0),
            (1, 1),
            (3, 7),
            (3, 0, 2),
            (1, 1, 2),
            (4, 1, 1),
            (7, 8, 9),
        )

        channels_last_shapes = (
            (0, 0, 0, 0),
            (1, 0, 3, 0),
            (0, 2, 3, 5),
            (2, 2, 2, 0),
            (5, 4, 3, 2),
            (8, 8, 7, 2),
            (9, 1, 3, 1),
            (4, 5, 8, 7),
        )

        channels_last_3d_shapes = (
            (0, 8, 7, 9, 2),
            (5, 0, 7, 9, 2),
            (5, 0, 7, 9, 0),
            (5, 8, 7, 9, 2),
            (5, 1, 7, 9, 2),
            (5, 1, 7, 9, 1),
        )

        pairs = (
            (shapes, torch.contiguous_format),
            (channels_last_shapes, torch.contiguous_format),
            (channels_last_3d_shapes, torch.contiguous_format),
            (channels_last_shapes, torch.channels_last),
            (channels_last_3d_shapes, torch.channels_last_3d),
        )

        for shapes, memory_format in pairs:
            for shape in shapes:
                # tests empty
                expected = torch.empty(shape, device=device, dtype=dtype, memory_format=memory_format)
                actual = refs.empty(shape, device=device, dtype=dtype, memory_format=memory_format)
                self.assertEqual(expected.stride(), actual.stride())

                # tests clone
                a = torch.testing.make_tensor(shape, device=device, dtype=dtype)
                expected = torch.clone(a, memory_format=memory_format)
                actual = refs.clone(a, memory_format=memory_format)
                self.assertEqual(expected.stride(), actual.stride())

                # tests contiguous
                a = torch.testing.make_tensor(shape, device=device, dtype=dtype, noncontiguous=True)
                expected = a.contiguous(memory_format=memory_format)
                actual = refs.contiguous(a, memory_format=memory_format)
                self.assertEqual(expected.stride(), actual.stride())

    @dtypes(torch.float32)
    def test_reshape_view_method(self, device, dtype):
        make_arg = partial(make_tensor, device=device, dtype=dtype)
        a = make_arg((5, 5))
        new_shape = 1, 5, 1, 5
        result_eager = a.reshape(*new_shape)
        result_refs = refs.reshape(a, *new_shape)
        self.assertEqual(result_eager, result_refs)

        result_eager = a.view(*new_shape)
        result_refs = refs.view(a, *new_shape)
        self.assertEqual(result_eager, result_refs)
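        # Unlike reshape, view never copies: it requires the new shape to be
        # expressible over the tensor's existing strides (true here since `a`
        # is contiguous).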


class TestPrimsBasic(TestCase):
    def test_torch_ops(self):
        r = make_tensor((2,), device='cpu', dtype=torch.float)
        self.assertEqual(torch.ops.prims.sin(r), torch.sin(r))

        r = LoggingTensor(r)
        with capture_logs() as logs:
            log_input("input", r)
            prims.sin(r)
        self.assertExpectedInline('\n'.join(logs), """\
$0 = input('input')
$1 = torch._ops.prims.sin.default($0)""")

    def test_mul_complex(self):
        prims.mul(torch.randn(2), 1 + 1j)


instantiate_device_type_tests(TestPrims, globals())


class TestRefs(TestCase):
    @dtypes(torch.float32)
    def test_constant_pad_nd_memory_format(self, device, dtype):
        # Test memory format is preserved in unambiguous cases
        for mf, ndim in (
            (torch.channels_last, 4),
            (torch.contiguous_format, 4),
            (torch.channels_last_3d, 5),
            (torch.contiguous_format, 5),
        ):
            a = torch.zeros([2] * ndim).to(memory_format=mf)
            res = refs.constant_pad_nd(a, pad=[1] * (2 * ndim))
            self.assertTrue(res.is_contiguous(memory_format=mf))

        # Ambiguous cases

        # is_channels_last_ and is_contiguous_, results in channels_last output
        a = torch.empty_strided((2, 1, 2, 2), stride=(4, 1, 2, 1))
        self.assertTrue(a.is_contiguous(memory_format=torch.channels_last))
        self.assertTrue(a.is_contiguous())
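        # Both predicates hold because the size-1 channel dim makes the stride
        # checks degenerate: strides of size-1 dims don't constrain contiguity.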
        actual = refs.constant_pad_nd(a, pad=[1] * 8)
        expect = torch.constant_pad_nd(a, pad=[1] * 8)
        self.assertEqual(actual.stride(), expect.stride())
        self.assertTrue(actual.is_contiguous(memory_format=torch.channels_last))

        # is_channels_last_contiguous_ but not is_channels_last_, results in
        # contiguous output
        a = torch.empty_strided((2, 1, 2, 2), stride=(4, 4, 2, 1))
        self.assertTrue(a.is_contiguous(memory_format=torch.channels_last))
        self.assertTrue(a.is_contiguous())
        actual = refs.constant_pad_nd(a, pad=[1] * 8)
        expect = torch.constant_pad_nd(a, pad=[1] * 8)
        self.assertEqual(actual.stride(), expect.stride())
        self.assertTrue(actual.is_contiguous())


instantiate_device_type_tests(TestRefs, globals())


class TestDecomp(TestCase):
    @onlyCUDA
    @skipCUDAIfRocm
    @dtypes(torch.float16, torch.float32)
    def test_decomposition_type_promotion_nvprim_amp(self, device, dtype):
        x = torch.rand(5, device=device).to(dtype)
        y = torch.rand(5, device=device).to(dtype)

        from torch._prims.context import TorchRefsNvfuserCapabilityMode, _is_func_unsupported_nvfuser
        from torch.fx.experimental.proxy_tensor import make_fx
        op = torch._decomp.decomposition_table.get(torch.ops.aten.leaky_relu_backward.default)

        def fn0(*arg):
            return _is_func_unsupported_nvfuser(mode, op, arg, {})

        def fn1(x):
            x = x * 2
            x = x @ x
            x = x * 2
            return x

        with TorchRefsNvfuserCapabilityMode() as mode:
            self.assertFalse(fn0(x, y, 0.3, False))

            # The autocast context emits C++-level ATen calls that are hidden from
            # TorchRefsNvfuserCapabilityMode, which works only at the Python level.
            # The first call to make_fx records the autocast C++ calls directly and
            # doesn't have the chance to translate them to nvprims. After the first
            # call, "gm" contains explicit calls to torch.ops.aten and nothing
            # is hidden, so the second call to make_fx actually translates the
            # recorded autocast dtype conversions to nvprims.
            with torch.autocast("cuda"):
                gm = make_fx(fn1)(x)
            gm = make_fx(gm)(x)
            call_function_nodes = list(filter(lambda n: n.op == "call_function", gm.graph.nodes))
            includes_aten_to_copy = any(
                torch.ops.aten._to_copy.default == node.target
                for node in call_function_nodes
            )
            self.assertFalse(includes_aten_to_copy)


instantiate_device_type_tests(TestDecomp, globals())


if __name__ == "__main__":
    run_tests()