mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 21:14:14 +08:00
See https://github.com/pytorch/pytorch/pull/129751#issue-2380881501. Most changes are auto-generated by the linter. You can review these PRs via `git diff --ignore-all-space --ignore-blank-lines HEAD~1`.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/129763
Approved by: https://github.com/jansel
80 lines
1.6 KiB
Python
80 lines
1.6 KiB
Python
import sys
|
|
|
|
import torch
|
|
from torch.testing._internal.inductor_utils import GPU_TYPE
|
|
|
|
|
|
def first_arg(x, y):
    """Gather slices of ``x`` along dim 0 using the index tensor ``y`` (``x[y]``).

    The ``__main__`` driver passes ``y = arange(4)`` while ``x.shape[0]`` is 3
    (or 1), so index 3 is deliberately out of range -- presumably to exercise
    the bounds checks emitted for the compiled kernel.
    """
    gathered = x[y]
    return gathered
|
|
|
|
|
|
def second_arg(x, y):
    """Index dim 1 of ``x`` with ``y`` while keeping dim 0 whole (``x[:, y]``).

    The driver's ``x`` has ``shape[1] == 2``, so entries 2 and 3 of
    ``y = arange(4)`` are deliberately out of range.
    """
    columns = x[:, y]
    return columns
|
|
|
|
|
|
def same_pm_one(x, y):
    """Index dims 0 and 1 of ``x`` with ``y + 1`` and ``y - 1`` respectively.

    Both index expressions are derived from the same tensor ``y``. With the
    driver's ``y = arange(4)`` they reach 4 on dim 0 (size 3) and 2 on dim 1
    (size 2), i.e. out of range.
    """
    row_idx = y + 1
    col_idx = y - 1
    return x[row_idx, col_idx]
|
|
|
|
|
|
def same_pp_one(x, y):
    """Index dims 0 and 1 of ``x`` with ``y + 1`` on both axes (a diagonal gather).

    Each axis gets its own ``y + 1`` computation, exactly as
    ``x[y + 1, y + 1]`` evaluates it. With the driver's ``y = arange(4)`` the
    largest index (4) is out of range for dims of size 3 and 2.
    """
    row_idx = y + 1
    col_idx = y + 1
    return x[row_idx, col_idx]
|
|
|
|
|
|
def store(x, y, z):
    """Write ``z`` into ``x`` in place at positions ``(y + 1, y + 1)``; returns None.

    The driver builds ``z`` with shape ``(y.numel(),) + x.shape[2:]`` and uses
    ``y = arange(4)``, so the largest target index (4) is out of range for
    dims of size 3 and 2.
    """
    row_idx = y + 1
    col_idx = y + 1
    x[row_idx, col_idx] = z
|
|
|
|
|
|
def upper1(x):
    """Gather the first four slices of ``x`` along dim 0 via an ``arange(4)`` index.

    The index tensor is created on ``x``'s device. For the driver's ``x``
    (``shape[0] == 3``), index 3 exceeds the upper bound.
    """
    return x[torch.arange(4, device=x.device)]
|
|
|
|
|
|
def lower1(x):
    """Index dim 0 of ``x`` with a 0-d int64 tensor holding -4.

    The index is built with ``x.new_full``. For the driver's ``x``
    (``shape[0] == 3``), -4 lies below the valid negative range.
    """
    return x[x.new_full((), -4, dtype=torch.int64)]
|
|
|
|
|
|
def upper2(x):
    """Index dim 0 of ``x`` with a 0-d int64 tensor holding 4.

    The index is built with ``x.new_full``. For the driver's ``x``
    (``shape[0] == 3``), 4 exceeds the upper bound.
    """
    return x[x.new_full((), 4, dtype=torch.int64)]
|
|
|
|
|
|
def lower2(x):
    """Index dim 0 of ``x`` with ``zeros(()) - 4``, i.e. a computed 0-d int64 index of -4.

    Unlike ``lower1``, the negative value comes from arithmetic on a tensor
    rather than a constant fill. For the driver's ``x`` (``shape[0] == 3``)
    it is out of range.
    """
    offset_index = x.new_zeros((), dtype=torch.int64) - 4
    return x[offset_index]
|
|
|
|
|
|
if __name__ == "__main__":
    # Names of the functions defined above; these are the only valid values
    # for the FN_NAME argument. Computed before any other name is bound here.
    fns = [
        name
        for name, obj in locals().items()
        if callable(obj) and obj.__module__ == __name__
    ]

    # Validate arguments with explicit checks instead of ``assert``: asserts
    # are stripped under ``python -O``, which would let e.g. an unrecognised
    # ONE_SIZE value silently fall through as False.
    usage = (
        f"usage: {sys.argv[0]} FN_NAME {{2,3}} {{True,False}} {{True,False}}\n"
        f"FN_NAME must be one of: {', '.join(fns)}"
    )
    if len(sys.argv) != 5:
        sys.exit(usage)
    _, fn_name, dims, dyn_shape, one_size = sys.argv
    if fn_name not in fns:
        sys.exit(f"unknown FN_NAME {fn_name!r}\n{usage}")
    if one_size not in ("True", "False"):
        sys.exit(f"ONE_SIZE must be 'True' or 'False', got {one_size!r}")
    one_size = one_size == "True"
    if dims not in ("2", "3"):
        sys.exit(f"DIMS must be '2' or '3', got {dims!r}")

    # x is (3, 2) or (3, 2, 4); y = arange(4) below, so its largest entry (3)
    # exceeds both dim 0 (size 3) and dim 1 (size 2) of x.
    shape_x = [3, 2, 4] if dims == "3" else [3, 2]
    if one_size:
        if fn_name != "first_arg":
            sys.exit(
                "only first_arg can be tested for a special case of 1-size tensor"
            )
        shape_x[0] = 1
    if dyn_shape not in ("True", "False"):
        sys.exit(f"DYNAMIC_SHAPES must be 'True' or 'False', got {dyn_shape!r}")
    dynamic_shapes = dyn_shape == "True"

    x = torch.randn(shape_x, device=GPU_TYPE)
    y = torch.arange(4, device=GPU_TYPE)
    fn = vars()[fn_name]
    fn = torch.compile(dynamic=dynamic_shapes)(fn)
    if fn_name == "store":
        # z matches the shape of x[y + 1, y + 1]: one entry per index in y,
        # followed by any trailing dims of x.
        shape = (y.numel(),) + x.shape[2:]
        z = torch.randn(shape, device=GPU_TYPE)
        fn(x, y, z)
    elif fn_name in ("upper1", "upper2", "lower1", "lower2"):
        # These build their own index tensor and take only x.
        fn(x)
    else:
        fn(x, y)
|