Update to TorchFix 0.4.0 (#119424)
`torch.library.Library` is updated to `torch.library._scoped_library` in test-heavy files where the change is obviously safe; elsewhere, `noqa: TOR901` is added. See https://github.com/pytorch/pytorch/pull/118318 for more context.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/119424
Approved by: https://github.com/zou3519
committed by PyTorch MergeBot
parent 5acd1f0f7d
commit bd9db6a9c7
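For illustration, the two patterns this diff switches between look roughly like this (a minimal sketch; the `my_ns`/`other_ns` namespaces and the `my_op` operator are made-up names, not part of the PR):

```python
import torch
from torch.library import Library, _scoped_library

# Before: the library object must be deleted by hand to undo the
# registration; TorchFix 0.4.0 reports this pattern as TOR901.
lib = Library("my_ns", "DEF")  # noqa: TOR901
lib.define("my_op(Tensor self) -> Tensor")
lib.impl("my_op", lambda x: x.clone(), "CPU")
print(torch.ops.my_ns.my_op(torch.ones(2)))
del lib  # cleanup is skipped entirely if the test fails before this line

# After: the registration is scoped to the `with` block, so it is
# removed even when an assertion inside the block raises.
with _scoped_library("other_ns", "DEF") as slib:
    slib.define("my_op(Tensor self) -> Tensor")
    slib.impl("my_op", lambda x: x.clone(), "CPU")
    print(torch.ops.other_ns.my_op(torch.ones(2)))
# here the op has been deregistered again
```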
@@ -63,10 +63,9 @@ class TestPythonRegistration(TestCase):
         # RuntimeError: impl("aten::neg", ...):
         # Explicitly provided namespace (aten) in operator name does not match ...
         with self.assertRaisesRegex(RuntimeError, "operator name does not match namespace"):
-            my_lib3 = Library("foo", "DEF")
-            my_lib3.define("neg(Tensor self) -> Tensor")
-            my_lib3.impl(torch.ops.aten.neg.default, my_neg, "AutogradCPU")
-            del my_lib3
+            with _scoped_library("foo", "DEF") as my_lib3:
+                my_lib3.define("neg(Tensor self) -> Tensor")
+                my_lib3.impl(torch.ops.aten.neg.default, my_neg, "AutogradCPU")
 
         # Example 2
         def my_mul(*args, **kwargs):
@@ -92,12 +91,12 @@ class TestPythonRegistration(TestCase):
 
     def test_error_if_fn_not_callable(self):
         with self.assertRaisesRegex(TypeError, "Input function is required to be a callable"):
-            my_lib = Library("aten", "IMPL")
-            my_lib.impl(torch.ops.aten.neg.default, [], "AutogradCPU")
+            with _scoped_library("aten", "IMPL") as my_lib:
+                my_lib.impl(torch.ops.aten.neg.default, [], "AutogradCPU")
 
     def test_finalizer(self):
         impls_refcnt = sys.getrefcount(torch.library._impls)
-        lib = Library(self.test_ns, "FRAGMENT")
+        lib = Library(self.test_ns, "FRAGMENT")  # noqa: TOR901
         lib.define("foo123(Tensor x) -> Tensor")
 
         # 1 for `lib`, 1 for sys.getrefcount
@@ -142,12 +141,11 @@ class TestPythonRegistration(TestCase):
             run[0] = True
             return args[0].clone()
 
-        my_lib1 = Library("aten", "IMPL")
-        my_lib1.impl('aten::sum', my_sum, "CPU")
-        x = torch.tensor([1, 2])
-        self.assertEqual(torch.sum(x), x)
-        self.assertTrue(run[0])
-        del my_lib1
+        with _scoped_library("aten", "IMPL") as my_lib1:
+            my_lib1.impl('aten::sum', my_sum, "CPU")
+            x = torch.tensor([1, 2])
+            self.assertEqual(torch.sum(x), x)
+            self.assertTrue(run[0])
         # Validate that the old behavior is restored for sum
         self.assertEqual(torch.sum(x), torch.tensor(3))
 
@@ -168,17 +166,16 @@ class TestPythonRegistration(TestCase):
             return jitted_where(*args, **kwargs)
 
         # overriding where's cuda kernel with Jiterator generated kernel
-        my_lib = Library("aten", "IMPL")
-        my_lib.impl('aten::where.self', inverted_where, "CUDA")
+        with _scoped_library("aten", "IMPL") as my_lib:
+            my_lib.impl('aten::where.self', inverted_where, "CUDA")
 
-        device = 'cuda'
-        cond = torch.tensor([True, True, False], device=device, dtype=torch.bool)
-        x = torch.tensor([1, 2, 3], device=device)
-        y = torch.tensor([-1, -2, -3], device=device)
+            device = 'cuda'
+            cond = torch.tensor([True, True, False], device=device, dtype=torch.bool)
+            x = torch.tensor([1, 2, 3], device=device)
+            y = torch.tensor([-1, -2, -3], device=device)
 
-        self.assertEqual(torch.where(cond, x, y), torch.tensor([-1, -2, 3]))
-        self.assertTrue(CALLED[0])
-        del my_lib
+            self.assertEqual(torch.where(cond, x, y), torch.tensor([-1, -2, 3]))
+            self.assertTrue(CALLED[0])
 
         # behavior restored after deregistration
         self.assertEqual(torch.where(cond, x, y), torch.tensor([1, 2, -3]))
@@ -199,13 +196,12 @@ class TestPythonRegistration(TestCase):
             return jitted_gelu(*args, **kwargs)
 
         # overriding gelu's cuda kernel with Jiterator generated relu kernel
-        my_lib = Library("aten", "IMPL")
-        my_lib.impl('aten::gelu', fast_gelu, "CUDA")
+        with _scoped_library("aten", "IMPL") as my_lib:
+            my_lib.impl('aten::gelu', fast_gelu, "CUDA")
 
-        x = torch.rand([3, 3], device='cuda', dtype=torch.float)
-        self.assertEqual(torch.nn.functional.gelu(x), torch.nn.functional.relu(x))
-        self.assertTrue(CALLED[0])
-        del my_lib
+            x = torch.rand([3, 3], device='cuda', dtype=torch.float)
+            self.assertEqual(torch.nn.functional.gelu(x), torch.nn.functional.relu(x))
+            self.assertTrue(CALLED[0])
 
         # behavior restored after deregistration
         self.assertNotEqual(torch.nn.functional.gelu(x), torch.nn.functional.relu(x))
@@ -226,13 +222,12 @@ class TestPythonRegistration(TestCase):
             return jitted_exp(*args, **kwargs)
 
         # overriding exp's cuda kernel with clipped_exp kernel
-        my_lib = Library("aten", "IMPL")
-        my_lib.impl('aten::exp', clipped_exp, "CUDA")
+        with _scoped_library("aten", "IMPL") as my_lib:
+            my_lib.impl('aten::exp', clipped_exp, "CUDA")
 
-        x = torch.tensor([0.0, 100.0], device='cuda', dtype=torch.float16)
-        self.assertEqual(torch.exp(x), torch.tensor([1.0, 22026.4657948], dtype=torch.float16))
-        self.assertTrue(CALLED[0])
-        del my_lib
+            x = torch.tensor([0.0, 100.0], device='cuda', dtype=torch.float16)
+            self.assertEqual(torch.exp(x), torch.tensor([1.0, 22026.4657948], dtype=torch.float16))
+            self.assertTrue(CALLED[0])
 
         # behavior restored after deregistration
         self.assertEqual(torch.exp(x), torch.tensor([1.0, torch.inf], dtype=torch.float16))
@@ -252,18 +247,17 @@ class TestPythonRegistration(TestCase):
             CALLED[0] = True
             return jitted_add(*args, **kwargs)
 
-        my_lib = Library("aten", "IMPL")
-        my_lib.impl('aten::add.Tensor', buggy_add, "CUDA")
+        with _scoped_library("aten", "IMPL") as my_lib:
+            my_lib.impl('aten::add.Tensor', buggy_add, "CUDA")
 
-        x_cpu = torch.rand([3, 3], device='cpu')
-        y_cpu = torch.rand([3], device='cpu')
+            x_cpu = torch.rand([3, 3], device='cpu')
+            y_cpu = torch.rand([3], device='cpu')
 
-        x_cuda = x_cpu.cuda()
-        y_cuda = y_cpu.cuda()
+            x_cuda = x_cpu.cuda()
+            y_cuda = y_cpu.cuda()
 
-        self.assertEqual(x_cuda + y_cuda, x_cpu + y_cpu + 1)
-        self.assertTrue(CALLED[0])
-        del my_lib
+            self.assertEqual(x_cuda + y_cuda, x_cpu + y_cpu + 1)
+            self.assertTrue(CALLED[0])
 
         # behavior restored after deregistration
         self.assertEqual(x_cuda + y_cuda, x_cpu + y_cpu)
@@ -277,97 +271,80 @@ class TestPythonRegistration(TestCase):
     def test_extend_library_with_dispatch_key_arg(self):
         def my_sum(*args, **kwargs):
             return args[0].clone()
-        my_lib1 = Library("aten", "IMPL", dispatch_key="CPU")
-
-        # RuntimeError: Explicitly provided dispatch key (Conjugate) is
-        # inconsistent with the dispatch key of the enclosing TORCH_LIBRARY_IMPL block
-        with self.assertRaisesRegex(RuntimeError, "inconsistent with the dispatch key"):
-            my_lib1.impl('sum', my_sum, "Conjugate")
-        my_lib1.impl('aten::sum', my_sum)
-        x = torch.tensor([1, 2])
-        self.assertEqual(torch.sum(x), x)
-        del my_lib1
+        with _scoped_library("aten", "IMPL", dispatch_key="CPU") as my_lib1:
+            # RuntimeError: Explicitly provided dispatch key (Conjugate) is
+            # inconsistent with the dispatch key of the enclosing TORCH_LIBRARY_IMPL block
+            with self.assertRaisesRegex(RuntimeError, "inconsistent with the dispatch key"):
+                my_lib1.impl('sum', my_sum, "Conjugate")
+            my_lib1.impl('aten::sum', my_sum)
+            x = torch.tensor([1, 2])
+            self.assertEqual(torch.sum(x), x)
 
     def test_create_new_library(self) -> None:
-        my_lib1 = Library(self.test_ns, "DEF")
-
-        my_lib1.define("sum(Tensor self) -> Tensor")
-
-        # Example 1
-        @torch.library.impl(my_lib1, "sum", "CPU")
-        def my_sum(*args, **kwargs):
-            return args[0].clone()
-
-        x = torch.tensor([1, 2])
-        op = getattr(torch.ops, self.test_ns).sum
-        self.assertEqual(op(x), x)
-
-        my_lib2 = Library(self.test_ns, "IMPL")
-
-        # Example 2
-        @torch.library.impl(my_lib2, op.default, "ZeroTensor")
-        def my_sum_zt(*args, **kwargs):
-            if args[0]._is_zerotensor():
-                return torch._efficientzerotensor(args[0].shape)
-            else:
-                return args[0].clone()
-
-        y = torch._efficientzerotensor(3)
-        self.assertTrue(op(y)._is_zerotensor())
-        self.assertEqual(op(x), x)
-
-        del my_lib2
-        del my_lib1
+        with _scoped_library(self.test_ns, "DEF") as my_lib1:
+            my_lib1.define("sum(Tensor self) -> Tensor")
+
+            # Example 1
+            @torch.library.impl(my_lib1, "sum", "CPU")
+            def my_sum(*args, **kwargs):
+                return args[0].clone()
+
+            x = torch.tensor([1, 2])
+            op = getattr(torch.ops, self.test_ns).sum
+            self.assertEqual(op(x), x)
+
+            with _scoped_library(self.test_ns, "IMPL") as my_lib2:
+                # Example 2
+                @torch.library.impl(my_lib2, op.default, "ZeroTensor")
+                def my_sum_zt(*args, **kwargs):
+                    if args[0]._is_zerotensor():
+                        return torch._efficientzerotensor(args[0].shape)
+                    else:
+                        return args[0].clone()
+
+                y = torch._efficientzerotensor(3)
+                self.assertTrue(op(y)._is_zerotensor())
+                self.assertEqual(op(x), x)
 
     def test_create_new_library_fragment_no_existing(self):
-        my_lib = Library(self.test_ns, "FRAGMENT")
-
-        my_lib.define("sum2(Tensor self) -> Tensor")
-
-        @torch.library.impl(my_lib, "sum2", "CPU")
-        def my_sum(*args, **kwargs):
-            return args[0]
-
-        x = torch.tensor([1, 2])
-        self.assertEqual(getattr(torch.ops, self.test_ns).sum2(x), x)
-
-        del my_lib
+        with _scoped_library(self.test_ns, "FRAGMENT") as my_lib:
+            my_lib.define("sum2(Tensor self) -> Tensor")
+
+            @torch.library.impl(my_lib, "sum2", "CPU")
+            def my_sum(*args, **kwargs):
+                return args[0]
+
+            x = torch.tensor([1, 2])
+            self.assertEqual(getattr(torch.ops, self.test_ns).sum2(x), x)
 
     def test_create_new_library_fragment_with_existing(self):
-        my_lib1 = Library(self.test_ns, "DEF")
-
-        # Create a fragment
-        my_lib2 = Library(self.test_ns, "FRAGMENT")
-
-        my_lib2.define("sum4(Tensor self) -> Tensor")
-
-        @torch.library.impl(my_lib2, "sum4", "CPU")
-        def my_sum4(*args, **kwargs):
-            return args[0]
-
-        x = torch.tensor([1, 2])
-        self.assertEqual(getattr(torch.ops, self.test_ns).sum4(x), x)
-
-        # Create another fragment
-        my_lib3 = Library(self.test_ns, "FRAGMENT")
-
-        my_lib3.define("sum3(Tensor self) -> Tensor")
-
-        @torch.library.impl(my_lib3, "sum3", "CPU")
-        def my_sum3(*args, **kwargs):
-            return args[0]
-
-        x = torch.tensor([1, 2])
-        self.assertEqual(getattr(torch.ops, self.test_ns).sum3(x), x)
-
-        del my_lib1
-        del my_lib2
-        del my_lib3
+        with _scoped_library(self.test_ns, "DEF") as my_lib1:
+            # Create a fragment
+            with _scoped_library(self.test_ns, "FRAGMENT") as my_lib2:
+                my_lib2.define("sum4(Tensor self) -> Tensor")
+
+                @torch.library.impl(my_lib2, "sum4", "CPU")
+                def my_sum4(*args, **kwargs):
+                    return args[0]
+
+                x = torch.tensor([1, 2])
+                self.assertEqual(getattr(torch.ops, self.test_ns).sum4(x), x)
+
+                # Create another fragment
+                with _scoped_library(self.test_ns, "FRAGMENT") as my_lib3:
+                    my_lib3.define("sum3(Tensor self) -> Tensor")
+
+                    @torch.library.impl(my_lib3, "sum3", "CPU")
+                    def my_sum3(*args, **kwargs):
+                        return args[0]
+
+                    x = torch.tensor([1, 2])
+                    self.assertEqual(getattr(torch.ops, self.test_ns).sum3(x), x)
 
     @unittest.skipIf(IS_WINDOWS, "Skipped under Windows")
     def test_alias_analysis(self):
         def test_helper(alias_analysis=""):
-            my_lib1 = Library(self.test_ns, "DEF")
+            my_lib1 = Library(self.test_ns, "DEF")  # noqa: TOR901
 
             called = [0]
 
@@ -388,11 +365,11 @@ class TestPythonRegistration(TestCase):
 
     def test_error_for_unsupported_ns_or_kind(self) -> None:
        with self.assertRaisesRegex(ValueError, "Unsupported kind"):
-            my_lib1 = Library("myns", "BLA")
+            my_lib1 = Library("myns", "BLA")  # noqa: TOR901
 
        for kind in ('DEF', 'FRAGMENT'):
            with self.assertRaisesRegex(ValueError, "reserved namespace"):
-                my_lib1 = Library("prim", kind)
+                my_lib1 = Library("prim", kind)  # noqa: TOR901
 
    def test_returning_symint(self) -> None:
        shape_env = ShapeEnv()
@@ -402,15 +379,15 @@ class TestPythonRegistration(TestCase):
 
         s0, s1 = ft.shape
 
-        tlib = Library(self.test_ns, "DEF")
-        tlib.define("sqsum(SymInt a, SymInt b) -> SymInt")
+        with _scoped_library(self.test_ns, "DEF") as tlib:
+            tlib.define("sqsum(SymInt a, SymInt b) -> SymInt")
 
-        @impl(tlib, "sqsum", "CompositeExplicitAutograd")
-        def sqsum(a: SymInt, b: SymInt):
-            return a * a + b * b
+            @impl(tlib, "sqsum", "CompositeExplicitAutograd")
+            def sqsum(a: SymInt, b: SymInt):
+                return a * a + b * b
 
-        out = getattr(torch.ops, self.test_ns).sqsum.default(s0, s1)
-        out_val = shape_env.evaluate_expr(out.node.expr)
+            out = getattr(torch.ops, self.test_ns).sqsum.default(s0, s1)
+            out_val = shape_env.evaluate_expr(out.node.expr)
         self.assertEqual(out_val, 13)
 
     def test_register_functional_op_error_cases(self):
@@ -566,8 +543,7 @@ class TestPythonRegistration(TestCase):
             getattr(torch.ops, self.test_ns).foo_functional.default, (x, y, z, w))
 
     def test_register_fallthrough(self):
-        try:
-            my_lib = Library('aten', 'IMPL')
+        with _scoped_library('aten', 'IMPL') as my_lib:
             my_lib.impl("mm", fallthrough_kernel, "AutocastCPU")
 
             a = torch.randn(2, 3, device='cpu', dtype=torch.float32)
@@ -577,8 +553,6 @@ class TestPythonRegistration(TestCase):
             self.assertEqual(torch.mm(a, b).dtype, torch.float32)
             # ops that don't have a fallthrough registered should not be affected
             self.assertEqual(torch.matmul(a, b).dtype, torch.bfloat16)
-        finally:
-            del my_lib
 
         with torch.autocast(device_type="cpu", dtype=torch.bfloat16):
             # default behavior should have been restored
@@ -694,13 +668,13 @@ $5: f32[2] = torch._ops.aten.clone.default($4, memory_format=torch.contiguous_fo
             print("woof")
             return torch.empty(())
 
-        my_lib = Library("my_lib", "DEF")
-        my_lib.define("weird(Tensor?[] self) -> Tensor")
-        my_lib.impl("weird", weird, "CPU")
-        with capture_logs() as logs:
-            x = LoggingTensor(torch.ones(2, 2))
-            log_input("x", x)
-            torch.ops.my_lib.weird.default([None, x])
+        with _scoped_library("my_lib", "DEF") as my_lib:
+            my_lib.define("weird(Tensor?[] self) -> Tensor")
+            my_lib.impl("weird", weird, "CPU")
+            with capture_logs() as logs:
+                x = LoggingTensor(torch.ones(2, 2))
+                log_input("x", x)
+                torch.ops.my_lib.weird.default([None, x])
 
         self.assertExpectedInline('\n'.join(logs), '''\
 $0: f32[2, 2] = input('x')
@@ -1485,28 +1459,29 @@ $0: f32[] = torch._ops.aten.empty.memory_format([], device=device(type='cpu'), p
             t.record_stream(s)
 
     def test_return_stream(self) -> None:
-        l_def = torch.library.Library("test_return_stream", "DEF")
-        l_def.define("return_stream(Tensor self) -> Stream")
-        l_impl = torch.library.Library("test_return_stream", "IMPL", "CPU")
-        l_impl.impl("return_stream", lambda _: torch.Stream(stream_id=0, device_index=1, device_type=2))
+        with _scoped_library("test_return_stream", "DEF") as l_def:
+            l_def.define("return_stream(Tensor self) -> Stream")
+            with _scoped_library("test_return_stream", "IMPL", "CPU") as l_impl:
+                l_impl.impl("return_stream",
+                            lambda _: torch.Stream(stream_id=0, device_index=1, device_type=2))
 
-        class TestMode(TorchDispatchMode):
-            def __torch_dispatch__(self, func, types, args=(), kwargs=None):
-                return torch.Stream(stream_id=1, device_index=2, device_type=3)
+                class TestMode(TorchDispatchMode):
+                    def __torch_dispatch__(self, func, types, args=(), kwargs=None):
+                        return torch.Stream(stream_id=1, device_index=2, device_type=3)
 
-        t = torch.tensor(5.)
-        s = torch.ops.test_return_stream.return_stream(t)
-        self.assertIsInstance(s, torch.Stream)
-        self.assertEqual(s.stream_id, 0)
-        self.assertEqual(s.device_index, 1)
-        self.assertEqual(s.device_type, 2)
+                t = torch.tensor(5.)
+                s = torch.ops.test_return_stream.return_stream(t)
+                self.assertIsInstance(s, torch.Stream)
+                self.assertEqual(s.stream_id, 0)
+                self.assertEqual(s.device_index, 1)
+                self.assertEqual(s.device_type, 2)
 
-        with TestMode():
-            s = torch.ops.test_return_stream.return_stream(t)
-            self.assertIsInstance(s, torch.Stream)
-            self.assertEqual(s.stream_id, 1)
-            self.assertEqual(s.device_index, 2)
-            self.assertEqual(s.device_type, 3)
+                with TestMode():
+                    s = torch.ops.test_return_stream.return_stream(t)
+                    self.assertIsInstance(s, torch.Stream)
+                    self.assertEqual(s.stream_id, 1)
+                    self.assertEqual(s.device_index, 2)
+                    self.assertEqual(s.device_type, 3)
 
     def test_subclass_autograd_device_check(self) -> None:
         class NonWrapperSubclass(torch.Tensor):