fastpath FunctionalTensor sizes() (#132084)

Another attempt at fast-pathing sizes() in FunctionalTensor, since it appears to improve compile-time performance by up to ~10%. See the investigation at https://github.com/pytorch/pytorch/issues/125977#issuecomment-2122915602.

After looking at some failing tests locally, I realized that we now need to handle metadata mutations manually, since the previous "smarter" size dispatch was handling those updates for us.
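
A minimal sketch (not the PR's code) of why that replay is needed. `Wrapper` below is a hypothetical stand-in for FunctionalTensor, built with torch.Tensor._make_wrapper_subclass so its sizes/strides are cached at construction. Once size queries no longer dispatch, an inplace_view op such as aten.unsqueeze_ changes the inner tensor's metadata but not the wrapper's, so the op is replayed on the wrapper under no_dispatch() to keep the two in sync:

import torch
from torch.utils._pytree import tree_map_only
from torch.utils._mode_utils import no_dispatch

class Wrapper(torch.Tensor):
    # Hypothetical stand-in for FunctionalTensor, for illustration only.
    __torch_function__ = torch._C._disabled_torch_function_impl

    @staticmethod
    def __new__(cls, elem):
        # dispatch_sizes_strides_policy defaults to None: size()/stride() read
        # the wrapper's own cached metadata instead of dispatching.
        return torch.Tensor._make_wrapper_subclass(
            cls, elem.shape, elem.stride(), dtype=elem.dtype, device=elem.device
        )

    def __init__(self, elem):
        self.elem = elem

    @classmethod
    def __torch_dispatch__(cls, func, types, args=(), kwargs=None):
        kwargs = kwargs or {}
        # Run the op on the inner tensor.
        out = func(
            *tree_map_only(Wrapper, lambda t: t.elem, args),
            **tree_map_only(Wrapper, lambda t: t.elem, kwargs),
        )
        if torch.Tag.inplace_view in func.tags:
            # The inner tensor's metadata changed, but the wrapper's cached
            # sizes/strides did not. Replay the op on the wrapper outside of
            # dispatch so the two stay in sync (mirrors the hunk below).
            with no_dispatch():
                func(*args, **kwargs)
            return args[0]  # inplace ops return their input directly
        return out  # (rewrapping of out-of-place outputs elided for brevity)

x = Wrapper(torch.randn(3))
x.unsqueeze_(0)
print(x.shape)  # torch.Size([1, 3]) -- stays correct only because of the replay

Without the replay, x.shape would still report the pre-mutation size, which is the class of failure the tests surfaced.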

Pull Request resolved: https://github.com/pytorch/pytorch/pull/132084
Approved by: https://github.com/ezyang
Author: Brian Hirsh
Date: 2024-08-01 17:39:58 +00:00
Committed by: PyTorch MergeBot
Commit: 997f64af38 (parent: c8958f8f84)
2 changed files with 10 additions and 3 deletions


@@ -3429,8 +3429,8 @@ Tensor unsqueeze_quantized(const Tensor& self, int64_t dim) {
 Tensor & unsqueeze_(Tensor& self, int64_t dim) {
   dim = maybe_wrap_dim(dim, self.dim() + 1);
-  auto g = inferUnsqueezeGeometry(self, dim);
-  self.as_strided_(g.sizes, g.strides);
+  auto g = inferUnsqueezeGeometry_symint(self, dim);
+  self.as_strided__symint(g.sizes, g.strides);
   return self;
 }


@@ -147,7 +147,7 @@ class FunctionalTensor(torch.Tensor):
             elem.device,  # device
             False,  # pin_memory
             elem.requires_grad,  # requires_grad
-            "sizes",  # dispatch_sizes_strides_policy
+            None,  # dispatch_sizes_strides_policy
             False,  # dispatch_device
             False,  # dispatch_layout
             extra_dispatch_keys,  # _extra_dispatch_keys
@@ -510,6 +510,13 @@ class FunctionalTensorMode(TorchDispatchMode):
             or func == torch.ops.aten.lift_fresh.default
         ):
             return outs_wrapped
+        # for metadata mutations, need to manually mutate the metadata of the FunctionalTensor wrapper
+        if (
+            torch.Tag.inplace_view in func.tags
+            and func is not torch.ops.aten.set_.source_Tensor
+        ):
+            with torch.utils._mode_utils.no_dispatch():
+                func(*args, **kwargs)
         # Wrapper tensor subclasses do not have correct aliasing info! Use this util to manually correct the output aliasing.
         # inplace ops like `aten.add_()` are expected to return inputs **directly**, instead of creating fresh tensor objects.
         # Use this util to figure out the right thing to return.