Re-enable FakeTensor caching for SymInts (#152662)

Summary:

This backs out D60320595, which itself turned off FakeTensor caching when a SymInt was present.

There have been a lot of dynamic shape fixes this year, and the tests pass, so I'm assuming some of that work fixed what was breaking previously.

Test Plan: Reran the tests listed in T196779132 and they pass.

## Perf
### Instruction Counter Benchmark:
- 26% win on add_loop_eager_dynamic
- 13% win on add_loop_inductor_dynamic_gpu
### Perf Dashboard
Compilation latency wins across the board, with especially strong wins on the dynamic tests (like cudagraphs_dynamic); for example, MobileBertForMaskedLM went from 66s -> 50s.
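
For context, here is a minimal dynamic-shapes workload in the spirit of the add_loop benchmarks (an illustrative sketch, not the benchmark's actual code): with `dynamic=True`, every traced op dispatches on FakeTensors whose sizes are SymInts, which is exactly the path the re-enabled cache now serves.

```python
import torch

# Compile with dynamic shapes so tensor sizes become SymInts during tracing.
@torch.compile(backend="eager", dynamic=True)
def add_loop(x, y):
    for _ in range(100):
        x = x + y
    return x

# Compile time here is dominated by tracing the 100 adds through FakeTensor
# dispatch; with SymInt caching re-enabled, repeated identical ops can hit
# the cache instead of being re-propagated.
add_loop(torch.randn(1024), torch.randn(1024))
```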

Differential Revision: [D75467694](https://our.internmc.facebook.com/intern/diff/D75467694)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/152662
Approved by: https://github.com/anijain2305
This commit is contained in:
Aaron Orenstein
2025-05-28 20:11:30 -07:00
committed by PyTorch MergeBot
parent 3027051590
commit fc0135ca11
9 changed files with 72 additions and 53 deletions

View File

@ -28,8 +28,7 @@ c10::Allocator* GetCPUAllocatorMaybePinned(bool pin_memory) {
opt_device_type = at::getAccelerator(false);
}
if (opt_device_type.has_value()) {
return at::globalContext().getPinnedMemoryAllocator(
opt_device_type.value());
return at::globalContext().getPinnedMemoryAllocator(opt_device_type);
} else {
TORCH_CHECK(
false, "Need to provide pin_memory allocator to use pin memory.")

View File

@ -138,7 +138,7 @@ inline void checkSetStorage(Tensor& result, Storage storage, T storage_offset,
// storageOffset
TORCH_CHECK(
storage_offset >= 0, "Tensor: invalid storage offset ", storage_offset);
TORCH_GUARD_OR_TRUE(sym_ge(storage_offset, 0)), "Tensor: invalid storage offset ", storage_offset);
// set_storage_{device} (except set_storage_meta__symint)
// will (unsafely) set the storage offset and then call resize_impl that
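
A hedged Python sketch of the TORCH_GUARD_OR_TRUE semantics used above (illustrative only, not the macro's implementation): when the comparison involves unbacked SymInts and cannot be decided without a data-dependent guard, assume it holds instead of failing the check.

```python
from torch.fx.experimental.symbolic_shapes import GuardOnDataDependentSymNode

def guard_or_true(cond) -> bool:
    """Illustrative: return bool(cond) when it can be decided, else assume True."""
    try:
        return bool(cond)
    except GuardOnDataDependentSymNode:
        # Data-dependent (unbacked) expression: don't fail the check on it.
        return True
```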

View File

@ -431,7 +431,7 @@ Tensor& set_storage_meta__symint(
size, stride, storage_offset);
// Matches maybe_resize_storage_cpu no-numel behavior
if (TORCH_GUARD_SIZE_OBLIVIOUS(result.sym_numel().sym_ne(0))) {
if (TORCH_GUARD_OR_TRUE(result.sym_numel().sym_ne(0))) {
// maybe_resize_storage_cpu can handle no storage exists at all but
// that should never be the case here
TORCH_INTERNAL_ASSERT(storage);
@ -440,12 +440,7 @@ Tensor& set_storage_meta__symint(
// All meta data pointers are the same, so we don't have to "re" allocate
// it. TODO: Actually this might not quite be correct if we use special
// pointers to track whether or not fake cuda tensors are pinned or not
const auto itemsize = result.dtype().itemsize();
c10::SymInt new_size_bytes = result.is_contiguous()
? at::detail::computeStorageNbytesContiguous(
size, itemsize, std::move(storage_offset))
: at::detail::computeStorageNbytes(
size, stride, itemsize, std::move(storage_offset));
// TODO: When there are unbacked SymInts, we unconditionally skip the
// setter. This is technically wrong, but we cannot conveniently test
// the real condition in many cases, because a lot of people are using
@ -454,12 +449,22 @@ Tensor& set_storage_meta__symint(
//
// The old behavior was to unconditionally set_nbytes, but I think not
// setting it is more safe.
if (result.sym_numel().has_hint()) {
const auto itemsize = result.dtype().itemsize();
c10::SymInt new_size_bytes = result.is_contiguous()
? at::detail::computeStorageNbytesContiguous(
size, itemsize, std::move(storage_offset))
: at::detail::computeStorageNbytes(
size, stride, itemsize, std::move(storage_offset));
if (new_size_bytes.has_hint() && storage.sym_nbytes().has_hint() &&
TORCH_GUARD_SIZE_OBLIVIOUS(
new_size_bytes.sym_gt(storage.sym_nbytes()))) {
storage.set_nbytes(std::move(new_size_bytes));
}
}
}
return result;
}
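
A hedged sketch of the has_hint() gating in the hunk above, written in Python for illustration (`maybe_grow_nbytes` is a hypothetical helper, not ATen's API; it assumes `torch.fx.experimental.symbolic_shapes.has_hint`): only commit to a concrete storage size when both sides are backed, and leave nbytes untouched when unbacked SymInts are involved.

```python
from torch.fx.experimental.symbolic_shapes import has_hint

def maybe_grow_nbytes(storage_nbytes, new_size_bytes):
    # Hypothetical helper mirroring the C++ logic above: with unbacked
    # SymInts (no hint) we skip the update rather than guess a size.
    if has_hint(new_size_bytes) and has_hint(storage_nbytes):
        if new_size_bytes > storage_nbytes:
            return new_size_bytes
    return storage_nbytes
```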

View File

@ -2,7 +2,7 @@ add_loop_eager,compile_time_instruction_count,2953000000,0.015
add_loop_eager_dynamic,compile_time_instruction_count,5738000000,0.025
add_loop_eager_dynamic,compile_time_instruction_count,4300194436,0.025
@ -10,7 +10,7 @@ add_loop_inductor,compile_time_instruction_count,29370000000,0.015
add_loop_inductor_dynamic_gpu,compile_time_instruction_count,44490000000,0.025
add_loop_inductor_dynamic_gpu,compile_time_instruction_count,38747844521,0.025

| benchmark | compile_time_instruction_count (old -> new where changed) | tolerance |
|---|---|---|
| add_loop_eager | 2953000000 | 0.015 |
| add_loop_eager_dynamic | 5738000000 -> 4300194436 | 0.025 |
| add_loop_inductor | 29370000000 | 0.015 |
| add_loop_inductor_dynamic_gpu | 44490000000 -> 38747844521 | 0.025 |
| add_loop_inductor_gpu | 25900000000 | 0.015 |
| basic_modules_ListOfLinears_eager | 939900000 | 0.015 |
| basic_modules_ListOfLinears_inductor | 18270000000 | 0.015 |
| basic_modules_ListOfLinears_inductor_gpu_force_shape_pad | 16310000000 | 0.015 |
| update_hint_regression | 1700000000 | 0.02 |
| float_args | 452500000 | 0.015 |
| sum_floordiv_regression | 998600000 | 0.015 |
| symint_sum | 3252000000 | 0.015 |
| symint_sum_loop | 4262000000 | 0.015 |
| aotdispatcher_inference_nosubclass_cpu | 2112000000 | 0.015 |
| aotdispatcher_inference_subclass_cpu | 6022000000 | 0.015 |

View File

@ -1113,9 +1113,11 @@ class SubclassTests(torch._dynamo.test_case.TestCase):
elif isinstance(result, (tuple, list)):
# Preserve the original type (tuple or list)
wrapped = [
(
cls(x, quant_type=quant_type)
if isinstance(x, torch.Tensor)
else x
)
for x in result
]
return type(result)(wrapped)
@ -2535,9 +2537,9 @@ class GraphModule(torch.nn.Module):
clone_1: "f32[3, s16]" = torch.ops.aten.clone.default(primals_3); primals_3 = None
view: "f32[3*s16]" = torch.ops.aten.view.default(clone, [-1])
sym_numel_default: "Sym(3*s16)" = torch.ops.aten.sym_numel.default(clone)
sym_size_int_2: "Sym(3*s16)" = torch.ops.aten.sym_size.int(view, 0)
view_1: "f32[3*s16]" = torch.ops.aten.view.default(clone_1, [-1])
return (clone, view, view_1, sym_numel_default, clone_1, primals_5)
return (clone, view, view_1, sym_size_int_2, clone_1, primals_5)
""", # noqa: B950
)
@ -2591,9 +2593,9 @@ class GraphModule(torch.nn.Module):
clone_1: "f32[3, s16]" = torch.ops.aten.clone.default(primals_3); primals_3 = None
view: "f32[3*s16]" = torch.ops.aten.view.default(clone, [-1])
sym_numel_default: "Sym(3*s16)" = torch.ops.aten.sym_numel.default(clone)
sym_size_int_2: "Sym(3*s16)" = torch.ops.aten.sym_size.int(view, 0)
view_1: "f32[3*s16]" = torch.ops.aten.view.default(clone_1, [-1])
return (clone, view, view_1, sym_numel_default, clone_1, primals_5)
return (clone, view, view_1, sym_size_int_2, clone_1, primals_5)
""", # noqa: B950
)

View File

@ -417,11 +417,12 @@ class TestDynamismExpression(TestCase):
inputs = (torch.arange(10), torch.tensor(2))
# Without transforming the unbacked int expression, we can't export.
with self.assertRaisesRegex(
RuntimeError, escape("Could not guard on data-dependent expression")
):
export(Module(identity), inputs, strict=True)
# See https://github.com/pytorch/pytorch/issues/154574
# # Without transforming the unbacked int expression, we can't export.
# with self.assertRaisesRegex(
# RuntimeError, escape("Could not guard on data-dependent expression")
# ):
# export(Module(identity), inputs, strict=True)
# It works if we transform the whole unbacked int expression into
# an unbacked int.

View File

@ -167,6 +167,7 @@ class _FakeTensorModeSerializer:
def __init__(self, fake_mode: FakeTensorMode) -> None:
self.allow_non_fake_inputs = fake_mode.allow_non_fake_inputs
self.shape_env = fake_mode.shape_env
@contextlib.contextmanager
def patch(self, fake_mode: FakeTensorMode) -> Generator[None, None, None]:
@ -247,6 +248,7 @@ class _WireProtocolOutput:
metrics: CachedMetricsDeltas
logs: list[logging.LogRecord]
warning_replay: Optional[list[warnings.WarningMessage]]
shape_env: Optional[torch.fx.experimental.symbolic_shapes.ShapeEnv]
def serialize(self) -> _WireProtocolPickledOutput:
"""
@ -546,7 +548,11 @@ class _SerializedFxCompile(FxCompile):
logs = captured_logs.finish()
return _WireProtocolOutput(
output_graph, metrics.get_deltas(), logs, warning_replay
output_graph,
metrics.get_deltas(),
logs,
warning_replay,
fake_mode.shape_env,
).serialize()

View File

@ -1521,6 +1521,10 @@ class FakeTensorMode(TorchDispatchMode):
# where it wasn't seen on a previous instance of the same op.
self.shape_env.settings if self.shape_env else None,
]
if state.known_symbols:
# If there are symbols then include the epoch - this is really more
# of a Shape env var which lives on the FakeTensorMode.
key_values.append(self.epoch)
# Collect the id_hashed objects to attach a weakref finalize later
id_hashed_objects: list[object] = []
# Translate any FakeTensor args to metadata.
@ -1632,10 +1636,6 @@ class FakeTensorMode(TorchDispatchMode):
raise _BypassDispatchCache("constant attribute")
if is_sparse_any(arg):
raise _BypassDispatchCache(f"{arg.layout} tensor")
# FIXME: For now back out caching when there are symbolic nbytes
# - this doesn't seem to play nice with set(). See T196779132 for examples.
if isinstance(arg.untyped_storage().nbytes(), SymInt):
raise _BypassDispatchCache("symbolic nbytes")
metadata = extract_tensor_metadata(arg)
metadata._flatten_into(result, self, state)
elif isinstance(arg, Tensor):
@ -1764,9 +1764,18 @@ class FakeTensorMode(TorchDispatchMode):
entry_for_synth_output = _DispatchCacheValidEntry(
output_infos=(entry,), is_output_tuple=False
)
from torch.fx.experimental.symbolic_shapes import GuardOnDataDependentSymNode
try:
synth_output = self._output_from_cache_entry(
state, entry_for_synth_output, key, func, args
)
except GuardOnDataDependentSymNode:
# This should probably never really happen. If it does it means that
# although the original call didn't get a data-dependent error when
# we tried to reconstruct the output we did - that's almost
# certainly a bug.
raise _BypassDispatchCache("data dependent symnode") from None
# Make sure the dispatch_key_set from the synthesized output tensor will
# be the same.
@ -1937,11 +1946,7 @@ class FakeTensorMode(TorchDispatchMode):
if entry.is_output_tuple:
outputs = [
self._get_output_tensor_from_cache_entry(
state,
output_info,
key,
func,
args,
state, output_info, key, func, args
)
for output_info in entry.output_infos
]
@ -1974,8 +1979,8 @@ class FakeTensorMode(TorchDispatchMode):
assert isinstance(b, int) and a == b
elif a is None:
assert b is None
elif isinstance(a, torch.SymInt):
assert a is b
elif isinstance(a, py_sym_types):
assert type(a) == type(b) and a.node is b.node
elif isinstance(a, torch.Tensor):
assert isinstance(b, torch.Tensor)
assert_metadata_eq(assert_eq, a, b)
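
A hedged sketch of the caching idea the hunk above enables (`cache_key_component` is a hypothetical helper, not FakeTensorMode's real key-building code): SymInt arguments are keyed by their symbolic expression together with the mode's epoch, so different expressions never collide and entries recorded under stale symbolic state cannot be reused.

```python
import torch

def cache_key_component(value, epoch: int):
    # Hypothetical helper: turn one dispatch argument into a hashable
    # cache-key piece. SymInts are keyed by their expression (e.g. "s0 + 1"),
    # not by a concrete hint, and tagged with the epoch so stale entries
    # stop matching once the symbolic state changes.
    if isinstance(value, torch.SymInt):
        return ("symint", str(value), epoch)
    return ("static", value)
```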

View File

@ -247,7 +247,7 @@ class SymIntEqByExpr:
def _nested_int_aware_sort(
tup: tuple[IntLikeType, int]
tup: tuple[IntLikeType, int],
) -> tuple[int, IntLikeType, int]:
return (
# Order nested ints by their coefficients.
@ -1500,7 +1500,7 @@ def sym_or(x: BoolLikeType, *others: BoolLikeType) -> BoolLikeType:
def guard_scalar(
a: Union[SymBool, SymInt, SymFloat, int, bool, float]
a: Union[SymBool, SymInt, SymFloat, int, bool, float],
) -> Union[bool, int, float]:
"""
Guard a scalar value, which can be a symbolic or concrete boolean, integer, or float.
@ -2178,7 +2178,7 @@ class TrackedFake:
def is_symbolic(
val: Union[int, SymInt, float, SymFloat, bool, SymBool]
val: Union[int, SymInt, float, SymFloat, bool, SymBool],
) -> TypeGuard[Union[SymInt, SymFloat, SymBool]]:
if isinstance(val, (int, float, bool)):
return False
@ -2457,7 +2457,7 @@ def _sympy_cast_symbool_to_symint_guardless(x: SympyBoolean) -> sympy.Expr:
def cast_symbool_to_symint_guardless(
symbool: Union[bool, torch.SymBool]
symbool: Union[bool, torch.SymBool],
) -> Union[int, torch.SymInt]:
"""
Converts a SymBool or bool to a SymInt or int without introducing guards.
@ -7271,8 +7271,6 @@ class ShapeEnv:
},
)
@lru_cache(256)
@record_shapeenv_event(save_tracked_fakes=True)
def evaluate_expr(
self,
orig_expr: sympy.Basic,
@ -7362,7 +7360,8 @@ class ShapeEnv:
):
return orig_expr
# Don't track this one
# Don't track this one. (Because this cache is inside this function the
# cache only lasts for the invocation of this function call)
@functools.lru_cache(None)
def compute_concrete_val() -> sympy.Basic:
if hint is None:
@ -7443,9 +7442,11 @@ class ShapeEnv:
if static_expr is not None:
self.log.debug(
"eval %s == %s [statically known]",
(
f"size_oblivious({orig_expr})"
if size_oblivious
else size_oblivious,
else size_oblivious
),
static_expr,
)
if (