mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 21:14:14 +08:00
See https://github.com/pytorch/pytorch/pull/129751#issue-2380881501. Most changes are auto-generated by the linter. You can review these PRs via: `git diff --ignore-all-space --ignore-blank-lines HEAD~1`. Pull Request resolved: https://github.com/pytorch/pytorch/pull/129754. Approved by: https://github.com/ezyang
54 lines
1.0 KiB
Python
54 lines
1.0 KiB
Python
import torch
|
|
|
|
|
|
@torch.jit.script
def fn(x, scale, shift):
    """Return ``scale * x / shift``, evaluated left to right."""
    # Multiply first, then divide, so the operation order matches
    # the single-expression form exactly.
    scaled = scale * x
    return scaled / shift
|
|
|
|
|
|
@torch.jit.script
def recurrent(x, scale, shift):
    """Feed ``x`` through ``fn`` 100 times in a row and return the result."""
    state = x
    # Each pass reuses the same scale/shift; only the running value changes.
    for step in range(100):
        state = fn(state, scale, shift)
    return state
|
|
|
|
|
|
# Build the 2x2 CUDA inputs; only scale and shift track gradients.
# The randn calls stay in x, scale, shift order so the RNG stream is
# consumed exactly as before.
x = torch.randn(2, 2, device="cuda")
scale = torch.randn(2, 2, device="cuda", requires_grad=True)
shift = torch.randn(2, 2, device="cuda", requires_grad=True)
inputs = [x, scale, shift]

# Run the scripted function once, then request its graph for the same
# inputs (the returned graph is not stored).
out = recurrent(*inputs)
recurrent.graph_for(*inputs)
|
|
|
|
|
|
import torch
|
|
|
|
|
|
@torch.jit.script
def recurrent_scaleshift(x, scale, shift):
    """Apply the update ``y = scale * y + shift`` 64 times, starting at ``x``."""
    state = x
    for step in range(64):
        # Multiply before adding, same as the one-line expression.
        scaled = scale * state
        state = scaled + shift
    return state
|
|
|
|
|
|
# Build the 2x2 CUDA inputs; only scale and shift track gradients.
# The randn calls stay in x, scale, shift order so the RNG stream is
# consumed exactly as before.
x = torch.randn(2, 2, device="cuda")
scale = torch.randn(2, 2, device="cuda", requires_grad=True)
shift = torch.randn(2, 2, device="cuda", requires_grad=True)
inputs = [x, scale, shift]
# Run the scripted function once, then request its graph for the same
# inputs (the returned graph is not stored).
out = recurrent_scaleshift(*inputs)
recurrent_scaleshift.graph_for(*inputs)
|
|
|
|
|
|
import torch
|
|
|
|
|
|
# Start from an empty CPU tensor that tracks gradients.
x = torch.tensor([], requires_grad=True)
torch.mean(x).backward()  # no error triggered
# Repeat the same mean/backward on a CUDA copy of the tensor.
# NOTE(review): presumably this is the failing case being reproduced — confirm.
x = x.cuda()
torch.mean(x).backward()
|