# NOTE: repository-mirror banner and PR metadata from the web scrape removed;
# this file is a test from the pytorch repository.
# Owner(s): ["module: dynamo"]
|
|
from typing import Callable, NamedTuple, Optional
|
|
|
|
import torch
|
|
import torch._dynamo
|
|
from torch._dynamo.test_case import run_tests, TestCase
|
|
from torch._dynamo.testing import CompileCounter, same
|
|
|
|
|
|
"""
|
|
This is an example of a pure-python version of autograd implemented by
|
|
@zdevito. It represents a rather challenging test case for TorchDynamo
|
|
to push the limits of what it can do.
|
|
"""
|
|
|
|
|
|
_name: int = 0
|
|
|
|
|
|
def fresh_name() -> str:
|
|
"""create a new unique name for a variable: v0, v1, v2"""
|
|
global _name
|
|
r = f"v{_name}"
|
|
_name += 1
|
|
return r
|
|
|
|
|
|
class Variable:
|
|
def __init__(self, value: torch.Tensor, name: Optional[str] = None):
|
|
self.value = value
|
|
self.name = name or fresh_name()
|
|
|
|
# We need to start with some tensors whose values were not computed
|
|
# inside the autograd. This function constructs leaf nodes.
|
|
@staticmethod
|
|
def constant(value: torch.Tensor, name: Optional[str] = None):
|
|
return Variable(value, name)
|
|
|
|
def __repr__(self):
|
|
return repr(self.value)
|
|
|
|
# This performs a pointwise multiplication of a Variable, tracking gradients
|
|
def __mul__(self, rhs: "Variable") -> "Variable":
|
|
# defined later in the notebook
|
|
return operator_mul(self, rhs)
|
|
|
|
def __add__(self, rhs: "Variable") -> "Variable":
|
|
return operator_add(self, rhs)
|
|
|
|
def sum(self, name: Optional[str] = None) -> "Variable":
|
|
return operator_sum(self, name)
|
|
|
|
def expand(self, sizes: list[int]) -> "Variable":
|
|
return operator_expand(self, sizes)
|
|
|
|
|
|
class TapeEntry(NamedTuple):
|
|
# names of the inputs to the original computation
|
|
inputs: list[str]
|
|
# names of the outputs of the original computation
|
|
outputs: list[str]
|
|
# apply chain rule
|
|
propagate: "Callable[list[Variable], list[Variable]]"
|
|
|
|
|
|
gradient_tape: list[TapeEntry] = []
|
|
|
|
|
|
def reset_tape():
|
|
gradient_tape.clear()
|
|
global _name
|
|
_name = 0
|
|
|
|
|
|
def grad(L, desired_results: list[Variable]) -> list[Variable]:
    """Reverse-mode autodiff: walk the tape backwards from the loss L and
    return dL/dX for each Variable in desired_results (None if X never
    contributed to the loss)."""
    # dL/dX for every value X encountered so far, keyed by variable name
    dL_d: dict[str, Variable] = {}
    # Seed the backward pass with dL/dL, which is 1
    dL_d[L.name] = Variable(torch.ones(()))
    # print(f'd{L.name} ------------------------')

    # Look up gradients by name. A variable that never fed into the loss
    # has no entry; report None for it rather than materializing zeros —
    # see the note below about zeros for more information.
    def gather_grad(entries: list[str]):
        return [dL_d.get(entry) for entry in entries]

    # Replay the tape in reverse, applying the chain rule at each step
    for entry in reversed(gradient_tape):
        dL_doutputs = gather_grad(entry.outputs)
        if not any(g is not None for g in dL_doutputs):
            # every gradient pathway into this op is zero; skip it
            # entirely (see the note below for details)
            continue

        # perform chain rule propagation specific to each compute
        dL_dinputs = entry.propagate(dL_doutputs)

        # Accumulate the gradient produced for each input. Each use of a
        # variable produces some gradient dL_dinput for that use; the
        # multivariate chain rule tells us it is safe to sum all the
        # contributions together.
        for input_name, dL_dinput in zip(entry.inputs, dL_dinputs):
            if input_name in dL_d:
                dL_d[input_name].value += dL_dinput.value
            else:
                dL_d[input_name] = dL_dinput

    # print some information to understand the values of each intermediate
    # for name, value in dL_d.items():
    #     print(f'd{L.name}_d{name} = {value.name}')
    # print(f'------------------------')

    return gather_grad(desired.name for desired in desired_results)
def operator_mul(self: Variable, rhs: Variable) -> Variable:
    """Pointwise multiply of two Variables, recording backprop on the tape."""
    # Peephole optimization: multiplying by the python float 1.0 is the
    # identity (operator_add's backward feeds 1.0 here).
    if isinstance(rhs, float) and rhs == 1.0:
        return self

    # forward pass
    r = Variable(self.value * rhs.value)
    # print(f'{r.name} = {self.name} * {rhs.name}')

    # record what the inputs and outputs of the op were
    inputs = [self.name, rhs.name]
    outputs = [r.name]

    # backward pass: chain rule for r = self * rhs
    def propagate(dL_doutputs: list[Variable]):
        (dL_dr,) = dL_doutputs

        # partial derivatives of r = self * rhs
        dr_dself = rhs
        dr_drhs = self

        # chain rule propagation from the output of multiply to its inputs
        return [dL_dr * dr_dself, dL_dr * dr_drhs]

    # finally, we record the compute we did on the tape
    gradient_tape.append(TapeEntry(inputs=inputs, outputs=outputs, propagate=propagate))
    return r
def operator_add(self: Variable, rhs: Variable) -> Variable:
    """Pointwise add of two Variables, recording backprop on the tape.

    Follows the same pattern as mul, but the backward closure does not
    end up capturing any Variables.
    """
    r = Variable(self.value + rhs.value)
    # print(f'{r.name} = {self.name} + {rhs.name}')

    def propagate(dL_doutputs: list[Variable]):
        (dL_dr,) = dL_doutputs
        # d(self + rhs)/dself == d(self + rhs)/drhs == 1.0; multiplying by
        # the float 1.0 hits operator_mul's identity peephole.
        dr_dself = 1.0
        dr_drhs = 1.0
        return [dL_dr * dr_dself, dL_dr * dr_drhs]

    gradient_tape.append(
        TapeEntry(inputs=[self.name, rhs.name], outputs=[r.name], propagate=propagate)
    )
    return r
def operator_sum(self: Variable, name: Optional[str]) -> "Variable":
    """Reduce a Variable to a scalar with torch.sum, recording backprop."""
    r = Variable(torch.sum(self.value), name=name)
    # print(f'{r.name} = {self.name}.sum()')

    def backprop(dL_doutputs: list[Variable]):
        (dL_dr,) = dL_doutputs
        # d(sum)/dx is all-ones: broadcast the scalar grad back up to the
        # input's shape (unpacked, so this assumes a 1-D input — the only
        # case exercised here)
        size = self.value.size()
        return [dL_dr.expand(*size)]

    gradient_tape.append(
        TapeEntry(inputs=[self.name], outputs=[r.name], propagate=backprop)
    )
    return r
def operator_expand(self: Variable, sizes: list[int]) -> "Variable":
    """Broadcast a scalar Variable up to `sizes`, recording backprop."""
    assert self.value.dim() == 0  # only works for scalars
    r = Variable(self.value.expand(sizes))
    # print(f'{r.name} = {self.name}.expand({sizes})')

    def backprop(dL_doutputs: list[Variable]):
        (dL_dr,) = dL_doutputs
        # the gradient of a broadcast is the sum over the broadcast dims
        return [dL_dr.sum()]

    gradient_tape.append(
        TapeEntry(inputs=[self.name], outputs=[r.name], propagate=backprop)
    )
    return r
def simple(a, b):
    """Compute (a + b) * b — the tiny two-op example program under test."""
    return (a + b) * b
class TestPythonAutograd(TestCase):
    """Checks that TorchDynamo can trace the pure-python autograd above
    into compiled frames with the expected op counts."""

    def _common(self, fn, expected_ops):
        """Compile `fn` with dynamo (optimize_assert: any graph break is an
        error), verify it matches eager on two fresh input pairs, and check
        exactly one frame with `expected_ops` ops was captured."""
        args1 = [torch.randn(10), torch.randn(10)]
        args2 = [torch.randn(10), torch.randn(10)]
        cnt = CompileCounter()
        fn_dynamo = torch._dynamo.optimize_assert(cnt)(fn)
        # reset the global tape/name counter before every run so each call
        # records the same program from a clean slate
        reset_tape()
        res1 = fn_dynamo(*args1)
        reset_tape()
        res2 = fn_dynamo(*args2)
        reset_tape()
        self.assertTrue(same(res1, fn(*args1)))
        reset_tape()
        self.assertTrue(same(res2, fn(*args2)))
        reset_tape()
        self.assertEqual(cnt.frame_count, 1)
        self.assertEqual(cnt.op_count, expected_ops)

    def test_forwards1(self):
        # forward only: add + mul + sum == 3 ops
        def fn(a, b):
            a = Variable.constant(a, name="a")
            b = Variable.constant(b, name="b")
            loss = simple(a, b).sum()
            return loss

        self._common(fn, 3)

    def test_forwards2(self):
        # same as test_forwards1, but with tape resets inside the compiled
        # function (dynamo must handle the global-state mutation)
        def fn(a, b):
            reset_tape()
            a = Variable.constant(a, name="a")
            b = Variable.constant(b, name="b")
            loss = simple(a, b).sum()
            reset_tape()
            return loss

        self._common(fn, 3)

    def test_backwards1(self):
        # forward (3 ops) plus backward replay of the tape == 8 ops
        def fn(a, b):
            a = Variable.constant(a, name="a")
            b = Variable.constant(b, name="b")
            loss = simple(a, b).sum()
            return grad(loss, [a, b])

        self._common(fn, 8)

    def test_backwards2(self):
        # as test_backwards1, with tape resets inside the compiled region
        def fn(a, b):
            reset_tape()
            a = Variable.constant(a, name="a")
            b = Variable.constant(b, name="b")
            loss = simple(a, b).sum()
            res = grad(loss, [a, b])
            reset_tape()
            return res

        self._common(fn, 8)

    def test_split(self):
        """Forward and backward compiled as two separate dynamo frames must
        agree with the eager (uncompiled) result."""
        v1 = Variable.constant(torch.randn(10), name="a")
        v2 = Variable.constant(torch.randn(10), name="b")
        cnt = CompileCounter()

        def forward(a, b):
            return simple(a, b).sum()

        # eager reference run
        reset_tape()
        loss1 = forward(v1, v2)
        grad1 = grad(loss1, [v1, v2])

        reset_tape()
        opt_forward = torch._dynamo.optimize_assert(cnt)(forward)
        opt_grad = torch._dynamo.optimize_assert(cnt)(grad)
        loss2 = opt_forward(v1, v2)
        # force two frames
        grad2 = opt_grad(loss2, [v1, v2])

        self.assertTrue(same(loss1, loss2))
        self.assertTrue(same(grad1, grad2))
        self.assertEqual(cnt.frame_count, 2)
        self.assertEqual(cnt.op_count, 8)
if __name__ == "__main__":
|
|
run_tests()
|