Files
pytorch/test/inductor/indirect_assert_helper.py
xinan.lin 07f76517e7 [Inductor][Windows] Fix Windows test case failure. (#161497)
Fixes Windows test case failures:
- TritonCodeGenTests.test_inductor_sequence_nr
- TritonCodeGenTests.test_indirect_device_assert
- CompiledOptimizerTests.test_static_address_finalizer

Pull Request resolved: https://github.com/pytorch/pytorch/pull/161497
Approved by: https://github.com/jansel
2025-08-28 12:40:42 +00:00

83 lines
1.8 KiB
Python

import sys
import torch
from torch.testing._internal.inductor_utils import GPU_TYPE
def first_arg(x, y):
    """Gather entries of ``x`` along dimension 0 using the index tensor ``y``.

    Any value in ``y`` outside ``[-x.size(0), x.size(0))`` is an
    out-of-range index.
    """
    gathered = x[y]
    return gathered
def second_arg(x, y):
    """Gather entries of ``x`` along dimension 1 using the index tensor ``y``.

    Dimension 0 is kept whole (``:``); only the second index is indirect.
    """
    columns = x[:, y]
    return columns
def same_pm_one(x, y):
    """Index ``x`` at rows ``y + 1`` and columns ``y - 1``.

    Both indices derive from the same tensor ``y``, shifted in opposite
    directions, so one can overflow the upper bound while the other goes
    negative.
    """
    row_idx = y + 1
    col_idx = y - 1
    return x[row_idx, col_idx]
def same_pp_one(x, y):
    """Index ``x`` at positions ``(y + 1, y + 1)`` for each element of ``y``.

    The shifted index is deliberately computed twice (once per dimension),
    mirroring two identical index expressions.
    """
    row_idx = y + 1
    col_idx = y + 1
    return x[row_idx, col_idx]
def store(x, y, z):
    """Write ``z`` into ``x`` at positions ``(y + 1, y + 1)``, in place.

    Returns ``None``; the only effect is the mutation of ``x``.
    """
    row_idx = y + 1
    col_idx = y + 1
    x[row_idx, col_idx] = z
def upper1(x):
    """Index dimension 0 of ``x`` with ``arange(4)`` built on ``x``'s device.

    The index is out of range whenever ``x.size(0) < 4``.
    """
    idx = torch.arange(4, device=x.device)
    picked = x[idx]
    return picked
def lower1(x):
    """Index ``x`` with a 0-d int64 tensor holding ``-4``.

    Negative indices wrap, so this is valid only when ``x.size(0) >= 4``.
    """
    neg_idx = x.new_full((), -4, dtype=torch.int64)
    picked = x[neg_idx]
    return picked
def upper2(x):
    """Index ``x`` with a 0-d int64 tensor holding ``4``.

    The index is valid only when ``x.size(0) >= 5``.
    """
    pos_idx = x.new_full((), 4, dtype=torch.int64)
    picked = x[pos_idx]
    return picked
def lower2(x):
    """Index ``x`` with ``0 - 4`` computed as a tensor expression.

    Unlike ``lower1``, the negative index is the result of a subtraction on
    a 0-d int64 zero tensor rather than a constant fill.
    """
    zero = x.new_zeros((), dtype=torch.int64)
    shifted = zero - 4
    return x[shifted]
if __name__ == "__main__":
    # Command-line driver:
    #   <script> FN_NAME DIMS DYN_SHAPE ONE_SIZE
    # FN_NAME   - one of the functions defined in this script
    # DIMS      - "2" or "3" (rank of the input tensor x)
    # DYN_SHAPE - "True"/"False", forwarded to torch.compile(dynamic=...)
    # ONE_SIZE  - "True"/"False"; if True, x has size 1 along dim 0
    #             (only supported for first_arg)

    # Names of the functions defined in this script; only these may be run.
    fns = [
        name
        for name, obj in locals().items()
        if callable(obj) and obj.__module__ == __name__
    ]

    if len(sys.argv) != 5:
        raise SystemExit(
            f"usage: {sys.argv[0]} FN_NAME DIMS DYN_SHAPE ONE_SIZE"
        )
    _, fn_name, dims, dyn_shape, one_size = sys.argv

    # Validate with explicit raises rather than `assert`, which is stripped
    # under `python -O` and would let malformed arguments be coerced silently.
    if fn_name not in fns:
        raise ValueError(f"unknown function {fn_name!r}; expected one of {fns}")
    if one_size not in ("True", "False"):
        raise ValueError(f"ONE_SIZE must be 'True' or 'False', got {one_size!r}")
    if dims not in ("2", "3"):
        raise ValueError(f"DIMS must be '2' or '3', got {dims!r}")
    if dyn_shape not in ("True", "False"):
        raise ValueError(f"DYN_SHAPE must be 'True' or 'False', got {dyn_shape!r}")

    one_size = one_size == "True"
    dynamic_shapes = dyn_shape == "True"

    shape_x = [3, 2, 4] if dims == "3" else [3, 2]
    if one_size:
        if fn_name != "first_arg":
            raise ValueError(
                "only first_arg can be tested for a special case of 1-size tensor"
            )
        shape_x[0] = 1

    x = torch.randn(shape_x, device=GPU_TYPE)
    # y has 4 entries while x.size(0) is at most 3, so some indices the
    # functions above derive from y are out of range.
    y = torch.arange(4, device=GPU_TYPE)
    fn = vars()[fn_name]
    fn = torch.compile(dynamic=dynamic_shapes)(fn)

    if fn_name == "store":
        shape = (y.numel(),) + x.shape[2:]
        z = torch.randn(shape, device=GPU_TYPE)
        fn(x, y, z)
        # store() returns nothing, so read x back explicitly. CPython does not
        # elide calls; the likely reason this print is needed (it was added
        # for a Windows failure) is that GPU kernel launches are asynchronous,
        # and copying x to the host forces a sync. That lets the device-side
        # assert surface before the process exits. NOTE(review): confirm.
        print(x)
    elif fn_name in ("upper1", "upper2", "lower1", "lower2"):
        print(fn(x))
    else:
        print(fn(x, y))