Detect torch function in lists as well (#160256)
We basically follow the same pattern we do for tensor arguments. The major downside is that we now have to traverse the entirety of the int list (and similar list arguments) where previously we didn't have to. A benchmark suggests a 2% regression on the affected code paths.

Signed-off-by: Edward Yang <ezyang@meta.com>
Pull Request resolved: https://github.com/pytorch/pytorch/pull/160256
Approved by: https://github.com/albanD
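As context, here is a minimal sketch of the behavior this change enables, distilled from the new tests below (it is not code from the PR itself): an object that defines `__torch_function__` now triggers dispatch even when it appears inside a plain Python list argument, such as the padding list of `F.pad`. The `IntLike` class is illustrative only.

```python
import torch
import torch.nn.functional as F

class IntLike:
    """Illustrative int-like object; any object defining __torch_function__ works."""
    def __init__(self, value):
        self.value = value

    def __torch_function__(self, func, types, args=(), kwargs=None):
        # After this change, the call is intercepted even though the object
        # sits inside the int list rather than in a tensor argument position.
        print(f"intercepted {func.__name__}")
        return args[0]  # hand the input tensor back unchanged

    def __index__(self):
        return self.value  # lets the object still behave like an int

x = torch.randn(2, 3)
F.pad(x, [1, IntLike(1), 0, 0])  # prints: intercepted pad
```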
commit 9a1c5c0a07
parent 524b78d4f6
committed by PyTorch MergeBot
@@ -615,6 +615,271 @@ class TestTorchFunctionOverride(TestCase):
         self.assertEqual(NothingImplemented() ** RPowOnly(), -1)
 
+    def test_torch_function_in_lists(self):
+        """Test that __torch_function__ is called for objects inside lists"""
+
+        class IntLike:
+            """Object that can be used in int lists"""
+            def __init__(self, value):
+                self.value = value
+                self.torch_function_called = False
+
+            def __torch_function__(self, func, types, args=(), kwargs=None):
+                self.torch_function_called = True
+                # Return a result that makes the operation succeed
+                if func.__name__ == 'pad':
+                    # For pad, return the input with shape adjusted
+                    return args[0]
+                elif func.__name__ == 'layer_norm':
+                    # For layer_norm, return normalized tensor
+                    return torch.ones_like(args[0])
+                elif func.__name__ == 'tensordot':
+                    # For tensordot, return appropriate shape
+                    return torch.tensor(42.0)
+                # Fallback
+                return torch.tensor(42.0)
+
+        # Test with F.pad which takes int list
+        import torch.nn.functional as F
+        x = torch.randn(2, 3)
+        obj = IntLike(1)
+
+        # pad takes [left, right, top, bottom] as padding
+        _ = F.pad(x, [1, obj, 0, 0])
+        self.assertTrue(obj.torch_function_called,
+                        "torch_function should be called for object in int list")
+
+        # Test multiple objects in list
+        obj1 = IntLike(1)
+        obj2 = IntLike(2)
+        _ = F.pad(x, [obj1, obj2, 0, 0])
+        self.assertTrue(obj1.torch_function_called or obj2.torch_function_called,
+                        "torch_function should be called for at least one object")
+
+    def test_torch_function_in_float_lists(self):
+        """Test that __torch_function__ is called for objects inside float lists"""
+
+        class FloatLike:
+            """Object that can be used in float lists"""
+            def __init__(self, value):
+                self.value = float(value)
+                self.torch_function_called = False
+
+            def __torch_function__(self, func, types, args=(), kwargs=None):
+                self.torch_function_called = True
+                # Return appropriate result
+                if func.__name__ == 'layer_norm':
+                    return torch.ones_like(args[0])
+                return torch.tensor(42.0)
+
+        import torch.nn.functional as F
+        x = torch.randn(2, 3, 4)
+        obj = FloatLike(4.0)
+
+        # layer_norm takes normalized_shape as int/float list
+        _ = F.layer_norm(x, [3, obj])
+        self.assertTrue(obj.torch_function_called,
+                        "torch_function should be called for object in float list")
+
+    def test_torch_function_in_scalar_lists(self):
+        """Test that __torch_function__ is called for scalar objects inside lists"""
+
+        class ScalarLike:
+            """Object that can be used as a scalar in lists"""
+            def __init__(self, value):
+                self.value = value
+                self.torch_function_called = False
+
+            def __torch_function__(self, func, types, args=(), kwargs=None):
+                self.torch_function_called = True
+                # Return a scalar tensor
+                return torch.tensor(self.value)
+
+            def __float__(self):
+                return float(self.value)
+
+            def __int__(self):
+                return int(self.value)
+
+        # Test with a function that takes scalar lists
+        # Using torch.as_tensor which can take scalar lists
+        obj1 = ScalarLike(1.0)
+        obj2 = ScalarLike(2.0)
+
+        # Create a tensor with scalar list containing torch function objects
+        # Use a different operation that should trigger torch_function
+        _ = torch.stack([obj1, obj2])
+        self.assertTrue(obj1.torch_function_called or obj2.torch_function_called,
+                        "torch_function should be called for scalar objects in list")
+
+    def test_torch_function_precedence_in_lists(self):
+        """Test precedence when multiple torch function objects are in a list"""
+
+        call_order = []
+
+        class HighPriority:
+            def __torch_function__(self, func, types, args=(), kwargs=None):
+                call_order.append('high')
+                # Delegate to lower priority
+                return NotImplemented
+
+        class LowPriority:
+            def __torch_function__(self, func, types, args=(), kwargs=None):
+                call_order.append('low')
+                # Return valid result
+                if func.__name__ == 'pad':
+                    return args[0]
+                return torch.tensor(42.0)
+
+        import torch.nn.functional as F
+        x = torch.randn(2, 3)
+
+        high = HighPriority()
+        low = LowPriority()
+
+        # Test with both objects in list
+        call_order.clear()
+        _ = F.pad(x, [1, high, low, 0])
+
+        # High priority should be called first
+        self.assertEqual(call_order[0], 'high',
+                         "Higher priority torch_function should be called first")
+        self.assertEqual(call_order[1], 'low',
+                         "Lower priority torch_function should be called after NotImplemented")
+
+    def test_torch_function_mixed_lists(self):
+        """Test lists with mix of regular values and torch function objects"""
+
+        class CountingInt:
+            call_count = 0
+
+            def __init__(self, value):
+                self.value = value
+
+            @classmethod
+            def reset(cls):
+                cls.call_count = 0
+
+            def __torch_function__(self, func, types, args=(), kwargs=None):
+                CountingInt.call_count += 1
+                # Return valid result
+                if func.__name__ == 'pad':
+                    return args[0]
+                return torch.tensor(42.0)
+
+            def __index__(self):
+                return self.value
+
+        import torch.nn.functional as F
+        x = torch.randn(2, 3)
+
+        obj = CountingInt(2)
+        CountingInt.reset()
+
+        # Mix regular ints with torch function object
+        _ = F.pad(x, [1, obj, 0, 0])
+
+        self.assertEqual(CountingInt.call_count, 1,
+                         "torch_function should be called exactly once for mixed list")
+
+    def test_torch_function_empty_lists(self):
+        """Test that empty lists work correctly"""
+
+        # This should work without calling any torch_function
+        x = torch.randn(1)  # Single element tensor
+
+        # Functions that accept empty lists should still work
+        # torch.stack with empty list of tensors would fail,
+        # but empty size lists should work
+        result = x.view([])  # Empty list means scalar
+        self.assertEqual(result.shape, torch.Size([]),
+                         "Empty list should work for size arguments")
+
+    def test_torch_function_not_first_in_list(self):
+        """Test that torch_function is called even when object is not first in list"""
+
+        class IntLikeNotFirst:
+            """Object with torch_function that won't be first in list"""
+            def __init__(self, value):
+                self.value = value
+                self.torch_function_called = False
+
+            def __torch_function__(self, func, types, args=(), kwargs=None):
+                self.torch_function_called = True
+                # Return input tensor for pad
+                return args[0]
+
+            def __index__(self):
+                return self.value
+
+        import torch.nn.functional as F
+        x = torch.randn(2, 3)
+
+        # Test with torch_function object as second item
+        obj_second = IntLikeNotFirst(2)
+        _ = F.pad(x, [1, obj_second, 0, 0])
+        self.assertTrue(obj_second.torch_function_called,
+                        "torch_function should be called when object is second in list")
+
+        # Test with torch_function object as third item
+        obj_third = IntLikeNotFirst(1)
+        _ = F.pad(x, [1, 1, obj_third, 0])
+        self.assertTrue(obj_third.torch_function_called,
+                        "torch_function should be called when object is third in list")
+
+        # Test with torch_function object as last item
+        obj_last = IntLikeNotFirst(1)
+        _ = F.pad(x, [1, 1, 1, obj_last])
+        self.assertTrue(obj_last.torch_function_called,
+                        "torch_function should be called when object is last in list")
+
+    def test_torch_function_nested_tuple_getitem(self):
+        """Test that torch_function is called with getitem for TF objects inside nested tuples"""
+
+        called_functions = []
+
+        class TorchFunctionObj:
+            """Object with torch_function that tracks which functions are called"""
+            def __init__(self, value):
+                self.value = value
+
+            def __torch_function__(self, func, types, args=(), kwargs=None):
+                called_functions.append(func.__name__)
+                # For getitem, return the tensor unchanged
+                if func.__name__ == '__getitem__':
+                    return args[0]
+                # Return a simple result for other functions
+                return torch.tensor(42.0)
+
+            def __index__(self):
+                return self.value
+
+        # Create a tensor to index
+        x = torch.randn(5, 5, 5)
+
+        # Create torch function objects - these will be INSIDE the nested structure
+        tf_obj1 = TorchFunctionObj(0)
+        tf_obj2 = TorchFunctionObj(1)
+
+        # Clear the called functions list
+        called_functions.clear()
+
+        # Test with tuple of tuple where TF objects are only on the INSIDE
+        # The outer structure is regular tuples, but inner elements have __torch_function__
+        # This tests the recursive detection logic added in the recent commit
+        x[(0, (tf_obj1, tf_obj2))]
+
+        # Assert that torch_function was called
+        self.assertTrue(len(called_functions) > 0,
+                        "torch_function should be called for TF objects inside nested tuples")
+
+        # Assert that getitem was called, not size
+        self.assertIn('__getitem__', called_functions,
+                      "getitem should be called for tuple indexing with torch function objects inside")
+
+        self.assertNotIn('size', called_functions,
+                         "size should not be called - we should use getitem, not convert to advanced indexing")
 
 
 def generate_tensor_like_override_tests(cls):
     from torch.testing._internal.generated.annotated_fn_args import annotated_args
@@ -1135,29 +1400,31 @@ class TestResolveName(TestCase):
         )
 
 class TestTorchFunctionWarning(TestCase):
-    def test_warn_on_invalid_torch_function_standalone_class(self):
+    def test_torch_function_standalone_class(self):
         class StandaloneTorchFunctionClass:
-            def __torch_function__(self, *args, **kwargs):
-                pass
+            @classmethod
+            def __torch_function__(cls, func, types, args=(), kwargs=None):
+                # Return a simple tensor for testing
+                return torch.tensor(42.0)
         a = StandaloneTorchFunctionClass()
-        with self.assertWarnsRegex(DeprecationWarning, "as a plain method is deprecated"):
-            # Function that handles torch_function on the python side
-            torch.nn.functional.dropout(a)
-        with self.assertWarnsRegex(UserWarning, "as a plain method is deprecated"):
-            # Function that handles torch_function in C++
-            torch.abs(a)
+        # Test that torch_function works without warnings
+        result1 = torch.nn.functional.dropout(a)
+        result2 = torch.abs(a)
+        self.assertEqual(result1, torch.tensor(42.0))
+        self.assertEqual(result2, torch.tensor(42.0))
 
-    def test_warn_on_invalid_torch_function_tensor_subclass(self):
+    def test_torch_function_tensor_subclass(self):
         class TensorSubclassTorchFunctionClass(torch.Tensor):
-            def __torch_function__(self, *args, **kwargs):
-                pass
+            @classmethod
+            def __torch_function__(cls, func, types, args=(), kwargs=None):
+                # Return a simple tensor for testing
+                return torch.tensor(99.0)
         b = TensorSubclassTorchFunctionClass()
-        with self.assertWarnsRegex(DeprecationWarning, "as a plain method is deprecated"):
-            # Function that handles torch_function on the python side
-            torch.nn.functional.dropout(b)
-        with self.assertWarnsRegex(UserWarning, "as a plain method is deprecated"):
-            # Function that handles torch_function in C++
-            torch.abs(b)
+        # Test that torch_function works without warnings
+        result1 = torch.nn.functional.dropout(b)
+        result2 = torch.abs(b)
+        self.assertEqual(result1, torch.tensor(99.0))
+        self.assertEqual(result2, torch.tensor(99.0))
 
 class TestDisabledUserWarnings(TestCase):
     def test_no_implicit_user_warning_for_deprecated_functions(self):