# Owner(s): ["module: dynamo"]

import unittest

import torch
from functorch import make_fx
from torch._dynamo import debug_utils
from torch._dynamo.debug_utils import aot_graph_input_parser
from torch._dynamo.test_case import TestCase
from torch.testing._internal.inductor_utils import HAS_CUDA

requires_cuda = unittest.skipUnless(HAS_CUDA, "requires cuda")
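
# Dtype shorthands matching the abbreviations ("f32", "i64", "i32") used in
# the AOT graph annotations exercised by the tests below.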
f32 = torch.float32
i64 = torch.int64
i32 = torch.int32


class TestDebugUtils(TestCase):
    def test_cast_model_to_fp64_dtype_args(self):
        # Test that dtype arguments are converted to fp64
        def fn(x):
            return (
                torch.ops.prims.convert_element_type(x, torch.float16),
                x.to(torch.float16),
                torch.full(x.shape, 2, dtype=torch.float32, device=x.device),
                x.new_empty(x.shape),
            )
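
        # make_fx traces fn through the core ATen decomposition table,
        # producing the aten/prims graph snapshotted below.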
        x = torch.randn(32, device="cpu")
        decomps = torch._decomp.core_aten_decompositions()
        fx = make_fx(fn, decomposition_table=decomps)(x)
        self.assertExpectedInline(
            fx.code.lstrip(),
            """\
def forward(self, x_1):
    convert_element_type = torch.ops.prims.convert_element_type.default(x_1, torch.float16)
    _to_copy = torch.ops.aten._to_copy.default(x_1, dtype = torch.float16);  x_1 = None
    full = torch.ops.aten.full.default([32], 2, dtype = torch.float32, device = device(type='cpu'), pin_memory = False)
    empty = torch.ops.aten.empty.memory_format([32], dtype = torch.float32, layout = torch.strided, device = device(type='cpu'), pin_memory = False)
    return (convert_element_type, _to_copy, full, empty)
""",  # NOQA: B950
        )
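
        # cast_to_fp64 rewrites the graph in place: re-reading fx.code below
        # shows every dtype argument (including full/empty defaults) as float64.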
        fp64_model, fp64_examples = debug_utils.cast_to_fp64(fx, (x,))
        self.assertEqual(fp64_examples, (x.to(torch.float64),))

        self.assertExpectedInline(
            fx.code.lstrip(),
            """\
def forward(self, x_1):
    convert_element_type = torch.ops.prims.convert_element_type.default(x_1, torch.float64)
    _to_copy = torch.ops.aten._to_copy.default(x_1, dtype = torch.float64);  x_1 = None
    full = torch.ops.aten.full.default([32], 2, dtype = torch.float64, device = device(type='cpu'), pin_memory = False)
    empty = torch.ops.aten.empty.memory_format([32], dtype = torch.float64, layout = torch.strided, device = device(type='cpu'), pin_memory = False)
    return (convert_element_type, _to_copy, full, empty)
""",  # NOQA: B950
        )

    @requires_cuda
    def test_aot_graph_parser(self):
        from torch import device
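
        # The forward below is written in the style of an AOT graph dump; the
        # string annotations such as "f32[6144, 4191]" carry the dtype/shape
        # metadata that aot_graph_input_parser reconstructs inputs from.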
        def forward(
            self,
            primals_1: "f32[1001, 6]",
            primals_2: "f32[1001]",
            primals_3: "f32[1001, 64]",
            primals_4: "f32[4190]",
            primals_5: "f32[4190]",
            primals_6: "f32[1739, 4190]",
            primals_48: "f32[6144, 4191]",
        ):
            _tensor_constant0: "i64[4190]" = self._tensor_constant0
            lift_fresh_copy: "i64[4190]" = torch.ops.aten.lift_fresh_copy.default(
                _tensor_constant0
            )
            _tensor_constant0 = None
            index: "f32[6144, 4190]" = torch.ops.aten.index.Tensor(
                primals_48, [None, lift_fresh_copy]
            )
            lift_fresh_copy = None

            _tensor_constant1: "i64[6]" = self._tensor_constant1
            lift_fresh_copy_1: "i64[6]" = torch.ops.aten.lift_fresh_copy.default(
                _tensor_constant1
            )
            _tensor_constant1 = None
            index_1: "f32[6144, 6]" = torch.ops.aten.index.Tensor(
                primals_48, [None, lift_fresh_copy_1]
            )
            primals_48 = lift_fresh_copy_1 = None
            permute: "f32[6, 1001]" = torch.ops.aten.permute.default(primals_1, [1, 0])
            primals_1 = None
            addmm: "f32[6144, 1001]" = torch.ops.aten.addmm.default(
                primals_2, index_1, permute
            )
            primals_2 = permute = None
            amax: "f32[6144, 1]" = torch.ops.aten.amax.default(addmm, [-1], True)
            sub: "f32[6144, 1001]" = torch.ops.aten.sub.Tensor(addmm, amax)
            exp: "f32[6144, 1001]" = torch.ops.aten.exp.default(sub)
            sub = None
            sum_1: "f32[6144, 1]" = torch.ops.aten.sum.dim_IntList(exp, [-1], True)
            div: "f32[6144, 1001]" = torch.ops.aten.div.Tensor(exp, sum_1)
            exp = None

            full_default: "i32[6144, 1001]" = torch.ops.aten.full.default(
                [6144, 1001],
                1,
                dtype=torch.int32,
                layout=torch.strided,
                device=device(type="cuda", index=0),
                pin_memory=False,
            )

            iota: "i32[1001]" = torch.ops.prims.iota.default(
                1001,
                start=0,
                step=1,
                dtype=torch.int32,
                device=device(type="cuda"),
                requires_grad=False,
            )

            mul: "i32[6144, 1001]" = torch.ops.aten.mul.Tensor(full_default, iota)
            full_default = iota = None

            iota_1: "i32[6144]" = torch.ops.prims.iota.default(
                6144,
                start=0,
                step=1001,
                dtype=torch.int32,
                device=device(type="cuda", index=0),
                requires_grad=False,
            )
            view: "i32[6150144]" = torch.ops.aten.reshape.default(mul, [-1])
            mul = None
            view_1: "f32[6150144]" = torch.ops.aten.reshape.default(div, [-1])
            div = None
            _embedding_bag = torch.ops.aten._embedding_bag.default(
                primals_3, view, iota_1, False, 0, False, view_1
            )

            return _embedding_bag
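
        # aot_graph_input_parser synthesizes tensors of the annotated
        # dtype/shape on the requested device (and, presumably, stubs out the
        # self._tensor_constant* attributes the body reads), so the pasted
        # graph can be called directly.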
        kwargs = aot_graph_input_parser(forward, device="cuda")
        # runs successfully
        forward(**kwargs)

    @requires_cuda
    def test_sym_aot_graph_parser(self):
        def forward(
            self,
            primals_1: "f32[1001, 6]",  # noqa: F821
            primals_2: "f32[s0]",  # noqa: F821
            primals_3: "Sym(s0)",  # noqa: F821
            primals_4: "f32[s1]",  # noqa: F821
            primals_5: "Sym(s1)",  # noqa: F821
        ):
            _tensor_constant0: "i64[4190]" = self._tensor_constant0
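
        # "s0" is pinned to 10 via sym_shapes; "s1" is not listed, so it falls
        # back to default_sym_shape=5, as the assertions below check.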
        kwargs = aot_graph_input_parser(
            forward, device="cuda", sym_shapes={"s0": 10}, default_sym_shape=5
        )

        self.assertEqual(list(kwargs["primals_2"].shape), [10])
        self.assertEqual(kwargs["primals_3"], 10)

        self.assertEqual(list(kwargs["primals_4"].shape), [5])
        self.assertEqual(kwargs["primals_5"], 5)

if __name__ == "__main__":
    from torch._dynamo.test_case import run_tests

    run_tests()