Compare commits

...

3 Commits

SHA1        Message                                                                                Date

e4de72ea5d  Experiment for cold start - fake tensor                                                2025-09-18 11:20:55 -07:00

d9258fb366  [functionalize] Avoid one more call to custom get_device on FunctionalTensorWrapper    2025-09-16 09:53:02 -07:00
            ghstack-source-id: 801aa346f3a2519296f325c6a4b69c09cb484b95
            Pull Request resolved: https://github.com/pytorch/pytorch/pull/163019

ef78f99412  [functional] Use the saved device on storage instead for device_custom                 2025-09-16 09:53:02 -07:00
            ghstack-source-id: a2f54f448ccd8eb4c10f12243ccc8ecf98ae6036
            Pull Request resolved: https://github.com/pytorch/pytorch/pull/162987
4 changed files with 132 additions and 19 deletions

View File

@@ -133,7 +133,7 @@ FunctionalTensorWrapper::FunctionalTensorWrapper(const Tensor& view_value, const
     : c10::TensorImpl(
           c10::DispatchKeySet(DispatchKey::Functionalize),
           view_value.dtype(),
-          view_value.device()
+          base->storage().data_ptr().device()
       ),
       value_(view_value),
       is_multi_output_view_(base->is_multi_output_view_ || meta.is_multi_output),
@@ -485,7 +485,10 @@ void FunctionalTensorWrapper::shallow_copy_from(const c10::intrusive_ptr<TensorI
 c10::Device FunctionalTensorWrapper::device_custom() const {
-  return value_.unsafeGetTensorImpl()->device();
+  // The storage pointer already uses the underlying tensor's custom device (if
+  // applicable) to extract the device, so we don't have to recurse again by
+  // calling value_.unsafeGetTensorImpl()->device().
+  return storage().data_ptr().device();
 }
 at::IntArrayRef FunctionalTensorWrapper::sizes_custom() const {
   return value_.unsafeGetTensorImpl()->sizes();

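Not part of the diff: a minimal, hedged sketch of the invariant the new device_custom() leans on. Under functionalization (created here via torch.func.functionalize), the wrapper's storage is built from the wrapped value, so the device read off the storage should agree with the device the underlying tensor reports. The assert below is illustrative, not a test from the PR.

import torch
from torch.func import functionalize

def f(x):
    # Inside functionalize, x and the view y are functional tensor wrappers,
    # so .device routes through FunctionalTensorWrapper::device_custom().
    y = x.view(2, 2)
    assert y.device == x.device  # storage device matches the wrapped value's device
    return y.add(1)

functionalize(f)(torch.randn(4))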
examples/cold.py (new file, 16 additions)
View File

@@ -0,0 +1,16 @@
+import torch
+
+
+@torch.compile(backend="eager")
+def fn(x, y, z):
+    for _ in range(100):
+        # x = torch.nn.functional.silu(x)
+        x = torch.addmm(x, y, z)
+    return x
+    # return torch.sin(torch.cos(x))
+
+
+x = torch.randn(20, 20)
+y = torch.randn(20, 20)
+z = torch.randn(20, 20)
+fn(x, y, z)

View File

@@ -1452,24 +1452,24 @@ def tensor_split_tensor_indices_or_sections_py_impl(
 # TODO: this doesn't appear to have enough precision in bfloat16
-@register_decomposition(aten.addmm)
-@out_wrapper(exact_dtype=True)
-@pw_cast_for_opmath
-def addmm(self: Tensor, mat1: Tensor, mat2: Tensor, beta: int = 1, alpha: int = 1):
-    if not self.is_floating_point() and not self.is_complex():
-        beta = int(beta)
-        alpha = int(alpha)
-    out = alpha * torch.mm(mat1, mat2)
-    if beta == 0:
-        return out
+# @register_decomposition(aten.addmm)
+# @out_wrapper(exact_dtype=True)
+# @pw_cast_for_opmath
+# def addmm(self: Tensor, mat1: Tensor, mat2: Tensor, beta: int = 1, alpha: int = 1):
+#     if not self.is_floating_point() and not self.is_complex():
+#         beta = int(beta)
+#         alpha = int(alpha)
+#     out = alpha * torch.mm(mat1, mat2)
+#     if beta == 0:
+#         return out
-    # The output of aten.addmm is contiguous, we need to match this behavior in the decomposition.
-    # The original implementation 'beta * self + out' would return a strided tensor if `self` is strided.
-    # We thus use `out`, the output of torch.mm, which is always contiguous, as the first argument for addition.
-    # This is relying on TensorIterator's behavior that it takes higher precedence on the stride of first input.
-    # Alternatively, we can write `(beta * self + out).contiguous()`, but it introduces another copy in some cases.
-    # This implementation is not ideal, and we should revisit this when we have a better solution.
-    return out + beta * self
+#     # The output of aten.addmm is contiguous, we need to match this behavior in the decomposition.
+#     # The original implementation 'beta * self + out' would return a strided tensor if `self` is strided.
+#     # We thus use `out`, the output of torch.mm, which is always contiguous, as the first argument for addition.
+#     # This is relying on TensorIterator's behavior that it takes higher precedence on the stride of first input.
+#     # Alternatively, we can write `(beta * self + out).contiguous()`, but it introduces another copy in some cases.
+#     # This implementation is not ideal, and we should revisit this when we have a better solution.
+#     return out + beta * self
 @register_decomposition(aten._addmm_activation)

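Not part of the diff: a hedged sketch of the contiguity point made in the (now commented-out) notes above. It assumes, as that comment states, that TensorIterator gives the first operand's strides precedence when allocating the output; the expected values in the comments are illustrative.

import torch

self_t = torch.randn(20, 20).t()               # a strided (transposed) "self"
mat1, mat2 = torch.randn(20, 30), torch.randn(30, 20)
out = torch.mm(mat1, mat2)                     # torch.mm output is contiguous
beta = 2.0

# Putting the contiguous mm result first keeps the sum contiguous, matching aten.addmm.
print((out + beta * self_t).is_contiguous())   # expected: True
# Reversing the operands lets the result inherit self's transposed layout.
print((beta * self_t + out).is_contiguous())   # expected: False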
View File

@@ -2026,7 +2026,101 @@ class OutputGraph(OutputGraphGuardsState):
             # a lot of fake_tensor ownership assumptions and runs afoul of detect_fake_mode
             self.tracing_context.fake_mode = backend_fake_mode
+
+        def pretty_print_counts(d):
+            # sort by value desc, then by key asc for stable, readable output
+            items = sorted(d.items(), key=lambda kv: (-kv[1], kv[0]))
+            width = max((len(k) for k, _ in items), default=0)
+            for k, v in items:
+                print(f"{k:<{width}} : {v}")
+
+        def profile(name, mod, inputs):
+            with (
+                torch._C.DisableTorchFunctionSubclass(),
+                torch._C.DisableTorchFunction(),
+            ):
+                # warm up before timing
+                mod(*inputs)
+                mod(*inputs)
+                mod(*inputs)
+                import time
+
+                t0 = time.perf_counter()
+                for _ in range(10):
+                    mod(*inputs)
+                t1 = time.perf_counter()
+                print(f"{name:<50}{round((t1 - t0) * 1000, 3)} ms")
+
         with self.restore_global_state():
+            counts = collections.defaultdict(lambda: 0)
+            for node in gm.graph.nodes:
+                if "call" in node.op:
+                    counts[str(node.target)] += 1
+            print("--- Dynamo Fx graph --")
+            pretty_print_counts(counts)
+
+            profile("Base Time", gm, self.example_inputs())
+
+            # Meta prop
+            meta_inputs = []
+            for inp in self.example_inputs():
+                meta_inputs.append(inp.to(device="meta"))
+            profile("Meta Time", gm, meta_inputs)
+
+            # Fake prop
+            from torch._subclasses.fake_tensor import (
+                FakeTensor,
+                FakeTensorMode,
+                in_kernel_invocation_manager,
+            )
+
+            fake_mode = torch._subclasses.fake_tensor.FakeTensorMode()
+            fake_inputs = []
+            for inp in self.example_inputs():
+                fake_inputs.append(fake_mode.from_tensor(inp))
+            with fake_mode:
+                profile("Fake Time", gm, fake_inputs)
+
+                with torch._subclasses.fake_tensor.disable_fake_tensor_cache(
+                    fake_mode
+                ):
+                    profile("Fake w/o caching Time", gm, fake_inputs)
+
+            # Patch __torch_dispatch__ to return a fake tensor very quickly.
+            # This makes __torch_dispatch__ essentially free and shows the upper
+            # limit of FakeTensor propagation given the Python overhead.
+            x = self.example_inputs()[0]
+
+            def get_fake_output():
+                empty = torch.empty_strided(
+                    x.size(),
+                    x.stride(),
+                    dtype=x.dtype,
+                    device="meta",
+                    requires_grad=x.requires_grad,
+                )
+                return FakeTensor(fake_mode, empty, x.device)
+
+            def create_fake_tensor(self, *args, **kwargs):
+                with in_kernel_invocation_manager(self):
+                    return get_fake_output()
+
+            FakeTensorMode.__torch_dispatch__ = create_fake_tensor
+            with fake_mode:
+                profile("Patched FakeTensorMode TD Time", gm, fake_inputs)
+
+            y = get_fake_output()
+
+            def return_fake_tensor(self, *args, **kwargs):
+                return y
+
+            FakeTensorMode.__torch_dispatch__ = return_fake_tensor
+            with fake_mode:
+                profile(
+                    "Free FakeTensorMode TD Time", gm, fake_inputs
+                )
+
             compiled_fn = self.call_user_compiler(gm, self.example_inputs())
 
         from torch.fx._lazy_graph_module import _LazyGraphModule
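Not part of the diff: a standalone, hedged sketch of the measurement idea behind the instrumentation above, i.e. timing eager execution of a small FX graph against FakeTensor propagation of the same graph. Names, sizes, and iteration counts are illustrative only.

import time

import torch
import torch.fx
from torch._subclasses.fake_tensor import FakeTensorMode

def fn(x, y, z):
    for _ in range(10):
        x = torch.addmm(x, y, z)
    return x

gm = torch.fx.symbolic_trace(fn)            # small graph of repeated addmm calls
inputs = [torch.randn(20, 20) for _ in range(3)]

def timeit(label, mod, args):
    mod(*args)
    mod(*args)                              # warm up before timing
    t0 = time.perf_counter()
    for _ in range(10):
        mod(*args)
    print(label, round((time.perf_counter() - t0) * 1000, 3), "ms")

timeit("eager", gm, inputs)

fake_mode = FakeTensorMode()
fake_inputs = [fake_mode.from_tensor(t) for t in inputs]
with fake_mode:
    timeit("fake", gm, fake_inputs)         # metadata-only propagation, no real compute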