Mirror of https://github.com/pytorch/pytorch.git, synced 2025-10-20 21:14:14 +08:00
Revert "Re-enable FakeTensor caching for SymInts (#152662)"
This reverts commit 7d11c61c26c596076613aa0111892f7cbccae32e.
Reverted https://github.com/pytorch/pytorch/pull/152662 on behalf of https://github.com/malfet due to Looks like it broke bunch of inductor tests, see 187d38185e/1 ([comment](https://github.com/pytorch/pytorch/pull/152662#issuecomment-2910293593))
@@ -28,7 +28,8 @@ c10::Allocator* GetCPUAllocatorMaybePinned(bool pin_memory) {
     opt_device_type = at::getAccelerator(false);
   }
   if (opt_device_type.has_value()) {
-    return at::globalContext().getPinnedMemoryAllocator(opt_device_type);
+    return at::globalContext().getPinnedMemoryAllocator(
+        opt_device_type.value());
   } else {
     TORCH_CHECK(
         false, "Need to provide pin_memory allocator to use pin memory.")
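For context on the code path this hunk restores: pin_memory=True on a CPU tensor goes through GetCPUAllocatorMaybePinned, and the restored TORCH_CHECK fires when no accelerator-backed pinned allocator is available. A minimal usage sketch, assuming a CUDA (or other accelerator) build of PyTorch:

    import torch

    # Minimal sketch: pin_memory=True asks the CPU allocator for pinned memory,
    # which is backed by the active accelerator's pinned-memory allocator.
    if torch.cuda.is_available():
        t = torch.empty(4, pin_memory=True)
        print(t.is_pinned())  # True on an accelerator build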
@@ -2,7 +2,7 @@ add_loop_eager,compile_time_instruction_count,2953000000,0.015



-add_loop_eager_dynamic,compile_time_instruction_count,4284000000,0.025
+add_loop_eager_dynamic,compile_time_instruction_count,5808000000,0.025



@@ -10,7 +10,7 @@ add_loop_inductor,compile_time_instruction_count,29370000000,0.015



-add_loop_inductor_dynamic_gpu,compile_time_instruction_count,38360000000,0.025
+add_loop_inductor_dynamic_gpu,compile_time_instruction_count,44010000000,0.025



@@ -34,11 +34,11 @@ basic_modules_ListOfLinears_inductor_gpu,compile_time_instruction_count,10370000



-update_hint_regression,compile_time_instruction_count,1703000000,0.02
+update_hint_regression,compile_time_instruction_count,1681000000,0.02



-float_args,compile_time_instruction_count,454300000,0.015
+float_args,compile_time_instruction_count,449800000,0.015



@@ -66,7 +66,7 @@ aotdispatcher_partitioner_cpu,compile_time_instruction_count,8585000000,0.015



-aotdispatcher_partitioner_cpu2,compile_time_instruction_count,1883000000,0.015
+aotdispatcher_partitioner_cpu2,compile_time_instruction_count,1900000000,0.015



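The four CSV hunks above adjust expected compile-time instruction counts for the PR-time benchmarks. A hedged reading of a row's format (benchmark name, metric, expected value, and what appears to be a relative noise tolerance); the helper below is a hypothetical illustration, not the benchmark harness itself:

    # Hypothetical check for a row such as
    #   add_loop_eager_dynamic,compile_time_instruction_count,5808000000,0.025
    def within_tolerance(measured: int, expected: int, rel_tol: float) -> bool:
        # Accept the measurement if it is within the stated relative tolerance.
        return abs(measured - expected) <= rel_tol * expected

    print(within_tolerance(5_900_000_000, 5_808_000_000, 0.025))  # True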
@@ -1113,11 +1113,9 @@ class SubclassTests(torch._dynamo.test_case.TestCase):
        elif isinstance(result, (tuple, list)):
            # Preserve the original type (tuple or list)
            wrapped = [
-                (
-                    cls(x, quant_type=quant_type)
-                    if isinstance(x, torch.Tensor)
-                    else x
-                )
+                cls(x, quant_type=quant_type)
+                if isinstance(x, torch.Tensor)
+                else x
                for x in result
            ]
            return type(result)(wrapped)
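A standalone sketch of the wrapping pattern in this hunk: map tensor elements through a constructor while preserving whether the container was a tuple or a list. Wrapper is a hypothetical stand-in for the tensor subclass used in the test:

    import torch

    class Wrapper:
        # Hypothetical stand-in for the subclass built via cls(x, quant_type=...)
        def __init__(self, t):
            self.t = t

    def wrap_all(result):
        if isinstance(result, (tuple, list)):
            # Preserve the original type (tuple or list)
            wrapped = [Wrapper(x) if isinstance(x, torch.Tensor) else x for x in result]
            return type(result)(wrapped)
        return result

    out = wrap_all((torch.ones(2), 3))
    print(type(out).__name__, type(out[0]).__name__)  # tuple Wrapper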
@@ -2537,9 +2535,9 @@ class GraphModule(torch.nn.Module):
        clone_1: "f32[3, s16]" = torch.ops.aten.clone.default(primals_3); primals_3 = None

        view: "f32[3*s16]" = torch.ops.aten.view.default(clone, [-1])
-        sym_size_int_2: "Sym(3*s16)" = torch.ops.aten.sym_size.int(view, 0)
+        sym_numel_default: "Sym(3*s16)" = torch.ops.aten.sym_numel.default(clone)
        view_1: "f32[3*s16]" = torch.ops.aten.view.default(clone_1, [-1])
-        return (clone, view, view_1, sym_size_int_2, clone_1, primals_5)
+        return (clone, view, view_1, sym_numel_default, clone_1, primals_5)
""", # noqa: B950
        )

@@ -2593,9 +2591,9 @@ class GraphModule(torch.nn.Module):
        clone_1: "f32[3, s16]" = torch.ops.aten.clone.default(primals_3); primals_3 = None

        view: "f32[3*s16]" = torch.ops.aten.view.default(clone, [-1])
-        sym_size_int_2: "Sym(3*s16)" = torch.ops.aten.sym_size.int(view, 0)
+        sym_numel_default: "Sym(3*s16)" = torch.ops.aten.sym_numel.default(clone)
        view_1: "f32[3*s16]" = torch.ops.aten.view.default(clone_1, [-1])
-        return (clone, view, view_1, sym_size_int_2, clone_1, primals_5)
+        return (clone, view, view_1, sym_numel_default, clone_1, primals_5)
""", # noqa: B950
        )

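Both expected-graph updates swap a sym_size.int call on the flattened view for a sym_numel call on the original clone; for an f32[3, s16] input both expressions denote 3*s16. A concrete static-shape illustration of why the two values agree:

    import torch

    x = torch.randn(3, 5)       # stand-in for the f32[3, s16] input
    view = x.clone().view(-1)   # flattened to length 3*5
    # The size of the flattened view along dim 0 equals the numel of the original.
    assert view.size(0) == x.numel() == 15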
@@ -167,7 +167,6 @@ class _FakeTensorModeSerializer:

    def __init__(self, fake_mode: FakeTensorMode) -> None:
        self.allow_non_fake_inputs = fake_mode.allow_non_fake_inputs
-        self.shape_env = fake_mode.shape_env

    @contextlib.contextmanager
    def patch(self, fake_mode: FakeTensorMode) -> Generator[None, None, None]:
@@ -248,7 +247,6 @@ class _WireProtocolOutput:
    metrics: CachedMetricsDeltas
    logs: list[logging.LogRecord]
    warning_replay: Optional[list[warnings.WarningMessage]]
-    shape_env: Optional[torch.fx.experimental.symbolic_shapes.ShapeEnv]

    def serialize(self) -> _WireProtocolPickledOutput:
        """
@@ -548,11 +546,7 @@ class _SerializedFxCompile(FxCompile):
            logs = captured_logs.finish()

        return _WireProtocolOutput(
-            output_graph,
-            metrics.get_deltas(),
-            logs,
-            warning_replay,
-            fake_mode.shape_env,
+            output_graph, metrics.get_deltas(), logs, warning_replay
        ).serialize()

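A hedged sketch of the wire-protocol shape after this revert (field types are simplified stand-ins, not the real PyTorch classes): the compiled graph plus captured metrics, logs, and warnings is bundled and pickled so an out-of-process compile can ship its result back:

    import dataclasses
    import pickle
    from typing import Any, Optional

    @dataclasses.dataclass
    class WireOutput:
        # Illustrative fields mirroring the four arguments passed above.
        output_graph: Any
        metrics: dict
        logs: list
        warning_replay: Optional[list]

        def serialize(self) -> bytes:
            return pickle.dumps(self)

    payload = WireOutput("graph-placeholder", {}, [], None).serialize()
    print(type(pickle.loads(payload)).__name__)  # WireOutput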
@@ -1521,10 +1521,6 @@ class FakeTensorMode(TorchDispatchMode):
            # where it wasn't seen on a previous instance of the same op.
            self.shape_env.settings if self.shape_env else None,
        ]
-        if state.known_symbols:
-            # If there are symbols then include the epoch - this is really more
-            # of a Shape env var which lives on the FakeTensorMode.
-            key_values.append(self.epoch)
        # Collect the id_hashed objects to attach a weakref finalize later
        id_hashed_objects: list[object] = []
        # Translate any FakeTensor args to metadata.
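A minimal sketch of the idea in the removed block: when symbolic values are involved, fold an epoch counter into the dispatch-cache key so entries go stale when the symbolic state changes. The names below are illustrative, not the real FakeTensorMode internals:

    def make_cache_key(op_name, args_metadata, epoch, has_symbols):
        key_values = [op_name, tuple(args_metadata)]
        if has_symbols:
            # Tie the entry to the current epoch so later symbolic changes miss.
            key_values.append(epoch)
        return tuple(key_values)

    print(make_cache_key("aten.add.Tensor", ["f32[2, 3]"], epoch=0, has_symbols=True))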
@@ -1636,6 +1632,10 @@ class FakeTensorMode(TorchDispatchMode):
                raise _BypassDispatchCache("constant attribute")
            if is_sparse_any(arg):
                raise _BypassDispatchCache(f"{arg.layout} tensor")
+            # FIXME: For now back out caching when there are symbolic nbytes
+            # - this doesn't seem to play nice with set(). See T196779132 for examples.
+            if isinstance(arg.untyped_storage().nbytes(), SymInt):
+                raise _BypassDispatchCache("symbolic nbytes")
            metadata = extract_tensor_metadata(arg)
            metadata._flatten_into(result, self, state)
        elif isinstance(arg, Tensor):
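A runnable sketch of the restored bypass condition: a tensor whose storage size is symbolic (a SymInt, as happens under dynamic shapes) is treated as uncacheable:

    import torch

    def has_symbolic_nbytes(t: torch.Tensor) -> bool:
        # Mirrors the isinstance check above; True means the cache should be bypassed.
        return isinstance(t.untyped_storage().nbytes(), torch.SymInt)

    print(has_symbolic_nbytes(torch.ones(3)))  # False for a static-shape tensor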
@@ -1937,7 +1937,11 @@ class FakeTensorMode(TorchDispatchMode):
            if entry.is_output_tuple:
                outputs = [
                    self._get_output_tensor_from_cache_entry(
-                        state, output_info, key, func, args
+                        state,
+                        output_info,
+                        key,
+                        func,
+                        args,
                    )
                    for output_info in entry.output_infos
                ]

@@ -244,7 +244,7 @@ class SymIntEqByExpr:


def _nested_int_aware_sort(
-    tup: tuple[IntLikeType, int],
+    tup: tuple[IntLikeType, int]
) -> tuple[int, IntLikeType, int]:
    return (
        # Order nested ints by their coefficients.
@@ -1380,7 +1380,7 @@ def sym_or(x: BoolLikeType, *others: BoolLikeType) -> BoolLikeType:


def guard_scalar(
-    a: Union[SymBool, SymInt, SymFloat, int, bool, float],
+    a: Union[SymBool, SymInt, SymFloat, int, bool, float]
) -> Union[bool, int, float]:
    if isinstance(a, (SymBool, bool)):
        return guard_bool(a)
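A pure-Python sketch of the dispatch visible in this hunk: route a (possibly symbolic) scalar to the matching guard helper by type. The branch bodies below are simplified stand-ins for guard_bool, guard_int, and guard_float:

    import torch

    def guard_scalar_sketch(a):
        if isinstance(a, (torch.SymBool, bool)):
            return bool(a)
        if isinstance(a, (torch.SymInt, int)):
            return int(a)
        if isinstance(a, (torch.SymFloat, float)):
            return float(a)
        raise AssertionError(f"unrecognized scalar {a!r}")

    print(guard_scalar_sketch(True), guard_scalar_sketch(3), guard_scalar_sketch(2.5))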
@@ -2036,7 +2036,7 @@ class TrackedFake:


def is_symbolic(
-    val: Union[int, SymInt, float, SymFloat, bool, SymBool],
+    val: Union[int, SymInt, float, SymFloat, bool, SymBool]
) -> TypeGuard[Union[SymInt, SymFloat, SymBool]]:
    if isinstance(val, (int, float, bool)):
        return False
@@ -2272,7 +2272,7 @@ def _sympy_cast_symbool_to_symint_guardless(x: SympyBoolean) -> sympy.Expr:


def cast_symbool_to_symint_guardless(
-    symbool: Union[bool, torch.SymBool],
+    symbool: Union[bool, torch.SymBool]
) -> Union[int, torch.SymInt]:
    if isinstance(symbool, bool):
        return 1 if symbool else 0
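A usage sketch of the bool fast path shown above, assuming the function is importable from torch.fx.experimental.symbolic_shapes (the module this hunk appears to come from): plain bools are mapped to 0/1 without creating any guards:

    # Assumes the import path matches the file shown in this hunk.
    from torch.fx.experimental.symbolic_shapes import cast_symbool_to_symint_guardless

    print(cast_symbool_to_symint_guardless(True))   # 1
    print(cast_symbool_to_symint_guardless(False))  # 0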
@@ -6927,6 +6927,8 @@ class ShapeEnv:
            },
        )

+    @lru_cache(256)
+    @record_shapeenv_event(save_tracked_fakes=True)
    def evaluate_expr(
        self,
        orig_expr: sympy.Basic,
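The decorators being restored include functools' lru_cache with a bounded size; applied to a method, the instance itself becomes part of the cache key. A minimal sketch of that behavior (record_shapeenv_event is PyTorch-internal and not reproduced here):

    from functools import lru_cache

    class Env:
        @lru_cache(256)
        def evaluate(self, expr: str) -> str:
            print("computing", expr)
            return expr.upper()

    e = Env()
    e.evaluate("s0 + 1")  # prints "computing s0 + 1"
    e.evaluate("s0 + 1")  # served from the cache, nothing printed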
@@ -7016,8 +7018,7 @@ class ShapeEnv:
        ):
            return orig_expr

-        # Don't track this one. (Because this cache is inside this function the
-        # cache only lasts for the invocation of this function call)
+        # Don't track this one
        @functools.lru_cache(None)
        def compute_concrete_val() -> sympy.Basic:
            if hint is None:
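The two-line comment removed here spells out the behavior: an lru_cache applied to a function defined inside another function is rebuilt on every call of the enclosing function, so nothing is memoized across invocations. A small demonstration of that lifetime:

    import functools

    def outer(x):
        @functools.lru_cache(None)
        def compute():
            print("computing for", x)
            return x * 2

        compute()          # computes
        return compute()   # cached, but only within this call of outer()

    outer(3)  # prints once
    outer(3)  # prints again: the inner cache did not survive the first call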
@@ -7098,11 +7099,9 @@ class ShapeEnv:
        if static_expr is not None:
            self.log.debug(
                "eval %s == %s [statically known]",
-                (
-                    f"size_oblivious({orig_expr})"
-                    if size_oblivious
-                    else size_oblivious
-                ),
+                f"size_oblivious({orig_expr})"
+                if size_oblivious
+                else size_oblivious,
                static_expr,
            )
        if (