Mirror of https://github.com/pytorch/pytorch.git, synced 2025-10-21 05:34:18 +08:00
Applies a bunch of new ruff lint rules that are now stable; some of them improve efficiency or readability. Since I already did passes on the codebase for these rules while they were in preview, there should be relatively few changes here; this is mostly future hardening.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/129825
Approved by: https://github.com/XuehaiPan, https://github.com/jansel, https://github.com/malfet
183 lines
5.8 KiB
Python
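# Correctness harness for the fastrnns benchmark runners: each experimental
# RNN implementation is compared against a PyTorch LSTM control on inputs,
# parameters, outputs, and gradients. Example invocation (the exact module
# path is an assumption based on the relative imports below):
#
#   python -m fastrnns.test --rnns jit_premul jit
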
import argparse

import torch
import torch.nn as nn

from .factory import pytorch_lstm_creator, varlen_pytorch_lstm_creator
from .runner import get_nn_runners


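# Debug helper: drops into pdb so a failed comparison can be inspected
# interactively.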
def barf():
    import pdb

    pdb.set_trace()


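# Recursively compares a tensor (or nested list/tuple of tensors) against
# expected values, within an absolute elementwise tolerance.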
def assertEqual(tensor, expected, threshold=0.001):
    if isinstance(tensor, (list, tuple)):
        for t, e in zip(tensor, expected):
            assertEqual(t, e, threshold)
    else:
        if (tensor - expected).abs().max() > threshold:
            barf()


def filter_requires_grad(tensors):
    return [t for t in tensors if t.requires_grad]


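# Compares an experimental RNN implementation against a control
# implementation: identical inputs and parameters must produce matching
# outputs and parameter gradients (within assertEqual's tolerance).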
def test_rnns(
    experim_creator,
    control_creator,
    check_grad=True,
    verbose=False,
    seqLength=100,
    numLayers=1,
    inputSize=512,
    hiddenSize=512,
    miniBatch=64,
    device="cuda",
    seed=17,
):
    creator_args = dict(
        seqLength=seqLength,
        numLayers=numLayers,
        inputSize=inputSize,
        hiddenSize=hiddenSize,
        miniBatch=miniBatch,
        device=device,
        seed=seed,
    )

    print("Setting up...")
    control = control_creator(**creator_args)
    experim = experim_creator(**creator_args)

    # Precondition
    assertEqual(experim.inputs, control.inputs)
    assertEqual(experim.params, control.params)

    print("Checking outputs...")
    control_outputs = control.forward(*control.inputs)
    experim_outputs = experim.forward(*experim.inputs)
    assertEqual(experim_outputs, control_outputs)

    print("Checking grads...")
    assert control.backward_setup is not None
    assert experim.backward_setup is not None
    assert control.backward is not None
    assert experim.backward is not None
    control_backward_inputs = control.backward_setup(control_outputs, seed)
    experim_backward_inputs = experim.backward_setup(experim_outputs, seed)

    control.backward(*control_backward_inputs)
    experim.backward(*experim_backward_inputs)

    control_grads = [p.grad for p in control.params]
    experim_grads = [p.grad for p in experim.params]
    assertEqual(experim_grads, control_grads)

    if verbose:
        print(experim.forward.graph_for(*experim.inputs))
        print()


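# Variable-length variant: compares the vl_py runner against the
# variable-length PyTorch LSTM control (varlen_pytorch_lstm_creator).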
def test_vl_py(**test_args):
    # XXX: This compares vl_py with vl_lstm.
    # It's done this way because those two don't give the same outputs so
    # the result isn't an apples-to-apples comparison right now.
    control_creator = varlen_pytorch_lstm_creator
    name, experim_creator, context = get_nn_runners("vl_py")[0]
    with context():
        print(f"testing {name}...")
        creator_keys = [
            "seqLength",
            "numLayers",
            "inputSize",
            "hiddenSize",
            "miniBatch",
            "device",
            "seed",
        ]
        creator_args = {key: test_args[key] for key in creator_keys}

        print("Setting up...")
        control = control_creator(**creator_args)
        experim = experim_creator(**creator_args)

        # Precondition
        assertEqual(experim.inputs, control.inputs[:2])
        assertEqual(experim.params, control.params)

        print("Checking outputs...")
        control_out, control_hiddens = control.forward(*control.inputs)
        control_hx, control_cx = control_hiddens
        experim_out, experim_hiddens = experim.forward(*experim.inputs)
        experim_hx, experim_cx = experim_hiddens

        experim_padded = nn.utils.rnn.pad_sequence(experim_out).squeeze(-2)
        assertEqual(experim_padded, control_out)
        assertEqual(torch.cat(experim_hx, dim=1), control_hx)
        assertEqual(torch.cat(experim_cx, dim=1), control_cx)

        print("Checking grads...")
        assert control.backward_setup is not None
        assert experim.backward_setup is not None
        assert control.backward is not None
        assert experim.backward is not None
        control_backward_inputs = control.backward_setup(
            (control_out, control_hiddens), test_args["seed"]
        )
        experim_backward_inputs = experim.backward_setup(
            (experim_out, experim_hiddens), test_args["seed"]
        )

        control.backward(*control_backward_inputs)
        experim.backward(*experim_backward_inputs)

        control_grads = [p.grad for p in control.params]
        experim_grads = [p.grad for p in experim.params]
        assertEqual(experim_grads, control_grads)

        if test_args["verbose"]:
            print(experim.forward.graph_for(*experim.inputs))
            print()


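# CLI entry point: parse the benchmark dimensions, then run the checks for
# each requested runner (plus the variable-length check when
# --variable-lstms is given).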
if __name__ == "__main__":
    parser = argparse.ArgumentParser(description="Test lstm correctness")

    parser.add_argument("--seqLength", default=100, type=int)
    parser.add_argument("--numLayers", default=1, type=int)
    parser.add_argument("--inputSize", default=512, type=int)
    parser.add_argument("--hiddenSize", default=512, type=int)
    parser.add_argument("--miniBatch", default=64, type=int)
    parser.add_argument("--device", default="cuda", type=str)
    # type=bool would treat any non-empty string (including "False") as
    # truthy, so parse the flag value explicitly.
    parser.add_argument(
        "--check-grad",
        "--check_grad",
        default=True,
        type=lambda s: s.lower() not in ("false", "0", "no"),
    )
    parser.add_argument("--variable-lstms", "--variable_lstms", action="store_true")
    parser.add_argument("--seed", default=17, type=int)
    parser.add_argument("--verbose", action="store_true")
    parser.add_argument("--rnns", nargs="*", help="What to run. jit_premul, jit, etc")
    args = parser.parse_args()
    if args.rnns is None:
        args.rnns = ["jit_premul", "jit"]
    print(args)

    if "cuda" in args.device:
        assert torch.cuda.is_available()

    rnn_runners = get_nn_runners(*args.rnns)

    should_test_varlen_lstms = args.variable_lstms
    test_args = vars(args)
    del test_args["rnns"]
    del test_args["variable_lstms"]

    if should_test_varlen_lstms:
        test_vl_py(**test_args)

    for name, creator, context in rnn_runners:
        with context():
            print(f"testing {name}...")
            test_rnns(creator, pytorch_lstm_creator, **test_args)