Files · pytorch/test/test_python_dispatch.py

rzou · 712d9882d2 · Skip test responsible for causing flakiness (#145109)

Investigation is a separate issue. For now I want to get the CI back up
and running on the other tests. The problem seems to be that
IncludeDispatchKeyGuard doesn't actually reset the state, which seems
very, very wrong.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/145109
Approved by: https://github.com/williamwen42
2025-01-18 00:58:25 +00:00
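
As context, a minimal sketch of the invariant the commit message says is being
violated. This is an assumption based on how the private torch._C bindings are
used in the test file below, not a confirmed repro:

    import torch
    from torch._C import DispatchKey

    def meta_included() -> bool:
        # Inspect the thread-local include set, the same way the tests below do.
        return torch._C._dispatch_tls_local_include_set().has(DispatchKey.Meta)

    assert not meta_included()
    with torch._C._IncludeDispatchKeyGuard(DispatchKey.Meta):
        assert meta_included()
    # The guard should restore the prior TLS state on exit; the flakiness
    # described above suggests this reset does not always happen.
    assert not meta_included()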

# Owner(s): ["module: __torch_dispatch__"]
# ruff: noqa: F841

import logging
import sys
import tempfile
import unittest
from copy import deepcopy

import torch
import torch._dynamo
from torch import SymInt
from torch._C import DispatchKey, DispatchKeySet
from torch._custom_op.functional import register_functional_op
from torch._subclasses.fake_tensor import FakeTensorMode
from torch.cuda.jiterator import _create_jit_fn
from torch.fx.experimental.proxy_tensor import make_fx
from torch.fx.experimental.symbolic_shapes import ShapeEnv
from torch.library import _scoped_library, fallthrough_kernel, impl, Library
from torch.multiprocessing.reductions import StorageWeakRef
from torch.testing._internal.common_device_type import (
    instantiate_device_type_tests,
    ops,
)
from torch.testing._internal.common_methods_invocations import op_db
from torch.testing._internal.common_utils import (
    first_sample,
    IS_WINDOWS,
    run_tests,
    TEST_WITH_ROCM,
    TestCase,
)
from torch.testing._internal.custom_op_db import custom_op_db
from torch.testing._internal.logging_tensor import (
    capture_logs,
    capture_logs_with_logging_tensor_mode,
    log_input,
    LoggingTensor,
    LoggingTensorMode,
    LoggingTensorReentrant,
)
from torch.testing._internal.two_tensor import TwoTensor
from torch.utils import _pytree as pytree
from torch.utils._mode_utils import all_same_mode, no_dispatch
from torch.utils._python_dispatch import (
    _get_current_dispatch_mode,
    _get_current_dispatch_mode_stack,
    is_in_torch_dispatch_mode,
    TorchDispatchMode,
)
from torch.utils._pytree import tree_map, tree_map_only


# used as DataLoader collate_fn below; named here to avoid trying to pickle a lambda
def _identity(x):
    return x


class TestDispatcherPythonBindings(TestCase):
    def test_call_boxed(self) -> None:
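        # Exercise the boxed (stack-based) calling convention from Python:
        # look up the schema, then invoke the op through the boxed entry point.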
        sin = torch._C._dispatch_find_schema_or_throw("aten::sin", "")
        x = torch.randn(3)
        y = torch._C._dispatch_call_boxed(sin, x)
        self.assertEqual(y, x.sin())


class TestPythonRegistration(TestCase):
    test_ns = "_test_python_registration"

    def tearDown(self):
        if hasattr(torch.ops, self.test_ns):
            del torch.ops._test_python_registration

    def test_fallback(self) -> None:
        test_key = "TESTING_ONLY_GenericMode"

        test_keyset = torch._C.DispatchKeySet(test_key)
        include_to_set = torch._C._dispatch_tls_local_include_set() | test_keyset
        exclude_to_set = torch._C._dispatch_tls_local_exclude_set()

        with _scoped_library("_", "IMPL") as my_lib:
            expected_op = None
            expected_args = None
            expected_kwargs = None
            # Use this out shape to make sure the result from our fallback
            # is what is returned to the user
            out_shape = None

            def my_fallback(op, *args, **kwargs):
                # Disable our handler during checks and generating the output
                with torch._C._ForceDispatchKeyGuard(
                    include_to_set, exclude_to_set | test_keyset
                ):
                    self.assertIs(op, expected_op)
                    self.assertEqual(args, expected_args)
                    self.assertEqual(kwargs, expected_kwargs)
                    # Return something specific
                    return torch.empty(out_shape)

            my_lib.fallback(my_fallback, test_key)

            a, b = torch.rand(2), torch.rand(2)

            with torch._C._ForceDispatchKeyGuard(include_to_set, exclude_to_set):
                # Check a factory function
                expected_op = torch.ops.aten.empty.memory_format
                expected_args = ((2, 2),)
                # Extra kwargs to bypass issues with default args in factory functions
                expected_kwargs = {
                    "dtype": torch.float64,
                    "pin_memory": False,
                    "device": torch.device("cpu"),
                }
                out_shape = (3,)
                out = torch.empty(*expected_args, **expected_kwargs)
                self.assertEqual(out.size(), out_shape)

                # Check a regular function
                expected_op = torch.ops.aten.add.Tensor
                expected_args = (a, b)
                expected_kwargs = {}
                out_shape = (4,)
                out = a + b
                self.assertEqual(out.size(), out_shape)

    def test_fallback_keyset(self) -> None:
        test_key_first = "TESTING_ONLY_GenericMode"
        test_key_second = "TESTING_ONLY_GenericWrapper"

        test_keyset = torch._C.DispatchKeySet(test_key_first) | torch._C.DispatchKeySet(
            test_key_second
        )
        include_to_set = torch._C._dispatch_tls_local_include_set() | test_keyset
        exclude_to_set = torch._C._dispatch_tls_local_exclude_set()

        with _scoped_library("_", "IMPL") as my_lib:
            first_called = False
            second_called = False

            def first_fallback(keyset, op, *args, **kwargs):
                nonlocal first_called
                if second_called:
                    # Recursive call
                    first_called = True
                    with torch._C._ForceDispatchKeyGuard(
                        include_to_set, exclude_to_set | test_keyset
                    ):
                        return op(*args, **kwargs)
                else:
                    # Redispatch down
                    keyset = keyset.remove(test_key_first)
                    return op.redispatch(keyset, *args, **kwargs)
            def second_fallback(op, *args, **kwargs):
                nonlocal second_called
                # Set to avoid infinite recursion
                second_called = True
                # A new dispatcher call should hit the first fallback again
                self.assertFalse(first_called)
                a, b = args
                # Do a subtraction here instead of an add!
                c = a - b
                self.assertTrue(first_called)
                return c
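
            # with_keyset=True makes the dispatcher pass the current
            # DispatchKeySet as the first argument to first_fallback, which is
            # what lets it redispatch below itself via op.redispatch.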
            my_lib.fallback(first_fallback, test_key_first, with_keyset=True)
            my_lib.fallback(second_fallback, test_key_second)

            a, b = torch.rand(2), torch.rand(2)
            with torch._C._ForceDispatchKeyGuard(include_to_set, exclude_to_set):
                c = a + b

            self.assertEqual(c, a - b)
            self.assertTrue(first_called)
            self.assertTrue(second_called)

    def test_fallback_fallthrough(self) -> None:
        test_key_first = "TESTING_ONLY_GenericMode"
        test_key_second = "TESTING_ONLY_GenericWrapper"

        test_keyset = torch._C.DispatchKeySet(test_key_first) | torch._C.DispatchKeySet(
            test_key_second
        )
        include_to_set = torch._C._dispatch_tls_local_include_set() | test_keyset
        exclude_to_set = torch._C._dispatch_tls_local_exclude_set()

        with _scoped_library("_", "IMPL") as my_lib:
            is_called = False

            def my_fallback(op, *args, **kwargs):
                nonlocal is_called
                is_called = True
                with torch._C._ForceDispatchKeyGuard(
                    include_to_set, exclude_to_set | test_keyset
                ):
                    return op(*args, **kwargs)

            my_lib.fallback(torch.library.fallthrough_kernel, test_key_first)
            my_lib.fallback(my_fallback, test_key_second)

            a, b = torch.rand(2), torch.rand(2)
            with torch._C._ForceDispatchKeyGuard(include_to_set, exclude_to_set):
                c = a + b

            self.assertEqual(c, a + b)
            self.assertTrue(is_called)

    @unittest.skip(
        "Causing flakiness, see https://github.com/pytorch/pytorch/issues/145108"
    )
    def test_fallthrough_for_dense_key_with_meta_in_tls(self) -> None:
        # This tests that if Meta is included in the TLS dispatch key set,
        # a meta kernel should be called regardless of whether a dense
        # backend has a fallthrough kernel registered.
        a = torch.randn((3, 3))
        with _scoped_library("custom", "DEF") as my_lib:
            my_lib.define("sum(Tensor self) -> Tensor")

            meta_is_called = False

            def sum_meta(*args, **kwargs):
                nonlocal meta_is_called
                meta_is_called = True
                return args[0]

            my_lib.impl("sum", fallthrough_kernel, "CPU")
            my_lib.impl("sum", sum_meta, "Meta")

            with torch._C._IncludeDispatchKeyGuard(torch.DispatchKey.Meta):
                torch.ops.custom.sum.default(a)
                self.assertTrue(meta_is_called)

    def test_override_aten_ops_with_multiple_libraries(self) -> None:
        x = torch.tensor([1, 2])
        with _scoped_library("aten", "IMPL") as my_lib2:
            with _scoped_library("aten", "IMPL") as my_lib1:
                # Example 1
                def my_neg(*args, **kwargs):
                    return args[0]._neg_view()

                # Now we are secretly making the operator a view op so autograd needs to know how
                # to handle it
                my_lib1.impl("neg", my_neg, "AutogradCPU")

                self.assertTrue(torch.neg(x).is_neg())

                # RuntimeError: impl("aten::neg", ...):
                # Explicitly provided namespace (aten) in operator name does not match ...
                with self.assertRaisesRegex(
                    RuntimeError, "operator name does not match namespace"
                ):
                    with _scoped_library("foo", "DEF") as my_lib3:
                        my_lib3.define("neg(Tensor self) -> Tensor")
                        my_lib3.impl(torch.ops.aten.neg.default, my_neg, "AutogradCPU")

                # Example 2
                def my_mul(*args, **kwargs):
                    return torch.zeros_like(args[0])

                # torch.ops.aten.mul.Tensor
                my_lib2.impl("aten::mul.Tensor", my_mul, "ZeroTensor")

                y = torch._efficientzerotensor(2)
                self.assertFalse(torch.mul(x, y)._is_zerotensor())
                # Assert that a user can't override the behavior of a (ns, op, dispatch_key)
                # combination if someone else has already overridden it before them
                with self.assertRaisesRegex(
                    RuntimeError, "already a kernel registered from python"
                ):
                    my_lib2.impl(torch.ops.aten.mul.Tensor, my_mul, "ZeroTensor")

            # Validate that lib2 is not affected by removing lib1
            self.assertFalse(torch.mul(x, y)._is_zerotensor())

        # Validate that the old behavior is restored for neg and mul
        self.assertFalse(torch.neg(x).is_neg())
        self.assertTrue(torch.mul(x, y)._is_zerotensor())

    def test_error_if_fn_not_callable(self):
        with self.assertRaisesRegex(
            TypeError, "Input function is required to be a callable"
        ):
            with _scoped_library("aten", "IMPL") as my_lib:
                my_lib.impl(torch.ops.aten.neg.default, [], "AutogradCPU")

    def test_finalizer(self):
        impls_refcnt = sys.getrefcount(torch.library._impls)
        lib = Library(self.test_ns, "FRAGMENT")  # noqa: TOR901
        lib.define("foo123(Tensor x) -> Tensor")

        # 1 for `lib`, 1 for sys.getrefcount
        self.assertEqual(sys.getrefcount(lib), 2)
        # We gained an additional reference that gets cleared when the finalizer runs
        self.assertEqual(sys.getrefcount(torch.library._impls), impls_refcnt + 1)
        # 1 for `lib`
        # 1 for the finalizer
        # 1 for sys.getrefcount
        self.assertEqual(sys.getrefcount(lib._op_impls), 3)

        def foo123(x):
            pass

        lib.impl(f"{self.test_ns}::foo123", foo123, "CPU")
        key = f"{self.test_ns}/foo123/CPU"
        self.assertTrue(key in torch.library._impls)

        saved_op_impls = lib._op_impls

        # del will definitely work if the following passes
        self.assertEqual(sys.getrefcount(lib), 2)
        del lib

        # 1 for saved_op_impls
        # 1 for sys.getrefcount
        # This function should be the last user of lib._op_impls:
        # - lib should not have a reference anymore (it was del'ed)
        # - lib's finalizer should not have a reference anymore
        self.assertEqual(sys.getrefcount(saved_op_impls), 2)

        self.assertTrue(key not in torch.library._impls)

        # lib's finalizer should not have a reference anymore
        self.assertEqual(sys.getrefcount(torch.library._impls), impls_refcnt)

    def test_override_cpu_sum(self) -> None:
        # Example 1
        run = [False]

        def my_sum(*args, **kwargs):
            run[0] = True
            return args[0].clone()

        with _scoped_library("aten", "IMPL") as my_lib1:
            my_lib1.impl("aten::sum", my_sum, "CPU")
            x = torch.tensor([1, 2])
            self.assertEqual(torch.sum(x), x)
            self.assertTrue(run[0])

        # Validate that the old behavior is restored for sum
        self.assertEqual(torch.sum(x), torch.tensor(3))

    def test_override_cuda_with_jiterator(self) -> None:
        def override_where_cuda() -> None:
            # Example 1: Invert the behavior of where's condition input
            not_where_code_string = """
            template <typename T> T inverted_where(bool cond, T a, T b){
                return !cond ? a : b;
            }
            """
            jitted_where = _create_jit_fn(not_where_code_string)

            CALLED = [False]

            def inverted_where(*args, **kwargs):
                CALLED[0] = True
                return jitted_where(*args, **kwargs)

            # overriding where's cuda kernel with Jiterator generated kernel
            with _scoped_library("aten", "IMPL") as my_lib:
                my_lib.impl("aten::where.self", inverted_where, "CUDA")

                device = "cuda"
                cond = torch.tensor(
                    [True, True, False], device=device, dtype=torch.bool
                )
                x = torch.tensor([1, 2, 3], device=device)
                y = torch.tensor([-1, -2, -3], device=device)

                self.assertEqual(torch.where(cond, x, y), torch.tensor([-1, -2, 3]))
                self.assertTrue(CALLED[0])

            # behavior restored after deregistration
            self.assertEqual(torch.where(cond, x, y), torch.tensor([1, 2, -3]))

        def override_gelu_cuda() -> None:
            # Example 2: Use relu to approximate gelu for faster compute
            fastest_gelu_code_string = """
            template <typename T> T fast_gelu(T a){
                return a > 0 ? a : 0;
            }
            """
            jitted_gelu = _create_jit_fn(fastest_gelu_code_string)

            CALLED = [False]

            def fast_gelu(*args, **kwargs):
                CALLED[0] = True
                return jitted_gelu(*args, **kwargs)

            # overriding gelu's cuda kernel with Jiterator generated relu kernel
            with _scoped_library("aten", "IMPL") as my_lib:
                my_lib.impl("aten::gelu", fast_gelu, "CUDA")

                x = torch.rand([3, 3], device="cuda", dtype=torch.float)
                self.assertEqual(
                    torch.nn.functional.gelu(x), torch.nn.functional.relu(x)
                )
                self.assertTrue(CALLED[0])

            # behavior restored after deregistration
            self.assertNotEqual(
                torch.nn.functional.gelu(x), torch.nn.functional.relu(x)
            )

        def override_exp_cuda() -> None:
            # Example 3: Preventing exp from exploding for float16
            clipped_exp_code_string = """
            template <typename T> T clipped_exp(T a){
                return a > T(10.0) ? T(22026.4657948) : exp(a);
            }
            """
            jitted_exp = _create_jit_fn(clipped_exp_code_string)

            CALLED = [False]

            def clipped_exp(*args, **kwargs):
                CALLED[0] = True
                return jitted_exp(*args, **kwargs)

            # overriding exp's cuda kernel with clipped_exp kernel
            with _scoped_library("aten", "IMPL") as my_lib:
                my_lib.impl("aten::exp", clipped_exp, "CUDA")

                x = torch.tensor([0.0, 100.0], device="cuda", dtype=torch.float16)
                self.assertEqual(
                    torch.exp(x),
                    torch.tensor([1.0, 22026.4657948], dtype=torch.float16),
                )
                self.assertTrue(CALLED[0])

            # behavior restored after deregistration
            self.assertEqual(
                torch.exp(x), torch.tensor([1.0, torch.inf], dtype=torch.float16)
            )

        def override_add_cuda() -> None:
            # Example 4: simulate a hardware bug, where the adder is always off by 1
            buggy_add_code_string = """
            template <typename T> T buggy_add(T a, T b){
                return a + b + T(1);
            }
            """
            jitted_add = _create_jit_fn(buggy_add_code_string)

            CALLED = [False]

            def buggy_add(*args, **kwargs):
                CALLED[0] = True
                return jitted_add(*args, **kwargs)

            with _scoped_library("aten", "IMPL") as my_lib:
                my_lib.impl("aten::add.Tensor", buggy_add, "CUDA")

                x_cpu = torch.rand([3, 3], device="cpu")
                y_cpu = torch.rand([3], device="cpu")
                x_cuda = x_cpu.cuda()
                y_cuda = y_cpu.cuda()

                self.assertEqual(x_cuda + y_cuda, x_cpu + y_cpu + 1)
                self.assertTrue(CALLED[0])

            # behavior restored after deregistration
            self.assertEqual(x_cuda + y_cuda, x_cpu + y_cpu)

        if torch.cuda.is_available() and not TEST_WITH_ROCM:
            override_where_cuda()
            override_gelu_cuda()
            override_exp_cuda()
            override_add_cuda()

    def test_extend_library_with_dispatch_key_arg(self):
        def my_sum(*args, **kwargs):
            return args[0].clone()

        with _scoped_library("aten", "IMPL", dispatch_key="CPU") as my_lib1:
            # RuntimeError: Explicitly provided dispatch key (Conjugate) is
            # inconsistent with the dispatch key of the enclosing TORCH_LIBRARY_IMPL block
            with self.assertRaisesRegex(
                RuntimeError, "inconsistent with the dispatch key"
            ):
                my_lib1.impl("sum", my_sum, "Conjugate")
            my_lib1.impl("aten::sum", my_sum)
            x = torch.tensor([1, 2])
            self.assertEqual(torch.sum(x), x)

    def test_create_new_library(self) -> None:
        with _scoped_library(self.test_ns, "DEF") as my_lib1:
            my_lib1.define("sum(Tensor self) -> Tensor")

            # Example 1
            @torch.library.impl(my_lib1, "sum", "CPU")
            def my_sum(*args, **kwargs):
                return args[0].clone()

            x = torch.tensor([1, 2])
            op = getattr(torch.ops, self.test_ns).sum
            self.assertEqual(op(x), x)

            with _scoped_library(self.test_ns, "IMPL") as my_lib2:
                # Example 2
                @torch.library.impl(my_lib2, op.default, "ZeroTensor")
                def my_sum_zt(*args, **kwargs):
                    if args[0]._is_zerotensor():
                        return torch._efficientzerotensor(args[0].shape)
                    else:
                        return args[0].clone()

                y = torch._efficientzerotensor(3)
                self.assertTrue(op(y)._is_zerotensor())
                self.assertEqual(op(x), x)

    def test_create_new_library_fragment_no_existing(self):
        with _scoped_library(self.test_ns, "FRAGMENT") as my_lib:
            my_lib.define("sum2(Tensor self) -> Tensor")

            @torch.library.impl(my_lib, "sum2", "CPU")
            def my_sum(*args, **kwargs):
                return args[0]

            x = torch.tensor([1, 2])
            self.assertEqual(getattr(torch.ops, self.test_ns).sum2(x), x)

    def test_create_new_library_fragment_with_existing(self):
        with _scoped_library(self.test_ns, "DEF") as my_lib1:
            # Create a fragment
            with _scoped_library(self.test_ns, "FRAGMENT") as my_lib2:
                my_lib2.define("sum4(Tensor self) -> Tensor")

                @torch.library.impl(my_lib2, "sum4", "CPU")
                def my_sum4(*args, **kwargs):
                    return args[0]

                x = torch.tensor([1, 2])
                self.assertEqual(getattr(torch.ops, self.test_ns).sum4(x), x)

            # Create another fragment
            with _scoped_library(self.test_ns, "FRAGMENT") as my_lib3:
                my_lib3.define("sum3(Tensor self) -> Tensor")

                @torch.library.impl(my_lib3, "sum3", "CPU")
                def my_sum3(*args, **kwargs):
                    return args[0]

                x = torch.tensor([1, 2])
                self.assertEqual(getattr(torch.ops, self.test_ns).sum3(x), x)

    @unittest.skipIf(IS_WINDOWS, "Skipped under Windows")
    def test_alias_analysis(self):
        def test_helper(alias_analysis=""):
            my_lib1 = Library(self.test_ns, "DEF")  # noqa: TOR901

            called = [0]

            @torch.library.define(
                my_lib1, "_op() -> None", alias_analysis=alias_analysis
            )
            def _op(*args, **kwargs):
                called[0] += 1

            @torch.jit.script
            def _test():
                torch.ops._test_python_registration._op()

            assert "_test_python_registration::_op" in str(_test.graph)

        with self.assertRaises(AssertionError):
            test_helper("")  # alias_analysis="FROM_SCHEMA"

        test_helper("CONSERVATIVE")

    def test_error_for_unsupported_ns_or_kind(self) -> None:
        with self.assertRaisesRegex(ValueError, "Unsupported kind"):
            my_lib1 = Library("myns", "BLA")  # noqa: TOR901

        for kind in ("DEF", "FRAGMENT"):
            with self.assertRaisesRegex(ValueError, "reserved namespace"):
                my_lib1 = Library("prim", kind)  # noqa: TOR901

    def test_returning_symint(self) -> None:
        shape_env = ShapeEnv()
        fake_tensor_mode = FakeTensorMode(shape_env=shape_env)

        ft = fake_tensor_mode.from_tensor(torch.rand(2, 3))

        s0, s1 = ft.shape

        with _scoped_library(self.test_ns, "DEF") as tlib:
            tlib.define("sqsum(SymInt a, SymInt b) -> SymInt")

            @impl(tlib, "sqsum", "CompositeExplicitAutograd")
            def sqsum(a: SymInt, b: SymInt):
                return a * a + b * b

            out = getattr(torch.ops, self.test_ns).sqsum.default(s0, s1)
            out_val = shape_env.evaluate_expr(out.node.expr)
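            # ft was created from a (2, 3) tensor, so s0 and s1 are the
            # symbolic sizes backed by 2 and 3; sqsum evaluates to 2*2 + 3*3 = 13.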
        self.assertEqual(out_val, 13)

    def test_register_functional_op_error_cases(self):
        with _scoped_library(self.test_ns, "FRAGMENT") as lib:
            with self.assertRaisesRegex(TypeError, "instance of OpOverload"):
                register_functional_op(lib, "abs", torch.ops.aten.abs_)
            with self.assertRaisesRegex(RuntimeError, "Expected op to be mutable"):
                register_functional_op(lib, "abs", torch.ops.aten.abs_.default)
            with self.assertRaisesRegex(RuntimeError, "Expected op to be mutable"):
                register_functional_op(lib, "abs", torch.ops.aten.abs.out)

        schemas = [
            "foo(Tensor x, Tensor(a!)[] y) -> ()",
            "foo(Tensor x, Tensor(a!) y, Tensor(b) z) -> Tensor(b)",
            "foo(Tensor x, Tensor(a!) y) -> (Tensor, Tensor(a))",
        ]
        for schema in schemas:
            with _scoped_library(self.test_ns, "FRAGMENT") as lib:
                lib.define(schema)
                with self.assertRaisesRegex(RuntimeError, "NYI"):
                    register_functional_op(
                        lib,
                        "foo_functional",
                        getattr(torch.ops, self.test_ns).foo.default,
                    )

    def _check_is_functional_variant(self, mutable_op, functional_op, args):
        # functional op should not mutate
        cloned_args = pytree.tree_map_only(torch.Tensor, torch.clone, args)
        functional_result = functional_op(*cloned_args)
        self.assertEqual(cloned_args, args)

        # check functional_result includes mutable_result
        mutable_result = mutable_op(*cloned_args)
        if mutable_result is None:
            flat_mutable_result = []
        else:
            flat_mutable_result = pytree.tree_leaves(mutable_result)
        flat_functional_result = pytree.tree_leaves(functional_result)
        assert len(flat_functional_result) > len(flat_mutable_result)
        self.assertEqual(
            flat_functional_result[: len(flat_mutable_result)], flat_mutable_result
        )

        # check rest of functional_result is the mutated args
        mutated_args = [
            maybe_mutated_arg
            for maybe_mutated_arg, arg in zip(cloned_args, args)
            if not (
                maybe_mutated_arg is not None
                and arg is not None
                and torch.allclose(maybe_mutated_arg, arg)
            )
        ]
        self.assertEqual(
            flat_functional_result[len(flat_mutable_result) :], mutated_args
        )

        # check that functionalization kernel was indeed registered
        def fn(*args):
            cloned_args = pytree.tree_map_only(torch.Tensor, torch.clone, args)
            mutable_op(*cloned_args)
            return cloned_args

        gm = make_fx(torch.func.functionalize(fn))(*args)
        has_functional_op = False
        for node in gm.graph.nodes:
            self.assertFalse(node.target is mutable_op)
            if node.target is functional_op:
                has_functional_op = True
        self.assertTrue(has_functional_op)

    def test_register_functional_op_no_returns(self):
        with _scoped_library(self.test_ns, "FRAGMENT") as lib:
            lib.define("foo(Tensor x, Tensor(a!) y, Tensor z, Tensor(b!) w) -> ()")

            def foo_impl(x, y, z, w):
                y.fill_(3.14)
                w.fill_(2.71)

            lib.impl("foo", foo_impl, "CPU")
            register_functional_op(
                lib, "foo_functional", getattr(torch.ops, self.test_ns).foo.default
            )
            x = torch.randn([])
            y = torch.randn([])
            z = torch.randn([])
            w = torch.randn([])
            self._check_is_functional_variant(
                getattr(torch.ops, self.test_ns).foo.default,
                getattr(torch.ops, self.test_ns).foo_functional.default,
                (x, y, z, w),
            )

    def test_register_functional_op_with_optional(self):
        with _scoped_library(self.test_ns, "FRAGMENT") as lib:
            lib.define(
                "foo(Tensor x, Tensor(a!) y, Tensor (b!) z, Tensor(c!)? w) -> ()"
            )

            def foo_impl(x, y, z, w):
                y.fill_(3.14)
                z.fill_(2.71)
                if w is not None:
                    w.fill_(1.618)

            lib.impl("foo", foo_impl, "CPU")
            register_functional_op(
                lib, "foo_functional", getattr(torch.ops, self.test_ns).foo.default
            )
            x = torch.randn([])
            y = torch.randn([])
            z = torch.randn([])
            w = torch.randn([])
            self._check_is_functional_variant(
                getattr(torch.ops, self.test_ns).foo.default,
                getattr(torch.ops, self.test_ns).foo_functional.default,
                (x, y, z, w),
            )
            self._check_is_functional_variant(
                getattr(torch.ops, self.test_ns).foo.default,
                getattr(torch.ops, self.test_ns).foo_functional.default,
                (x, y, z, None),
            )

    def test_register_functional_op_one_return(self):
        with _scoped_library(self.test_ns, "FRAGMENT") as lib:
            lib.define(
                "foo(Tensor x, Tensor(a!) y, Tensor(c!) z, Tensor(b!) w) -> Tensor"
            )

            def foo_impl(x, y, z, w):
                y.fill_(3.14)
                w.fill_(2.71)
                z.fill_(0.99)
                return x.clone()

            lib.impl("foo", foo_impl, "CPU")
            register_functional_op(
                lib, "foo_functional", getattr(torch.ops, self.test_ns).foo.default
            )
            x = torch.randn([])
            y = torch.randn([])
            z = torch.randn([])
            w = torch.randn([])
            self._check_is_functional_variant(
                getattr(torch.ops, self.test_ns).foo.default,
                getattr(torch.ops, self.test_ns).foo_functional.default,
                (x, y, z, w),
            )

    def test_register_functional_op_multiple_returns(self):
        with _scoped_library(self.test_ns, "FRAGMENT") as lib:
            lib.define(
                "foo(Tensor x, Tensor(a!) y, Tensor z, Tensor(b!) w) -> (Tensor, Tensor)"
            )

            def foo_impl(x, y, z, w):
                y.fill_(3.14)
                w.fill_(2.71)
                return x.clone(), z.clone()

            lib.impl("foo", foo_impl, "CPU")
            register_functional_op(
                lib, "foo_functional", getattr(torch.ops, self.test_ns).foo.default
            )

            x = torch.randn([])
            y = torch.randn([])
            z = torch.randn([])
            w = torch.randn([])
            self._check_is_functional_variant(
                getattr(torch.ops, self.test_ns).foo.default,
                getattr(torch.ops, self.test_ns).foo_functional.default,
                (x, y, z, w),
            )

    def test_register_fallthrough(self):
        with _scoped_library("aten", "IMPL") as my_lib:
            my_lib.impl("mm", fallthrough_kernel, "AutocastCPU")

            a = torch.randn(2, 3, device="cpu", dtype=torch.float32)
            b = torch.randn(3, 2, device="cpu", dtype=torch.float32)
            with torch.autocast(device_type="cpu", dtype=torch.bfloat16):
                # dtype for mm should be float32 since we registered a fallthrough
                self.assertEqual(torch.mm(a, b).dtype, torch.float32)
                # ops that don't have a fallthrough registered should not be affected
                self.assertEqual(torch.matmul(a, b).dtype, torch.bfloat16)

        with torch.autocast(device_type="cpu", dtype=torch.bfloat16):
            # default behavior should have been restored
            self.assertEqual(torch.mm(a, b).dtype, torch.bfloat16)


class TestPythonDispatch(TestCase):
    def test_basic(self) -> None:
        with capture_logs() as logs:
            x = LoggingTensor(torch.tensor([3.0]), requires_grad=True)
            log_input("x", x)
            y = x * x
            saved_x = y.grad_fn._saved_self
            grad_y = LoggingTensor(torch.tensor([1.0]))
            log_input("grad_y", grad_y)
            (g,) = torch.autograd.grad((y,), (x,), (grad_y,))

        self.assertEqual(g.elem, torch.tensor([6.0]))
        with torch.no_grad():
            self.assertEqual(saved_x, x)
            self.assertEqual(saved_x._version, x._version)
            x.add_(2)
            self.assertEqual(saved_x, x)
            # TODO: figure out why broken
            # self.assertEqual(saved_x._version, x._version)
        self.assertExpectedInline(
            "\n".join(logs),
            """\
$0: f32[1] = input('x')
$1: f32[1] = torch._ops.aten.mul.Tensor($0, $0)
$2: f32[1] = input('grad_y')
$3: f32[1] = torch._ops.aten.mul.Tensor($2, $0)
$4: f32[1] = torch._ops.aten.mul.Tensor($2, $0)
$5: f32[1] = torch._ops.aten.add.Tensor($4, $3)""",
        )

    def test_out(self) -> None:
        with capture_logs() as logs:
            x = LoggingTensor(torch.ones(1))
            y = LoggingTensor(torch.zeros(1))
            log_input("x", x)
            log_input("y", y)
            torch.abs(x, out=y)

        self.assertEqual(y.elem, torch.ones(1))
        # TODO: arguably this shouldn't pass and we should complain
        # that out isn't a kwarg
        self.assertExpectedInline(
            "\n".join(logs),
            """\
$0: f32[1] = input('x')
$1: f32[1] = input('y')
$2: f32[1] = torch._ops.aten.abs.out($0, out=$1)""",
        )

    def test_kwarg_only(self) -> None:
        with capture_logs() as logs:
            x = LoggingTensor(torch.ones(1))
            y = LoggingTensor(torch.ones(1, 1))
            z = LoggingTensor(torch.ones(1))
            log_input("x", x)
            log_input("y", y)
            log_input("z", z)
            torch.addmv(x, y, z)
            torch.addmv(x, y, z, beta=1)
            torch.addmv(x, y, z, beta=2)
            torch.addmv(x, y, z, alpha=2)
            torch.addmv(x, y, z, beta=2, alpha=2)
        # The expectation is that beta/alpha don't show up when they're
        # defaulted, even if the user explicitly passed those default values.
        self.assertExpectedInline(
            "\n".join(logs),
            """\
$0: f32[1] = input('x')
$1: f32[1, 1] = input('y')
$2: f32[1] = input('z')
$3: f32[1] = torch._ops.aten.addmv.default($0, $1, $2)
$4: f32[1] = torch._ops.aten.addmv.default($0, $1, $2)
$5: f32[1] = torch._ops.aten.addmv.default($0, $1, $2, beta=2)
$6: f32[1] = torch._ops.aten.addmv.default($0, $1, $2, alpha=2)
$7: f32[1] = torch._ops.aten.addmv.default($0, $1, $2, beta=2, alpha=2)""",
        )

    def test_kwarg_only_and_positional_default(self) -> None:
        with capture_logs() as logs:
            x = LoggingTensor(torch.ones(1))
            log_input("x", x)
            torch.ops.aten._foobar(x)
            torch.ops.aten._foobar(x, False)
            torch.ops.aten._foobar(x, arg3=False)
            torch.ops.aten._foobar(x, False, arg3=False)

        # What we are testing here is that we omit arg2
        # if it is defaulted, even if a kwarg is set
        self.assertExpectedInline(
            "\n".join(logs),
            """\
$0: f32[1] = input('x')
$1: f32[1] = torch._ops.aten._foobar.default($0)
$2: f32[1] = torch._ops.aten._foobar.default($0, False)
$3: f32[1] = torch._ops.aten._foobar.default($0, arg3=False)
$4: f32[1] = torch._ops.aten._foobar.default($0, False, arg3=False)""",
        )

    def test_produce_real_type(self) -> None:
        with capture_logs() as logs:
            x = LoggingTensor(torch.ones(2, 2))
            log_input("x", x)
            x.to(dtype=torch.double)  # non-optional dtype
            torch.cumprod(x, 0, dtype=torch.double)  # optional dtype
            x[:, 1].contiguous(
                memory_format=torch.contiguous_format
            )  # optional memory format
            # There doesn't appear to be any layout signatures which are
            # triggerable using tensor subclasses (need to use a mode)

        self.assertExpectedInline(
            "\n".join(logs),
            """\
$0: f32[2, 2] = input('x')
$1: f64[2, 2] = torch._ops.aten._to_copy.default($0, dtype=torch.float64)
$2: f64[2, 2] = torch._ops.aten.cumprod.default($0, 0, dtype=torch.float64)
$3: f32[2, 2] = torch._ops.aten.slice.Tensor($0, 0, 0, 9223372036854775807)
$4: f32[2] = torch._ops.aten.select.int($3, 1, 1)
$5: f32[2] = torch._ops.aten.clone.default($4, memory_format=torch.contiguous_format)""",
        )

    def test_optional_tensor_list(self) -> None:
        def weird(xs):
            print("woof")
            return torch.empty(())

        with _scoped_library("my_lib", "DEF") as my_lib:
            my_lib.define("weird(Tensor?[] self) -> Tensor")
            my_lib.impl("weird", weird, "CPU")
            with capture_logs() as logs:
                x = LoggingTensor(torch.ones(2, 2))
                log_input("x", x)
                torch.ops.my_lib.weird.default([None, x])

            self.assertExpectedInline(
                "\n".join(logs),
                """\
$0: f32[2, 2] = input('x')
$1: f32[] = torch._ops.my_lib.weird.default(['None', '$0'])""",
            )

    def test_list_ret(self) -> None:
        # test all sequence types are permissible returns
        for list_type in (list, tuple):

            class A(torch.Tensor):
                @staticmethod
                def __new__(cls, elem):
                    return torch.Tensor._make_subclass(cls, elem, elem.requires_grad)

                @classmethod
                def __torch_dispatch__(cls, func, types, args=(), kwargs=None):
                    if func.overloadpacket == torch.ops.aten.split:
                        with no_dispatch():
                            return list_type(torch.split(*args))
                    else:
                        raise AssertionError(f"unrecognized func: {func}")

            self.assertEqual(
                torch.split(A(torch.tensor([0, 1])), 2),
                torch.split(torch.tensor([0, 1]), 2),
            )

    def test_invalid_ret(self) -> None:
        # test invalid return gets reasonable error message
        class A(torch.Tensor):
            @staticmethod
            def __new__(cls, elem):
                return torch.Tensor._make_subclass(cls, elem, elem.requires_grad)

            @classmethod
            def __torch_dispatch__(cls, func, types, args=(), kwargs=None):
                return "arf"

        # Wobbles depending on NDEBUG mode of pybind11
        self.assertRaisesRegex(
            RuntimeError,
            "Unable to cast",
            lambda: A(torch.zeros(1)).neg(),
        )
        self.assertRaisesRegex(
            RuntimeError,
            "Unable to cast",
            lambda: A(torch.zeros(1)).detach(),
        )

    def test_detach_appears_twice_when_called_once(self) -> None:
        with capture_logs() as logs:
            x = LoggingTensor(torch.tensor([3.0]), requires_grad=True)
            log_input("x", x)
            x.detach()
        # FIXME: We actually want this to emit a single detach. However,
        # it currently emits two, for reasons unclear to us. Leaving
        # this test here to make sure we don't regress even further (it
        # would be bad if calling .detach() once emits 3+ detaches).
        self.assertExpectedInline(
            "\n".join(logs),
            """\
$0: f32[1] = input('x')
$1: f32[1] = torch._ops.aten.detach.default($0)
$2: f32[1] = torch._ops.aten.detach.default($1)""",
        )

    def test_storage(self) -> None:
        # For now, just make sure it doesn't crash. Ideally, we should
        # return some virtual storage that is safe to work with
        x = LoggingTensor(torch.ones(1))
        storage = x.untyped_storage()
        self.assertRaises(RuntimeError, lambda: storage.data_ptr())

    def test_make_wrapper_subclass_noalloc(self) -> None:
        # This is ludicrously big (8TB) and this should pass because wrapper
        # subclasses don't allocate
        torch.Tensor._make_wrapper_subclass(LoggingTensor, (1000000000000,))

    def test_version(self) -> None:
        x = LoggingTensor(torch.ones(1))
        prev_vc = x._version
        x.detach().add_(2)
        cur_vc = x._version
        self.assertNotEqual(prev_vc, cur_vc)
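        # .data bypasses the version counter, so this in-place op should
        # leave _version unchanged.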
        x.data.add_(2)
        self.assertEqual(cur_vc, x._version)

    def test_subclass_priority(self) -> None:
        class ErrorA(RuntimeError):
            pass

        class ErrorB(RuntimeError):
            pass

        # The big tests for code coverage are test_precedence_semantics in
        # test_overrides.py; this is just to make sure it is wired up at all
        # correctly for __torch_dispatch__
        class A(torch.Tensor):
            @staticmethod
            def __new__(cls, elem):
                return torch.Tensor._make_subclass(cls, elem, elem.requires_grad)

            @classmethod
            def __torch_dispatch__(cls, func, types, args=(), kwargs=None):
                raise ErrorA

        class B(A):
            @staticmethod
            def __new__(cls, elem):
                return torch.Tensor._make_subclass(cls, elem, elem.requires_grad)

            @classmethod
            def __torch_dispatch__(cls, func, types, args=(), kwargs=None):
                raise ErrorB

        self.assertRaises(
            ErrorA, lambda: torch.add(A(torch.empty(1)), A(torch.empty(1)))
        )
        self.assertRaises(
            ErrorB, lambda: torch.add(A(torch.empty(1)), B(torch.empty(1)))
        )
        self.assertRaises(
            ErrorB, lambda: torch.add(B(torch.empty(1)), A(torch.empty(1)))
        )
        self.assertRaises(
            ErrorB, lambda: torch.add(B(torch.empty(1)), B(torch.empty(1)))
        )

    def test_format(self) -> None:
        x = LoggingTensor(torch.ones(1))
        s1 = str(x)
        s2 = repr(x)
        s3 = f"{x}"
        self.assertExpectedInline(s1, """LoggingTensor(tensor([1.]))""")
        self.assertEqual(s1, s2)
        self.assertEqual(s1, s3)

    def test_custom_autograd(self) -> None:
        escape = [None]

        class Square(torch.autograd.Function):
            @staticmethod
            def forward(ctx, x):
                y = x**2
                ctx.save_for_backward(x)
                return y

            @staticmethod
            def backward(ctx, grad_output):
                assert isinstance(grad_output, LoggingTensor)
                (x,) = ctx.saved_tensors
                assert isinstance(x, LoggingTensor)
                escape[0] = x
                return grad_output * 2 * x

        with capture_logs() as logs:
            x = LoggingTensor(torch.ones(1), requires_grad=True)
            log_input("x", x)
            x.grad = LoggingTensor(torch.zeros(1))
            log_input("x.grad", x.grad)
            y = Square.apply(x)
            grad_output = LoggingTensor(torch.ones(1))
            log_input("grad_output", grad_output)
            y.backward(grad_output)

        with torch.no_grad():
            self.assertEqual(escape[0], x)
            self.assertEqual(escape[0]._version, x._version)
            # TODO: figure out why x.requires_grad = False doesn't
            # trigger an error for LoggingTensor
            x.add_(2)
            self.assertEqual(escape[0], x)
            # TODO: figure out why this is broken
            # self.assertEqual(escape[0]._version, x._version)

        self.assertExpectedInline(
            "\n".join(logs),
            """\
$0: f32[1] = input('x')
$1: f32[1] = input('x.grad')
$2: f32[1] = torch._ops.aten.pow.Tensor_Scalar($0, 2)
$3: f32[1] = input('grad_output')
$4: f32[1] = torch._ops.aten.mul.Tensor($3, 2)
$5: f32[1] = torch._ops.aten.mul.Tensor($4, $0)
$6: f32[1] = torch._ops.aten.add_.Tensor($1, $5)""",
        )
    def test_subclass_creation(self):
        # Make sure these statements run without error.
        # In particular, check that when the internal detach returns a
        # subclass, it is cleanly overwritten.
        class Foo(torch.Tensor):
            pass

        err_msg = "subclass Foo but.*already associated to a python object of type LoggingTensor"
        with self.assertRaisesRegex(RuntimeError, err_msg):
            a = torch.Tensor._make_subclass(Foo, LoggingTensor(torch.rand(2)))
        with self.assertRaisesRegex(RuntimeError, err_msg):
            b = LoggingTensor(torch.rand(2)).as_subclass(Foo)
        with self.assertRaisesRegex(RuntimeError, err_msg):
            Foo(LoggingTensor(torch.rand(2)))

        with self.assertRaisesRegex(TypeError, "Foo must define __torch_dispatch__"):
            torch.Tensor._make_wrapper_subclass(Foo, (2, 2))

    def test_new_ones(self) -> None:
        class MyTensor(torch.Tensor):
            @classmethod
            def __torch_dispatch__(cls, func, types, args=(), kwargs=None):
                return MyTensor(3)

        self.assertEqual(type(MyTensor(2).new_ones(3)), MyTensor)

    def test_like(self) -> None:
        class MyTensor(torch.Tensor):
            @classmethod
            def __torch_dispatch__(cls, func, types, args=(), kwargs=None):
                return MyTensor(3)

        for f in ["empty", "ones", "rand", "randn", "zeros"]:
            f_name = f + "_like"
            self.assertEqual(type(getattr(torch, f_name)(MyTensor(2))), MyTensor)

        self.assertEqual(type(torch.full_like(MyTensor(2), 1.0)), MyTensor)
        self.assertEqual(type(torch.randint_like(MyTensor(2), high=3)), MyTensor)

    def test_make_fx_with_subclass(self) -> None:
        def f(x, y):
            # Returns (TwoTensor, Tensor)
            return x * y, y + y

        x_a = torch.zeros(4)
        x_b = torch.zeros(4)
        y = torch.ones(4)

        # make_fx() is not responsible for unwrapping tensor subclass inputs,
        # so we do it manually here.
        # Why? In general, make_fx(f)(*args) promises that the graph returned has the same calling
        # convention as f(*args). Unwrapping tensor subclass inputs can potentially change
        # the number of input args to the graph, breaking that assumption
        def f_to_trace(x_a, x_b, y):
            x = TwoTensor(x_a, x_b)
            out1, out2 = f(x, y)
            out1_unwrapped_attrs, _ = out1.__tensor_flatten__()
            return (*[getattr(out1, attr) for attr in out1_unwrapped_attrs], out2)

        fx_g = make_fx(f_to_trace, tracing_mode="fake")(x_a, x_b, y)
        self.assertExpectedInline(
            fx_g.code,
            """\
def forward(self, x_a_1, x_b_1, y_1):
    mul = torch.ops.aten.mul.Tensor(x_a_1, y_1); x_a_1 = None
    mul_1 = torch.ops.aten.mul.Tensor(x_b_1, y_1); x_b_1 = None
    add = torch.ops.aten.add.Tensor(y_1, y_1); y_1 = None
    return (mul, mul_1, add)
    """,
        )

    # See https://github.com/pytorch/pytorch/issues/117794
    def test_return_and_correct_aliasing_gives_correct_stride(self):
        t = TwoTensor(torch.randn(2, 2), torch.randn(2, 2))
        x = torch.randn(2, 2)
        # slicing should result in the same stride for TwoTensor as a dense tensor would give
        self.assertEqual(t[:, 0].stride(), x[:, 0].stride())

    def test_make_wrapper_subclass_propagates_metadata(self) -> None:
        class WrapperTensor(torch.Tensor):
            elem: torch.Tensor

            __slots__ = ["elem"]

            @staticmethod
            def __new__(cls, elem, *args, **kwargs):
                r = torch.Tensor._make_wrapper_subclass(  # type: ignore[attr-defined]
                    cls,
                    elem.size(),
                    dtype=elem.dtype,
                    layout=elem.layout,
                    device=elem.device,
                    requires_grad=elem.requires_grad,
                    strides=elem.stride(),
                    storage_offset=elem.storage_offset(),
                )
                r.elem = elem
                return r

            @classmethod
            def __torch_dispatch__(cls, func, types, args=(), kwargs=None):
                raise RuntimeError("NYI")

        # non-contiguous strides, non-zero storage offset
        x = torch.randn(4, 6).t().diagonal(offset=2)
        y = WrapperTensor(x)
        self.assertEqual(y.size(), x.size())
        self.assertEqual(y.stride(), x.stride())
        self.assertEqual(y.storage_offset(), x.storage_offset())

    def test_wrapper_subclass_serializes(self) -> None:
        with tempfile.TemporaryFile() as f:
            # purposefully use int64 to test non-default dtype
            x = LoggingTensor(torch.randperm(3))
            torch.save(x, f)
            f.seek(0)
            with torch.serialization.safe_globals([LoggingTensor]):
                x_loaded = torch.load(f)
                self.assertTrue(type(x_loaded) is type(x))
                self.assertEqual(x, x_loaded)
                self.assertEqual(x.elem, x_loaded.elem)
                self.assertFalse(x is x_loaded)

    def test_deepcopy_wrapper_subclass(self) -> None:
        # purposefully use int64 to test non-default dtype
        x = LoggingTensor(torch.randperm(3))
        x_copy = deepcopy(x)
        self.assertTrue(type(x_copy) is type(x))
        self.assertEqual(x, x_copy)
        self.assertEqual(x.elem, x_copy.elem)
        self.assertFalse(x is x_copy)

    def test_deepcopy_wrapper_subclass_with_clone_returning_different_type(
        self,
    ) -> None:
        class MyWrapperTensor(torch.Tensor):
            elem: torch.Tensor

            __slots__ = ["elem"]

            @staticmethod
            def __new__(cls, elem, *args, **kwargs):
                r = torch.Tensor._make_wrapper_subclass(  # type: ignore[attr-defined]
                    cls,
                    elem.size(),
                    dtype=elem.dtype,
                    layout=elem.layout,
                    device=elem.device,
                    requires_grad=elem.requires_grad,
                    strides=elem.stride(),
                    storage_offset=elem.storage_offset(),
                )
                r.elem = elem
                return r

            @classmethod
            def __torch_dispatch__(cls, func, types, args=(), kwargs=None):
                if func.overloadpacket.__name__ == "clone":
                    # Return a plain tensor from clone().
                    return args[0].elem.clone()
                raise RuntimeError("NYI")

            # NB: The default Tensor.__torch_function__ implementation called for deepcopy
            # disables __torch_function__ by the time we get to clone(), so there is no need to
            # explicitly disable __torch_function__ for this subclass.

        x = MyWrapperTensor(torch.randn(3))
        with self.assertRaisesRegex(
            RuntimeError,
            "for which cloning returns another instance of the same subclass",
        ):
            x_copy = deepcopy(x)

    def test_deepcopy_non_wrapper_subclass(self) -> None:
        # Ensure correct error is thrown for common error cases.
        class SubTensorError1(torch.Tensor):
            # Default implementation of new_empty() returns a plain tensor.
            pass

        class SubTensorError2(torch.Tensor):
            # new_empty() incorrectly returns a different type (i.e. a plain tensor).
            def new_empty(self, shape):
                return torch.Tensor(shape)

        for error_cls in [SubTensorError1, SubTensorError2]:
            x = error_cls(3)
            with self.assertRaisesRegex(
                RuntimeError,
                "for which that function returns another instance of the same subclass",
            ):
                x_copy = deepcopy(x)

        # Ensure a correctly implemented new_empty() causes deepcopy() to work.
        class SubTensorSuccess(torch.Tensor):
            def new_empty(self, shape):
                return type(self)(shape)

        x = SubTensorSuccess(3)
        x_copy = deepcopy(x)
        self.assertIs(type(x_copy), type(x))

    def test_wrapper_subclass_extra_dispatch_keys(self) -> None:
        class ExtraKeysTensor(torch.Tensor):
            @staticmethod
            def __new__(cls, elem, *args, **kwargs):
                # NB: only the non-kwarg overload of _make_wrapper_subclass supports
                # extra dispatch keys. We probably want to unify the two APIs
                # in the future.
                r = torch.Tensor._make_wrapper_subclass(  # type: ignore[attr-defined]
                    cls,
                    elem.size(),
                    elem.stride(),
                    elem.storage_offset(),
                    torch.contiguous_format,
                    elem.dtype,
                    elem.layout,
                    elem.device,
                    False,
                    False,
                    None,
                    False,
                    False,
                    DispatchKeySet(DispatchKey.NestedTensor),
                )
                return r

            @classmethod
            def __torch_dispatch__(cls, func, types, args=(), kwargs=None):
                pass

        x = ExtraKeysTensor(torch.randn(3))
        self.assertTrue(torch._C._dispatch_keys(x).has(DispatchKey.NestedTensor))
        self.assertFalse(
            torch._C._dispatch_keys(x).has(DispatchKey.AutogradNestedTensor)
        )

    def test_wrapper_subclass_multiprocessing_preserves_dtype(self):
        # a and b have dtype of int64, which is purposefully different from the default
        # assumed by _make_wrapper_subclass().
        a = torch.randperm(5)
        b = torch.randperm(5)
        data = TwoTensor(a, b)
        expected_dtype = data.dtype

        loader = torch.utils.data.DataLoader(
            [data, data],
            batch_size=2,
            num_workers=2,
            collate_fn=_identity,
        )
        for batch in loader:
            self.assertEqual(batch[0].dtype, expected_dtype)

    def test_index_put_where_only_index_is_subclass(self) -> None:
        called_funcs = []

        class MyTensor(torch.Tensor):
            elem: torch.Tensor

            __slots__ = ["elem"]

            @staticmethod
            def __new__(cls, elem, *args, **kwargs):
                r = torch.Tensor._make_wrapper_subclass(
                    cls,
                    elem.size(),
                    dtype=elem.dtype,
                    layout=elem.layout,
                    device=elem.device,
                    requires_grad=elem.requires_grad,
                )
                r.elem = elem
                return r

            @classmethod
            def __torch_dispatch__(cls, func, types, args=(), kwargs=None):
                called_funcs.append(func)
                return MyTensor(torch.tensor(3))

        x = torch.randn(3, 3)
        idxs = (MyTensor(torch.tensor(0)),)
        v = torch.randn(1)
        res = x.index_put_(idxs, v)
        self.assertEqual(called_funcs, [torch.ops.aten.index_put_.default])

    def test_torch_dispatch_mode_basic(self) -> None:
        with capture_logs(is_mode=True) as logs:
            with LoggingTensorMode():
                torch.empty([])
        self.assertExpectedInline(
            "\n".join(logs),
            """\
$0: f32[] = torch._ops.aten.empty.memory_format([], device=device(type='cpu'), pin_memory=False)""",
        )

    def test_torch_dispatch_mode_unrelated_tensors(self) -> None:
        x = torch.randn([])
        y = torch.randn([])
        with capture_logs(is_mode=True) as logs:
            with LoggingTensorMode():
                x + y
        self.assertExpectedInline(
            "\n".join(logs), """$2: f32[] = torch._ops.aten.add.Tensor($0, $1)"""
        )

    def test_nested_push_logging_tensor_mode(self):
        x = torch.randn([])
        y = torch.randn([])
        with capture_logs(is_mode=True) as logs:
            with LoggingTensorMode():
                with LoggingTensorMode():
                    torch.empty([])
                    x + y

        self.assertExpectedInline(
            "\n".join(logs),
            """\
$0: f32[] = torch._ops.aten.empty.memory_format([], device=device(type='cpu'), pin_memory=False)
$0: f32[] = torch._ops.aten.empty.memory_format([], device=device(type='cpu'), pin_memory=False)
$3: f32[] = torch._ops.aten.add.Tensor($1, $2)
$3: f32[] = torch._ops.aten.add.Tensor($1, $2)""",
        )

    def test_capture_logs_with_torch_dispatch_mode(self):
        x = torch.randn([])
        y = torch.randn([])
        with capture_logs_with_logging_tensor_mode() as logs:
            torch.empty([])
            x + y

        self.assertExpectedInline(
            "\n".join(logs),
            """\
$0: f32[] = torch._ops.aten.empty.memory_format([], device=device(type='cpu'), pin_memory=False)
$3: f32[] = torch._ops.aten.add.Tensor($1, $2)""",
        )

        x = torch.randn([])
        y = torch.randn([])

        with capture_logs_with_logging_tensor_mode() as logs1:
            with capture_logs_with_logging_tensor_mode() as logs2:
                torch.empty([])
                x + y

        self.assertExpectedInline(
            "\n".join(logs2),
            """\
$0: f32[] = torch._ops.aten.empty.memory_format([], device=device(type='cpu'), pin_memory=False)
$0: f32[] = torch._ops.aten.empty.memory_format([], device=device(type='cpu'), pin_memory=False)
$3: f32[] = torch._ops.aten.add.Tensor($1, $2)
$3: f32[] = torch._ops.aten.add.Tensor($1, $2)""",
        )
        self.assertEqual(logs1, logs2)

    def test_torch_dispatch_mode_subclass_priority(self) -> None:
        class ErrorA(RuntimeError):
            pass

        class ErrorB(RuntimeError):
            pass

        class A(torch.Tensor):
            @staticmethod
            def __new__(cls, elem):
                return torch.Tensor._make_subclass(cls, elem, elem.requires_grad)

            @classmethod
            def __torch_dispatch__(cls, func, types, args=(), kwargs=None):
                with AMode():
                    raise ErrorA

        class B(A):
            @staticmethod
            def __new__(cls, elem):
                return torch.Tensor._make_subclass(cls, elem, elem.requires_grad)

            @classmethod
            def __torch_dispatch__(cls, func, types, args=(), kwargs=None):
                with BMode():
                    func(*args, **kwargs)

        class AMode(TorchDispatchMode):
            def __torch_dispatch__(self, func, types, args=(), kwargs=None):
                raise ErrorA

        class BMode(TorchDispatchMode):
            def __torch_dispatch__(self, func, types, args=(), kwargs=None):
                raise ErrorB

        a = A(torch.empty(1))
        b = B(torch.empty(1))
        with self.assertRaises(ErrorA):
            a + a
        with self.assertRaises(ErrorB):
            a + b
        # B has precedence over A due to the subclass relationship, but
        # modes take precedence over arguments
        with self.assertRaises(ErrorA):
            with AMode():
                b + b
        with self.assertRaises(ErrorB):
            with BMode():
                a + a
        with self.assertRaises(ErrorB):
            with BMode():
                a + b

    def test_mode_with_make_subclass(self):
        class SubTensor(torch.Tensor):
            @staticmethod
            def __new__(cls, elem):
                return torch.Tensor._make_subclass(cls, elem, elem.requires_grad)

        class BasicMode(TorchDispatchMode):
            def __torch_dispatch__(self, func, types, args=(), kwargs=None):
                return func(*args, **kwargs)

        x = torch.randn(3)
        with BasicMode():
            y = SubTensor(x)
        self.assertIsInstance(y, SubTensor)

    def test_torch_dispatch_mode_respects_no_dispatch(self) -> None:
        with capture_logs(is_mode=True) as logs1:
            with LoggingTensorMode():
                torch.ones([2, 3])
                with no_dispatch():
                    torch.ones([2, 3])
        with capture_logs(is_mode=True) as logs2:
            with LoggingTensorMode():
                torch.ones([2, 3])
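        # The call made under no_dispatch() is invisible to the mode, so the
        # two captured logs should be identical.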
        self.assertEqual(logs1, logs2)

    def test_shallow_copy_and_detach(self) -> None:
        seen = set()
        test_case = self

        class TestMode(TorchDispatchMode):
            def __torch_dispatch__(self, func, types, args=(), kwargs=None):
                tree_map_only(
                    torch.Tensor, lambda t: test_case.assertIn(t, seen), (args, kwargs)
                )
                if kwargs is None:
                    kwargs = {}
                r = func(*args, **kwargs)
                tree_map_only(torch.Tensor, lambda t: seen.add(t), r)
                return r

        with TestMode():
            x = torch.randn(3, requires_grad=True)
            loss = (x * x).sum()
            loss.backward()

    def test_exception_handling(self):
        class A(torch.Tensor):
            @staticmethod
            def __new__(cls, elem):
                return torch.Tensor._make_subclass(cls, elem, elem.requires_grad)

        class AMode(TorchDispatchMode):
            def __torch_dispatch__(self, func, types, args=(), kwargs=None):
                if func.__name__ == "randn.default":
                    raise RuntimeError
                return A(torch.zeros(()))

        with AMode():
            try:
                torch.randn(())
            except RuntimeError:
                pass
            self.assertTrue(isinstance(torch.zeros(()), A))

    def test_with_mode_created_separately(self):
        class ErrorA(RuntimeError):
            pass

        class A(TorchDispatchMode):
            def __torch_dispatch__(self, func, types, args=(), kwargs=None):
                raise ErrorA

        x = A()
        with self.assertRaises(ErrorA):
            with x:
                torch.empty([])

    def test_with_nested_modes(self):
        class ErrorA(RuntimeError):
            def __init__(self, msg):
                super().__init__(msg)

        class A(TorchDispatchMode):
            def __init__(self, msg):
                self.msg = msg

            def __torch_dispatch__(self, func, types, args=(), kwargs=None):
                raise ErrorA(self.msg)

        with self.assertRaisesRegex(ErrorA, "layer2"):
            with A("layer1"):
                with A("layer2"):
                    torch.empty([])

    def test_make_subclass_with_modes(self):
        class ModeTensor(torch.Tensor):
            def __new__(cls, elem, mode):
                r = torch.Tensor._make_subclass(cls, elem, elem.requires_grad)
                r.elem = elem
                r.mode = mode
                return r

            @classmethod
            def __torch_dispatch__(cls, func, types, args=(), kwargs=None):
                raise NotImplementedError("Shouldn't be here")

        class Mode(TorchDispatchMode):
            def __torch_dispatch__(self, func, types, args=(), kwargs=None):
                def unwrap(e):
                    if isinstance(e, ModeTensor):
                        return e.elem
                    else:
                        return e

                def wrap(t):
                    if isinstance(t, torch.Tensor):
                        return ModeTensor(t, self)
                    else:
                        return t

                return wrap(func(*tuple(unwrap(a) for a in args), **kwargs))

        class BasicMode(TorchDispatchMode):
            def __torch_dispatch__(self, func, types, args=(), kwargs=None):
                return func(*args, **kwargs)

        x = torch.tensor(4.0)
        with Mode():
            y = x + x
            z = y + y
        self.assertIsInstance(y, ModeTensor)
        self.assertIsInstance(z, ModeTensor)

        with Mode():
            with BasicMode():  # we can't nest two modes that call make_subclass because it only accepts vanilla tensors
                y = x + x
                z = y + y
        self.assertIsInstance(y, ModeTensor)
        self.assertIsInstance(z, ModeTensor)

        assert self.assertRaisesRegex(
            RuntimeError,
            "subclass Mode but.* associated to a python object of type Mode",
        )

    def test_notimplemented_mode(self):
        sub_count = 0

        class PoliteMode(TorchDispatchMode):
            def __init__(self) -> None:
                self.pre_count = 0
                self.post_count = 0

            def __torch_dispatch__(self, func, types, args=(), kwargs=None):
                self.pre_count += 1
                if any(t is not torch.Tensor for t in types):
                    return NotImplemented
                self.post_count += 1
                return func(*args, **kwargs)

        class SubTensor(torch.Tensor):
            def __new__(cls, elem):
                r = torch.Tensor._make_wrapper_subclass(cls, elem.shape)
                r.elem = elem
                return r

            @classmethod
            def __torch_dispatch__(cls, func, types, args=(), kwargs=None):
                nonlocal sub_count
                sub_count += 1

                def unwrap(t):
                    if isinstance(t, SubTensor):
                        return t.elem
                    else:
                        return t

                return func(*tree_map(unwrap, args), **tree_map(unwrap, kwargs))

        a = SubTensor(torch.randn(2))
        with PoliteMode() as mode:
            a.abs()
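        # The mode runs first, sees SubTensor in `types`, and returns
        # NotImplemented; SubTensor.__torch_dispatch__ then unwraps and
        # redispatches, so the mode runs a second time on plain tensors and
        # completes. Hence two pre-dispatches but only one completed dispatch.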
        self.assertEqual(mode.pre_count, 2)
        self.assertEqual(mode.post_count, 1)
        self.assertEqual(sub_count, 1)

        # make sure this doesn't error
        with PoliteMode():
            with PoliteMode():
                a.abs()

    def test_nesting_same_mode(self):
        # If the pushed mode is the same instance as the current mode, we allow pushing an already active mode.
        with capture_logs(is_mode=True) as logs:
            with LoggingTensorMode() as reenabled:
                with reenabled:
                    torch.empty([])
        self.assertExpectedInline(
            "\n".join(logs),
            """\
$0: f32[] = torch._ops.aten.empty.memory_format([], device=device(type='cpu'), pin_memory=False)
$0: f32[] = torch._ops.aten.empty.memory_format([], device=device(type='cpu'), pin_memory=False)""",
        )

    def test_error_using_class_method_on_mode(self):
        class A(TorchDispatchMode):
            @classmethod
            def __torch_dispatch__(cls, func, types, args=(), kwargs=None):
                return func(args, kwargs)

        x = torch.tensor(5.0)
        with self.assertRaisesRegex(
            RuntimeError, "classmethod is not supported, please make it a plain method"
        ):
            with A():
                x + x

    def test_get_cur_mode(self):
        class A(TorchDispatchMode):
            def __torch_dispatch__(self, func, types, args=(), kwargs=None):
                pass

        self.assertEqual(_get_current_dispatch_mode(), None)

        with A() as mode1:
            self.assertEqual(_get_current_dispatch_mode(), mode1)

        with mode1:
            with A() as mode2:
                self.assertEqual(_get_current_dispatch_mode(), mode2)

    def test_get_mode_stack(self):
        class A(TorchDispatchMode):
            def __torch_dispatch__(self, func, types, args=(), kwargs=None):
                pass

        self.assertEqual(_get_current_dispatch_mode_stack(), [])

        with A() as mode1:
            self.assertEqual(_get_current_dispatch_mode_stack(), [mode1])

        with mode1:
            with A() as mode2:
                self.assertEqual(_get_current_dispatch_mode_stack(), [mode1, mode2])

    def test_all_same_mode(self):
        x = LoggingTensorMode()
        y = LoggingTensorMode()
        self.assertTrue(all_same_mode([x, x, x]))
        self.assertFalse(all_same_mode([x, None]))
        self.assertFalse(all_same_mode([x, y]))

    def test_mode_detection(self):
        class InfraMode(TorchDispatchMode):
            @classmethod
            def is_infra_mode(cls):
                return True

        class NonInfraMode(TorchDispatchMode):
            pass

        with InfraMode():
            self.assertTrue(is_in_torch_dispatch_mode())
            self.assertFalse(is_in_torch_dispatch_mode(include_infra_modes=False))
            with NonInfraMode():
                self.assertTrue(is_in_torch_dispatch_mode())
                self.assertTrue(is_in_torch_dispatch_mode(include_infra_modes=False))
                with InfraMode():
                    self.assertTrue(is_in_torch_dispatch_mode())
                    self.assertTrue(
                        is_in_torch_dispatch_mode(include_infra_modes=False)
                    )

                self.assertTrue(is_in_torch_dispatch_mode())
                self.assertTrue(is_in_torch_dispatch_mode(include_infra_modes=False))

            self.assertTrue(is_in_torch_dispatch_mode())
            self.assertFalse(is_in_torch_dispatch_mode(include_infra_modes=False))

        self.assertFalse(is_in_torch_dispatch_mode())
        self.assertFalse(is_in_torch_dispatch_mode(include_infra_modes=False))

    def test_tolist_numpy_with_torch_dispatch_mode(self) -> None:
        x = LoggingTensor(torch.tensor([2.0, 3.0]))
        with self.assertRaisesRegex(
            RuntimeError, "is not supported for tensor subclasses."
        ):
            x.tolist()
        with self.assertRaisesRegex(
            RuntimeError, "is not supported for tensor subclasses."
        ):
            x.numpy()
        with self.assertRaises(AssertionError):
            self.assertEqual(x, None)

    # See https://github.com/pytorch/pytorch/issues/136064
    def test_view_returns_alias_under_torch_dispatch(self):
        class MyMode(TorchDispatchMode):
            def __init__(self, testcase):
                self.testcase = testcase

            def __torch_dispatch__(self, func, types, args=(), kwargs=None):
                out = func(*args, **kwargs)
                if func == torch.ops.aten.view.dtype:
                    # view should return a fresh TensorImpl
                    self.testcase.assertTrue(out is not args[0])
                return out

        with MyMode(self):
            x = torch.ones(4, dtype=torch.float32)
            out = x.view(torch.float32)

    def test_record_stream(self) -> None:
        class TestMode(TorchDispatchMode):
            def __init__(self, testcase):
                self.testcase = testcase

            def __torch_dispatch__(self, func, types, args=(), kwargs=None):
                self.testcase.assertEqual(func.name(), "aten::record_stream")
                self.testcase.assertIsInstance(args[0], torch.Tensor)
                self.testcase.assertIsInstance(args[1], torch.Stream)
                self.testcase.assertEqual(args[1].stream_id, 1)
                self.testcase.assertEqual(args[1].device_index, 2)
                self.testcase.assertEqual(args[1].device_type, 3)

        t = torch.tensor(5.0)
        s = torch.Stream(stream_id=1, device_index=2, device_type=3)
        with TestMode(self):
            t.record_stream(s)

    def test_return_stream(self) -> None:
        with _scoped_library("test_return_stream", "DEF") as l_def:
            l_def.define("return_stream(Tensor self) -> Stream")
            with _scoped_library("test_return_stream", "IMPL", "CPU") as l_impl:
                l_impl.impl(
                    "return_stream",
                    lambda _: torch.Stream(stream_id=0, device_index=1, device_type=2),
                )

                class TestMode(TorchDispatchMode):
                    def __torch_dispatch__(self, func, types, args=(), kwargs=None):
                        return torch.Stream(stream_id=1, device_index=2, device_type=3)

                t = torch.tensor(5.0)
                s = torch.ops.test_return_stream.return_stream(t)
                self.assertIsInstance(s, torch.Stream)
                self.assertEqual(s.stream_id, 0)
                self.assertEqual(s.device_index, 1)
                self.assertEqual(s.device_type, 2)

                with TestMode():
                    s = torch.ops.test_return_stream.return_stream(t)
                self.assertIsInstance(s, torch.Stream)
                self.assertEqual(s.stream_id, 1)
                self.assertEqual(s.device_index, 2)
                self.assertEqual(s.device_type, 3)

    def test_subclass_autograd_device_check(self) -> None:
        class NonWrapperSubclass(torch.Tensor):
            elem: torch.Tensor

            __slots__ = ["elem"]

            @staticmethod
            def __new__(cls, elem, *args, **kwargs):
                # Wrong device here!
                r = torch.Tensor._make_subclass(
                    cls, elem.to("meta"), elem.requires_grad
                )
                # ...the real tensor is held as an element on the tensor.
                r.elem = elem
                return r

            @classmethod
            def __torch_dispatch__(cls, func, types, args=(), kwargs=None):
                def unwrap(e):
                    return e.elem if isinstance(e, NonWrapperSubclass) else e

                def wrap(e):
                    return NonWrapperSubclass(e) if isinstance(e, torch.Tensor) else e

                rs = tree_map(
                    wrap, func(*tree_map(unwrap, args), **tree_map(unwrap, kwargs))
                )
                logging.getLogger("NonWrapperSubclass").info(
                    f"{func.__module__}.{func.__name__}",  # noqa: G004
                    args,
                    kwargs,
                    rs,
                )
                return rs

        x = NonWrapperSubclass(torch.tensor([3.0, 4.0], requires_grad=True))
        y = torch.randn(2, requires_grad=True)
        z = x * y
        self.assertIsInstance(z, NonWrapperSubclass)
        z.sum().backward(torch.tensor(1))
        self.assertEqual(x.grad, y)
        self.assertEqual(y.grad, x)

    def test_none_wrapping(self):
        # A Tensor subclass that returns None when doing add
        # See LoggingTensor above for more details on the subclass
        class SubclassWithNone(torch.Tensor):
            @staticmethod
            def __new__(cls, elem, *args, **kwargs):
                r = torch.Tensor._make_wrapper_subclass(
                    cls,
                    elem.size(),
                    dtype=elem.dtype,
                    layout=elem.layout,
                    device=elem.device,
                    requires_grad=elem.requires_grad,
                )
                r.elem = elem
                return r

            @classmethod
            def __torch_dispatch__(cls, func, types, args=(), kwargs=None):
                def unwrap(e):
                    return e.elem if isinstance(e, SubclassWithNone) else e

                def wrap(e):
                    return SubclassWithNone(e) if isinstance(e, torch.Tensor) else e

                rs = tree_map(
                    wrap, func(*tree_map(unwrap, args), **tree_map(unwrap, kwargs))
                )
                if func.overloadpacket.__name__ == "add":
                    return None
                else:
                    return rs

        x = SubclassWithNone(torch.rand(2))
        # Make sure both run without error
        self.assertIsInstance(x * 2, SubclassWithNone)
        self.assertIsNone(x + 2)

        x.requires_grad_()
        out = x.acos().sum()
        # The backward of acos does an add followed by an rsqrt, so here we
        # make sure that the undefined Tensor generated by the user code is
        # handled gracefully. If the acos formula changes in the future, this
        # can be replaced by any other composite function whose backward does
        # an add followed by some other op.
with self.assertRaisesRegex(RuntimeError, "but got None"):
out.backward()
def test_storage_can_be_converted_to_python_object(self):
s = torch.Storage()
z = LoggingTensor(torch.empty([]))
z.set_(s)
def test_autograd_in_attr(self):
# We want the wrapped Tensor to require gradients!
true_t = torch.rand(2, requires_grad=True)
t = LoggingTensorReentrant(true_t)
out = t + 2
self.assertFalse(out.requires_grad)
self.assertIsNone(out.grad_fn)
self.assertTrue(out.elem.requires_grad)
self.assertIsNotNone(out.elem.grad_fn)
with self.assertRaisesRegex(RuntimeError, "does not require grad"):
out.sum().backward()
out.elem.sum().backward()
self.assertIsNone(t.grad)
self.assertIsNotNone(t.elem.grad)
def test_dispatch_super_call(self):
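        # super().__torch_dispatch__ runs the default dispatch behavior, so the
        # subclass merely observes which ops were called.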
called = []
class SubTensor(torch.Tensor):
@staticmethod
def __new__(cls, elem):
return torch.Tensor._make_subclass(cls, elem)
@classmethod
def __torch_dispatch__(cls, func, types, args=(), kwargs=None):
called.append(func)
return super().__torch_dispatch__(func, types, args, kwargs)
x = torch.randn(2)
y = torch.randn(2)
self.assertEqual(SubTensor(x) + SubTensor(y), x + y)
self.assertEqual(called, [torch.ops.aten.add.Tensor])
def test_dispatch_super_call_list_arg(self):
called = []
class SubTensorWithListArg(torch.Tensor):
@staticmethod
def __new__(cls, elem):
return torch.Tensor._make_subclass(cls, elem)
@classmethod
def __torch_dispatch__(cls, func, types, args=(), kwargs=None):
called.append(func)
return super().__torch_dispatch__(func, types, list(args), kwargs)
x = torch.randn(2)
self.assertEqual(SubTensorWithListArg(x).neg(), x.neg())
self.assertEqual(called, [torch.ops.aten.neg.default])
def test_dispatch_super_dont_autograd(self):
called = []
class SubTensor(torch.Tensor):
@staticmethod
def __new__(cls, elem):
return torch.Tensor._make_subclass(cls, elem, elem.requires_grad)
@classmethod
def __torch_dispatch__(cls, func, types, args=(), kwargs=None):
called.append(func)
# This argument still requires grad because it was passed
# through directly...
self.assertTrue(args[0].requires_grad)
r = super().__torch_dispatch__(func, types, args, kwargs)
# But the output better not require grad, because that means
# you did autograd again in torch dispatch (oops)
self.assertFalse(r.requires_grad)
return r
x = SubTensor(torch.randn(2, requires_grad=True))
x.neg()
self.assertEqual(called, [torch.ops.aten.neg.default])
def test_set_data(self):
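        # Reading `.data` dispatches once per access; assigning to `.data` does not
        # dispatch at all, and neither operation should change the subclass type.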
called = 0
class SubTensor(torch.Tensor):
@classmethod
def __torch_dispatch__(cls, func, types, args=(), kwargs=None):
nonlocal called
called += 1
return super().__torch_dispatch__(func, types, args, kwargs)
x = SubTensor(torch.empty(2))
x.data
self.assertEqual(called, 1)
x.data = torch.empty(2)
self.assertEqual(called, 1)
x.data
self.assertEqual(called, 2)
self.assertIs(type(x), SubTensor)
x.set_(torch.empty(2))
self.assertEqual(called, 3)
x.data
self.assertEqual(called, 4)
self.assertIs(type(x), SubTensor)
def test_construct_int_tensor(self):
class SubTensor(torch.Tensor):
pass
# should not fail
SubTensor(torch.zeros(2, dtype=torch.int))
def test_multiple_ops_subclass(self):
        # This is a direct (non-wrapper) subclass; don't do that in real code!
class MySubclass(torch.Tensor):
@staticmethod
def __new__(cls, elem):
r = torch.Tensor._make_subclass(cls, elem)
return r
@classmethod
def __torch_dispatch__(cls, func, types, args=(), kwargs=None):
with no_dispatch():
return func(*args, **kwargs)
x = MySubclass(torch.rand(2, 2, dtype=torch.complex64))
y = x.conj()
        # Details of the bug that this tests for:
        # Here, y's dispatch keys are: {PythonTLSSnapshot, AutogradCPU, Conjugate, Python, CPU}
        # A few calls to the dispatcher are going to happen here:
        # - call_exp: the user calls exp on y
        # - PythonTLSSnapshot: records the TLS on entry and redispatches
        # - AutogradCPU: no input requires grad, so it does nothing and redispatches
        # - Conjugate: no special implementation for exp: use the fallback, which
        #   first clones the Tensor (to materialize the conj) and then redispatches
        # - call_clone: the conjugate fallback calls clone on y
        # - PythonTLSSnapshot: records the TLS on entry and redispatches
        # - (AutogradCPU: skipped, as autograd added itself to the exclude set above)
        # - Conjugate: special implementation for clone: just skip this key
        # - Python: reset the TLS based on the snapshot above and call the user
        #   implementation (this actually calls into the dispatcher again, but since
        #   both of our keys were disabled beforehand, it is not detailed here)
        # - exit Python: restore the TLS and exit
        # - exit Conjugate: nothing was done in-place, so just exit
        # - exit PythonTLSSnapshot: done with this call, reset the saved TLS to empty
        # - Python: reset the TLS again based on the snapshot. <- this used to fail
        # - More steps...
y.exp()
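
        # A minimal sketch (illustrative only, not executed by this test) of how
        # the keys involved above can be inspected; torch._C._dispatch_keys is
        # assumed here to expose a tensor's keyset:
        #
        #     keys = torch._C._dispatch_keys(y)
        #     assert "Conjugate" in repr(keys) and "Python" in repr(keys)

    # Helper shared by the slow-path tests below: builds either a wrapper
    # subclass (which owns no real storage) or a plain _make_subclass subclass.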
@staticmethod
def subclass_helper(cls, data, use_wrapper_subclass, **kwargs):
if use_wrapper_subclass:
kwargs["device"] = data.device
kwargs["dtype"] = data.dtype
kwargs["layout"] = data.layout
kwargs["requires_grad"] = True
return torch.Tensor._make_wrapper_subclass(cls, data.size(), **kwargs) # type: ignore[attr-defined]
else:
return torch.Tensor._make_subclass(cls, data, True, **kwargs)
def test_is_contiguous_slow_path(self):
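        # With dispatch_sizes_strides_policy="strides", contiguity queries are not
        # answered from the TensorImpl; they reach __torch_dispatch__ as
        # aten.is_contiguous calls instead.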
data = torch.randn(3, 3)
contiguous_data = data.clone()
not_contiguous_data = torch.as_strided(data.clone(), (2, 2), (1, 2))
for use_wrapper_subclass in [True, False]:
class ExampleTensor1(torch.Tensor):
@staticmethod
def __new__(cls, data, wrapper):
return TestPythonDispatch.subclass_helper(
cls, data, wrapper, dispatch_sizes_strides_policy="strides"
)
@classmethod
def __torch_dispatch__(cls, func, types, args, kwargs):
return NotImplemented
class ExampleTensor2(torch.Tensor):
@staticmethod
def __new__(cls, data, wrapper):
return TestPythonDispatch.subclass_helper(
cls, data, wrapper, dispatch_sizes_strides_policy="strides"
)
@classmethod
def __torch_dispatch__(cls, func, types, args, kwargs):
if func.overloadpacket == torch.ops.aten.is_contiguous:
return contiguous_data.is_contiguous()
return NotImplemented
class ExampleTensor3(torch.Tensor):
@staticmethod
def __new__(cls, data, wrapper):
return TestPythonDispatch.subclass_helper(
cls, data, wrapper, dispatch_sizes_strides_policy="strides"
)
@classmethod
def __torch_dispatch__(cls, func, types, args, kwargs):
if func.overloadpacket == torch.ops.aten.is_contiguous:
return not_contiguous_data.is_contiguous()
return NotImplemented
err_msg = "Multiple dispatch failed for 'torch.ops.aten.is_contiguous'"
e = ExampleTensor1(torch.randn(3, 3), use_wrapper_subclass)
with self.assertRaisesRegex(TypeError, err_msg):
e.is_contiguous()
with self.assertRaisesRegex(TypeError, err_msg):
e.contiguous()
e = ExampleTensor2(torch.randn(3, 3), use_wrapper_subclass)
self.assertEqual(e.is_contiguous(), True)
e.contiguous() # this will just return the original TensorImpl since is_contiguous = True
err_msg = "Multiple dispatch failed for"
e = ExampleTensor3(torch.randn(3, 3), use_wrapper_subclass)
self.assertEqual(e.is_contiguous(), False)
with self.assertRaisesRegex(TypeError, err_msg):
e.contiguous()
def test_fancy_strides(self):
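        # The "strides" policy also routes memory-format queries
        # (is_contiguous(memory_format=...), is_strides_like_format,
        # is_non_overlapping_and_dense) through __torch_dispatch__; returning None
        # defers to the default implementation.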
calls = []
class ExampleTensor(torch.Tensor):
@staticmethod
def __new__(cls, data):
return TestPythonDispatch.subclass_helper(
cls, data, False, dispatch_sizes_strides_policy="strides"
)
@classmethod
def __torch_dispatch__(cls, func, types, args, kwargs):
if func in [
torch.ops.aten.is_contiguous.default,
torch.ops.aten.is_contiguous.memory_format,
torch.ops.aten.is_strides_like_format.default,
torch.ops.aten.is_non_overlapping_and_dense.default,
torch.ops.aten.stride.default,
]:
calls.append((func, list(args)[1:]))
return None
with no_dispatch():
return func(*args, **kwargs)
e = ExampleTensor(torch.randn(2, 2))
self.assertFalse(e.is_contiguous(memory_format=torch.channels_last))
self.assertEqual(
calls, [(torch.ops.aten.is_contiguous.memory_format, [torch.channels_last])]
)
calls.clear()
self.assertFalse(
torch.ops.aten.is_strides_like_format.default(e, torch.channels_last)
)
self.assertEqual(
calls,
[(torch.ops.aten.is_strides_like_format.default, [torch.channels_last])],
)
calls.clear()
self.assertTrue(torch.ops.aten.is_non_overlapping_and_dense.default(e))
self.assertEqual(
calls, [(torch.ops.aten.is_non_overlapping_and_dense.default, [])]
)
def test_device_slowpath(self):
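        # dispatch_device=True routes .device queries to __torch_dispatch__ as
        # torch.ops.prim.device calls instead of reading the TensorImpl's device.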
for use_wrapper_subclass in [True]:
class ExampleTensor1(torch.Tensor):
@staticmethod
def __new__(cls, data, wrapper):
return TestPythonDispatch.subclass_helper(
cls, data, wrapper, dispatch_device=True
)
@classmethod
def __torch_dispatch__(cls, func, types, args, kwargs):
return NotImplemented
class ExampleTensor2(torch.Tensor):
@staticmethod
def __new__(cls, data, wrapper):
return TestPythonDispatch.subclass_helper(
cls, data, wrapper, dispatch_device=True
)
@classmethod
def __torch_dispatch__(cls, func, types, args, kwargs):
if func.overloadpacket == torch.ops.prim.device:
return torch.device("meta")
return NotImplemented
class ExampleTensor3(torch.Tensor):
@staticmethod
def __new__(cls, data, wrapper):
return TestPythonDispatch.subclass_helper(
cls, data, wrapper, dispatch_device=True
)
@classmethod
def __torch_dispatch__(cls, func, types, args, kwargs):
if func.overloadpacket == torch.ops.prim.device:
return torch.device("meta")
return NotImplemented
err_msg = "Multiple dispatch failed for 'torch.ops.prim.device'"
with self.assertRaisesRegex(TypeError, err_msg):
e = ExampleTensor1(torch.randn(3, 3), use_wrapper_subclass)
                e.device  # the property access itself dispatches and raises
ten = torch.rand([1])
e = ExampleTensor2(torch.randn(3, 3, device="cpu"), use_wrapper_subclass)
self.assertEqual(e.device.type, "meta")
self.assertEqual(ten.type_as(e).device.type, "meta")
e = ExampleTensor3(torch.randn(3, 3, device="cpu"), use_wrapper_subclass)
self.assertEqual(e.device.type, "meta")
self.assertEqual(ten.type_as(e).device.type, "meta")
def test_dim_slowpath(self):
data = torch.randn(3, 3)
for use_wrapper_subclass in [True, False]:
class DimNotImplementedTensor(torch.Tensor):
@staticmethod
def __new__(cls, data, wrapper):
return TestPythonDispatch.subclass_helper(
cls, data, wrapper, dispatch_sizes_strides_policy="sizes"
)
@classmethod
def __torch_dispatch__(cls, func, types, args, kwargs):
return NotImplemented
class DimImplementedTensor(torch.Tensor):
@staticmethod
def __new__(cls, data, wrapper):
return TestPythonDispatch.subclass_helper(
cls, data, wrapper, dispatch_sizes_strides_policy="sizes"
)
@classmethod
def __torch_dispatch__(cls, func, types, args, kwargs):
if func.overloadpacket == torch.ops.aten.dim:
return data.dim()
return NotImplemented
err_msg = "Multiple dispatch failed for 'torch.ops.aten.dim'"
e = DimNotImplementedTensor(torch.randn(3, 3), use_wrapper_subclass)
with self.assertRaisesRegex(TypeError, err_msg):
e.dim()
t = DimImplementedTensor(torch.randn(3, 3), use_wrapper_subclass)
self.assertEqual(t.dim(), 2)
def test_maybe_tuple_bug(self):
class T(torch.Tensor):
@classmethod
def __torch_function__(cls, *args, **kwargs):
pass
a = torch.rand(3)
a[[T(), T()]]
def test_standard_is_not_subclass(self):
# https://github.com/pytorch/pytorch/issues/79079
self.assertFalse(torch._C._dispatch_isTensorSubclassLike(torch.empty(0)))
def test_sym_sizes_strides_slow_path(self):
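        # With the "sizes" policy a subclass may answer sym_size/sym_stride with
        # symbolic SymInts; here one is minted from a fresh ShapeEnv.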
class TestTensor(torch.Tensor):
@staticmethod
def __new__(cls, *args, **kwargs):
r = torch.Tensor._make_wrapper_subclass( # type: ignore[attr-defined]
cls, (0,), dispatch_sizes_strides_policy="sizes"
)
return r
@classmethod
def __torch_dispatch__(cls, func, types, args=(), kwargs=None):
if func in (
torch.ops.aten.sym_size.default,
torch.ops.aten.sym_stride.default,
):
from torch._dynamo.source import ConstantSource
from torch.fx.experimental.symbolic_shapes import (
DimDynamic,
ShapeEnv,
)
shape_env = ShapeEnv()
si = shape_env.create_symintnode(
shape_env.create_symbol(
123,
source=ConstantSource("abc"),
dynamic_dim=DimDynamic.DUCK,
constraint_dim=None,
),
hint=123,
)
return (si,)
t = TestTensor()
si = t.size()[0]
self.assertIsInstance(si, torch.SymInt)
si = t.stride()[0]
self.assertIsInstance(si, torch.SymInt)
def test_strides_slow_path(self):
for use_wrapper_subclass in [True, False]:
class StridesNotImplemented(torch.Tensor):
@staticmethod
def __new__(cls, data, wrapper):
return TestPythonDispatch.subclass_helper(
cls, data, wrapper, dispatch_sizes_strides_policy="strides"
)
@classmethod
def __torch_dispatch__(cls, func, types, args, kwargs):
return NotImplemented
class StridesCustomReturn(torch.Tensor):
@staticmethod
def __new__(cls, data, wrapper):
return TestPythonDispatch.subclass_helper(
cls, data, wrapper, dispatch_sizes_strides_policy="strides"
)
@classmethod
def __torch_dispatch__(cls, func, types, args, kwargs):
if func == torch.ops.aten.sym_stride.default:
return (4, 2)
return NotImplemented
class StridesDefaultReturn(torch.Tensor):
@staticmethod
def __new__(cls, data, wrapper):
return TestPythonDispatch.subclass_helper(
cls, data, wrapper, dispatch_sizes_strides_policy="strides"
)
@classmethod
def __torch_dispatch__(cls, func, types, args, kwargs):
if func == torch.ops.aten.sym_stride.default:
return None
return NotImplemented
err_msg = "Multiple dispatch failed for 'torch.ops.aten.sym_stride'"
e = StridesNotImplemented(torch.randn(3, 3), use_wrapper_subclass)
with self.assertRaisesRegex(TypeError, err_msg):
e.stride()
e = StridesCustomReturn(torch.randn(3, 3), use_wrapper_subclass)
self.assertEqual(e.stride(), (4, 2))
e = StridesDefaultReturn(torch.randn(6, 2), use_wrapper_subclass)
self.assertEqual(e.stride(), (2, 1))
def test_sizes_slow_path(self):
for use_wrapper_subclass in [True, False]:
data = torch.randn(6, 2)
class SizesNotImplemented(torch.Tensor):
@staticmethod
def __new__(cls, data, wrapper):
return TestPythonDispatch.subclass_helper(
cls, data, wrapper, dispatch_sizes_strides_policy="sizes"
)
@classmethod
def __torch_dispatch__(cls, func, types, args, kwargs):
if func.overloadpacket == torch.ops.aten.dim:
return data.dim()
return NotImplemented
class SizesCustomReturn(torch.Tensor):
@staticmethod
def __new__(cls, data, wrapper):
return TestPythonDispatch.subclass_helper(
cls, data, wrapper, dispatch_sizes_strides_policy="sizes"
)
@classmethod
def __torch_dispatch__(cls, func, types, args, kwargs):
if func.overloadpacket == torch.ops.aten.dim:
return data.dim()
if func.overloadpacket == torch.ops.aten.sym_size:
return (5, 3)
return NotImplemented
class SizesDefaultReturn(torch.Tensor):
@staticmethod
def __new__(cls, data, wrapper):
return TestPythonDispatch.subclass_helper(
cls, data, wrapper, dispatch_sizes_strides_policy="sizes"
)
@classmethod
def __torch_dispatch__(cls, func, types, args, kwargs):
if func.overloadpacket == torch.ops.aten.dim:
return data.dim()
if func.overloadpacket == torch.ops.aten.sym_size:
return None
return NotImplemented
err_msg = "Multiple dispatch failed for 'torch.ops.aten.sym_size'"
e = SizesNotImplemented(torch.randn(3, 3), use_wrapper_subclass)
with self.assertRaisesRegex(TypeError, err_msg):
e.size()
e = SizesCustomReturn(torch.randn(3, 3), use_wrapper_subclass)
self.assertEqual(e.size(), (5, 3))
e = SizesDefaultReturn(torch.randn(4, 2), use_wrapper_subclass)
self.assertEqual(e.size(), (4, 2))
def test_custom_size_policy_dynamic_shapes(self):
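        # Under symbolic tracing, size/stride queries on the wrapper should yield
        # the inner tensor's SymInts rather than specialized integers.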
data = torch.randn(6, 2)
class CustomSizeDynamicShapesTensor(torch.Tensor):
@staticmethod
def __new__(cls, inner):
return torch.Tensor._make_wrapper_subclass(
# TODO: right now, _make_wrapper_subclass's dynamic shape interaction is not great.
# Calling the overload that has kwargs causes us to go down the first overload path,
# which will **always** specialize sizes.
# We should probably eventually fix this so that the first overload can just handle dynamic shapes.
cls,
inner.size(),
inner.stride(),
None,
None,
inner.dtype,
inner.layout,
inner.device,
False,
inner.requires_grad,
"sizes",
)
def __init__(self, inner):
self.inner = inner
@classmethod
def __torch_dispatch__(cls, func, types, args, kwargs):
if func == torch.ops.aten.sym_size.default:
return args[0].inner.shape
if func == torch.ops.aten.sym_stride.default:
return args[0].inner.shape
return NotImplemented
x = torch.ones(2, 2)
def trace_fn(x):
x_wrapper = CustomSizeDynamicShapesTensor(x)
return x_wrapper.size(), x_wrapper.stride()
fx_g = make_fx(trace_fn, tracing_mode="symbolic")(x)
        self.assertExpectedInline(
            fx_g.code.strip(),
            """\
def forward(self, x_1):
    sym_size_int = torch.ops.aten.sym_size.int(x_1, 0)
    sym_size_int_1 = torch.ops.aten.sym_size.int(x_1, 1); x_1 = None
    return ((sym_size_int, sym_size_int_1), (sym_size_int, sym_size_int_1))""",
        )
def test_data_ptr_respects_numel_slow_path(self):
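        # data_ptr() consults numel() internally, so under the "sizes" policy the
        # numel query should reach __torch_dispatch__.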
data = torch.randn(6, 2)
class NumelDefaultReturn(torch.Tensor):
@staticmethod
def __new__(cls, data, wrapper):
return TestPythonDispatch.subclass_helper(
cls, data, wrapper, dispatch_sizes_strides_policy="sizes"
)
@classmethod
def __torch_dispatch__(cls, func, types, args, kwargs):
if func.overloadpacket == torch.ops.aten.dim:
return data.dim()
if func.overloadpacket == torch.ops.aten.numel:
numel_called[0] = True
return None
return NotImplemented
for use_wrapper_subclass in (False, True):
numel_called = [False]
e = NumelDefaultReturn(torch.randn(2, 2), use_wrapper_subclass)
e.data_ptr()
self.assertTrue(numel_called[0])
def test_layout_slow_path(self):
for use_wrapper_subclass in [True, False]:
data = torch.randn(6, 2)
class LayoutNotImplemented(torch.Tensor):
@staticmethod
def __new__(cls, data, wrapper):
return TestPythonDispatch.subclass_helper(
cls, data, wrapper, dispatch_layout=True
)
@classmethod
def __torch_dispatch__(cls, func, types, args, kwargs):
return NotImplemented
class LayoutCustomReturn(torch.Tensor):
@staticmethod
def __new__(cls, data, wrapper):
return TestPythonDispatch.subclass_helper(
cls, data, wrapper, dispatch_layout=True
)
@classmethod
def __torch_dispatch__(cls, func, types, args, kwargs):
if func.overloadpacket == torch.ops.prim.layout:
return torch.sparse_csr
return NotImplemented
class LayoutDefaultReturn(torch.Tensor):
@staticmethod
def __new__(cls, data, wrapper):
return TestPythonDispatch.subclass_helper(
cls, data, wrapper, dispatch_layout=True
)
@classmethod
def __torch_dispatch__(cls, func, types, args, kwargs):
if func.overloadpacket == torch.ops.prim.layout:
return data.layout
return NotImplemented
err_msg = "Multiple dispatch failed for 'torch.ops.prim.layout'"
e = LayoutNotImplemented(torch.randn(3, 3), use_wrapper_subclass)
with self.assertRaisesRegex(TypeError, err_msg):
e.layout
e = LayoutCustomReturn(torch.randn(3, 3), use_wrapper_subclass)
self.assertEqual(e.layout, torch.sparse_csr)
e = LayoutDefaultReturn(torch.randn(4, 2), use_wrapper_subclass)
self.assertEqual(e.layout, torch.strided)
def test_wrapper_subclass_reentrant_dispatch_with_mode(self):
# Tests the interaction between a wrapper subclass using reentrant dispatch
# and a TorchDispatchMode. See https://github.com/pytorch/pytorch/issues/136565
# simple passthrough TorchDispatchMode
class CustomDispatchMode(TorchDispatchMode):
def __torch_dispatch__(self, func, types, args=..., kwargs=None):
return func(*args, **kwargs)
# derive from TwoTensor to minimize boilerplate
class MySubclass(TwoTensor):
def __torch_dispatch__(self, func, types, args, kwargs=None):
with torch.overrides.enable_reentrant_dispatch():
return func(args[0].a)
t = MySubclass(torch.rand(2), torch.rand(2))
with CustomDispatchMode():
res = t.clone()
self.assertEqual(res, t.a)
self.assertIs(type(res), torch.Tensor)


class TestPythonDispatcher(TestCase):
def test_basic(self):
x = torch.randn(2, requires_grad=True)
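        # _EnablePythonDispatcher returns a guard object: the Python dispatcher
        # stays enabled for as long as `r` is kept alive.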
r = torch._C._EnablePythonDispatcher()
torch.add(x, x)
def test_lstsq(self):
a = torch.randn(4, 3)
b = torch.rand(4, 3)
expected_shape = torch.linalg.lstsq(a, b).solution.shape
r = torch._C._EnablePythonDispatcher()
python_disp_shape = torch.linalg.lstsq(a, b).solution.shape
self.assertEqual(expected_shape, python_disp_shape)


class TestWrapperSubclassAliasing(TestCase):
def _test_wrapper_subclass_aliasing(self, op, args, kwargs):
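        # Runs the op on plain tensors and on TwoTensor wrappers, then checks that
        # output/input identity and storage aliasing match between the two runs.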
def to_subclass(t: torch.Tensor):
return TwoTensor(t, t.clone())
result_ref = op(*args, **kwargs)
args_subclass = pytree.tree_map_only(torch.Tensor, to_subclass, args)
kwargs_subclass = pytree.tree_map_only(torch.Tensor, to_subclass, kwargs)
result_test = op(*args_subclass, **kwargs_subclass)
args_ref_flat = pytree.arg_tree_leaves(*args, **kwargs)
args_ref_flat_tensors = [
x for x in args_ref_flat if isinstance(x, torch.Tensor)
]
args_test_flat = pytree.tree_leaves((args_subclass, kwargs_subclass))
args_test_flat_tensors = [
x for x in args_test_flat if isinstance(x, torch.Tensor)
]
result_ref_flat = pytree.tree_leaves(result_ref)
result_ref_flat_tensors = [
x for x in result_ref_flat if isinstance(x, torch.Tensor)
]
result_test_flat = pytree.tree_leaves(result_test)
result_test_flat_tensors = [
x for x in result_test_flat if isinstance(x, torch.Tensor)
]
for o_ref, o_test in zip(result_ref_flat_tensors, result_test_flat_tensors):
for a_ref, a_test in zip(args_ref_flat_tensors, args_test_flat_tensors):
out_is_inpt = o_ref is a_ref
if out_is_inpt:
self.assertTrue(o_test is a_test)
out_aliases_inpt = StorageWeakRef(
o_ref.untyped_storage()
) == StorageWeakRef(a_ref.untyped_storage())
if out_aliases_inpt:
self.assertTrue(
StorageWeakRef(o_test.untyped_storage())
== StorageWeakRef(a_test.untyped_storage())
)
else:
self.assertFalse(
StorageWeakRef(o_test.untyped_storage())
== StorageWeakRef(a_test.untyped_storage())
)
# This tests the correctness of `torch.utils._python_dispatch.return_and_correct_aliasing`,
# a util for wrapper subclasses to promise correct aliasing behavior.
# It's probably overkill to test every OpInfo,
# so I picked a sampling of ops with representative schemas.
@ops(
[
op
for op in op_db
if op.name
in [
"mul", # out-of-place
"cat", # out-of-place (TensorList input)
"index", # out-of-place (Optional TensorList input)
"mul_", # inplace
"view", # view
"t_", # inplace-view
"split", # view (multi-return)
"native_batch_norm", # mutable op (returns outputs and mutates some inputs)
]
],
allowed_dtypes=(torch.float,),
)
def test_wrapper_subclass_aliasing(self, device, dtype, op):
samples = op.sample_inputs(device, dtype)
sample = first_sample(self, samples)
args = (sample.input, *sample.args)
kwargs = sample.kwargs
self._test_wrapper_subclass_aliasing(op, args, kwargs)
@ops(custom_op_db, allowed_dtypes=(torch.float,))
def test_wrapper_subclass_aliasing_custom(self, device, dtype, op):
samples = op.sample_inputs(device, dtype)
sample = first_sample(self, samples)
args = (sample.input, *sample.args)
kwargs = sample.kwargs
self._test_wrapper_subclass_aliasing(op, args, kwargs)
def test_wrapper_subclass_aliasing_conv2d(self, device):
args = (torch.randn(4, 4, 4, 4), torch.randn(4, 4, 4, 4))
kwargs = {}
        # conv2d has scalar-defaulted list args such as 'int[2] padding=0',
        # which torchscript expands into 'int[2] padding=[0, 0]'.
        # Make sure that _return_and_correct_aliasing can handle this case
# (I'm using inference_mode to make sure conv2d doesn't decompose and goes to torch_dispatch)
with torch.inference_mode():
self._test_wrapper_subclass_aliasing(
torch.ops.aten.conv2d.default, args, kwargs
)
def test_wrapper_subclass_aliasing_out_op(self, device):
# Make sure that _return_and_correct_aliasing can handle kwargs w mutable tensors
args = (torch.ones(4), torch.ones(4))
kwargs = {"out": torch.empty(4)}
self._test_wrapper_subclass_aliasing(torch.ops.aten.add.out, args, kwargs)
def test_wrapper_subclass_aliasing_fft_fft2(self, device):
args = (torch.randn(4, 4),)
kwargs = {}
# fft_fft2 has a default arg 'int[1] dim=[-2,-1]',
# Make sure that _return_and_correct_aliasing can handle this case
# (I'm using inference_mode to make sure fft_fft2 doesn't decompose and goes to torch_dispatch)
with torch.inference_mode():
self._test_wrapper_subclass_aliasing(torch.ops.aten.fft_fft2, args, kwargs)


instantiate_device_type_tests(TestWrapperSubclassAliasing, globals())

if __name__ == "__main__":
    run_tests()