Detect torch function in lists as well (#160256)

We basically follow the same pattern we use for tensor arguments. The major downside is that we now have to traverse the entirety of the int list (and friends), where previously we didn't have to. A benchmark suggests a 2% regression on the affected paths.
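
For illustration, a minimal sketch of the new behavior (the IntLike wrapper here is hypothetical; it mirrors the F.pad test added below):

    import torch
    import torch.nn.functional as F

    class IntLike:
        # Any object defining __torch_function__ now triggers dispatch
        # even when it appears only inside an int list argument.
        def __init__(self, value):
            self.value = value

        def __torch_function__(self, func, types, args=(), kwargs=None):
            print(f"intercepted {func.__name__}")  # prints "intercepted pad"
            return args[0]  # hand the input tensor back unchanged

    x = torch.randn(2, 3)
    # Previously only tensor arguments were scanned for __torch_function__;
    # now the padding int list is traversed too, so IntLike intercepts F.pad.
    F.pad(x, [1, IntLike(1), 0, 0])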

Signed-off-by: Edward Yang <ezyang@meta.com>

Pull Request resolved: https://github.com/pytorch/pytorch/pull/160256
Approved by: https://github.com/albanD
Author:       Edward Yang
Date:         2025-09-02 00:51:52 -04:00
Committed by: PyTorch MergeBot
Parent:       524b78d4f6
Commit:       9a1c5c0a07
3 changed files with 436 additions and 52 deletions

@@ -615,6 +615,271 @@ class TestTorchFunctionOverride(TestCase):
         self.assertEqual(NothingImplemented() ** RPowOnly(), -1)
 
+    def test_torch_function_in_lists(self):
+        """Test that __torch_function__ is called for objects inside lists"""
+
+        class IntLike:
+            """Object that can be used in int lists"""
+
+            def __init__(self, value):
+                self.value = value
+                self.torch_function_called = False
+
+            def __torch_function__(self, func, types, args=(), kwargs=None):
+                self.torch_function_called = True
+                # Return a result that makes the operation succeed
+                if func.__name__ == 'pad':
+                    # For pad, return the input with shape adjusted
+                    return args[0]
+                elif func.__name__ == 'layer_norm':
+                    # For layer_norm, return normalized tensor
+                    return torch.ones_like(args[0])
+                elif func.__name__ == 'tensordot':
+                    # For tensordot, return appropriate shape
+                    return torch.tensor(42.0)
+                # Fallback
+                return torch.tensor(42.0)
+
+        # Test with F.pad which takes int list
+        import torch.nn.functional as F
+
+        x = torch.randn(2, 3)
+        obj = IntLike(1)
+        # pad takes [left, right, top, bottom] as padding
+        _ = F.pad(x, [1, obj, 0, 0])
+        self.assertTrue(obj.torch_function_called,
+                        "torch_function should be called for object in int list")
+
+        # Test multiple objects in list
+        obj1 = IntLike(1)
+        obj2 = IntLike(2)
+        _ = F.pad(x, [obj1, obj2, 0, 0])
+        self.assertTrue(obj1.torch_function_called or obj2.torch_function_called,
+                        "torch_function should be called for at least one object")
+
+    def test_torch_function_in_float_lists(self):
+        """Test that __torch_function__ is called for objects inside float lists"""
+
+        class FloatLike:
+            """Object that can be used in float lists"""
+
+            def __init__(self, value):
+                self.value = float(value)
+                self.torch_function_called = False
+
+            def __torch_function__(self, func, types, args=(), kwargs=None):
+                self.torch_function_called = True
+                # Return appropriate result
+                if func.__name__ == 'layer_norm':
+                    return torch.ones_like(args[0])
+                return torch.tensor(42.0)
+
+        import torch.nn.functional as F
+
+        x = torch.randn(2, 3, 4)
+        obj = FloatLike(4.0)
+        # layer_norm takes normalized_shape as int/float list
+        _ = F.layer_norm(x, [3, obj])
+        self.assertTrue(obj.torch_function_called,
+                        "torch_function should be called for object in float list")
+
+    def test_torch_function_in_scalar_lists(self):
+        """Test that __torch_function__ is called for scalar objects inside lists"""
+
+        class ScalarLike:
+            """Object that can be used as a scalar in lists"""
+
+            def __init__(self, value):
+                self.value = value
+                self.torch_function_called = False
+
+            def __torch_function__(self, func, types, args=(), kwargs=None):
+                self.torch_function_called = True
+                # Return a scalar tensor
+                return torch.tensor(self.value)
+
+            def __float__(self):
+                return float(self.value)
+
+            def __int__(self):
+                return int(self.value)
+
+        # Test with a function that takes scalar lists
+        # Using torch.as_tensor which can take scalar lists
+        obj1 = ScalarLike(1.0)
+        obj2 = ScalarLike(2.0)
+        # Create a tensor with scalar list containing torch function objects
+        # Use a different operation that should trigger torch_function
+        _ = torch.stack([obj1, obj2])
+        self.assertTrue(obj1.torch_function_called or obj2.torch_function_called,
+                        "torch_function should be called for scalar objects in list")
+
+    def test_torch_function_precedence_in_lists(self):
+        """Test precedence when multiple torch function objects are in a list"""
+        call_order = []
+
+        class HighPriority:
+            def __torch_function__(self, func, types, args=(), kwargs=None):
+                call_order.append('high')
+                # Delegate to lower priority
+                return NotImplemented
+
+        class LowPriority:
+            def __torch_function__(self, func, types, args=(), kwargs=None):
+                call_order.append('low')
+                # Return valid result
+                if func.__name__ == 'pad':
+                    return args[0]
+                return torch.tensor(42.0)
+
+        import torch.nn.functional as F
+
+        x = torch.randn(2, 3)
+        high = HighPriority()
+        low = LowPriority()
+
+        # Test with both objects in list
+        call_order.clear()
+        _ = F.pad(x, [1, high, low, 0])
+        # High priority should be called first
+        self.assertEqual(call_order[0], 'high',
+                         "Higher priority torch_function should be called first")
+        self.assertEqual(call_order[1], 'low',
+                         "Lower priority torch_function should be called after NotImplemented")
+
+    def test_torch_function_mixed_lists(self):
+        """Test lists with mix of regular values and torch function objects"""
+
+        class CountingInt:
+            call_count = 0
+
+            def __init__(self, value):
+                self.value = value
+
+            @classmethod
+            def reset(cls):
+                cls.call_count = 0
+
+            def __torch_function__(self, func, types, args=(), kwargs=None):
+                CountingInt.call_count += 1
+                # Return valid result
+                if func.__name__ == 'pad':
+                    return args[0]
+                return torch.tensor(42.0)
+
+            def __index__(self):
+                return self.value
+
+        import torch.nn.functional as F
+
+        x = torch.randn(2, 3)
+        obj = CountingInt(2)
+        CountingInt.reset()
+        # Mix regular ints with torch function object
+        _ = F.pad(x, [1, obj, 0, 0])
+        self.assertEqual(CountingInt.call_count, 1,
+                         "torch_function should be called exactly once for mixed list")
+
+    def test_torch_function_empty_lists(self):
+        """Test that empty lists work correctly"""
+        # This should work without calling any torch_function
+        x = torch.randn(1)  # Single element tensor
+        # Functions that accept empty lists should still work
+        # torch.stack with empty list of tensors would fail,
+        # but empty size lists should work
+        result = x.view([])  # Empty list means scalar
+        self.assertEqual(result.shape, torch.Size([]),
+                         "Empty list should work for size arguments")
+
+    def test_torch_function_not_first_in_list(self):
+        """Test that torch_function is called even when object is not first in list"""
+
+        class IntLikeNotFirst:
+            """Object with torch_function that won't be first in list"""
+
+            def __init__(self, value):
+                self.value = value
+                self.torch_function_called = False
+
+            def __torch_function__(self, func, types, args=(), kwargs=None):
+                self.torch_function_called = True
+                # Return input tensor for pad
+                return args[0]
+
+            def __index__(self):
+                return self.value
+
+        import torch.nn.functional as F
+
+        x = torch.randn(2, 3)
+
+        # Test with torch_function object as second item
+        obj_second = IntLikeNotFirst(2)
+        _ = F.pad(x, [1, obj_second, 0, 0])
+        self.assertTrue(obj_second.torch_function_called,
+                        "torch_function should be called when object is second in list")
+
+        # Test with torch_function object as third item
+        obj_third = IntLikeNotFirst(1)
+        _ = F.pad(x, [1, 1, obj_third, 0])
+        self.assertTrue(obj_third.torch_function_called,
+                        "torch_function should be called when object is third in list")
+
+        # Test with torch_function object as last item
+        obj_last = IntLikeNotFirst(1)
+        _ = F.pad(x, [1, 1, 1, obj_last])
+        self.assertTrue(obj_last.torch_function_called,
+                        "torch_function should be called when object is last in list")
+
+    def test_torch_function_nested_tuple_getitem(self):
+        """Test that torch_function is called with getitem for TF objects inside nested tuples"""
+        called_functions = []
+
+        class TorchFunctionObj:
+            """Object with torch_function that tracks which functions are called"""
+
+            def __init__(self, value):
+                self.value = value
+
+            def __torch_function__(self, func, types, args=(), kwargs=None):
+                called_functions.append(func.__name__)
+                # For getitem, return the tensor unchanged
+                if func.__name__ == '__getitem__':
+                    return args[0]
+                # Return a simple result for other functions
+                return torch.tensor(42.0)
+
+            def __index__(self):
+                return self.value
+
+        # Create a tensor to index
+        x = torch.randn(5, 5, 5)
+        # Create torch function objects - these will be INSIDE the nested structure
+        tf_obj1 = TorchFunctionObj(0)
+        tf_obj2 = TorchFunctionObj(1)
+        # Clear the called functions list
+        called_functions.clear()
+
+        # Test with tuple of tuple where TF objects are only on the INSIDE
+        # The outer structure is regular tuples, but inner elements have __torch_function__
+        # This tests the recursive detection logic added in the recent commit
+        x[(0, (tf_obj1, tf_obj2))]
+
+        # Assert that torch_function was called
+        self.assertTrue(len(called_functions) > 0,
+                        "torch_function should be called for TF objects inside nested tuples")
+        # Assert that getitem was called, not size
+        self.assertIn('__getitem__', called_functions,
+                      "getitem should be called for tuple indexing with torch function objects inside")
+        self.assertNotIn('size', called_functions,
+                         "size should not be called - we should use getitem, not convert to advanced indexing")
+
 
 def generate_tensor_like_override_tests(cls):
     from torch.testing._internal.generated.annotated_fn_args import annotated_args
@@ -1135,29 +1400,31 @@ class TestResolveName(TestCase):
         )
 
 
 class TestTorchFunctionWarning(TestCase):
-    def test_warn_on_invalid_torch_function_standalone_class(self):
+    def test_torch_function_standalone_class(self):
         class StandaloneTorchFunctionClass:
-            def __torch_function__(self, *args, **kwargs):
-                pass
+            @classmethod
+            def __torch_function__(cls, func, types, args=(), kwargs=None):
+                # Return a simple tensor for testing
+                return torch.tensor(42.0)
+
         a = StandaloneTorchFunctionClass()
-        with self.assertWarnsRegex(DeprecationWarning, "as a plain method is deprecated"):
-            # Function that handles torch_function on the python side
-            torch.nn.functional.dropout(a)
-        with self.assertWarnsRegex(UserWarning, "as a plain method is deprecated"):
-            # Function that handles torch_function in C++
-            torch.abs(a)
+        # Test that torch_function works without warnings
+        result1 = torch.nn.functional.dropout(a)
+        result2 = torch.abs(a)
+        self.assertEqual(result1, torch.tensor(42.0))
+        self.assertEqual(result2, torch.tensor(42.0))
 
-    def test_warn_on_invalid_torch_function_tensor_subclass(self):
+    def test_torch_function_tensor_subclass(self):
         class TensorSubclassTorchFunctionClass(torch.Tensor):
-            def __torch_function__(self, *args, **kwargs):
-                pass
+            @classmethod
+            def __torch_function__(cls, func, types, args=(), kwargs=None):
+                # Return a simple tensor for testing
+                return torch.tensor(99.0)
+
         b = TensorSubclassTorchFunctionClass()
-        with self.assertWarnsRegex(DeprecationWarning, "as a plain method is deprecated"):
-            # Function that handles torch_function on the python side
-            torch.nn.functional.dropout(b)
-        with self.assertWarnsRegex(UserWarning, "as a plain method is deprecated"):
-            # Function that handles torch_function in C++
-            torch.abs(b)
+        # Test that torch_function works without warnings
+        result1 = torch.nn.functional.dropout(b)
+        result2 = torch.abs(b)
+        self.assertEqual(result1, torch.tensor(99.0))
+        self.assertEqual(result2, torch.tensor(99.0))
 
 
 class TestDisabledUserWarnings(TestCase):
     def test_no_implicit_user_warning_for_deprecated_functions(self):