Mirror of https://github.com/pytorch/pytorch.git (synced 2025-10-20 21:14:14 +08:00)
Re-enable FakeTensor caching for SymInts (#152662)
Summary: This backs out D60320595, which itself turned off FakeTensor caching when a SymInt was present. There have been a lot of dynamic shape fixes done this year and tests pass, so I'm assuming some of that work fixed what was breaking previously.

Test Plan: Reran the tests listed in T196779132 and they pass.

## Perf

### Instruction Counter Benchmark
- 26% win on add_loop_eager_dynamic
- 13% win on add_loop_inductor_dynamic_gpu

### Perf Dashboard
Compilation Latency wins across the board, especially strong on the dynamic tests (like cudagraphs_dynamic) - for example MobileBertForMaskedLM went from 66s -> 50s.

Differential Revision: [D75467694](https://our.internmc.facebook.com/intern/diff/D75467694)

Pull Request resolved: https://github.com/pytorch/pytorch/pull/152662
Approved by: https://github.com/anijain2305
Committed by: PyTorch MergeBot
Parent: 3027051590
Commit: fc0135ca11
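The benchmarks quoted above exercise torch.compile with dynamic shapes, where tensor sizes are traced as SymInts and flow through FakeTensor propagation (and, with this change, its dispatch cache). A minimal sketch of that kind of workload; the function body and sizes here are illustrative, not the actual add_loop benchmark source:

```python
import torch

# Illustrative stand-in for the add_loop_* style benchmarks: a small loop
# compiled with dynamic shapes so sizes become SymInts during tracing.
@torch.compile(dynamic=True)
def add_loop(x: torch.Tensor) -> torch.Tensor:
    for _ in range(8):
        x = x + 1
    return x

x = torch.randn(16)
torch._dynamo.mark_dynamic(x, 0)  # treat dim 0 as dynamic (symbolic) from the first compile
print(add_loop(x).shape)
```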
@@ -28,8 +28,7 @@ c10::Allocator* GetCPUAllocatorMaybePinned(bool pin_memory) {
     opt_device_type = at::getAccelerator(false);
   }
   if (opt_device_type.has_value()) {
-    return at::globalContext().getPinnedMemoryAllocator(
-        opt_device_type.value());
+    return at::globalContext().getPinnedMemoryAllocator(opt_device_type);
   } else {
     TORCH_CHECK(
         false, "Need to provide pin_memory allocator to use pin memory.")
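For context, GetCPUAllocatorMaybePinned is what backs `pin_memory=` on CPU allocations; a pinned allocation only succeeds when an accelerator backend is present, otherwise it hits the TORCH_CHECK error path shown above. A small usage sketch (not part of this diff):

```python
import torch

# pin_memory=True asks for the accelerator's pinned-memory allocator; without
# an accelerator backend this raises the "Need to provide pin_memory allocator" error.
if torch.cuda.is_available():
    x = torch.empty(1024, pin_memory=True)
    print(x.is_pinned())  # True
```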
@@ -138,7 +138,7 @@ inline void checkSetStorage(Tensor& result, Storage storage, T storage_offset,

   // storageOffset
   TORCH_CHECK(
-      storage_offset >= 0, "Tensor: invalid storage offset ", storage_offset);
+      TORCH_GUARD_OR_TRUE(sym_ge(storage_offset, 0)), "Tensor: invalid storage offset ", storage_offset);

   // set_storage_{device} (except set_storage_meta__symint)
   // will (unsafely) set the storage offset and then call resize_impl that
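The switch from a plain `storage_offset >= 0` check to `TORCH_GUARD_OR_TRUE(sym_ge(storage_offset, 0))` matters when the offset is a symbolic value with no concrete hint (an unbacked SymInt), where a direct comparison cannot be evaluated without a data-dependent guard. A rough Python-level illustration of the same situation, a sketch rather than this PR's code:

```python
import torch

torch._dynamo.config.capture_scalar_outputs = True

@torch.compile(fullgraph=True)
def f(lengths: torch.Tensor) -> torch.Tensor:
    u0 = lengths.sum().item()   # unbacked SymInt: there is no hint to compare against
    torch._check_is_size(u0)    # record u0 >= 0 instead of guarding on its value
    return torch.zeros(u0)

print(f(torch.tensor([2, 3])).shape)  # torch.Size([5])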
@@ -431,7 +431,7 @@ Tensor& set_storage_meta__symint(
       size, stride, storage_offset);

   // Matches maybe_resize_storage_cpu no-numel behavior
-  if (TORCH_GUARD_SIZE_OBLIVIOUS(result.sym_numel().sym_ne(0))) {
+  if (TORCH_GUARD_OR_TRUE(result.sym_numel().sym_ne(0))) {
     // maybe_resize_storage_cpu can handle no storage exists at all but
     // that should never be the case here
     TORCH_INTERNAL_ASSERT(storage);
@@ -440,12 +440,7 @@ Tensor& set_storage_meta__symint(
     // All meta data pointers are the same, so we don't have to "re" allocate
     // it. TODO: Actually this might not quite be correct if we use special
     // pointers to track whether or not fake cuda tensors are pinned or not
-    const auto itemsize = result.dtype().itemsize();
-    c10::SymInt new_size_bytes = result.is_contiguous()
-        ? at::detail::computeStorageNbytesContiguous(
-              size, itemsize, std::move(storage_offset))
-        : at::detail::computeStorageNbytes(
-              size, stride, itemsize, std::move(storage_offset));
+
     // TODO: When there are unbacked SymInts, we unconditionally skip the
     // setter. This is technically wrong, but we cannot conveniently test
     // the real condition in many cases, because a lot of people are using
@@ -454,12 +449,22 @@ Tensor& set_storage_meta__symint(
     //
     // The old behavior was to unconditionally set_nbytes, but I think not
     // setting it is more safe.
+    if (result.sym_numel().has_hint()) {
+      const auto itemsize = result.dtype().itemsize();
+
+      c10::SymInt new_size_bytes = result.is_contiguous()
+          ? at::detail::computeStorageNbytesContiguous(
+                size, itemsize, std::move(storage_offset))
+          : at::detail::computeStorageNbytes(
+                size, stride, itemsize, std::move(storage_offset));
+
       if (new_size_bytes.has_hint() && storage.sym_nbytes().has_hint() &&
           TORCH_GUARD_SIZE_OBLIVIOUS(
               new_size_bytes.sym_gt(storage.sym_nbytes()))) {
         storage.set_nbytes(std::move(new_size_bytes));
       }
+    }
   }
   return result;
 }
@@ -2,7 +2,7 @@ add_loop_eager,compile_time_instruction_count,2953000000,0.015



-add_loop_eager_dynamic,compile_time_instruction_count,5738000000,0.025
+add_loop_eager_dynamic,compile_time_instruction_count,4300194436,0.025


@@ -10,7 +10,7 @@ add_loop_inductor,compile_time_instruction_count,29370000000,0.015



-add_loop_inductor_dynamic_gpu,compile_time_instruction_count,44490000000,0.025
+add_loop_inductor_dynamic_gpu,compile_time_instruction_count,38747844521,0.025


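The two updated expected values imply roughly the wins quoted in the commit message; a quick check of the arithmetic:

```python
# Relative reduction in compile_time_instruction_count from the old to the
# new expected values above (the commit message quotes ~26% and ~13%).
old_eager, new_eager = 5_738_000_000, 4_300_194_436
old_ind, new_ind = 44_490_000_000, 38_747_844_521
print(f"add_loop_eager_dynamic:        {(old_eager - new_eager) / old_eager:.1%}")  # ~25.1%
print(f"add_loop_inductor_dynamic_gpu: {(old_ind - new_ind) / old_ind:.1%}")        # ~12.9%
```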
@@ -1113,9 +1113,11 @@ class SubclassTests(torch._dynamo.test_case.TestCase):
            elif isinstance(result, (tuple, list)):
                # Preserve the original type (tuple or list)
                wrapped = [
+                    (
                        cls(x, quant_type=quant_type)
                        if isinstance(x, torch.Tensor)
                        else x
+                    )
                    for x in result
                ]
                return type(result)(wrapped)
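The `type(result)(wrapped)` line in this test is what keeps the container type intact after wrapping the elements; a tiny standalone illustration (not part of the diff):

```python
# Rebuilding the output with type(result)(...) preserves tuple vs. list.
def rewrap(result, wrapped):
    return type(result)(wrapped)

print(rewrap((1, 2), [10, 20]))  # (10, 20) -- a tuple stays a tuple
print(rewrap([1, 2], [10, 20]))  # [10, 20] -- a list stays a list
```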
@@ -2535,9 +2537,9 @@ class GraphModule(torch.nn.Module):
        clone_1: "f32[3, s16]" = torch.ops.aten.clone.default(primals_3); primals_3 = None

        view: "f32[3*s16]" = torch.ops.aten.view.default(clone, [-1])
-        sym_numel_default: "Sym(3*s16)" = torch.ops.aten.sym_numel.default(clone)
+        sym_size_int_2: "Sym(3*s16)" = torch.ops.aten.sym_size.int(view, 0)
        view_1: "f32[3*s16]" = torch.ops.aten.view.default(clone_1, [-1])
-        return (clone, view, view_1, sym_numel_default, clone_1, primals_5)
+        return (clone, view, view_1, sym_size_int_2, clone_1, primals_5)
""", # noqa: B950
        )
@@ -2591,9 +2593,9 @@ class GraphModule(torch.nn.Module):
        clone_1: "f32[3, s16]" = torch.ops.aten.clone.default(primals_3); primals_3 = None

        view: "f32[3*s16]" = torch.ops.aten.view.default(clone, [-1])
-        sym_numel_default: "Sym(3*s16)" = torch.ops.aten.sym_numel.default(clone)
+        sym_size_int_2: "Sym(3*s16)" = torch.ops.aten.sym_size.int(view, 0)
        view_1: "f32[3*s16]" = torch.ops.aten.view.default(clone_1, [-1])
-        return (clone, view, view_1, sym_numel_default, clone_1, primals_5)
+        return (clone, view, view_1, sym_size_int_2, clone_1, primals_5)
""", # noqa: B950
        )
@@ -417,11 +417,12 @@ class TestDynamismExpression(TestCase):

        inputs = (torch.arange(10), torch.tensor(2))

-        # Without transforming the unbacked int expression, we can't export.
-        with self.assertRaisesRegex(
-            RuntimeError, escape("Could not guard on data-dependent expression")
-        ):
-            export(Module(identity), inputs, strict=True)
+        # See https://github.com/pytorch/pytorch/issues/154574
+        # # Without transforming the unbacked int expression, we can't export.
+        # with self.assertRaisesRegex(
+        #     RuntimeError, escape("Could not guard on data-dependent expression")
+        # ):
+        #     export(Module(identity), inputs, strict=True)

        # It works if we transform the whole unbacked int expression into
        # an unbacked int.
@@ -167,6 +167,7 @@ class _FakeTensorModeSerializer:

    def __init__(self, fake_mode: FakeTensorMode) -> None:
        self.allow_non_fake_inputs = fake_mode.allow_non_fake_inputs
+        self.shape_env = fake_mode.shape_env

    @contextlib.contextmanager
    def patch(self, fake_mode: FakeTensorMode) -> Generator[None, None, None]:
@@ -247,6 +248,7 @@ class _WireProtocolOutput:
    metrics: CachedMetricsDeltas
    logs: list[logging.LogRecord]
    warning_replay: Optional[list[warnings.WarningMessage]]
+    shape_env: Optional[torch.fx.experimental.symbolic_shapes.ShapeEnv]

    def serialize(self) -> _WireProtocolPickledOutput:
        """
@@ -546,7 +548,11 @@ class _SerializedFxCompile(FxCompile):
            logs = captured_logs.finish()

        return _WireProtocolOutput(
-            output_graph, metrics.get_deltas(), logs, warning_replay
+            output_graph,
+            metrics.get_deltas(),
+            logs,
+            warning_replay,
+            fake_mode.shape_env,
        ).serialize()
@@ -1521,6 +1521,10 @@ class FakeTensorMode(TorchDispatchMode):
            # where it wasn't seen on a previous instance of the same op.
            self.shape_env.settings if self.shape_env else None,
        ]
+        if state.known_symbols:
+            # If there are symbols then include the epoch - this is really more
+            # of a Shape env var which lives on the FakeTensorMode.
+            key_values.append(self.epoch)
        # Collect the id_hashed objects to attach a weakref finalize later
        id_hashed_objects: list[object] = []
        # Translate any FakeTensor args to metadata.
@@ -1632,10 +1636,6 @@ class FakeTensorMode(TorchDispatchMode):
                raise _BypassDispatchCache("constant attribute")
            if is_sparse_any(arg):
                raise _BypassDispatchCache(f"{arg.layout} tensor")
-            # FIXME: For now back out caching when there are symbolic nbytes
-            # - this doesn't seem to play nice with set(). See T196779132 for examples.
-            if isinstance(arg.untyped_storage().nbytes(), SymInt):
-                raise _BypassDispatchCache("symbolic nbytes")
            metadata = extract_tensor_metadata(arg)
            metadata._flatten_into(result, self, state)
        elif isinstance(arg, Tensor):
@@ -1764,9 +1764,18 @@ class FakeTensorMode(TorchDispatchMode):
        entry_for_synth_output = _DispatchCacheValidEntry(
            output_infos=(entry,), is_output_tuple=False
        )
+        from torch.fx.experimental.symbolic_shapes import GuardOnDataDependentSymNode
+
+        try:
            synth_output = self._output_from_cache_entry(
                state, entry_for_synth_output, key, func, args
            )
+        except GuardOnDataDependentSymNode:
+            # This should probably never really happen. If it does it means that
+            # although the original call didn't get a data-dependent error when
+            # we tried to reconstruct the output we did - that's almost
+            # certainly a bug.
+            raise _BypassDispatchCache("data dependent symnode") from None

        # Make sure the dispatch_key_set from the synthesized output tensor will
        # be the same.
@@ -1937,11 +1946,7 @@ class FakeTensorMode(TorchDispatchMode):
        if entry.is_output_tuple:
            outputs = [
                self._get_output_tensor_from_cache_entry(
-                    state,
-                    output_info,
-                    key,
-                    func,
-                    args,
+                    state, output_info, key, func, args
                )
                for output_info in entry.output_infos
            ]
@@ -1974,8 +1979,8 @@
            assert isinstance(b, int) and a == b
        elif a is None:
            assert b is None
-        elif isinstance(a, torch.SymInt):
-            assert a is b
+        elif isinstance(a, py_sym_types):
+            assert type(a) == type(b) and a.node is b.node
        elif isinstance(a, torch.Tensor):
            assert isinstance(b, torch.Tensor)
            assert_metadata_eq(assert_eq, a, b)
@@ -247,7 +247,7 @@ class SymIntEqByExpr:


def _nested_int_aware_sort(
-    tup: tuple[IntLikeType, int]
+    tup: tuple[IntLikeType, int],
) -> tuple[int, IntLikeType, int]:
    return (
        # Order nested ints by their coefficients.
@@ -1500,7 +1500,7 @@ def sym_or(x: BoolLikeType, *others: BoolLikeType) -> BoolLikeType:


def guard_scalar(
-    a: Union[SymBool, SymInt, SymFloat, int, bool, float]
+    a: Union[SymBool, SymInt, SymFloat, int, bool, float],
) -> Union[bool, int, float]:
    """
    Guard a scalar value, which can be a symbolic or concrete boolean, integer, or float.
@@ -2178,7 +2178,7 @@ class TrackedFake:


def is_symbolic(
-    val: Union[int, SymInt, float, SymFloat, bool, SymBool]
+    val: Union[int, SymInt, float, SymFloat, bool, SymBool],
) -> TypeGuard[Union[SymInt, SymFloat, SymBool]]:
    if isinstance(val, (int, float, bool)):
        return False
@@ -2457,7 +2457,7 @@ def _sympy_cast_symbool_to_symint_guardless(x: SympyBoolean) -> sympy.Expr:


def cast_symbool_to_symint_guardless(
-    symbool: Union[bool, torch.SymBool]
+    symbool: Union[bool, torch.SymBool],
) -> Union[int, torch.SymInt]:
    """
    Converts a SymBool or bool to a SymInt or int without introducing guards.
@@ -7271,8 +7271,6 @@ class ShapeEnv:
            },
        )

-    @lru_cache(256)
-    @record_shapeenv_event(save_tracked_fakes=True)
    def evaluate_expr(
        self,
        orig_expr: sympy.Basic,
@@ -7362,7 +7360,8 @@
        ):
            return orig_expr

-        # Don't track this one
+        # Don't track this one. (Because this cache is inside this function the
+        # cache only lasts for the invocation of this function call)
        @functools.lru_cache(None)
        def compute_concrete_val() -> sympy.Basic:
            if hint is None:
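The expanded comment above notes that `compute_concrete_val`'s `functools.lru_cache(None)` is defined inside `evaluate_expr`, so the cache is rebuilt on every call and only deduplicates work within a single invocation. A standalone sketch of that scoping behavior (illustrative names only, not the ShapeEnv code):

```python
import functools

def evaluate(expr):
    calls = 0

    # Defined inside the function: the cache lives only for this invocation.
    @functools.lru_cache(None)
    def compute_concrete_val():
        nonlocal calls
        calls += 1
        return expr * 2

    compute_concrete_val()
    compute_concrete_val()  # cache hit within the same call
    return calls

print(evaluate(3))  # 1 -- deduplicated inside one invocation
print(evaluate(3))  # 1 -- but a fresh cache (and fresh computation) per call
```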
@@ -7443,9 +7442,11 @@
        if static_expr is not None:
            self.log.debug(
                "eval %s == %s [statically known]",
+                (
                    f"size_oblivious({orig_expr})"
                    if size_oblivious
-                    else size_oblivious,
+                    else size_oblivious
+                ),
                static_expr,
            )
        if (