mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 21:14:14 +08:00
Fixes #123698 This PR makes TensorImpl::has_symbolic_sizes_strides return false for NestedTensors. 1. It passes in the actual sizes when we call `_make_wrapper_subclass` - this is the change that makes the subclass register as `has_symbolic_sizes_strides() == True` 2. It adds a field to `_make_wrapper_subclass` where an explicit `numel` can be provided. This allows us to skip the numel computation for the storage, which previously fails due to arithmetic on NestedInts. 3. Implements `aten::numel` for NJT - this is separate from the overridden numel in `make_wrapper_subclass` for now. Note also that this means that we leave `dispatch_sizes_strides_policy="sizes"`, so that we call into the custom `numel` implementation (as well as `sizes` and `strides`), because `numel` cannot currently be computed from `sizes` for NJT. Note also that this depends on #121361, because calling TensorImpl::set_sizes_and_strides() tries to clone the sizes into the tensor, which means that we need `clone` to be implemented on NestedInt. Differential Revision: [D57225736](https://our.internmc.facebook.com/intern/diff/D57225736) Pull Request resolved: https://github.com/pytorch/pytorch/pull/124687 Approved by: https://github.com/albanD