mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 21:14:14 +08:00
fastpath FunctionalTensor sizes() (#132084)
Another attempt at fast-pathing sizes() in FunctionalTensor, since it appears to improve compile time perf by up to ~10%. See the investigation from https://github.com/pytorch/pytorch/issues/125977#issuecomment-2122915602. After looking at some failing tests locally, I realized that we need to manually handle metadata mutations now, since the previous "smarter" size dispatch was handling the updates for us.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/132084
Approved by: https://github.com/ezyang
This commit is contained in:
committed by
PyTorch MergeBot
parent
c8958f8f84
commit
997f64af38
@@ -3429,8 +3429,8 @@ Tensor unsqueeze_quantized(const Tensor& self, int64_t dim) {
 Tensor & unsqueeze_(Tensor& self, int64_t dim) {
   dim = maybe_wrap_dim(dim, self.dim() + 1);
-  auto g = inferUnsqueezeGeometry(self, dim);
-  self.as_strided_(g.sizes, g.strides);
+  auto g = inferUnsqueezeGeometry_symint(self, dim);
+  self.as_strided__symint(g.sizes, g.strides);
   return self;
 }
@@ -147,7 +147,7 @@ class FunctionalTensor(torch.Tensor):
             elem.device,  # device
             False,  # pin_memory
             elem.requires_grad,  # requires_grad
-            "sizes",  # dispatch_sizes_strides_policy
+            None,  # dispatch_sizes_strides_policy
             False,  # dispatch_device
             False,  # dispatch_layout
             extra_dispatch_keys,  # _extra_dispatch_keys
@@ -510,6 +510,13 @@ class FunctionalTensorMode(TorchDispatchMode):
                 or func == torch.ops.aten.lift_fresh.default
             ):
                 return outs_wrapped
+            # for metadata mutations, need to manually mutate the metadata of the FunctionalTensor wrapper
+            if (
+                torch.Tag.inplace_view in func.tags
+                and func is not torch.ops.aten.set_.source_Tensor
+            ):
+                with torch.utils._mode_utils.no_dispatch():
+                    func(*args, **kwargs)
             # Wrapper tensor subclasses do not have correct aliasing info! Use this util to manually correct the output aliasing.
             # inplace ops like `aten.add_()` are expected to return inputs **directly**, instead of creating fresh tensor objects.
             # Use this util to figure out the right thing to return.
Reference in New Issue
Block a user