Update to TorchFix 0.4.0 (#119424)
`torch.library.Library` updated to `torch.library._scoped_library` in files with many tests where it seems obvious to do; otherwise `noqa: TOR901` added - see https://github.com/pytorch/pytorch/pull/118318 for more context.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/119424
Approved by: https://github.com/zou3519
Committed by: PyTorch MergeBot
Parent: 5acd1f0f7d
Commit: bd9db6a9c7
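For context, the two registration styles this commit moves between differ in lifetime. A minimal, hedged sketch of both (the namespaces `demo_ns`/`demo_ns2` are invented for illustration, and `_scoped_library` is a private helper that exists in the PyTorch versions this commit targets):

import torch

# Module-level Library: registrations persist for the whole process.
# TorchFix 0.4.0 reports this pattern as TOR901, so the commit either
# rewrites it or appends `# noqa: TOR901`.
lib = torch.library.Library("demo_ns", "DEF")  # noqa: TOR901
lib.define("twice(Tensor x) -> Tensor")
lib.impl("twice", lambda x: x * 2, "CompositeExplicitAutograd")
assert torch.ops.demo_ns.twice(torch.ones(2)).tolist() == [2.0, 2.0]

# Scoped form: registrations are torn down when the block exits, which
# is why tests that create throwaway operators prefer it.
with torch.library._scoped_library("demo_ns2", "DEF") as slib:
    slib.define("twice(Tensor x) -> Tensor")
    slib.impl("twice", lambda x: x * 2, "CompositeExplicitAutograd")
    assert torch.ops.demo_ns2.twice(torch.ones(2)).tolist() == [2.0, 2.0]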
.flake8 (20 changed lines)

@@ -2,7 +2,7 @@
 # NOTE: **Mirror any changes** to this file the [tool.ruff] config in pyproject.toml
 # before we can fully move to use ruff
 enable-extensions = G
-select = B,C,E,F,G,P,SIM1,T4,W,B9,TOR0,TOR1,TOR2
+select = B,C,E,F,G,P,SIM1,T4,W,B9,TOR0,TOR1,TOR2,TOR9
 max-line-length = 120
 # C408 ignored because we like the dict keyword argument syntax
 # E501 is not flexible enough, we're using B950 instead
@@ -27,6 +27,9 @@ ignore =
     # TODO(kit1980): fix all TOR102 issues
     # `torch.load` without `weights_only` parameter is unsafe
     TOR102,
+    # TODO(kit1980): resolve all TOR003 issues
+    # pass `use_reentrant` explicitly to `checkpoint`.
+    TOR003
 per-file-ignores =
     __init__.py: F401
     test/**: F821
@@ -36,6 +39,21 @@ per-file-ignores =
     torchgen/executorch/api/types/__init__.py: F401,F403
     test/dynamo/test_higher_order_ops.py: B950
     torch/testing/_internal/dynamo_test_failures.py: B950
+    # TOR901 is only for test, we want to ignore it for everything else.
+    # It's not easy to configure this without affecting other per-file-ignores,
+    # so we explicitly list every file where it's violated outside of test.
+    torch/__init__.py: F401,TOR901
+    torch/_custom_op/impl.py: TOR901
+    torch/_export/serde/upgrade.py: TOR901
+    torch/_functorch/vmap.py: TOR901
+    torch/_inductor/test_operators.py: TOR901
+    torch/_library/abstract_impl.py: TOR901
+    torch/_meta_registrations.py: TOR901
+    torch/_prims/__init__.py: F401,TOR901
+    torch/_prims/rng_prims.py: TOR901
+    torch/ao/quantization/fx/_decomposed.py: TOR901
+    torch/distributed/_functional_collectives.py: TOR901
+    torch/distributed/_spmd/data_parallel.py: TOR901
 optional-ascii-coding = True
 exclude =
     ./.git,
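A note on the .flake8 mechanics above: flake8 cannot express "apply TOR901 only under test/", so the change opts the TOR9 family in globally through `select` and then exempts, one by one, every non-test file that still violates it under `per-file-ignores`. A hedged illustration of the layering (the path is invented):

[flake8]
select = TOR9
per-file-ignores =
    some/module.py: TOR901

Globally selected checks fire everywhere; a `per-file-ignores` entry silences a rule for one file or glob; and an inline `# noqa: TOR901` comment silences a single line - the form used throughout the test diffs below.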

@@ -46,7 +46,7 @@ init_command = [
     'mccabe==0.7.0',
     'pycodestyle==2.11.1',
     'pyflakes==3.1.0',
-    'torchfix==0.2.0',
+    'torchfix==0.4.0',
 ]

@@ -8,4 +8,4 @@ flake8-pyi==20.5.0
 mccabe==0.6.1
 pycodestyle==2.6.0
 pyflakes==2.2.0
-torchfix==0.2.0
+torchfix==0.4.0

@@ -256,7 +256,7 @@ def ddm_backward(grad: torch.Tensor) -> torch.Tensor:
     return grad


-dummy_lib = torch.library.Library("dummy", "DEF")
+dummy_lib = torch.library.Library("dummy", "DEF")  # noqa: TOR901
 dummy_lib.define("ddm(Tensor x) -> Tensor")
 dummy_lib.impl("ddm", ddm, "CompositeExplicitAutograd")
 dummy_lib.define("ddm_backward(Tensor x) -> Tensor")

@@ -25,7 +25,7 @@ def maybe_dupe_op(x):


 aten = torch.ops.aten
-lib = torch.library.Library("custom", "DEF")
+lib = torch.library.Library("custom", "DEF")  # noqa: TOR901
 lib.define("maybe_dupe_op(Tensor a) -> (Tensor, Tensor)")
 lib.impl("maybe_dupe_op", maybe_dupe_op, "CPU")
 lib.impl("maybe_dupe_op", maybe_dupe_op, "Meta")

@@ -39,7 +39,7 @@ class DecoratorTests(torch._dynamo.test_case.TestCase):
         import torch.library
         from torch.library import Library

-        foo = Library("foo", "DEF")
+        foo = Library("foo", "DEF")  # noqa: TOR901
         foo.define("custom(Tensor self) -> Tensor")

         # Dynamic shape data dependent operator. For static shape compilation, Dynamo

@@ -43,7 +43,7 @@ from torch.testing._internal.common_utils import (
 _orig_module_call = torch.nn.Module.__call__

 # Custom operator that only supports CPU and Meta
-lib = torch.library.Library("test_sample", "DEF")
+lib = torch.library.Library("test_sample", "DEF")  # noqa: TOR901
 lib.define("foo(Tensor self) -> Tensor")
 lib.impl("foo", torch.sin, "CPU")


@@ -326,7 +326,7 @@ class TestDeserialize(TestCase):

     def test_auto_functionalize(self):
         try:
-            lib = torch.library.Library("mylib", "FRAGMENT")
+            lib = torch.library.Library("mylib", "FRAGMENT")  # noqa: TOR901
             torch.library.define(
                 "mylib::foo1",
                 "(Tensor(a!) x, Tensor[] y, Tensor(b!) z, SymInt w, Tensor n) -> Tensor",

@@ -522,7 +522,7 @@ class TestDeserialize(TestCase):
     def test_tensor_tensor_list(self):
         try:
             from torch.library import Library
-            lib = Library("_export", "FRAGMENT")
+            lib = Library("_export", "FRAGMENT")  # noqa: TOR901
             lib.define(
                 "_test_tensor_tensor_list_output(Tensor x, Tensor y) -> (Tensor, Tensor[])",
                 tags=torch.Tag.pt2_compliant_tag)

@@ -17,9 +17,15 @@ class TestCustomLowering(TorchTestCase):
     @classmethod
     def setUpClass(cls):
         super().setUpClass()
-        cls.test_inductor_ops = torch.library.Library("test_inductor_ops", "DEF")
-        cls.impl_cuda = torch.library.Library("test_inductor_ops", "IMPL", "CUDA")
-        cls.impl_meta = torch.library.Library("test_inductor_ops", "IMPL", "Meta")
+        cls.test_inductor_ops = torch.library.Library(  # noqa: TOR901
+            "test_inductor_ops", "DEF"
+        )
+        cls.impl_cuda = torch.library.Library(  # noqa: TOR901
+            "test_inductor_ops", "IMPL", "CUDA"
+        )
+        cls.impl_meta = torch.library.Library(  # noqa: TOR901
+            "test_inductor_ops", "IMPL", "Meta"
+        )
         cls._register_jagged_to_padded_dense()

     @classmethod

@@ -123,7 +123,7 @@ skip_if_x86_mac = functools.partial(
 )
 vec_dtypes = [torch.float, torch.bfloat16, torch.float16]

-libtest = torch.library.Library("test", "FRAGMENT")
+libtest = torch.library.Library("test", "FRAGMENT")  # noqa: TOR901
 ids = set()

 f32 = torch.float32

@@ -275,31 +275,30 @@ class TestInductorDynamic(TestCase):
     @torch._dynamo.config.patch(capture_scalar_outputs=True)
     @torch._inductor.config.patch(implicit_fallbacks=True)
     def test_item_to_inputs_kernel_nobreak(self, device):
-        lib = torch.library.Library("test", "DEF")
-
-        try:
-
-            @custom_ops.custom_op("test::foo")
-            def foo(x: torch.Tensor, y: int) -> torch.Tensor:
-                raise NotImplementedError()
-
-            @custom_ops.impl("test::foo")
-            def foo_impl(x: torch.Tensor, y: int) -> torch.Tensor:
-                return x.clone()
-
-            @torch.library.impl_abstract("test::foo", lib=lib)
-            def foo_meta(x: torch.Tensor, y: int) -> torch.Tensor:
-                return x.clone()
-
-            @torch.compile(fullgraph=True)
-            def f(x, r):
-                y = x.item()
-                return torch.ops.test.foo(r, y)
-
-            f(torch.tensor([3], device=device), torch.randn(10, device=device))
-        finally:
-            custom_ops._destroy("test::foo")
+        with torch.library._scoped_library("test", "DEF") as lib:
+            try:
+
+                @custom_ops.custom_op("test::foo")
+                def foo(x: torch.Tensor, y: int) -> torch.Tensor:
+                    raise NotImplementedError()
+
+                @custom_ops.impl("test::foo")
+                def foo_impl(x: torch.Tensor, y: int) -> torch.Tensor:
+                    return x.clone()
+
+                @torch.library.impl_abstract("test::foo", lib=lib)
+                def foo_meta(x: torch.Tensor, y: int) -> torch.Tensor:
+                    return x.clone()
+
+                @torch.compile(fullgraph=True)
+                def f(x, r):
+                    y = x.item()
+                    return torch.ops.test.foo(r, y)
+
+                f(torch.tensor([3], device=device), torch.randn(10, device=device))
+            finally:
+                custom_ops._destroy("test::foo")

     @torch._dynamo.config.patch(
         capture_scalar_outputs=True, capture_dynamic_output_shape_ops=True
@ -359,35 +358,34 @@ class TestInductorDynamic(TestCase):
|
|||||||
)
|
)
|
||||||
@torch._inductor.config.patch(implicit_fallbacks=True)
|
@torch._inductor.config.patch(implicit_fallbacks=True)
|
||||||
def test_dynamic_stride_nobreak(self, device):
|
def test_dynamic_stride_nobreak(self, device):
|
||||||
lib = torch.library.Library("test", "DEF")
|
with torch.library._scoped_library("test", "DEF") as lib:
|
||||||
|
try:
|
||||||
|
|
||||||
try:
|
@custom_ops.custom_op("test::foo")
|
||||||
|
def foo(x: torch.Tensor) -> torch.Tensor:
|
||||||
|
raise NotImplementedError()
|
||||||
|
|
||||||
@custom_ops.custom_op("test::foo")
|
@custom_ops.impl("test::foo")
|
||||||
def foo(x: torch.Tensor) -> torch.Tensor:
|
def foo_impl(x: torch.Tensor) -> torch.Tensor:
|
||||||
raise NotImplementedError()
|
stride = x.item()
|
||||||
|
return torch.empty_strided((1,), (stride,), device=x.device)
|
||||||
|
|
||||||
@custom_ops.impl("test::foo")
|
@torch.library.impl_abstract("test::foo", lib=lib)
|
||||||
def foo_impl(x: torch.Tensor) -> torch.Tensor:
|
def foo_meta(x: torch.Tensor) -> torch.Tensor:
|
||||||
stride = x.item()
|
ctx = torch.library.get_ctx()
|
||||||
return torch.empty_strided((1,), (stride,), device=x.device)
|
stride = ctx.new_dynamic_size()
|
||||||
|
return torch.empty_strided((1,), (stride,), device=x.device)
|
||||||
|
|
||||||
@torch.library.impl_abstract("test::foo", lib=lib)
|
@torch.compile(fullgraph=True)
|
||||||
def foo_meta(x: torch.Tensor) -> torch.Tensor:
|
def f(x):
|
||||||
ctx = torch.library.get_ctx()
|
r = torch.ops.test.foo(x)
|
||||||
stride = ctx.new_dynamic_size()
|
y = r.stride(0)
|
||||||
return torch.empty_strided((1,), (stride,), device=x.device)
|
return torch.empty(y, device=x.device)
|
||||||
|
|
||||||
@torch.compile(fullgraph=True)
|
f(torch.tensor([3], device=device))
|
||||||
def f(x):
|
|
||||||
r = torch.ops.test.foo(x)
|
|
||||||
y = r.stride(0)
|
|
||||||
return torch.empty(y, device=x.device)
|
|
||||||
|
|
||||||
f(torch.tensor([3], device=device))
|
finally:
|
||||||
|
custom_ops._destroy("test::foo")
|
||||||
finally:
|
|
||||||
custom_ops._destroy("test::foo")
|
|
||||||
|
|
||||||
@torch._inductor.config.patch(disable_cpp_codegen=True)
|
@torch._inductor.config.patch(disable_cpp_codegen=True)
|
||||||
def test_floor(self):
|
def test_floor(self):
|
||||||
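Both conversions above replace a manual create/try/finally pattern with a context manager that owns the library's lifetime. As a sketch of the idea - this mirrors how such a helper can be built from the public `Library` API, and it is an assumption that the real `torch.library._scoped_library` is implemented exactly this way:

import contextlib
import torch

@contextlib.contextmanager
def scoped_library(ns, kind, dispatch_key=""):
    # Create the library, hand it to the caller, and guarantee that its
    # registrations are removed even if the body raises.
    lib = torch.library.Library(ns, kind, dispatch_key)  # noqa: TOR901
    try:
        yield lib
    finally:
        lib._destroy()

Because teardown now lives in a `finally` the context manager owns, a failing test can no longer leak operator registrations into later tests.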

@@ -1744,7 +1744,7 @@ class TestQuantizePT2E(PT2EQuantizationTestCase):

     def test_observer_callback(self):
         from torch.library import Library, impl
-        test_lib = Library("test_int4", "DEF")
+        test_lib = Library("test_int4", "DEF")  # noqa: TOR901
         test_lib.define("quantize_per_tensor_int4(Tensor input, float scale, int zero_point) -> Tensor")

         @impl(test_lib, "quantize_per_tensor_int4", "CompositeExplicitAutograd")

@@ -5,7 +5,7 @@ import warnings

 import numpy as np
 import torch
-from torch.library import Library
+from torch.library import _scoped_library, Library
 from torch.testing._internal.common_utils import (
     instantiate_parametrized_tests,
     parametrize,
@@ -38,7 +38,7 @@ class TestAutogradFallback(TestCase):
         return getattr(getattr(torch.ops, self.test_ns), name).default

     def get_lib(self):
-        lib = Library(self.test_ns, "FRAGMENT")
+        lib = Library(self.test_ns, "FRAGMENT")  # noqa: TOR901
         self.lib = lib
         return lib

@@ -146,166 +146,167 @@ class TestAutogradFallback(TestCase):
             # To be clear, none of these situations are OK and will lead
             # to other problems down the line. We're testing them because
             # it is fairly common to actually do these things.
-            lib = Library(self.test_ns, "FRAGMENT")
+            with _scoped_library(self.test_ns, "FRAGMENT") as lib:
                 lib.define("foo(Tensor self) -> Tensor")
                 lib.impl("foo", lambda x: x, "CPU")
                 op = self.get_op("foo")

                 x = torch.randn(3, requires_grad=True)
                 y = op(x).sum()
                 with self._check_ctx(mode):
                     y.backward()
                 self.assertEqual(x.grad, torch.ones_like(x))

                 lib.define("bar(Tensor(a!) self) -> Tensor(a!)")
                 lib.impl("bar", lambda x: x, "CPU")
                 op = self.get_op("bar")

                 x = torch.randn(3, requires_grad=True)
                 y = op(x).sum()
                 with self._check_ctx(mode):
                     y.backward()
                 self.assertEqual(x.grad, torch.ones_like(x))

     @parametrize("mode", ("nothing", "warn"))
     def test_composite_registered_to_cpu(self, mode):
         with autograd_fallback_mode(mode):
-            lib = Library(self.test_ns, "FRAGMENT")
+            with _scoped_library(self.test_ns, "FRAGMENT") as lib:
                 lib.define("foo(Tensor self) -> Tensor")
                 lib.impl("foo", lambda x: x.sin().sum(), "CPU")
                 op = self.get_op("foo")

                 x = torch.randn(3, requires_grad=True)
                 y = op(x)
                 with self._check_ctx(mode):
                     y.backward()
                 self.assertEqual(x.grad, x.cos())

     @parametrize("mode", ("nothing", "warn"))
     def test_autograd_function_registered_to_cpu(self, mode):
         with autograd_fallback_mode(mode):
-            lib = Library(self.test_ns, "FRAGMENT")
+            with _scoped_library(self.test_ns, "FRAGMENT") as lib:
                 lib.define("foo(Tensor self) -> Tensor")

                 class NumpySin(torch.autograd.Function):
                     @staticmethod
                     def forward(ctx, x):
                         ctx.save_for_backward(x)
                         return torch.tensor(np.sin(x.cpu().numpy()))

                     @staticmethod
                     def backward(ctx, gx):
                         (x,) = ctx.saved_tensors
                         return gx * x.cos()

                 lib.impl("foo", NumpySin.apply, "CPU")
                 op = self.get_op("foo")

                 x = torch.randn(3, requires_grad=True)
                 y = op(x).sum()
                 with self._check_ctx(mode):
                     y.backward()
                 self.assertEqual(x.grad, x.cos())

     @parametrize("mode", ("nothing", "warn"))
     def test_inplace_autograd_function_registered_to_cpu(self, mode):
         with autograd_fallback_mode(mode):
-            lib = Library(self.test_ns, "FRAGMENT")
+            with _scoped_library(self.test_ns, "FRAGMENT") as lib:
                 lib.define("foo(Tensor(a!) self) -> Tensor(a!)")

                 class NumpySin_(torch.autograd.Function):
                     @staticmethod
                     def forward(ctx, x):
                         ctx.save_for_backward(x.clone())
                         x_np = x.detach().numpy()
                         np.sin(x_np, out=x_np)
                         ctx.mark_dirty(x)
                         return x

                     @staticmethod
                     def backward(ctx, gx):
                         (x,) = ctx.saved_tensors
                         return gx * x.cos()

                 lib.impl("foo", NumpySin_.apply, "CPU")
                 op = self.get_op("foo")

                 x = torch.randn(3, requires_grad=True)
                 z = x.clone()
                 w = z[0]
                 y = op(w)

                 expected = torch.zeros_like(x)
                 expected[0] = x[0].cos()
                 with self._check_ctx(mode):
-                    (gx,) = torch.autograd.grad(y, x, torch.ones_like(y), retain_graph=True)
+                    (gx,) = torch.autograd.grad(
+                        y, x, torch.ones_like(y), retain_graph=True
+                    )
                 self.assertEqual(gx, expected)

                 expected = torch.ones_like(x)
                 expected[0] = x[0].cos()
                 with self._check_ctx(mode):
                     (gx,) = torch.autograd.grad(z, x, torch.ones_like(z))
                 self.assertEqual(gx, expected)

     @parametrize("mode", ("nothing", "warn"))
     def test_inplace_on_tensor_that_does_not_require_grad(self, mode):
         # We don't do anything special (that is, we don't rebase history).
         # See NOTE [autograd fallback and in-place operations] for why
         with autograd_fallback_mode(mode):
-            lib = Library(self.test_ns, "FRAGMENT")
+            with _scoped_library(self.test_ns, "FRAGMENT") as lib:
                 # Correct usage of (a!)
                 lib.define("foo(Tensor(a!) self, Tensor other) -> Tensor(a!)")

                 def foo_impl(x, y):
                     x_d = x.detach()
                     y = y.detach()
                     x_d.add_(y)
                     return x

                 lib.impl("foo", foo_impl, "CPU")
                 foo = self.get_op("foo")

                 # Incorrect usage of (a!): user doesn't return tensor as-is
                 lib.define("bar(Tensor(a!) self, Tensor other) -> Tensor(a!)")

                 def bar_impl(x, y):
                     x_d = x.detach()
                     y = y.detach()
                     x_d.add_(y)
                     return x_d.clone()

                 lib.impl("bar", bar_impl, "CPU")
                 bar = self.get_op("bar")

                 # User mutated input tensor but didn't return it.
                 lib.define("baz(Tensor(a!) self, Tensor other) -> ()")

                 def baz_impl(x, y):
                     x_d = x.detach()
                     y = y.detach()
                     x_d.add_(y)

                 lib.impl("baz", baz_impl, "CPU")
                 baz = self.get_op("baz")

                 # Test in-place on non-view
                 for op in (foo, bar, baz):
                     x = torch.randn(3)
                     y = torch.randn(3, requires_grad=True)
                     with self.assertRaisesRegex(RuntimeError, "does not require grad"):
                         z = x.clone()
                         op(z, y)
                         torch.autograd.grad(z, y, torch.ones_like(z), allow_unused=True)

                 # Test in-place on view
                 for op in (foo, bar, baz):
                     x = torch.randn(3)
                     y = torch.randn(3, requires_grad=True)
                     with self.assertRaisesRegex(RuntimeError, "does not require grad"):
                         z = x[:]
                         op(z, y)
                         torch.autograd.grad(z, x, torch.ones_like(z), allow_unused=True)

     @parametrize("mode", ("nothing", "warn"))
     def test_post_autograd_returns_leaf(self, mode):

@@ -46,7 +46,7 @@ class CustomOpTestCaseBase(TestCase):
         return getattr(torch.ops, self.test_ns)

     def lib(self):
-        result = torch.library.Library(self.test_ns, "FRAGMENT")
+        result = torch.library.Library(self.test_ns, "FRAGMENT")  # noqa: TOR901
         self.libraries.append(result)
         return result


@@ -87,7 +87,7 @@ class FakeTensorTest(TestCase):
     def test_custom_op_fallback(self):
         from torch.library import Library, impl

-        test_lib = Library("my_test_op", "DEF")
+        test_lib = Library("my_test_op", "DEF")  # noqa: TOR901
         test_lib.define('foo(Tensor self) -> Tensor')

         @impl(test_lib, 'foo', 'CPU')

@@ -743,7 +743,7 @@ class MultiOutputWithWithInvalidMatches:
 class QuantizationFp8Pattern:
     @classmethod
     def setup(cls):
-        cls.quantization = torch.library.Library("fp8_quantization", "DEF")
+        cls.quantization = torch.library.Library("fp8_quantization", "DEF")  # noqa: TOR901
         cls.quantization.define("quantize_per_tensor_affine_fp8(Tensor self, int dtype, float scale) -> Tensor")
         cls.quantization.define("dequantize_per_tensor_affine_fp8(Tensor self, int dtype, float scale) -> Tensor")


@@ -1325,26 +1325,22 @@ class TestMeta(TestCase):

     @onlyCPU
     def test_meta_autograd_no_error(self):
-        lib = torch.library.Library("meta_test", "DEF")
-        impl_cpu = torch.library.Library("meta_test", "IMPL", "CPU")
-        impl_meta = torch.library.Library("meta_test", "IMPL", "Meta")
-
-        def foo_impl(x):
-            return x + 1
-
-        lib.define("foo(Tensor a) -> Tensor")
-        impl_meta.impl("foo", foo_impl)
-        impl_cpu.impl("foo", foo_impl)
-
-        a = torch.ones(2, device='meta')
-        # The point of the test is that this should not error:
-        # We have a fallthrough kernel registered to the AutogradMeta
-        # key for custom ops, so it's fine that `foo()` doesn't have
-        # an autograd kernel.
-        b = torch.ops.meta_test.foo.default(a)
-        del impl_meta
-        del impl_cpu
-        del lib
+        with torch.library._scoped_library("meta_test", "DEF") as lib:
+            with torch.library._scoped_library("meta_test", "IMPL", "CPU") as impl_cpu:
+                with torch.library._scoped_library("meta_test", "IMPL", "Meta") as impl_meta:
+                    def foo_impl(x):
+                        return x + 1
+
+                    lib.define("foo(Tensor a) -> Tensor")
+                    impl_meta.impl("foo", foo_impl)
+                    impl_cpu.impl("foo", foo_impl)
+
+                    a = torch.ones(2, device='meta')
+                    # The point of the test is that this should not error:
+                    # We have a fallthrough kernel registered to the AutogradMeta
+                    # key for custom ops, so it's fine that `foo()` doesn't have
+                    # an autograd kernel.
+                    b = torch.ops.meta_test.foo.default(a)

     def test_huber_loss_backward(self):
         inps = [torch.rand(2**52, device='meta') for _ in range(3)]
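The converted test above nests three `with` blocks, one per library handle. As an aside (not something this commit does), `contextlib.ExitStack` can flatten such a stack of scopes; the `meta_demo` namespace here is invented:

import contextlib
import torch

with contextlib.ExitStack() as stack:
    lib = stack.enter_context(torch.library._scoped_library("meta_demo", "DEF"))
    impl_cpu = stack.enter_context(torch.library._scoped_library("meta_demo", "IMPL", "CPU"))
    lib.define("foo(Tensor a) -> Tensor")
    impl_cpu.impl("foo", lambda x: x + 1)  # uses the library's CPU dispatch key
    # every registration is dropped when the ExitStack closes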

@@ -955,7 +955,7 @@ class TestSymbolicTracing(TestCase):
         import torch.library
         from torch.library import Library

-        foo = Library("foo", "DEF")
+        foo = Library("foo", "DEF")  # noqa: TOR901
         foo.define("foo(Tensor self) -> Tensor")

         # Operator where meta and cpu disagree on strides

@@ -63,10 +63,9 @@ class TestPythonRegistration(TestCase):
         # RuntimeError: impl("aten::neg", ...):
         # Explicitly provided namespace (aten) in operator name does not match ...
         with self.assertRaisesRegex(RuntimeError, "operator name does not match namespace"):
-            my_lib3 = Library("foo", "DEF")
-            my_lib3.define("neg(Tensor self) -> Tensor")
-            my_lib3.impl(torch.ops.aten.neg.default, my_neg, "AutogradCPU")
-        del my_lib3
+            with _scoped_library("foo", "DEF") as my_lib3:
+                my_lib3.define("neg(Tensor self) -> Tensor")
+                my_lib3.impl(torch.ops.aten.neg.default, my_neg, "AutogradCPU")

         # Example 2
         def my_mul(*args, **kwargs):
@@ -92,12 +91,12 @@ class TestPythonRegistration(TestCase):

     def test_error_if_fn_not_callable(self):
         with self.assertRaisesRegex(TypeError, "Input function is required to be a callable"):
-            my_lib = Library("aten", "IMPL")
-            my_lib.impl(torch.ops.aten.neg.default, [], "AutogradCPU")
+            with _scoped_library("aten", "IMPL") as my_lib:
+                my_lib.impl(torch.ops.aten.neg.default, [], "AutogradCPU")

     def test_finalizer(self):
         impls_refcnt = sys.getrefcount(torch.library._impls)
-        lib = Library(self.test_ns, "FRAGMENT")
+        lib = Library(self.test_ns, "FRAGMENT")  # noqa: TOR901
         lib.define("foo123(Tensor x) -> Tensor")

         # 1 for `lib`, 1 for sys.getrefcount
@@ -142,12 +141,11 @@ class TestPythonRegistration(TestCase):
             run[0] = True
             return args[0].clone()

-        my_lib1 = Library("aten", "IMPL")
-        my_lib1.impl('aten::sum', my_sum, "CPU")
-        x = torch.tensor([1, 2])
-        self.assertEqual(torch.sum(x), x)
-        self.assertTrue(run[0])
-        del my_lib1
+        with _scoped_library("aten", "IMPL") as my_lib1:
+            my_lib1.impl('aten::sum', my_sum, "CPU")
+            x = torch.tensor([1, 2])
+            self.assertEqual(torch.sum(x), x)
+            self.assertTrue(run[0])
         # Validate that the old behavior is restored for sum
         self.assertEqual(torch.sum(x), torch.tensor(3))

@@ -168,17 +166,16 @@ class TestPythonRegistration(TestCase):
             return jitted_where(*args, **kwargs)

         # overriding where's cuda kernel with Jiterator generated kernel
-        my_lib = Library("aten", "IMPL")
-        my_lib.impl('aten::where.self', inverted_where, "CUDA")
-
-        device = 'cuda'
-        cond = torch.tensor([True, True, False], device=device, dtype=torch.bool)
-        x = torch.tensor([1, 2, 3], device=device)
-        y = torch.tensor([-1, -2, -3], device=device)
-
-        self.assertEqual(torch.where(cond, x, y), torch.tensor([-1, -2, 3]))
-        self.assertTrue(CALLED[0])
-        del my_lib
+        with _scoped_library("aten", "IMPL") as my_lib:
+            my_lib.impl('aten::where.self', inverted_where, "CUDA")
+
+            device = 'cuda'
+            cond = torch.tensor([True, True, False], device=device, dtype=torch.bool)
+            x = torch.tensor([1, 2, 3], device=device)
+            y = torch.tensor([-1, -2, -3], device=device)
+
+            self.assertEqual(torch.where(cond, x, y), torch.tensor([-1, -2, 3]))
+            self.assertTrue(CALLED[0])

         # behavior restored after deregistration
         self.assertEqual(torch.where(cond, x, y), torch.tensor([1, 2, -3]))
@@ -199,13 +196,12 @@ class TestPythonRegistration(TestCase):
             return jitted_gelu(*args, **kwargs)

         # overriding gelu's cuda kernel with Jiterator generated relu kernel
-        my_lib = Library("aten", "IMPL")
-        my_lib.impl('aten::gelu', fast_gelu, "CUDA")
-
-        x = torch.rand([3, 3], device='cuda', dtype=torch.float)
-        self.assertEqual(torch.nn.functional.gelu(x), torch.nn.functional.relu(x))
-        self.assertTrue(CALLED[0])
-        del my_lib
+        with _scoped_library("aten", "IMPL") as my_lib:
+            my_lib.impl('aten::gelu', fast_gelu, "CUDA")
+
+            x = torch.rand([3, 3], device='cuda', dtype=torch.float)
+            self.assertEqual(torch.nn.functional.gelu(x), torch.nn.functional.relu(x))
+            self.assertTrue(CALLED[0])

         # behavior restored after deregistration
         self.assertNotEqual(torch.nn.functional.gelu(x), torch.nn.functional.relu(x))
@@ -226,13 +222,12 @@ class TestPythonRegistration(TestCase):
             return jitted_exp(*args, **kwargs)

         # overriding exp's cuda kernel with clipped_exp kernel
-        my_lib = Library("aten", "IMPL")
-        my_lib.impl('aten::exp', clipped_exp, "CUDA")
-
-        x = torch.tensor([0.0, 100.0], device='cuda', dtype=torch.float16)
-        self.assertEqual(torch.exp(x), torch.tensor([1.0, 22026.4657948], dtype=torch.float16))
-        self.assertTrue(CALLED[0])
-        del my_lib
+        with _scoped_library("aten", "IMPL") as my_lib:
+            my_lib.impl('aten::exp', clipped_exp, "CUDA")
+
+            x = torch.tensor([0.0, 100.0], device='cuda', dtype=torch.float16)
+            self.assertEqual(torch.exp(x), torch.tensor([1.0, 22026.4657948], dtype=torch.float16))
+            self.assertTrue(CALLED[0])

         # behavior restored after deregistration
         self.assertEqual(torch.exp(x), torch.tensor([1.0, torch.inf], dtype=torch.float16))
@@ -252,18 +247,17 @@ class TestPythonRegistration(TestCase):
             CALLED[0] = True
             return jitted_add(*args, **kwargs)

-        my_lib = Library("aten", "IMPL")
-        my_lib.impl('aten::add.Tensor', buggy_add, "CUDA")
-
-        x_cpu = torch.rand([3, 3], device='cpu')
-        y_cpu = torch.rand([3], device='cpu')
-
-        x_cuda = x_cpu.cuda()
-        y_cuda = y_cpu.cuda()
-
-        self.assertEqual(x_cuda + y_cuda, x_cpu + y_cpu + 1)
-        self.assertTrue(CALLED[0])
-        del my_lib
+        with _scoped_library("aten", "IMPL") as my_lib:
+            my_lib.impl('aten::add.Tensor', buggy_add, "CUDA")
+
+            x_cpu = torch.rand([3, 3], device='cpu')
+            y_cpu = torch.rand([3], device='cpu')
+
+            x_cuda = x_cpu.cuda()
+            y_cuda = y_cpu.cuda()
+
+            self.assertEqual(x_cuda + y_cuda, x_cpu + y_cpu + 1)
+            self.assertTrue(CALLED[0])

         # behavior restored after deregistration
         self.assertEqual(x_cuda + y_cuda, x_cpu + y_cpu)
@@ -277,97 +271,80 @@ class TestPythonRegistration(TestCase):
     def test_extend_library_with_dispatch_key_arg(self):
         def my_sum(*args, **kwargs):
             return args[0].clone()
-        my_lib1 = Library("aten", "IMPL", dispatch_key="CPU")
-        # RuntimeError: Explicitly provided dispatch key (Conjugate) is
-        # inconsistent with the dispatch key of the enclosing TORCH_LIBRARY_IMPL block
-        with self.assertRaisesRegex(RuntimeError, "inconsistent with the dispatch key"):
-            my_lib1.impl('sum', my_sum, "Conjugate")
-        my_lib1.impl('aten::sum', my_sum)
-        x = torch.tensor([1, 2])
-        self.assertEqual(torch.sum(x), x)
-        del my_lib1
+        with _scoped_library("aten", "IMPL", dispatch_key="CPU") as my_lib1:
+            # RuntimeError: Explicitly provided dispatch key (Conjugate) is
+            # inconsistent with the dispatch key of the enclosing TORCH_LIBRARY_IMPL block
+            with self.assertRaisesRegex(RuntimeError, "inconsistent with the dispatch key"):
+                my_lib1.impl('sum', my_sum, "Conjugate")
+            my_lib1.impl('aten::sum', my_sum)
+            x = torch.tensor([1, 2])
+            self.assertEqual(torch.sum(x), x)

     def test_create_new_library(self) -> None:
-        my_lib1 = Library(self.test_ns, "DEF")
-
-        my_lib1.define("sum(Tensor self) -> Tensor")
-
-        # Example 1
-        @torch.library.impl(my_lib1, "sum", "CPU")
-        def my_sum(*args, **kwargs):
-            return args[0].clone()
-
-        x = torch.tensor([1, 2])
-        op = getattr(torch.ops, self.test_ns).sum
-        self.assertEqual(op(x), x)
-
-        my_lib2 = Library(self.test_ns, "IMPL")
-
-        # Example 2
-        @torch.library.impl(my_lib2, op.default, "ZeroTensor")
-        def my_sum_zt(*args, **kwargs):
-            if args[0]._is_zerotensor():
-                return torch._efficientzerotensor(args[0].shape)
-            else:
-                return args[0].clone()
-
-        y = torch._efficientzerotensor(3)
-        self.assertTrue(op(y)._is_zerotensor())
-        self.assertEqual(op(x), x)
-
-        del my_lib2
-        del my_lib1
+        with _scoped_library(self.test_ns, "DEF") as my_lib1:
+            my_lib1.define("sum(Tensor self) -> Tensor")
+
+            # Example 1
+            @torch.library.impl(my_lib1, "sum", "CPU")
+            def my_sum(*args, **kwargs):
+                return args[0].clone()
+
+            x = torch.tensor([1, 2])
+            op = getattr(torch.ops, self.test_ns).sum
+            self.assertEqual(op(x), x)
+
+            with _scoped_library(self.test_ns, "IMPL") as my_lib2:
+                # Example 2
+                @torch.library.impl(my_lib2, op.default, "ZeroTensor")
+                def my_sum_zt(*args, **kwargs):
+                    if args[0]._is_zerotensor():
+                        return torch._efficientzerotensor(args[0].shape)
+                    else:
+                        return args[0].clone()
+
+                y = torch._efficientzerotensor(3)
+                self.assertTrue(op(y)._is_zerotensor())
+                self.assertEqual(op(x), x)

     def test_create_new_library_fragment_no_existing(self):
-        my_lib = Library(self.test_ns, "FRAGMENT")
-
-        my_lib.define("sum2(Tensor self) -> Tensor")
-
-        @torch.library.impl(my_lib, "sum2", "CPU")
-        def my_sum(*args, **kwargs):
-            return args[0]
-
-        x = torch.tensor([1, 2])
-        self.assertEqual(getattr(torch.ops, self.test_ns).sum2(x), x)
-
-        del my_lib
+        with _scoped_library(self.test_ns, "FRAGMENT") as my_lib:
+            my_lib.define("sum2(Tensor self) -> Tensor")
+
+            @torch.library.impl(my_lib, "sum2", "CPU")
+            def my_sum(*args, **kwargs):
+                return args[0]
+
+            x = torch.tensor([1, 2])
+            self.assertEqual(getattr(torch.ops, self.test_ns).sum2(x), x)

     def test_create_new_library_fragment_with_existing(self):
-        my_lib1 = Library(self.test_ns, "DEF")
-
-        # Create a fragment
-        my_lib2 = Library(self.test_ns, "FRAGMENT")
-
-        my_lib2.define("sum4(Tensor self) -> Tensor")
-
-        @torch.library.impl(my_lib2, "sum4", "CPU")
-        def my_sum4(*args, **kwargs):
-            return args[0]
-
-        x = torch.tensor([1, 2])
-        self.assertEqual(getattr(torch.ops, self.test_ns).sum4(x), x)
-
-        # Create another fragment
-        my_lib3 = Library(self.test_ns, "FRAGMENT")
-
-        my_lib3.define("sum3(Tensor self) -> Tensor")
-
-        @torch.library.impl(my_lib3, "sum3", "CPU")
-        def my_sum3(*args, **kwargs):
-            return args[0]
-
-        x = torch.tensor([1, 2])
-        self.assertEqual(getattr(torch.ops, self.test_ns).sum3(x), x)
-
-        del my_lib1
-        del my_lib2
-        del my_lib3
+        with _scoped_library(self.test_ns, "DEF") as my_lib1:
+            # Create a fragment
+            with _scoped_library(self.test_ns, "FRAGMENT") as my_lib2:
+                my_lib2.define("sum4(Tensor self) -> Tensor")
+
+                @torch.library.impl(my_lib2, "sum4", "CPU")
+                def my_sum4(*args, **kwargs):
+                    return args[0]
+
+                x = torch.tensor([1, 2])
+                self.assertEqual(getattr(torch.ops, self.test_ns).sum4(x), x)
+
+                # Create another fragment
+                with _scoped_library(self.test_ns, "FRAGMENT") as my_lib3:
+                    my_lib3.define("sum3(Tensor self) -> Tensor")
+
+                    @torch.library.impl(my_lib3, "sum3", "CPU")
+                    def my_sum3(*args, **kwargs):
+                        return args[0]
+
+                    x = torch.tensor([1, 2])
+                    self.assertEqual(getattr(torch.ops, self.test_ns).sum3(x), x)

     @unittest.skipIf(IS_WINDOWS, "Skipped under Windows")
     def test_alias_analysis(self):
         def test_helper(alias_analysis=""):
-            my_lib1 = Library(self.test_ns, "DEF")
+            my_lib1 = Library(self.test_ns, "DEF")  # noqa: TOR901

             called = [0]

@@ -388,11 +365,11 @@ class TestPythonRegistration(TestCase):

     def test_error_for_unsupported_ns_or_kind(self) -> None:
         with self.assertRaisesRegex(ValueError, "Unsupported kind"):
-            my_lib1 = Library("myns", "BLA")
+            my_lib1 = Library("myns", "BLA")  # noqa: TOR901

         for kind in ('DEF', 'FRAGMENT'):
             with self.assertRaisesRegex(ValueError, "reserved namespace"):
-                my_lib1 = Library("prim", kind)
+                my_lib1 = Library("prim", kind)  # noqa: TOR901

     def test_returning_symint(self) -> None:
         shape_env = ShapeEnv()
@@ -402,15 +379,15 @@ class TestPythonRegistration(TestCase):

         s0, s1 = ft.shape

-        tlib = Library(self.test_ns, "DEF")
-        tlib.define("sqsum(SymInt a, SymInt b) -> SymInt")
-
-        @impl(tlib, "sqsum", "CompositeExplicitAutograd")
-        def sqsum(a: SymInt, b: SymInt):
-            return a * a + b * b
-
-        out = getattr(torch.ops, self.test_ns).sqsum.default(s0, s1)
-        out_val = shape_env.evaluate_expr(out.node.expr)
+        with _scoped_library(self.test_ns, "DEF") as tlib:
+            tlib.define("sqsum(SymInt a, SymInt b) -> SymInt")
+
+            @impl(tlib, "sqsum", "CompositeExplicitAutograd")
+            def sqsum(a: SymInt, b: SymInt):
+                return a * a + b * b
+
+            out = getattr(torch.ops, self.test_ns).sqsum.default(s0, s1)
+            out_val = shape_env.evaluate_expr(out.node.expr)
         self.assertEqual(out_val, 13)

     def test_register_functional_op_error_cases(self):
@ -566,8 +543,7 @@ class TestPythonRegistration(TestCase):
|
|||||||
getattr(torch.ops, self.test_ns).foo_functional.default, (x, y, z, w))
|
getattr(torch.ops, self.test_ns).foo_functional.default, (x, y, z, w))
|
||||||
|
|
||||||
def test_register_fallthrough(self):
|
def test_register_fallthrough(self):
|
||||||
try:
|
with _scoped_library('aten', 'IMPL') as my_lib:
|
||||||
my_lib = Library('aten', 'IMPL')
|
|
||||||
my_lib.impl("mm", fallthrough_kernel, "AutocastCPU")
|
my_lib.impl("mm", fallthrough_kernel, "AutocastCPU")
|
||||||
|
|
||||||
a = torch.randn(2, 3, device='cpu', dtype=torch.float32)
|
a = torch.randn(2, 3, device='cpu', dtype=torch.float32)
|
||||||
@ -577,8 +553,6 @@ class TestPythonRegistration(TestCase):
|
|||||||
self.assertEqual(torch.mm(a, b).dtype, torch.float32)
|
self.assertEqual(torch.mm(a, b).dtype, torch.float32)
|
||||||
# ops that don't have a fallthrough registered should not be affected
|
# ops that don't have a fallthrough registered should not be affected
|
||||||
self.assertEqual(torch.matmul(a, b).dtype, torch.bfloat16)
|
self.assertEqual(torch.matmul(a, b).dtype, torch.bfloat16)
|
||||||
finally:
|
|
||||||
del my_lib
|
|
||||||
|
|
||||||
with torch.autocast(device_type="cpu", dtype=torch.bfloat16):
|
with torch.autocast(device_type="cpu", dtype=torch.bfloat16):
|
||||||
# default behavior should have been restored
|
# default behavior should have been restored
|
||||||
@ -694,13 +668,13 @@ $5: f32[2] = torch._ops.aten.clone.default($4, memory_format=torch.contiguous_fo
|
|||||||
print("woof")
|
print("woof")
|
||||||
return torch.empty(())
|
return torch.empty(())
|
||||||
|
|
||||||
my_lib = Library("my_lib", "DEF")
|
with _scoped_library("my_lib", "DEF") as my_lib:
|
||||||
my_lib.define("weird(Tensor?[] self) -> Tensor")
|
my_lib.define("weird(Tensor?[] self) -> Tensor")
|
||||||
my_lib.impl("weird", weird, "CPU")
|
my_lib.impl("weird", weird, "CPU")
|
||||||
with capture_logs() as logs:
|
with capture_logs() as logs:
|
||||||
x = LoggingTensor(torch.ones(2, 2))
|
x = LoggingTensor(torch.ones(2, 2))
|
||||||
log_input("x", x)
|
log_input("x", x)
|
||||||
torch.ops.my_lib.weird.default([None, x])
|
torch.ops.my_lib.weird.default([None, x])
|
||||||
|
|
||||||
self.assertExpectedInline('\n'.join(logs), '''\
|
self.assertExpectedInline('\n'.join(logs), '''\
|
||||||
$0: f32[2, 2] = input('x')
|
$0: f32[2, 2] = input('x')
|
||||||
@ -1485,28 +1459,29 @@ $0: f32[] = torch._ops.aten.empty.memory_format([], device=device(type='cpu'), p
|
|||||||
t.record_stream(s)
|
t.record_stream(s)
|
||||||
|
|
||||||
def test_return_stream(self) -> None:
|
def test_return_stream(self) -> None:
|
||||||
l_def = torch.library.Library("test_return_stream", "DEF")
|
with _scoped_library("test_return_stream", "DEF") as l_def:
|
||||||
l_def.define("return_stream(Tensor self) -> Stream")
|
l_def.define("return_stream(Tensor self) -> Stream")
|
||||||
l_impl = torch.library.Library("test_return_stream", "IMPL", "CPU")
|
with _scoped_library("test_return_stream", "IMPL", "CPU") as l_impl:
|
||||||
l_impl.impl("return_stream", lambda _: torch.Stream(stream_id=0, device_index=1, device_type=2))
|
l_impl.impl("return_stream",
|
||||||
|
lambda _: torch.Stream(stream_id=0, device_index=1, device_type=2))
|
||||||
|
|
||||||
class TestMode(TorchDispatchMode):
|
class TestMode(TorchDispatchMode):
|
||||||
def __torch_dispatch__(self, func, types, args=(), kwargs=None):
|
def __torch_dispatch__(self, func, types, args=(), kwargs=None):
|
||||||
return torch.Stream(stream_id=1, device_index=2, device_type=3)
|
return torch.Stream(stream_id=1, device_index=2, device_type=3)
|
||||||
|
|
||||||
t = torch.tensor(5.)
|
t = torch.tensor(5.)
|
||||||
s = torch.ops.test_return_stream.return_stream(t)
|
s = torch.ops.test_return_stream.return_stream(t)
|
||||||
self.assertIsInstance(s, torch.Stream)
|
self.assertIsInstance(s, torch.Stream)
|
||||||
self.assertEqual(s.stream_id, 0)
|
self.assertEqual(s.stream_id, 0)
|
||||||
self.assertEqual(s.device_index, 1)
|
self.assertEqual(s.device_index, 1)
|
||||||
self.assertEqual(s.device_type, 2)
|
self.assertEqual(s.device_type, 2)
|
||||||
|
|
||||||
with TestMode():
|
with TestMode():
|
||||||
s = torch.ops.test_return_stream.return_stream(t)
|
s = torch.ops.test_return_stream.return_stream(t)
|
||||||
self.assertIsInstance(s, torch.Stream)
|
self.assertIsInstance(s, torch.Stream)
|
||||||
self.assertEqual(s.stream_id, 1)
|
self.assertEqual(s.stream_id, 1)
|
||||||
self.assertEqual(s.device_index, 2)
|
self.assertEqual(s.device_index, 2)
|
||||||
self.assertEqual(s.device_type, 3)
|
self.assertEqual(s.device_type, 3)
|
||||||
|
|
||||||
def test_subclass_autograd_device_check(self) -> None:
|
def test_subclass_autograd_device_check(self) -> None:
|
||||||
class NonWrapperSubclass(torch.Tensor):
|
class NonWrapperSubclass(torch.Tensor):
|
||||||

@@ -26,17 +26,17 @@ def secretly_mutating(x):
 def output_is_input(x):
     return x

-custom_lib = torch.library.Library("bad_schemas", "DEF")
+custom_lib = torch.library.Library("bad_schemas", "DEF")  # noqa: TOR901
 custom_lib.define("secretly_aliasing(Tensor x) -> Tensor")
 custom_lib.define("secretly_mutating(Tensor x) -> Tensor")
 custom_lib.define("output_is_input(Tensor(a) x) -> Tensor(a)")

-custom_lib_cpu = torch.library.Library("bad_schemas", "IMPL", "CPU")
+custom_lib_cpu = torch.library.Library("bad_schemas", "IMPL", "CPU")  # noqa: TOR901
 custom_lib_cpu.impl("secretly_aliasing", secretly_aliasing)
 custom_lib_cpu.impl("secretly_mutating", secretly_mutating)
 custom_lib_cpu.impl("output_is_input", output_is_input)

-custom_lib_meta = torch.library.Library("bad_schemas", "IMPL", "Meta")
+custom_lib_meta = torch.library.Library("bad_schemas", "IMPL", "Meta")  # noqa: TOR901
 custom_lib_meta.impl("secretly_aliasing", secretly_aliasing)
 custom_lib_meta.impl("secretly_mutating", secretly_mutating)
 custom_lib_meta.impl("output_is_input", output_is_input)

@@ -10205,7 +10205,7 @@ tensor([[[1.+1.j, 1.+1.j, 1.+1.j, ..., 1.+1.j, 1.+1.j, 1.+1.j],
         from torch.library import Library, impl
         global _my_storage

-        my_lib = Library("my_lib", "DEF")
+        my_lib = Library("my_lib", "DEF")  # noqa: TOR901
         my_lib.define('my_func() -> None')

         a = torch.tensor([1.])
|
|||||||
"TestPythonRegistration.test_alias_analysis", # test_python_dispatch
|
"TestPythonRegistration.test_alias_analysis", # test_python_dispatch
|
||||||
"TestWrapperSubclassAliasingCPU.test_wrapper_subclass_aliasing_conv2d_cpu", # test_python_dispatch
|
"TestWrapperSubclassAliasingCPU.test_wrapper_subclass_aliasing_conv2d_cpu", # test_python_dispatch
|
||||||
"TestPythonRegistration.test_finalizer", # test_python_dispatch
|
"TestPythonRegistration.test_finalizer", # test_python_dispatch
|
||||||
"TestPythonRegistration.test_override_cpu_sum", # test_python_dispatch
|
|
||||||
"TestPythonDispatch.test_subclass_autograd_device_check", # test_python_dispatch
|
"TestPythonDispatch.test_subclass_autograd_device_check", # test_python_dispatch
|
||||||
"TestPythonDispatch.test_make_subclass_with_modes", # test_python_dispatch
|
"TestPythonDispatch.test_make_subclass_with_modes", # test_python_dispatch
|
||||||
"LoggingTests.test_trace_source_nested", # dynamo/test_logging
|
"LoggingTests.test_trace_source_nested", # dynamo/test_logging
|
||||||