mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 21:14:14 +08:00
Remove useless parentheses in `raise` statements if the exception type is raised with no argument. Pull Request resolved: https://github.com/pytorch/pytorch/pull/124261 Approved by: https://github.com/albanD
1027 lines
31 KiB
Python
1027 lines
31 KiB
Python
# Owner(s): ["module: dynamo"]
|
|
|
|
import copy
|
|
import math
|
|
|
|
import torch
|
|
|
|
import torch._dynamo.test_case
|
|
import torch._dynamo.testing
|
|
import torch._dynamo.utils
|
|
from torch.testing._internal.common_utils import skipIfRocm
|
|
from torch.testing._internal.triton_utils import HAS_CUDA, requires_cuda
|
|
|
|
if HAS_CUDA:
|
|
import triton
|
|
from torch.testing._internal.triton_utils import add_kernel
|
|
|
|
|
|
class CustomFunc1(torch.autograd.Function):
    """Autograd function that adds its input to itself.

    ``backward`` hands the incoming gradient back unchanged.
    """

    @staticmethod
    def forward(ctx, foo):
        doubled = foo + foo
        return doubled

    @staticmethod
    def backward(ctx, grad_output):
        # The upstream gradient is returned as-is (no scaling is applied).
        return grad_output
|
|
|
|
|
|
class CustomFunc3(torch.autograd.Function):
    """Autograd function whose ``forward`` contains an explicit Dynamo graph break.

    ``forward`` returns ``foo + foo + foo`` and saves it; ``backward`` scales
    the incoming gradient by ``sqrt(numel)`` of the saved tensor.
    """

    # Test there is graph break in forward function
    @staticmethod
    def forward(ctx, foo):
        doubled = foo + foo
        torch._dynamo.graph_break()
        tripled = doubled + foo
        ctx.save_for_backward(tripled)
        return tripled

    @staticmethod
    def backward(ctx, grad_output):
        (saved,) = ctx.saved_tensors
        scale = math.sqrt(saved.numel())
        return grad_output * scale
|
|
|
|
|
|
class Module1(torch.nn.Module):
    """Calls ``CustomFunc1`` through ``apply`` on a freshly constructed instance."""

    def forward(self, foo):
        func = CustomFunc1()
        return func.apply(foo)
|
|
|
|
|
|
class Module2(torch.nn.Module):
    """Calls ``CustomFunc1`` through an ``apply`` reference stored at construction."""

    def __init__(self):
        super().__init__()
        # Keep the class-level ``apply`` as an attribute of the module.
        self.fn = CustomFunc1.apply

    def forward(self, foo):
        apply_fn = self.fn
        return apply_fn(foo)
|
|
|
|
|
|
class Module3(torch.nn.Module):
    """Calls ``CustomFunc1`` through ``apply`` on a freshly constructed instance.

    The body is the same as ``Module1``'s ``forward``.
    """

    def forward(self, foo):
        func = CustomFunc1()
        return func.apply(foo)
|
|
|
|
|
|
class Module4(torch.nn.Module):
    """Calls ``CustomFunc1`` through an ``apply`` reference stored at construction.

    The body is the same as ``Module2``.
    """

    def __init__(self):
        super().__init__()
        # Keep the class-level ``apply`` as an attribute of the module.
        self.fn = CustomFunc1.apply

    def forward(self, foo):
        apply_fn = self.fn
        return apply_fn(foo)
|
|
|
|
|
|
class Module5(torch.nn.Module):
    """Calls ``CustomFunc3`` through ``apply`` on a freshly constructed instance."""

    def forward(self, foo):
        func = CustomFunc3()
        return func.apply(foo)
|
|
|
|
|
|
class Module6(torch.nn.Module):
    """Calls ``CustomFunc3`` through an ``apply`` reference stored at construction."""

    def __init__(self):
        super().__init__()
        # Keep the class-level ``apply`` as an attribute of the module.
        self.fn = CustomFunc3.apply

    def forward(self, foo):
        apply_fn = self.fn
        return apply_fn(foo)
|
|
|
|
|
|
class LinearFunction(torch.autograd.Function):
    """Linear map ``input.mm(weight.t())`` plus optional bias, using ``setup_context``.

    ``forward`` takes no ``ctx``; the tensors that ``backward`` needs are
    saved in ``setup_context``.
    """

    # Note that forward, setup_context, and backward are @staticmethods
    @staticmethod
    def forward(input, weight, bias):
        output = input.mm(weight.t())
        if bias is None:
            return output
        # Broadcast the bias over the rows and add it in place.
        output += bias.unsqueeze(0).expand_as(output)
        return output

    @staticmethod
    # inputs is a Tuple of all of the inputs passed to forward.
    # output is the output of the forward().
    def setup_context(ctx, inputs, output):
        input, weight, bias = inputs
        ctx.save_for_backward(input, weight, bias)

    # This function has only a single output, so it gets only one gradient
    @staticmethod
    def backward(ctx, grad_output):
        input, weight, bias = ctx.saved_tensors

        # Each gradient is computed only when autograd asks for it.
        grad_input = grad_output.mm(weight) if ctx.needs_input_grad[0] else None
        grad_weight = grad_output.t().mm(input) if ctx.needs_input_grad[1] else None
        grad_bias = None
        if bias is not None and ctx.needs_input_grad[2]:
            grad_bias = grad_output.sum(0)

        return grad_input, grad_weight, grad_bias
|
|
|
|
|
|
class ModuleLinear(torch.nn.Module):
    """Module wrapper around ``LinearFunction.apply``; ``bias`` defaults to ``None``."""

    def forward(self, input, weight, bias=None):
        linear = LinearFunction.apply
        return linear(input, weight, bias)
|
|
|
|
|
|
class MaterializingGradFunction(torch.autograd.Function):
    """Returns two clones of the input with ``set_materialize_grads(False)``."""

    @staticmethod
    def forward(ctx, x):
        ctx.set_materialize_grads(False)
        first, second = x.clone(), x.clone()
        return first, second

    @staticmethod
    def backward(ctx, grad_out1, grad_out2):
        # Both incoming gradients are passed back unchanged.
        return grad_out1, grad_out2
|
|
|
|
|
|
class MaterializingGradModule(torch.nn.Module):
    """Module wrapper around ``MaterializingGradFunction.apply``."""

    def forward(self, x):
        apply_fn = MaterializingGradFunction.apply
        return apply_fn(x)
|
|
|
|
|
|
class CustomFuncBwdPrintGraphBreak(torch.autograd.Function):
    """Doubles the input in ``forward``; ``backward`` calls ``print`` before returning."""

    @staticmethod
    def forward(ctx, foo):
        doubled = torch.add(foo, foo)
        return doubled

    @staticmethod
    def backward(ctx, grad_output):
        # The print call is the point of this function (see the class name).
        print("graph break!")
        return grad_output
|
|
|
|
|
|
class CustomFuncBwdPrintModule(torch.nn.Module):
    """Module wrapper around ``CustomFuncBwdPrintGraphBreak.apply``."""

    def forward(self, x):
        apply_fn = CustomFuncBwdPrintGraphBreak.apply
        return apply_fn(x)
|
|
|
|
|
|
class CustomFuncStrideBwd(torch.autograd.Function):
    """Doubles the input in ``forward``; ``backward`` returns ``grad_output.stride()``."""

    @staticmethod
    def forward(ctx, foo):
        doubled = torch.add(foo, foo)
        return doubled

    @staticmethod
    def backward(ctx, grad_output):
        # Returns the stride of the gradient, not a gradient tensor.
        strides = grad_output.stride()
        return strides
|
|
|
|
|
|
class CustomFuncStrideModule(torch.nn.Module):
    """Module wrapper around ``CustomFuncStrideBwd.apply``."""

    def forward(self, x):
        apply_fn = CustomFuncStrideBwd.apply
        return apply_fn(x)
|
|
|
|
|
|
class CustomFuncSaveForBwd(torch.autograd.Function):
    """Returns ``foo + foo + foo`` and saves that result for ``backward``.

    ``backward`` scales the incoming gradient by ``sqrt(numel)`` of the
    saved tensor.
    """

    @staticmethod
    def forward(ctx, foo):
        doubled = foo + foo
        tripled = doubled + foo
        ctx.save_for_backward(tripled)
        return tripled

    @staticmethod
    def backward(ctx, grad_output):
        (saved,) = ctx.saved_tensors
        scale = math.sqrt(saved.numel())
        return grad_output * scale
|
|
|
|
|
|
class SaveForBwdModule(torch.nn.Module):
    """Calls ``CustomFuncSaveForBwd`` through ``apply`` on a fresh instance."""

    def forward(self, foo):
        func = CustomFuncSaveForBwd()
        return func.apply(foo)
|
|
|
|
|
|
class ContextSaveAndMark(torch.autograd.Function):
    """Returns the input unchanged after saving it, then marking it non-differentiable.

    Both context calls happen inside ``torch.no_grad()``.
    """

    @staticmethod
    def forward(ctx, x):
        with torch.no_grad():
            # Order matters for this variant: save first, mark second.
            ctx.save_for_backward(x)
            ctx.mark_non_differentiable(x)
            return x

    @staticmethod
    def backward(ctx, grad_output):
        return grad_output
|
|
|
|
|
|
class ContextMarkAndSave(torch.autograd.Function):
    """Returns the input unchanged after marking it non-differentiable, then saving it.

    Both context calls happen inside ``torch.no_grad()``.
    """

    @staticmethod
    def forward(ctx, x):
        with torch.no_grad():
            # Order matters for this variant: mark first, save second.
            ctx.mark_non_differentiable(x)
            ctx.save_for_backward(x)
            return x

    @staticmethod
    def backward(ctx, grad_output):
        return grad_output
|
|
|
|
|
|
class ModuleWithGradFunc(torch.nn.Module):
    """Module that forwards its input to ``func.apply``.

    Args:
        func: object exposing an ``apply`` callable; the callable is stored
            on the module as ``self.f``.
    """

    def __init__(self, func):
        super().__init__()
        self.f = func.apply

    def forward(self, x):
        apply_fn = self.f
        return apply_fn(x)
|
|
|
|
|
|
class AutogradFunctionTests(torch._dynamo.test_case.TestCase):
|
|
# Sound behaviors, tested for working capture
|
|
def test_autograd_function_equivalence(self):
    """Module1..Module4 compiled with the eager backend must map ones to twos.

    Runs once with ``requires_grad=True`` and once with ``False``.
    """
    for needs_grad in (True, False):
        for idx in range(1, 5):
            torch._dynamo.reset()
            # Modules are looked up by name so the loop covers Module1..Module4.
            eager_model = globals()[f"Module{idx}"]()
            compiled_model = torch._dynamo.optimize("eager")(eager_model)
            actual = compiled_model(torch.ones(2, 3, requires_grad=needs_grad))
            # A one-element tensor; allclose broadcasts it against the output.
            expected = torch.tensor([2.0], requires_grad=needs_grad)
            self.assertTrue(torch.allclose(actual, expected))
|
|
|
|
def test_autograd_function_has_graph_break(self):
    """Module5/Module6 must match eager and compile exactly two frames.

    Each model is called three times, eagerly and compiled, before the
    frame count is checked.
    """
    for needs_grad in (True, False):
        x = torch.randn(10, requires_grad=needs_grad)
        for model in (Module5(), Module6()):
            torch._dynamo.reset()
            counter = torch._dynamo.testing.CompileCounter()
            compiled_model = torch._dynamo.optimize(counter)(model)
            for _ in range(3):
                expected = model(x)
                actual = compiled_model(x)
                self.assertTrue(torch.allclose(expected, actual))
            self.assertEqual(counter.frame_count, 2)
|
|
|
|
def test_linear_setup_context(self):
    """The compiled ``ModuleLinear`` must produce the same output as eager."""
    eager_model = ModuleLinear()
    compiled_model = torch._dynamo.optimize("eager")(eager_model)
    input = torch.randn(2, 2, dtype=torch.double, requires_grad=True)
    weight = torch.randn(3, 2, dtype=torch.double, requires_grad=True)
    # The compiled model is invoked first, then the eager reference.
    compiled_out = compiled_model(input, weight)
    eager_out = eager_model(input, weight)
    self.assertEqual(compiled_out, eager_out)
|
|
|
|
def test_materialize_grad(self):
    """The compiled ``MaterializingGradModule`` must match its eager output."""
    eager_model = MaterializingGradModule()
    compiled_model = torch._dynamo.optimize("eager")(eager_model)
    x = torch.randn(2, 2, dtype=torch.double, requires_grad=True)
    # The compiled model is invoked first, then the eager reference.
    compiled_out = compiled_model(x)
    eager_out = eager_model(x)
    self.assertEqual(compiled_out, eager_out)
|
|
|
|
def test_print_in_bwd(self):
    """With ``nopython=True``, a ``print`` in ``backward`` must raise ``Unsupported``."""
    compiled_model = torch._dynamo.optimize("eager", nopython=True)(
        CustomFuncBwdPrintModule()
    )
    x = torch.randn(2, 2, dtype=torch.double, requires_grad=True)
    expected_error = torch._dynamo.exc.Unsupported
    with self.assertRaisesRegex(expected_error, "builtin: print"):
        compiled_model(x)
|
|
|
|
def test_stride_in_bwd(self):
    """``CustomFuncStrideModule`` must match eager, with one frame and one graph break."""
    torch._dynamo.utils.counters.clear()
    counter = torch._dynamo.testing.CompileCounter()
    eager_model = CustomFuncStrideModule()
    compiled_model = torch.compile(backend=counter)(eager_model)
    x = torch.randn(2, 2, dtype=torch.double, requires_grad=True)
    expected = eager_model(x)
    actual = compiled_model(x)

    self.assertEqual(expected, actual)
    self.assertEqual(counter.frame_count, 1)
    # graph break: Illegal getattr invocation stride in strict mode.
    break_counts = list(torch._dynamo.utils.counters["graph_break"].values())
    self.assertEqual(break_counts, [1])
|
|
|
|
def test_enum_arg(self):
    """An ``Enum`` member passed to ``Function.apply`` must compile under ``fullgraph=True``."""
    from enum import Enum

    class SomeEnum(Enum):
        A = 0
        B = 1

    class Foo(torch.autograd.Function):
        @staticmethod
        def forward(ctx, x, e):
            # The enum member selects which op is applied.
            if e is SomeEnum.A:
                return x.sin()
            return x.cos()

        @staticmethod
        def backward(ctx, g):
            return g

    @torch.compile(backend="eager", fullgraph=True)
    def f(x, enum):
        return Foo.apply(x, enum)

    x = torch.tensor([[1.0, 2, 3], [4, 5, 6]], requires_grad=True)
    y = f(x, SomeEnum.A)
    self.assertEqual(y, x.sin())
|
|
|
|
def test_save_for_bwd(self):
    """``SaveForBwdModule`` must compile with ``nopython=True`` and match eager.

    Compilation with ``nopython=True`` already fails on any graph break;
    the compiled output is additionally compared against the eager module so
    the test does not discard the result.
    """
    model = SaveForBwdModule()
    opt_model = torch._dynamo.optimize("eager", nopython=True)(model)
    x = torch.randn(2, 2, dtype=torch.double, requires_grad=True)
    expected = model(x)
    actual = opt_model(x)
    self.assertEqual(actual, expected)
|
|
|
|
def test_allow_in_graph(self):
    """An ``allow_in_graph`` Function with a ``graph_break`` in ``forward`` compiles as one frame.

    The function is compiled with ``fullgraph=True`` and its output is
    compared with a direct eager ``apply``.
    """
    torch._dynamo.utils.counters.clear()
    counter = torch._dynamo.testing.CompileCounter()

    @torch._dynamo.allow_in_graph
    class AllowInGraphFunc(torch.autograd.Function):
        @staticmethod
        def forward(ctx, x):
            torch._dynamo.graph_break()
            # Stash the leading dimension for use in backward.
            ctx.x0 = x.size(0)
            return x * 2

        @staticmethod
        def backward(ctx, grad_out):
            return grad_out * ctx.x0

    @torch.compile(backend=counter, fullgraph=True)
    def fn(x):
        return AllowInGraphFunc.apply(x)

    x = torch.rand(2, 3, requires_grad=True)
    compiled_out = fn(x)

    self.assertEqual(compiled_out, AllowInGraphFunc.apply(x))
    self.assertEqual(counter.frame_count, 1)
|
|
|
|
def test_once_differentiable(self):
    """A Function whose ``backward`` is ``@once_differentiable`` compiles as one full graph."""
    from torch.autograd.function import once_differentiable

    torch._dynamo.utils.counters.clear()
    counter = torch._dynamo.testing.CompileCounter()

    class ScaleGradient(torch.autograd.Function):
        @staticmethod
        def forward(ctx, x):
            return x

        @staticmethod
        @once_differentiable
        def backward(ctx, grad):
            return grad * 0.5

    @torch.compile(backend=counter, fullgraph=True)
    def fn(x):
        return ScaleGradient.apply(x)

    x = torch.randn(3, requires_grad=True)
    compiled_out = fn(x)

    self.assertEqual(compiled_out, ScaleGradient.apply(x))
    self.assertEqual(counter.frame_count, 1)
|
|
|
|
def test_classmethod(self):
|
|
class Shake(torch.autograd.Function):
|
|
@classmethod
|
|
def forward(cls, ctx, foo):
|
|
return foo + foo
|
|
|
|
@classmethod
|
|
def backward(cls, ctx, grad_output):
|
|
return grad_output
|
|
|
|
def f(x):
|
|
return Shake.apply(x)
|
|
|
|
x = torch.randn(4, 4, 4, 4, requires_grad=True)
|
|
opt_m = torch.compile(backend="eager")(f)
|
|
opt_m(x)
|
|
|
|
def test_function_context_save_and_mark(self):
    """``ContextSaveAndMark`` inside a module must give the same output eagerly and compiled."""
    mod = ModuleWithGradFunc(ContextSaveAndMark)
    args = [torch.rand([1])]
    kwargs = {}
    eager_out = mod(*args, **kwargs)

    torch._dynamo.reset()
    compiled_model = torch._dynamo.optimize("eager")(mod)
    compiled_out = compiled_model(*args, **kwargs)
    self.assertEqual(eager_out, compiled_out)
|
|
|
|
def test_function_context_mark_and_save(self):
    """``ContextMarkAndSave`` inside a module must give the same output eagerly and compiled."""
    mod = ModuleWithGradFunc(ContextMarkAndSave)
    args = [torch.rand([1])]
    kwargs = {}
    eager_out = mod(*args, **kwargs)

    torch._dynamo.reset()
    compiled_model = torch._dynamo.optimize("eager")(mod)
    compiled_out = compiled_model(*args, **kwargs)
    self.assertEqual(eager_out, compiled_out)
|
|
|
|
def test_multi_output(self):
    """A Function returning two tensors compiles as one full graph and matches eager."""
    torch._dynamo.utils.counters.clear()
    counter = torch._dynamo.testing.CompileCounter()

    class Foo(torch.autograd.Function):
        @staticmethod
        def forward(ctx, x):
            return x.clone(), x.clone()

        @staticmethod
        def backward(ctx, grad1, grad2):
            return grad1 + grad2

    @torch.compile(backend=counter, fullgraph=True)
    def f(x):
        return Foo.apply(x)

    x = torch.randn(3, requires_grad=True)
    compiled_out = f(x)

    self.assertEqual(compiled_out, Foo.apply(x))
    self.assertEqual(counter.frame_count, 1)
|
|
|
|
def test_amp_custom_fwd_bwd(self):
    """A Function decorated with ``custom_fwd``/``custom_bwd`` compiles as one full graph.

    The compiled result is also back-propagated before being compared
    with a direct eager ``apply``.
    """
    torch._dynamo.utils.counters.clear()
    counter = torch._dynamo.testing.CompileCounter()

    class MyMM(torch.autograd.Function):
        @staticmethod
        @torch.cuda.amp.custom_fwd
        def forward(ctx, a, b):
            ctx.save_for_backward(a, b)
            return a.mm(b)

        @staticmethod
        @torch.cuda.amp.custom_bwd
        def backward(ctx, grad):
            a, b = ctx.saved_tensors
            return grad.mm(b.t()), a.t().mm(grad)

    @torch.compile(backend=counter, fullgraph=True)
    def fn(a, b):
        return MyMM.apply(a, b)

    a = torch.randn([64, 64], dtype=torch.float32, requires_grad=True)
    grad = a.clone()
    # The same tensor is used for both operands.
    compiled_out = fn(a, a)
    compiled_out.backward(grad)

    self.assertEqual(compiled_out, MyMM.apply(a, a))
    self.assertEqual(counter.frame_count, 1)
|
|
|
|
def test_graph_break_if_lifted_free_variable(self):
    """A ``forward`` that closes over an outer tensor yields one frame and one graph break."""
    torch._dynamo.utils.counters.clear()
    counter = torch._dynamo.testing.CompileCounter()
    # Free variable captured by Foo.forward below.
    delta = torch.randn(3)

    class Foo(torch.autograd.Function):
        @staticmethod
        def forward(ctx, x):
            return x.clone(), (x + delta).clone()

        @staticmethod
        def backward(ctx, grad1, grad2):
            return grad1 + grad2

    @torch.compile(backend=counter)
    def f(x):
        return Foo.apply(x)

    x = torch.randn(3, requires_grad=True)
    compiled_out = f(x)

    self.assertEqual(compiled_out, Foo.apply(x))
    self.assertEqual(counter.frame_count, 1)
    break_counts = list(torch._dynamo.utils.counters["graph_break"].values())
    self.assertEqual(break_counts, [1])
|
|
|
|
def test_function_with_bound_free_variable(self):
    """A module applying a Function to its own parameter must match eager when compiled."""

    class LowerBound(torch.autograd.Function):
        @staticmethod
        def forward(ctx, inputs, bound):
            # The bound is saved as a one-element tensor next to the inputs.
            ctx.save_for_backward(inputs, inputs.new_ones(1) * bound)
            return inputs.clamp(min=bound)

        @staticmethod
        def backward(ctx, grad_output):
            inputs, bound = ctx.saved_tensors
            return (inputs >= bound) * grad_output, None

    class MyMod(torch.nn.Module):
        def __init__(self):
            super().__init__()
            self.gamma = torch.nn.Parameter(torch.rand([4, 128, 32, 32]))

        def forward(self, x):
            gamma = LowerBound.apply(self.gamma, 1)
            return x + gamma

    mod = MyMod()
    args = [torch.rand([4, 128, 32, 32])]
    kwargs = {}
    eager_out = mod(*args, **kwargs)

    compiled_model = torch._dynamo.optimize("eager")(mod)
    compiled_out = compiled_model(*args, **kwargs)
    self.assertEqual(eager_out, compiled_out)
|
|
|
|
# I pulled all of these test cases from test_autograd.py
|
|
# In the future, we should make the Dynamo test suite actually
|
|
# run on test_autograd.py (it's disabled right now) and delete these.
|
|
def test_smoke_from_test_autograd(self):
    """Run a collection of ``autograd.Function`` patterns eagerly, then under Dynamo.

    No outputs are compared; the test passes when neither the eager call nor
    the ``optimize("eager")`` call of the inner ``test`` function raises.
    """

    class Func(torch.autograd.Function):
        @staticmethod
        def forward(ctx, x):
            out0, out1 = x.clone(), x.clone()
            ctx.mark_non_differentiable(out1)
            ctx._materialize_non_diff_grads = False
            return out0, out1

        @staticmethod
        def backward(ctx, g0, g1):
            assert g1 is None
            return g0

    def mult1(x):
        return x.prod(dim=-1).prod(dim=-1)

    class Mult(torch.autograd.Function):
        @staticmethod
        def forward(ctx, x):
            y = mult1(x)
            ctx.save_for_backward(x, y)
            return y

        @staticmethod
        def backward(ctx, grad_output):
            x, y = ctx.saved_tensors
            return (grad_output * y)[:, None, None] / x

    mult2 = Mult.apply

    class Double(torch.autograd.Function):
        @staticmethod
        def forward(ctx, x):
            y = x**2
            ctx.save_for_backward(x, y)
            return y

        @staticmethod
        def backward(ctx, grad_output):
            x, _ = ctx.saved_tensors
            return grad_output * 2 * x

    # this is equivalent, but uses the output of .forward() in .backward()
    class Double2(Double):
        @staticmethod
        def backward(ctx, grad_output):
            x, y = ctx.saved_tensors
            return grad_output * 2 * y / x

    double = Double.apply
    double2 = Double2.apply

    class Identity(torch.autograd.Function):
        @staticmethod
        def forward(ctx, a, b):
            return a, a + b

        @staticmethod
        def backward(ctx, grad_a, grad_b):
            return grad_a + grad_b, grad_b

    class MyFunc2(torch.autograd.Function):
        @staticmethod
        def forward(ctx, inp):
            return inp.clone()

        @staticmethod
        def backward(ctx, gO):
            return torch.tensor(float("nan")).expand(10, 10)

    def run_fn(a):
        out = MyFunc2.apply(a)
        return out.sum()

    class MyFn(torch.autograd.Function):
        @staticmethod
        def forward(ctx, inp):
            return inp.view_as(inp)

        @staticmethod
        def backward(ctx, grad):
            return grad

    class MyAdder(torch.autograd.Function):
        @staticmethod
        def forward(ctx, a, b):
            # In-place update of the first argument, reported via mark_dirty.
            a.add_(b)
            ctx.mark_dirty(a)
            return a

        @staticmethod
        def backward(ctx, grad):
            return grad, grad

    class InplaceMul(torch.autograd.Function):
        @staticmethod
        def forward(ctx, x):
            result = x.mul_(2)
            ctx.mark_dirty(result)
            return result

        @staticmethod
        def backward(ctx, grad_output):
            pass

        @staticmethod
        def jvp(ctx, x_t):
            # jvp_err is not defined in this scope; jvp is never called here.
            if jvp_err:  # noqa: F821
                return x_t
            else:
                return x_t.mul_(2)

    class MyFn2(torch.autograd.Function):
        @staticmethod
        def forward(ctx, x, y):
            return x + y, x

        @staticmethod
        def vjp(ctx, gO1, gO2):
            return gO1 + gO2, gO1

        @staticmethod
        def jvp(ctx, x_t, y_t):
            # fn is not defined in this scope; jvp is never called here.
            return x_t + y_t, fn(x_t)  # noqa: F821

    class MyFn3(torch.autograd.Function):
        @staticmethod
        def forward(ctx, inp, inplace):
            view = inp.clone()[:3]
            if inplace:
                view += 2
            return view

        @staticmethod
        def backward(ctx, grad):
            return grad, None

    def test():
        # Func: second output is non-differentiable.
        a = torch.tensor(1.0, requires_grad=True)
        out = Func.apply(a)[0]
        out.backward()

        # Mult: forward only.
        x = torch.ones(2, 4, 4).requires_grad_()
        mult2(x)

        # Double / Double2: forward only.
        x = torch.tensor(2).double().requires_grad_()
        double(x)
        double2(x)

        # Identity: two outputs.
        x = torch.randn(5, 5, requires_grad=True)
        y = torch.randn(5, 5, requires_grad=True)
        q, p = Identity.apply(x, y)

        # MyFn: view_as of the input.
        a = torch.rand(1, 2)
        b = torch.rand(1, requires_grad=True)
        view_a = MyFn.apply(a)

        # MyAdder: in-place add followed by backward.
        a = torch.ones(2, requires_grad=True)
        b = torch.ones(2, requires_grad=True)
        c = MyAdder.apply(a.clone(), b)
        c.sum().backward()

        # InplaceMul: in-place multiply on a clone.
        z = torch.tensor(1.0, requires_grad=True)
        x = z.clone()
        y = InplaceMul.apply(x)

        # MyFn2: once with grad-requiring inputs, once without.
        a = torch.tensor(1.0, dtype=torch.double, requires_grad=True)
        b = torch.tensor(1.0, dtype=torch.double, requires_grad=True)
        c = torch.tensor(1.0, dtype=torch.double)
        d = torch.tensor(1.0, dtype=torch.double)
        MyFn2.apply(a, b)
        MyFn2.apply(c, d)

        # MyFn3: non-tensor flag argument.
        base = torch.rand(10, requires_grad=True)
        foo = MyFn3.apply(base, False)

    # Eager run first, then the Dynamo-optimized run.
    test()
    opt_test = torch._dynamo.optimize("eager")(test)
    opt_test()
|
|
|
|
def test_tensor_subclass_intermediary_input(self):
    """A wrapper tensor subclass built inside the compiled function is fed to a Function.

    Outputs and input gradients of the compiled function are compared with
    the eager run.
    """

    class FooTensor(torch.Tensor):
        @staticmethod
        def __new__(cls, data, config, scale):
            # config holds (size, stride, storage_offset, dtype, layout,
            # requires_grad), matching the tuple built in foo() below.
            self = torch.Tensor._make_wrapper_subclass(
                cls,
                config[0],
                strides=config[1],
                storage_offset=config[2],
                dtype=config[3],
                layout=config[4],
                requires_grad=config[5],
                device=data.device,
            )
            self._data = data
            self._config = config
            self._scale = scale
            return self

        def __repr__(self):
            return "FooTensor"

        def __tensor_flatten__(self):
            inner_names = ("_data",)
            metadata = (self._config, self._scale)
            return inner_names, metadata

        @staticmethod
        def __tensor_unflatten__(tensors, metadatas, outer_size, outer_stride):
            return FooTensor(tensors["_data"], metadatas[0], metadatas[1])

        @classmethod
        def __torch_dispatch__(cls, func, types, args, kwargs=None):
            # handling clone and view is so dynamo fakefication passes, it's not
            # intended to be handling user code
            if func == torch.ops.aten.clone.default:
                return FooTensor(
                    args[0]._data.clone(), args[0]._config, args[0]._scale
                )
            if func == torch.ops.aten.view.default:
                new_data = args[0]._data.view(*args[1:])
                return FooTensor(new_data, args[0]._config, args[0]._scale)

            raise NotImplementedError

    class foo_autograd_fn(torch.autograd.Function):
        @staticmethod
        def forward(ctx, x):
            # access some data from `x`, where `x` is a tensor subclass
            x2 = x._data + 1.0
            # create and return a tensor subclass from within a torch.autograd.Function
            x3 = FooTensor(x2, x._config, x._scale)
            return x3._data

        @staticmethod
        def backward(ctx, g):
            return g

    x_ref = torch.randn(4, 4).requires_grad_(True)
    x = copy.deepcopy(x_ref)
    scale = torch.tensor(1.0)
    # Weird that this is needed, but not having this breaks a lot of things
    torch._dynamo.allow_in_graph(FooTensor)

    def foo(x, scale):
        config = (
            x.size(),
            x.stride(),
            x.storage_offset(),
            x.dtype,
            x.layout,
            x.requires_grad,
        )
        x = FooTensor(x, config, scale)
        x = foo_autograd_fn.apply(x)
        return x

    # Eager reference run.
    y_ref = foo(x_ref, scale)
    y_ref.sum().backward()

    # Compiled run on the deep-copied input.
    foo_opt = torch.compile(foo, backend="eager")
    y = foo_opt(x, scale)
    y.sum().backward()

    self.assertEqual(y, y_ref)
    self.assertEqual(x.grad, x_ref.grad)
|
|
|
|
def test_smuggle_symint_issue_111031(self):
    """Storing ``x.size(0)`` on ``ctx`` must compile as one frame with ``dynamic=True``."""
    from torch.autograd import Function

    class Foo(Function):
        @staticmethod
        def forward(ctx, x):
            # The size is stashed on ctx rather than saved as a tensor.
            ctx.x0 = x.size(0)
            return x * 2

        @staticmethod
        def backward(ctx, grad_out):
            return grad_out * ctx.x0

    counter = torch._dynamo.testing.CompileCounter()

    @torch.compile(backend=counter, fullgraph=True, dynamic=True)
    def foo(x):
        return Foo.apply(x)

    foo(torch.randn(2, requires_grad=True))
    self.assertEqual(counter.frame_count, 1)
|
|
|
|
def test_needs_input_grad(self):
    """A compiled ``backward`` that reads ``ctx.needs_input_grad`` yields one frame and two ops."""
    counter = torch._dynamo.testing.CompileCounter()

    class NeedsInputGradFunc(torch.autograd.Function):
        @staticmethod
        def forward(ctx, foo):
            result = foo + foo
            ctx.save_for_backward(result)
            return result

        # Only backward is compiled here; forward runs eagerly.
        @staticmethod
        @torch.compile(backend=counter, fullgraph=True)
        def backward(ctx, grad_output):
            (result,) = ctx.saved_tensors
            if ctx.needs_input_grad[0]:
                return grad_output * result.sin()
            return None

    x = torch.randn(10, requires_grad=True)
    loss = NeedsInputGradFunc.apply(x).sum()
    loss.backward()
    self.assertEqual(x.grad.shape, x.shape)
    self.assertEqual(counter.frame_count, 1)
    self.assertEqual(counter.op_count, 2)
|
|
|
|
def test_repeated_save_for_backward_calls(self):
    """Calling ``save_for_backward`` twice in ``forward`` must match eager.

    Output and both input gradients of the compiled function are compared
    with the eager run, and exactly one frame must be compiled.
    """
    from torch.autograd import Function

    class Foo(Function):
        @staticmethod
        def forward(ctx, x, y):
            # Two consecutive calls; backward unpacks two saved tensors.
            ctx.save_for_backward(x)
            ctx.save_for_backward(x, y)
            return x * y

        @staticmethod
        def backward(ctx, grad_out):
            x, y = ctx.saved_tensors
            return grad_out * x, grad_out * y

    counter = torch._dynamo.testing.CompileCounter()

    def foo(x, y):
        return Foo.apply(x, y)

    x_ref = torch.randn(2, requires_grad=True)
    y_ref = torch.randn(2, requires_grad=True)
    # Independent leaf copies for the compiled run.
    x_test = x_ref.clone().detach().requires_grad_()
    y_test = y_ref.clone().detach().requires_grad_()

    out_ref = foo(x_ref, y_ref)
    out_ref.sum().backward()

    compiled_foo = torch.compile(foo, backend=counter)
    out_test = compiled_foo(x_test, y_test)
    out_test.sum().backward()

    self.assertEqual(counter.frame_count, 1)
    self.assertEqual(out_ref, out_test)
    self.assertEqual(x_ref.grad, x_test.grad)
    self.assertEqual(y_ref.grad, y_test.grad)
|
|
|
|
def test_smuggle_tensor_and_complex_structures(self):
    """Storing a tensor and a list on ``ctx`` must compile as one frame with ``dynamic=True``."""
    from torch.autograd import Function

    class Foo(Function):
        @staticmethod
        def forward(ctx, x):
            # Both a tensor and a plain Python list are attached to ctx.
            ctx.x0 = x
            ctx.x1 = [1, 2, 3]
            return x * 2

        @staticmethod
        def backward(ctx, grad_out):
            x0mul = grad_out * ctx.x0
            for i in ctx.x1:
                x0mul = (x0mul * i) + x0mul
            return x0mul

    counter = torch._dynamo.testing.CompileCounter()

    @torch.compile(backend=counter, fullgraph=True, dynamic=True)
    def foo(x):
        return Foo.apply(x)

    foo(torch.randn(2, requires_grad=True))
    self.assertEqual(counter.frame_count, 1)
|
|
|
|
def test_default_values(self):
    """A ``forward`` with a defaulted argument must compile with and without grad inputs."""
    from torch.autograd import Function

    class Foo(Function):
        @staticmethod
        def forward(ctx, x, alpha=0.99):
            return x

        @staticmethod
        def backward(ctx, grad_out):
            return grad_out

    @torch.compile
    def foo(x):
        return Foo.apply(x)

    # Make sure guards for default values do not crash
    for needs_grad in (False, True):
        foo(torch.randn(2, requires_grad=needs_grad))
|
|
|
|
def test_tuple_arg(self):
    """A tuple argument to ``apply`` must compile as one full graph with two ops."""
    counter = torch._dynamo.testing.CompileCounter()

    class TupleArgFunc(torch.autograd.Function):
        @staticmethod
        def forward(ctx, x, shape):
            ctx.save_for_backward(torch.randn(shape))
            return x + 1

        @staticmethod
        def backward(ctx, grad_output):
            (result,) = ctx.saved_tensors
            return result, None

    @torch.compile(backend=counter, fullgraph=True)
    def fn():
        # x and shape are closed over; they are bound below, before the call.
        return TupleArgFunc.apply(x, shape)

    shape = (10, 10)
    x = torch.randn(shape, requires_grad=True)
    out = fn()
    out.sum().backward()
    self.assertEqual(out, x + 1)
    self.assertEqual(x.grad.shape, shape)
    self.assertEqual(counter.frame_count, 1)
    self.assertEqual(counter.op_count, 2)
|
|
|
|
@requires_cuda
@skipIfRocm
def test_triton_kernel_basic(self):
    """A Function launching ``add_kernel`` in ``forward`` compiles with inductor; output is ``x + y``."""

    class Add(torch.autograd.Function):
        @staticmethod
        def forward(ctx, x, y):
            ctx.save_for_backward(x, y)
            output = torch.zeros_like(x)
            n_elements = output.numel()
            # One program per BLOCK_SIZE elements.
            grid = lambda meta: (  # noqa: E731
                triton.cdiv(n_elements, meta["BLOCK_SIZE"]),
            )
            add_kernel[grid](x, y, output, n_elements, BLOCK_SIZE=16)
            return output

        @staticmethod
        def backward(ctx, grad_output):
            x, y = ctx.saved_tensors
            return x * grad_output, y * grad_output

    @torch.compile(fullgraph=True, backend="inductor")
    def f(x, y):
        return Add.apply(x, y)

    x = torch.randn(10, device="cuda", requires_grad=True)
    y = torch.randn(10, device="cuda", requires_grad=True)
    z = f(x, y)
    # Backward is executed as well; only the forward value is checked.
    z.sum().backward()
    self.assertEqual(x + y, z)
|
|
|
|
@requires_cuda
@skipIfRocm
def test_triton_kernel_multiple_out(self):
    """Like the basic triton test, but ``forward`` returns two outputs and also stores tensors on ``ctx``."""

    class Add(torch.autograd.Function):
        @staticmethod
        def forward(ctx, x, y):
            ctx.save_for_backward(x, y)
            # The inputs are additionally attached as plain ctx attributes.
            ctx.t1 = x
            ctx.t2 = y
            output = torch.zeros_like(x)
            n_elements = output.numel()
            # One program per BLOCK_SIZE elements.
            grid = lambda meta: (  # noqa: E731
                triton.cdiv(n_elements, meta["BLOCK_SIZE"]),
            )
            add_kernel[grid](x, y, output, n_elements, BLOCK_SIZE=16)
            return output, x

        @staticmethod
        def backward(ctx, grad_output, old_x):
            x, y = ctx.saved_tensors
            x1 = ctx.t1
            y1 = ctx.t2
            return old_x * x * x1 * grad_output, y * y1 * grad_output

    @torch.compile(fullgraph=True, backend="inductor")
    def f(x, y):
        return Add.apply(x, y)

    x = torch.randn(10, device="cuda", requires_grad=True)
    y = torch.randn(10, device="cuda", requires_grad=True)
    z, _ = f(x, y)
    # Backward is executed as well; only the forward value is checked.
    z.sum().backward()
    self.assertEqual(x + y, z)
|
|
|
|
|
|
if __name__ == "__main__":
    # Hand control to the Dynamo test runner when executed as a script.
    torch._dynamo.test_case.run_tests()
|