Revert "Re-enable FakeTensor caching for SymInts (#152662)"

This reverts commit 7d11c61c26c596076613aa0111892f7cbccae32e.

Reverted https://github.com/pytorch/pytorch/pull/152662 on behalf of https://github.com/malfet due to Looks like it broke a bunch of inductor tests, see 187d38185e/1 ([comment](https://github.com/pytorch/pytorch/pull/152662#issuecomment-2910293593))
PyTorch MergeBot
2025-05-26 17:13:22 +00:00
parent 187d38185e
commit 3f64502c98
6 changed files with 34 additions and 38 deletions

View File

@ -28,7 +28,8 @@ c10::Allocator* GetCPUAllocatorMaybePinned(bool pin_memory) {
opt_device_type = at::getAccelerator(false);
}
if (opt_device_type.has_value()) {
return at::globalContext().getPinnedMemoryAllocator(opt_device_type);
return at::globalContext().getPinnedMemoryAllocator(
opt_device_type.value());
} else {
TORCH_CHECK(
false, "Need to provide pin_memory allocator to use pin memory.")

View File

@ -2,7 +2,7 @@ add_loop_eager,compile_time_instruction_count,2953000000,0.015
add_loop_eager_dynamic,compile_time_instruction_count,4284000000,0.025
add_loop_eager_dynamic,compile_time_instruction_count,5808000000,0.025
@ -10,7 +10,7 @@ add_loop_inductor,compile_time_instruction_count,29370000000,0.015
add_loop_inductor_dynamic_gpu,compile_time_instruction_count,38360000000,0.025
add_loop_inductor_dynamic_gpu,compile_time_instruction_count,44010000000,0.025
@ -34,11 +34,11 @@ basic_modules_ListOfLinears_inductor_gpu,compile_time_instruction_count,10370000
update_hint_regression,compile_time_instruction_count,1703000000,0.02
update_hint_regression,compile_time_instruction_count,1681000000,0.02
float_args,compile_time_instruction_count,454300000,0.015
float_args,compile_time_instruction_count,449800000,0.015
@ -66,7 +66,7 @@ aotdispatcher_partitioner_cpu,compile_time_instruction_count,8585000000,0.015
aotdispatcher_partitioner_cpu2,compile_time_instruction_count,1883000000,0.015
aotdispatcher_partitioner_cpu2,compile_time_instruction_count,1900000000,0.015
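
Each row above pairs a benchmark and metric with an expected value and what reads as a relative noise threshold; the revert puts the expected compile_time_instruction_count values back to their pre-#152662 numbers. As a sketch only (check_expected and the measured dict are hypothetical helpers, not part of this file), a check against such rows could look like:

import csv

def check_expected(path: str, measured: dict[str, float]) -> list[str]:
    # Compare measured values against each row's expected value, allowing the
    # per-row relative tolerance from the last column.
    failures = []
    with open(path) as f:
        for name, metric, expected, tolerance in csv.reader(f):
            if name not in measured:
                continue
            expected_v, tol = float(expected), float(tolerance)
            actual = measured[name]
            if abs(actual - expected_v) > expected_v * tol:
                failures.append(f"{name} ({metric}): {actual} not within {tol:.1%} of {expected_v:.0f}")
    return failures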


View File

@ -1113,11 +1113,9 @@ class SubclassTests(torch._dynamo.test_case.TestCase):
elif isinstance(result, (tuple, list)):
# Preserve the original type (tuple or list)
wrapped = [
(
cls(x, quant_type=quant_type)
if isinstance(x, torch.Tensor)
else x
)
cls(x, quant_type=quant_type)
if isinstance(x, torch.Tensor)
else x
for x in result
]
return type(result)(wrapped)
@ -2537,9 +2535,9 @@ class GraphModule(torch.nn.Module):
clone_1: "f32[3, s16]" = torch.ops.aten.clone.default(primals_3); primals_3 = None
view: "f32[3*s16]" = torch.ops.aten.view.default(clone, [-1])
sym_size_int_2: "Sym(3*s16)" = torch.ops.aten.sym_size.int(view, 0)
sym_numel_default: "Sym(3*s16)" = torch.ops.aten.sym_numel.default(clone)
view_1: "f32[3*s16]" = torch.ops.aten.view.default(clone_1, [-1])
return (clone, view, view_1, sym_size_int_2, clone_1, primals_5)
return (clone, view, view_1, sym_numel_default, clone_1, primals_5)
""", # noqa: B950
)
@ -2593,9 +2591,9 @@ class GraphModule(torch.nn.Module):
clone_1: "f32[3, s16]" = torch.ops.aten.clone.default(primals_3); primals_3 = None
view: "f32[3*s16]" = torch.ops.aten.view.default(clone, [-1])
sym_size_int_2: "Sym(3*s16)" = torch.ops.aten.sym_size.int(view, 0)
sym_numel_default: "Sym(3*s16)" = torch.ops.aten.sym_numel.default(clone)
view_1: "f32[3*s16]" = torch.ops.aten.view.default(clone_1, [-1])
return (clone, view, view_1, sym_size_int_2, clone_1, primals_5)
return (clone, view, view_1, sym_numel_default, clone_1, primals_5)
""", # noqa: B950
)
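
The expected graphs above swap sym_numel_default back to sym_size_int_2; for the contiguous f32[3, s16] input both expressions denote 3*s16. A quick eager-mode illustration with a concrete size (a sketch, not part of the test):

import torch

t = torch.randn(3, 7)      # stands in for the f32[3, s16] input
flat = t.view(-1)          # f32[3*s16]
# The flattened view's first dimension equals the original tensor's numel: 3*s16.
assert flat.size(0) == t.numel() == 21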

View File

@ -167,7 +167,6 @@ class _FakeTensorModeSerializer:
def __init__(self, fake_mode: FakeTensorMode) -> None:
self.allow_non_fake_inputs = fake_mode.allow_non_fake_inputs
self.shape_env = fake_mode.shape_env
@contextlib.contextmanager
def patch(self, fake_mode: FakeTensorMode) -> Generator[None, None, None]:
@ -248,7 +247,6 @@ class _WireProtocolOutput:
metrics: CachedMetricsDeltas
logs: list[logging.LogRecord]
warning_replay: Optional[list[warnings.WarningMessage]]
shape_env: Optional[torch.fx.experimental.symbolic_shapes.ShapeEnv]
def serialize(self) -> _WireProtocolPickledOutput:
"""
@ -548,11 +546,7 @@ class _SerializedFxCompile(FxCompile):
logs = captured_logs.finish()
return _WireProtocolOutput(
output_graph,
metrics.get_deltas(),
logs,
warning_replay,
fake_mode.shape_env,
output_graph, metrics.get_deltas(), logs, warning_replay
).serialize()
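
These hunks drop shape_env from both the serializer snapshot and the _WireProtocolOutput fields. For readers unfamiliar with the pattern, a stripped-down sketch of the snapshot-and-patch idea (only allow_non_fake_inputs is taken from the diff; the class name and everything else is illustrative):

import contextlib
from typing import Generator

class _ModeSnapshot:
    # Capture the mode settings that need to survive serialization.
    def __init__(self, mode) -> None:
        self.allow_non_fake_inputs = mode.allow_non_fake_inputs

    @contextlib.contextmanager
    def patch(self, mode) -> Generator[None, None, None]:
        # Temporarily apply the captured settings to a freshly constructed mode.
        saved = mode.allow_non_fake_inputs
        mode.allow_non_fake_inputs = self.allow_non_fake_inputs
        try:
            yield
        finally:
            mode.allow_non_fake_inputs = saved

# usage sketch: snap = _ModeSnapshot(fake_mode); later, with snap.patch(new_mode): ...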

View File

@ -1521,10 +1521,6 @@ class FakeTensorMode(TorchDispatchMode):
# where it wasn't seen on a previous instance of the same op.
self.shape_env.settings if self.shape_env else None,
]
if state.known_symbols:
# If there are symbols then include the epoch - this is really more
# of a Shape env var which lives on the FakeTensorMode.
key_values.append(self.epoch)
# Collect the id_hashed objects to attach a weakref finalize later
id_hashed_objects: list[object] = []
# Translate any FakeTensor args to metadata.
@ -1636,6 +1632,10 @@ class FakeTensorMode(TorchDispatchMode):
raise _BypassDispatchCache("constant attribute")
if is_sparse_any(arg):
raise _BypassDispatchCache(f"{arg.layout} tensor")
# FIXME: For now back out caching when there are symbolic nbytes
# - this doesn't seem to play nice with set(). See T196779132 for examples.
if isinstance(arg.untyped_storage().nbytes(), SymInt):
raise _BypassDispatchCache("symbolic nbytes")
metadata = extract_tensor_metadata(arg)
metadata._flatten_into(result, self, state)
elif isinstance(arg, Tensor):
@ -1937,7 +1937,11 @@ class FakeTensorMode(TorchDispatchMode):
if entry.is_output_tuple:
outputs = [
self._get_output_tensor_from_cache_entry(
state, output_info, key, func, args
state,
output_info,
key,
func,
args,
)
for output_info in entry.output_infos
]
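
The FakeTensorMode hunks restore the cache bypass for symbolic storage sizes: while building a dispatch-cache key, any argument that cannot be cached raises _BypassDispatchCache and the caller falls back to an uncached dispatch. A stripped-down sketch of that pattern (only the exception name mirrors the diff; the helpers are illustrative):

import torch

class _BypassDispatchCache(Exception):
    """Raised while building a cache key to force an uncached dispatch."""

def cache_key(func, args):
    parts = [func]
    for arg in args:
        if isinstance(arg, torch.SymInt):
            # Mirrors the restored check: symbolic values make the key unreliable.
            raise _BypassDispatchCache("symbolic value in key")
        parts.append(arg)
    return tuple(parts)

def dispatch(cache, func, *args):
    try:
        key = cache_key(func, args)
    except _BypassDispatchCache:
        return func(*args)          # uncached fallback
    if key not in cache:
        cache[key] = func(*args)
    return cache[key]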

View File

@ -244,7 +244,7 @@ class SymIntEqByExpr:
def _nested_int_aware_sort(
tup: tuple[IntLikeType, int],
tup: tuple[IntLikeType, int]
) -> tuple[int, IntLikeType, int]:
return (
# Order nested ints by their coefficients.
@ -1380,7 +1380,7 @@ def sym_or(x: BoolLikeType, *others: BoolLikeType) -> BoolLikeType:
def guard_scalar(
a: Union[SymBool, SymInt, SymFloat, int, bool, float],
a: Union[SymBool, SymInt, SymFloat, int, bool, float]
) -> Union[bool, int, float]:
if isinstance(a, (SymBool, bool)):
return guard_bool(a)
@ -2036,7 +2036,7 @@ class TrackedFake:
def is_symbolic(
val: Union[int, SymInt, float, SymFloat, bool, SymBool],
val: Union[int, SymInt, float, SymFloat, bool, SymBool]
) -> TypeGuard[Union[SymInt, SymFloat, SymBool]]:
if isinstance(val, (int, float, bool)):
return False
@ -2272,7 +2272,7 @@ def _sympy_cast_symbool_to_symint_guardless(x: SympyBoolean) -> sympy.Expr:
def cast_symbool_to_symint_guardless(
symbool: Union[bool, torch.SymBool],
symbool: Union[bool, torch.SymBool]
) -> Union[int, torch.SymInt]:
if isinstance(symbool, bool):
return 1 if symbool else 0
@ -6927,6 +6927,8 @@ class ShapeEnv:
},
)
@lru_cache(256)
@record_shapeenv_event(save_tracked_fakes=True)
def evaluate_expr(
self,
orig_expr: sympy.Basic,
@ -7016,8 +7018,7 @@ class ShapeEnv:
):
return orig_expr
# Don't track this one. (Because this cache is inside this function the
# cache only lasts for the invocation of this function call)
# Don't track this one
@functools.lru_cache(None)
def compute_concrete_val() -> sympy.Basic:
if hint is None:
@ -7098,11 +7099,9 @@ class ShapeEnv:
if static_expr is not None:
self.log.debug(
"eval %s == %s [statically known]",
(
f"size_oblivious({orig_expr})"
if size_oblivious
else size_oblivious
),
f"size_oblivious({orig_expr})"
if size_oblivious
else size_oblivious,
static_expr,
)
if (
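
The comment change in the evaluate_expr hunk concerns a functools.lru_cache defined inside the function: the decorated closure, and therefore its cache, is recreated on every call, so memoization only lasts for a single invocation. A tiny illustration of that scoping (names are illustrative):

import functools

def evaluate(expr: str) -> int:
    @functools.lru_cache(None)   # a fresh cache object is created per evaluate() call
    def compute() -> int:
        print("computing", expr)
        return hash(expr)

    compute()        # computed
    compute()        # served from this call's cache
    return compute()

evaluate("a + b")    # prints "computing a + b" once
evaluate("a + b")    # prints again: the previous call's cache is gone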