Migrating some more callsites (#163580)

Pull Request resolved: https://github.com/pytorch/pytorch/pull/163580
Approved by: https://github.com/avikchaudhuri
ghstack dependencies: #165582
Tugsbayasgalan Manlaibaatar
2025-10-17 10:07:14 -07:00
committed by PyTorch MergeBot
parent 22ae059d32
commit c73f5080de
5 changed files with 26 additions and 17 deletions
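The callsites below are all migrated to the same pattern: the strict torch.export.export call is wrapped in torch._export.config.patch(use_new_tracer_experimental=True) so it runs under the new experimental tracer. A minimal sketch of that pattern, using a throwaway module and inputs that are illustrative only (not taken from this PR):

import torch


class M(torch.nn.Module):
    def forward(self, x):
        return x + 1


# The flag is flipped only for the duration of the export call; everything
# else about the strict export path stays the same.
with torch._export.config.patch(use_new_tracer_experimental=True):
    ep = torch.export.export(M(), (torch.randn(3, 3),), strict=True)
print(ep.graph_module.graph)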

View File

@@ -3,6 +3,7 @@ import sys
 from benchmark_base import BenchmarkBase
 import torch
+from torch._dynamo.utils import CompileTimeInstructionCounter
 class Benchmark(BenchmarkBase):
@@ -32,7 +33,11 @@ class Benchmark(BenchmarkBase):
     def _work(self):
         # enable_cpp_symbolic_shape_guards has impact on this benchmark
         # Keep using False value for consistency.
-        with torch._dynamo.config.patch("enable_cpp_symbolic_shape_guards", False):
+        with (
+            torch._dynamo.config.patch("enable_cpp_symbolic_shape_guards", False),
+            torch._export.config.patch(use_new_tracer_experimental=True),
+            CompileTimeInstructionCounter.record(),
+        ):
             torch.export.export(self.m, (self.input,), strict=True)

View File

@@ -38,7 +38,7 @@ update_hint_regression,compile_time_instruction_count,1719000000,0.1
-sum_floordiv_regression,compile_time_instruction_count,966100000,0.1
+sum_floordiv_regression,compile_time_instruction_count,3686995725,0.1


View File

@@ -2712,19 +2712,20 @@ def forward(self, x):
             torch._dynamo.exc.UserError,
             ".*y.*size.*2.* = 4 is not equal to .*x.*size.*1.* = 3",
         ):
-            torch.export.export(bar, (x, y), dynamic_shapes=dynamic_shapes, strict=True)
+            with torch._export.config.patch(use_new_tracer_experimental=True):
+                torch.export.export(
+                    bar, (x, y), dynamic_shapes=dynamic_shapes, strict=True
+                )
 
         y = torch.randn(10, 3, 3)
-        ebar = torch.export.export(
-            bar, (x, y), dynamic_shapes=dynamic_shapes, strict=True
-        )
-        self.assertEqual(
-            [
-                str(node.meta["val"].shape)
-                for node in ebar.graph_module.graph.nodes
-                if node.op == "placeholder"
-            ],
-            ["torch.Size([s17, s27, s27])", "torch.Size([s17, s27, s27])"],
-        )
+        with torch._export.config.patch(use_new_tracer_experimental=True):
+            ebar = torch.export.export(
+                bar, (x, y), dynamic_shapes=dynamic_shapes, strict=True
+            )
+        for node in ebar.graph_module.graph.nodes:
+            if node.op == "placeholder":
+                shape = node.meta["val"].shape
+                self.assertEqual(shape[1], shape[2])
 
     @torch._dynamo.config.patch(
         capture_dynamic_output_shape_ops=True,

View File

@@ -157,7 +157,10 @@ class AOTIRunnerUtil:
             # This should really be the default behavior of torch.export.export
            model = WrapperModule(model)
-        with torch.no_grad():
+        with (
+            torch.no_grad(),
+            torch._export.config.patch(use_new_tracer_experimental=True),
+        ):
             # strict=False needs extra migration work
             ep = torch.export.export(
                 model,

View File

@@ -92,13 +92,13 @@ class TestMemoryPlanning(TestCase):
         )
         FileCheck().check(
-            "int64_t int_array_0[] = {24L + align(12L*s77), };"
+            "int64_t int_array_0[] = {24L + align(12L*s6), };"
         ).check_next("int64_t int_array_1[] = {1L, };").check_next(
             "AtenTensorHandle pool1_handle;"
         ).check_next(
             "aoti_torch_empty_strided(1, int_array_0, int_array_1,"
         ).check_next("RAIIAtenTensorHandle pool1(pool1_handle);").check_next(
-            "int64_t int_array_2[] = {s77, 3L};"
+            "int64_t int_array_2[] = {s6, 3L};"
         ).check_next("int64_t int_array_3[] = {3L, 1L};").check_next(
             "AtenTensorHandle tmp_tensor_handle_0;"
         ).check_next("aoti_torch__alloc_from_pool(pool1, 0").run(code)