# Owner(s): ["module: dynamo"]
# flake8: noqa
"""TorchDynamo tests for tracing small Python functions.

Most tests in ``FunctionTests`` are written as plain functions of one or more
tensors and decorated with ``make_test``, which turns them into unittest
methods that forward to ``torch._dynamo.testing.standard_test``.

Many function bodies contain locals or statements that look redundant (unused
variables, constant branches, round-trips through containers).  They are
deliberate: each one exercises a particular Python construct, so do not
"clean them up".
"""
import collections
import functools
import inspect
import itertools
import operator
from typing import Any

import torch

import torch._dynamo.test_case
import torch._dynamo.testing
from torch import sub
from torch._dynamo.testing import requires_static_shapes
from torch.nn import functional as F

# Module-level objects read as globals by the tests below.
tensor_for_import_testing = torch.ones(10, 10)
d = torch.ones(10, 10)  # read by test_globalvar
e = torch.nn.Linear(10, 10)  # called by test_globalmodule
flag = True  # read by test_load_global_bool


def constant3(a, b):
    """Return ``a - b + 3.0``; the constant is spelled as ``(1.0 + 2)``."""
    return a - b + (1.0 + 2)


def func_with_default(a, b, some_default_arg=True):
    """Return ``a - b`` when ``some_default_arg`` is truthy, otherwise ``None``."""
    if some_default_arg:
        return a - b


def make_test(fn):
    """Wrap ``fn`` into a unittest method.

    The returned ``test_fn(self)`` calls
    ``torch._dynamo.testing.standard_test`` with ``fn`` and with ``nargs`` set
    to the number of parameters in ``fn``'s signature.
    """
    nargs = len(inspect.signature(fn).parameters)

    def test_fn(self):
        return torch._dynamo.testing.standard_test(self, fn=fn, nargs=nargs)

    return test_fn


# Helpers wrapped in torch.jit decorators; used by test_inline_jit_annotations.
@torch.jit.script_if_tracing
def inline_script_if_tracing(x):
    return x + 1.2


@torch.jit.ignore
def inline_ignore(x):
    return x + 3.4


@torch.jit.unused
def inline_unused(x):
    return x + 5.6


class FunctionTests(torch._dynamo.test_case.TestCase):
    # NOTE: functions decorated with @make_test take tensors, not ``self``;
    # make_test replaces them with real test methods.

    @make_test
    def test_inline_jit_annotations(x):
        x = inline_script_if_tracing(x)
        x = inline_ignore(x)
        x = inline_unused(x)
        # NOTE(review): bare return, so this function yields None rather than
        # ``x`` -- confirm that is intentional.
        return

    @make_test
    def test_add(a, b):
        return a + b

    @make_test
    def test_is_not_null(a, b):
        if a is not None and b is not None:
            return a + b

    # --- constants -------------------------------------------------------

    @make_test
    def test_constant1(a, b, c):
        return a - b * c + 1.0

    @make_test
    def test_constant2(a, b, c):
        return a - b * c + 1

    @make_test
    def test_constant3(a):
        b = 1
        c = 2
        d = 3
        return b + c - d + a

    @make_test
    def test_constant4(a, b):
        c = 2
        d = 3
        if c > d:
            return a - b
        return b - a

    @make_test
    def test_finfo(a, b):
        if torch.iinfo(torch.int32).bits == 32:
            return torch.finfo(a.dtype).min * b

    # --- different ways of reaching the same callable ---------------------

    @make_test
    def test_globalfn(a, b):
        return sub(a, b)

    @make_test
    def test_viatorch(a, b):
        return torch.sub(a, b)

    @make_test
    def test_viamethod(a, b):
        return a.sub(b)

    @make_test
    def test_indirect1(a, b):
        t = a.sub
        return t(b)

    @make_test
    def test_indirect2(a, b):
        t = a.sub
        args = (b,)
        return t(*args)

    @make_test
    def test_indirect3(a, b):
        t = a.sub
        args = (b,)
        kwargs = {}
        return t(*args, **kwargs)

    @make_test
    def test_methodcall1(a, b, c):
        return constant3(a, b) * c

    @make_test
    def test_methodcall2(a, b):
        return constant3(a=b, b=a) + 1

    @make_test
    def test_methodcall3(a, b):
        return constant3(a, b=1.0) + b

    @make_test
    def test_device_constant(a):
        return a + torch.ones(1, device=torch.device("cpu"))

    @make_test
    def test_tuple1(a, b):
        args = (a, b)
        return sub(*args)

    @make_test
    def test_tuple2(a, b):
        args = [a, b]
        return sub(*args)

    @make_test
    def test_is_in_onnx_export(x, y):
        if torch.onnx.is_in_onnx_export():
            return x - 1
        else:
            return y + 1

    @make_test
    def test_is_fx_tracing(x, y):
        if torch.fx._symbolic_trace.is_fx_tracing():
            return x - 1
        else:
            return y + 1

    # --- passing sequences of tensors to torch.cat ------------------------

    @make_test
    def test_listarg1(a, b):
        return torch.cat([a, b])

    @make_test
    def test_listarg2(a, b):
        return torch.cat((a, b), dim=0)

    @make_test
    def test_listarg3(a, b):
        kwargs = {"tensors": (a, b), "dim": 0}
        return torch.cat(**kwargs)

    @make_test
    def test_listarg4(a, b):
        return torch.cat(tensors=[a, b], dim=0)

    @make_test
    def test_listarg5(a, b):
        args = [(a, b)]
        kwargs = {"dim": 0}
        return torch.cat(*args, **kwargs)

    # --- indexing, slicing and unpacking ----------------------------------

    @make_test
    def test_slice1(a):
        return a[5]

    @make_test
    def test_slice2(a):
        return a[:5]

    @make_test
    def test_slice3(a):
        return a[5:]

    @make_test
    def test_slice4(a):
        return a[2:5]

    @make_test
    def test_slice5(a):
        return a[::2]

    @make_test
    def test_slice6(a):
        return torch.unsqueeze(a, 0)[:, 2:]

    @make_test
    def test_unpack1(a):
        a, b = a[:5], a[5:]
        return a - b

    @make_test
    def test_unpack2(a):
        packed = [a[:5], a[5:]]
        a, b = packed
        return a - b

    @make_test
    def test_unpack3(a):
        packed = (a[:5], a[5:])
        a, b = packed
        return a - b

    @make_test
    def test_fn_with_self_set(a, b):
        # avg_pool2d is an odd one with __self__ set
        return F.avg_pool2d(
            torch.unsqueeze(a, 0) * torch.unsqueeze(b, 1), kernel_size=2, padding=1
        )

    @make_test
    def test_return_tuple1(a, b):
        return (a - b, b - a, a, b)

    # --- module-level globals ---------------------------------------------

    @make_test
    def test_globalvar(a, b):
        return a - b + d

    @make_test
    def test_globalmodule(x):
        return e(x)

    @make_test
    def test_inline_with_default(a, b, c):
        return func_with_default(a, b) * c

    @make_test
    def test_inner_function(x):
        def fn(x):
            return torch.add(x, x)

        return fn(x)

    @make_test
    def test_transpose_for_scores(x):
        new_x_shape = x.size()[:-1] + (2, 5)
        x = x.view(*new_x_shape)
        return x.permute(0, 2, 1)

    @make_test
    def test_return_tuple2(x):
        return (torch.add(x, x), x)

    @make_test
    def test_load_global_bool(x):
        if flag:
            return torch.add(x, x)
        else:
            return x

    # --- len() / builtin conversions --------------------------------------

    @make_test
    def test_len_tensor(x):
        z = len(x)
        return torch.add(x, z)

    @make_test
    def test_len_constant_list(x):
        z = len([1, 2, 3])
        return torch.add(x, z)

    @make_test
    def test_len_constant_dict(x):
        z = len({"foo": "bar"})
        return torch.add(x, z)

    @make_test
    def test_dict_copy(x):
        z = dict({"foo": x + 1})
        return z

    @make_test
    def test_len_constant_misc_iterables(x):
        a = len((1, 2, 3))
        b = len("test str")
        c = a + b
        return torch.add(x, c)

    @make_test
    def test_float(x):
        y = float(1.2)
        y += float("1.2")
        return torch.add(x, y)

    # --- tensor attributes used in Python conditions ----------------------

    @make_test
    def test_dtype(x):
        if x.dtype == torch.float32:
            return x + 1

    @make_test
    def test_device(x):
        if not x.is_cuda:
            return x + 1

    @make_test
    def test_ndim(x):
        if x.ndim == 2 and x.ndimension() == 2 and x.dim() == 2:
            return x + 1

    @make_test
    def test_is_sparse(x):
        if not x.is_sparse:
            return x + 1

    @requires_static_shapes
    @make_test
    def test_shape1(x):
        if x.shape[0] == 10:
            return x + 1

    @requires_static_shapes
    @make_test
    def test_shape2(x):
        if x.size(1) == 10:
            return x + 1

    @make_test
    def test_del(a, b):
        c = a + 1
        d = c + 2
        del c, a
        return b + d

    @requires_static_shapes
    @make_test
    def test_chunks1(x):
        chunk_size = 5
        assert x.shape[0] % chunk_size == 0
        assert x.shape[0] // chunk_size == 2
        return x[:chunk_size] - x[chunk_size:]

    @make_test
    def test_import1(x, y):
        # Function-scope imports are the construct under test here.
        import torch
        from torch import sub

        return sub(torch.add(x, y), y)

    # --- returning and mutating containers --------------------------------

    @make_test
    def test_return_dict(x, y):
        z = [x + y, y, False]
        return {"x": x, "z": z, "a": x, "b": z, "c": x}

    @make_test
    def test_return_dict2(x, y):
        tmp = {"x": x}
        tmp["z"] = [x + y, y]
        tmp["y"] = y
        tmp["z"].append(False)
        return tmp

    @make_test
    def test_funcdef_closure(x, y):
        x = x + y + 1.0

        def inner(z):
            # Rebinds both enclosing variables on every call.
            nonlocal x, y
            y = x + z + 20.0
            x = y + z + 10.0

        inner(2.0)
        inner(3.0)

        return x, y

    @make_test
    def test_module_constant(x, y):
        r = x + y
        for i in range(torch._dynamo.testing.three):
            r = r / y
        return r

    @make_test
    def test_inline_softmax(x, y):
        # This is common in some huggingface models
        return torch.nn.Softmax(dim=-1)(x + y * 2)

    @make_test
    def test_dtype_compare(a, b):
        if a.dtype == torch.float16:
            return a + 10
        if a.dtype == torch.float32:
            return a - b * 32

    @make_test
    def test_build_list_unpack(a, b):
        it1 = (x + 1 for x in (a, b))
        it2 = (x - 1 for x in (a, b))
        return torch.cat([*it1, *it2], dim=-1)

    @make_test
    def test_tensor_len(a, b):
        return a + b + len(a) + b.__len__()

    @make_test
    def test_pop(a, b):
        ll = [a, b]
        ll.append(a + 1)
        ll.extend(
            [
                b + 2,
                a + b,
            ]
        )
        ll.pop(-1)
        ll.pop(0)
        ll.pop()
        v1, v2 = ll
        return v1 - v2

    @make_test
    def test_list_convert(a, b):
        ll = [a + 2, b]
        ll = tuple(ll)
        tmp = b + 3
        ll = list(ll)
        v1, v2 = ll
        return v1 - v2 + tmp

    @make_test
    def test_list_add(a, b):
        l1 = (a, b)
        l2 = ()  # being a LOAD_CONST in the bytecode
        l3 = l1 + l2
        return l3[0] + l3[1]

    @make_test
    def test_startswith(a, b):
        x = a + b
        if "foobar".startswith("foo") and "test" in constant3.__module__:
            x = x + 1
        return x

    @make_test
    def test_dict_ops(a, b):
        tmp = {"a": a + 1, "b": b + 2}
        v = tmp.pop("b") + tmp.get("a") + tmp.get("missing", 3) + tmp.pop("missing", 4)
        tmp.update({"d": 3})
        tmp["c"] = v + tmp["d"]
        if "c" in tmp and "missing" not in tmp:
            return tmp["c"] - tmp["a"] + len(tmp)

    def test_dict_param_keys(self):
        """Use an ``nn.Parameter`` as a dict key inside the traced function."""
        a_param = torch.nn.Parameter(torch.ones([4, 4]))

        def fn(a):
            tmp = {"a": a, a_param: 3}
            return tmp["a"] + tmp[a_param]

        # fn closes over a_param, so it is wrapped and invoked by hand here
        # instead of being decorated at class level.
        test = make_test(fn)
        test(self)

    def test_default_dict(self):
        """Write to and read from a closed-over ``collections.defaultdict``."""
        dd = collections.defaultdict(dict)
        param = torch.nn.Parameter(torch.ones([2, 2]))

        def fn(x):
            dd["a"] = x + 1
            dd[param] = 123
            dd["c"] = x * 2
            # "b" was never assigned, so this lookup goes through the
            # defaultdict's ``dict`` factory.
            return dd["b"], dd

        test = make_test(fn)
        test(self)

    # --- builtins and functional helpers ----------------------------------

    @make_test
    def test_min_max(a, b):
        c = a + b
        a = a.sum()
        b = b.sum()
        a = min(max(a, 0), 1)
        b = max(0, min(1, b))
        return max(a, b) - min(a, b) + c

    @make_test
    def test_map_sum(a, b, c, d):
        return sum(map(lambda x: x + 1, [a, b, c, d]))

    @make_test
    def test_reduce(a, b, c, d):
        return functools.reduce(operator.add, [a, b, c, d])

    @make_test
    def test_tuple_contains(a, b):
        v1 = "a"
        v2 = "b"
        v3 = "c"
        vals1 = (v1, v2, v3)
        vals2 = ("d", "e", "f")
        if "a" in vals1 and "b" not in vals2:
            return a + b
        return a - b

    @make_test
    def test_tuple_iadd(a, b):
        output = (a, b)
        output += (a + b, a - b)
        return output

    @make_test
    def test_unpack_ex1(x):
        output = (x, x + 1, x + 2, x + 3)
        a, b, *cd = output
        return a - b / cd[0]

    @make_test
    def test_unpack_ex2(x):
        output = (x, x + 1, x + 2, x + 3)
        *ab, c, d = output
        return c - d / ab[0]

    @make_test
    def test_unpack_ex3(x):
        output = (x, x + 1, x + 2, x + 3)
        a, *bc, d = output
        return a - d / bc[0]

    @make_test
    def test_const_tuple_add1(x):
        output = (x, x + 1, x + 2, x + 3)
        output = () + output + ()
        return output[2] + output[3]

    @make_test
    def test_const_tuple_add2(x):
        output = (x, x + 1, x + 2, x + 3)
        output = (None,) + output + (None,)
        return output[2] + output[3]

    @make_test
    def test_list_truth(a, b):
        tmp = [1, 2, 3]
        if tmp:
            return a + b
        else:
            return a - b

    @make_test
    def test_list_reversed(a, b):
        tmp = [a + 1, a + 2, a + 3]
        return a + b + next(iter(reversed(tmp)))

    @make_test
    def test_list_clear(a, b):
        tmp = [a + 1, a + 2]
        tmp.clear()
        tmp.append(a + b)
        return tmp

    @make_test
    def test_islice_chain(a, b):
        tmp1 = [a + 1, a + 2]
        tmp2 = [a + 3, a + 4]
        a, b = list(itertools.islice(itertools.chain(tmp1, tmp2), 1, 3))
        c = next(itertools.islice(tmp1, 1, None))
        return a - b / c

    @make_test
    def test_is_quantized(a, b):
        if not a.is_quantized:
            return a + b

    # --- f-strings --------------------------------------------------------

    @make_test
    def test_fstrings1(a, b):
        x = 1.229
        tmp = f"{x:.2f} bar"
        if tmp.startswith("1.23"):
            return a + b

    @requires_static_shapes
    @make_test
    def test_fstrings2(x):
        tmp = f"{x.shape[0]} bar"
        if tmp.startswith("10"):
            return x + 1

    @make_test
    def test_fstrings3(x):
        tmp = f"{x.__class__.__name__} foo"
        if tmp.startswith("Tensor"):
            return x + 1

    @requires_static_shapes
    @make_test
    def test_tensor_new_with_size(x):
        y = torch.rand(5, 8)
        z = x.new(y.size())
        assert z.size() == y.size()

    @requires_static_shapes
    @make_test
    def test_tensor_new_with_shape(x):
        y = torch.rand(5, 8)
        z = x.new(y.shape)
        assert z.size() == y.size()

    @make_test
    def test_jit_annotate(x):
        y = torch.jit.annotate(Any, x + 1)
        return y + 2

    @requires_static_shapes
    @make_test
    def test_is_contiguous_memory_format(tensor):
        if torch.jit.is_scripting():
            return None
        elif tensor.is_contiguous(memory_format=torch.contiguous_format):
            return tensor + 1

    @make_test
    def test_list_slice_assignment(x):
        # ``m`` is never read again; the slice assignment itself is the
        # construct under test.
        m = [1, 2, 3, 4]
        m[1:] = [6] * (len(m) - 1)
        return x + 1

    # # This is to test the new syntax for pattern matching
    # # ("match ... case ...") added on python 3.10.
    # # Uncomment these test cases if you run on 3.10+
    # @make_test
    # def test_match_sequence(a):
    #     point = (5, 8)
    #     match point:
    #         case (0, 0):
    #             return a
    #         case (0, y):
    #             return a - y
    #         case (x, 0):
    #             return a + x
    #         case (x, y):
    #             return a + x - y

    # @make_test
    # def test_match_mapping_and_match_keys(x):
    #     param = {"a": 0.5}
    #     match param:
    #         case {"a": param}:
    #             return x * param
    #         case {"b": param}:
    #             return x / param


if __name__ == "__main__":
    from torch._dynamo.test_case import run_tests

    run_tests()