mirror of
				https://github.com/pytorch/pytorch.git
				synced 2025-10-20 21:14:14 +08:00 
			
		
		
		
	[do not revert] Compute contiguity symbolically to avoid dde, and introduce c++ sym_is_contiguous (#155590)
When we compute contiguity for a tensor with dynamic shapes, we proceed as follows: 1) Try to compute it without guarding. 2) If all shapes are hinted, compute it, potentially adding guards. 3) If any input is not hinted, compute it symbolically. sym_is_contiguous returns a SymBool that can then either be evaluated, or have guard_or_false called on it to avoid data-dependent errors (DDEs). Example: bool is_contiguous = input.sym_is_contiguous().guard_or_false(__FILE__, __LINE__); is_contiguous_or_false is a helper function that does exactly that. In this PR I only handle default contiguity; I will follow up with changes for other formats such as channels_last. We use this pattern in several locations in this PR to avoid DDEs. Differential Revision: [D77183032](https://our.internmc.facebook.com/intern/diff/D77183032) Pull Request resolved: https://github.com/pytorch/pytorch/pull/155590 Approved by: https://github.com/ezyang
This commit is contained in:
		
				
					committed by
					
						 PyTorch MergeBot
						PyTorch MergeBot
					
				
			
			
				
	
			
			
			
						parent
						
							22edb457c9
						
					
				
				
					commit
					d0a9629435
				
			
							
								
								
									
										2
									
								
								.github/ci_commit_pins/xla.txt
									
									
									
									
										vendored
									
									
								
							
							
						
						
									
										2
									
								
								.github/ci_commit_pins/xla.txt
									
									
									
									
										vendored
									
									
								
							| @ -1 +1 @@ | |||||||
| 55a75404c9b75cd5fd62ab5d4deafc8c506b3af2 | 926700d7832caa552ba2e1fc8302f6a2f4d2f6d8 | ||||||
|  | |||||||
| @ -499,8 +499,8 @@ int64_t FunctionalTensorWrapper::dim_custom() const { | |||||||
| int64_t FunctionalTensorWrapper::numel_custom() const { | int64_t FunctionalTensorWrapper::numel_custom() const { | ||||||
|   return value_.unsafeGetTensorImpl()->numel(); |   return value_.unsafeGetTensorImpl()->numel(); | ||||||
| } | } | ||||||
| bool FunctionalTensorWrapper::is_contiguous_custom(at::MemoryFormat memory_format) const { | c10::SymBool FunctionalTensorWrapper::sym_is_contiguous_custom(at::MemoryFormat memory_format) const { | ||||||
|   return value_.unsafeGetTensorImpl()->is_contiguous(memory_format); |   return value_.unsafeGetTensorImpl()->sym_is_contiguous(memory_format); | ||||||
| } | } | ||||||
| c10::SymIntArrayRef FunctionalTensorWrapper::sym_sizes_custom() const { | c10::SymIntArrayRef FunctionalTensorWrapper::sym_sizes_custom() const { | ||||||
|   return value_.unsafeGetTensorImpl()->sym_sizes(); |   return value_.unsafeGetTensorImpl()->sym_sizes(); | ||||||
|  | |||||||
| @ -236,7 +236,8 @@ struct TORCH_API FunctionalTensorWrapper : public c10::TensorImpl { | |||||||
|   at::IntArrayRef strides_custom() const override; |   at::IntArrayRef strides_custom() const override; | ||||||
|   int64_t dim_custom() const override; |   int64_t dim_custom() const override; | ||||||
|   int64_t numel_custom() const override; |   int64_t numel_custom() const override; | ||||||
|   bool is_contiguous_custom(at::MemoryFormat memory_format) const override; |   c10::SymBool sym_is_contiguous_custom( | ||||||
|  |       at::MemoryFormat memory_format) const override; | ||||||
|   c10::SymIntArrayRef sym_sizes_custom() const override; |   c10::SymIntArrayRef sym_sizes_custom() const override; | ||||||
|   c10::SymInt sym_size_custom(int64_t d) const override; |   c10::SymInt sym_size_custom(int64_t d) const override; | ||||||
|   c10::SymIntArrayRef sym_strides_custom() const override; |   c10::SymIntArrayRef sym_strides_custom() const override; | ||||||
|  | |||||||
| @ -320,11 +320,9 @@ static at::Tensor _unsafe_view_functionalize(const at::Tensor & self, at::SymInt | |||||||
|   auto stride = at::detail::computeStride(self.sym_sizes(), self.sym_strides(), inferred_size); |   auto stride = at::detail::computeStride(self.sym_sizes(), self.sym_strides(), inferred_size); | ||||||
|  |  | ||||||
|   if (!stride.has_value()) { |   if (!stride.has_value()) { | ||||||
|     // With unbacked symints, computeStride could fail even on contiguous |  | ||||||
|     // tensors. In this case, we can use the strides of an empty tensor of |     TORCH_SYM_CHECK( | ||||||
|     // inferred_size. |         self.sym_is_contiguous(), | ||||||
|     TORCH_CHECK( |  | ||||||
|         self.is_contiguous(), |  | ||||||
|         "View is not valid from size:", |         "View is not valid from size:", | ||||||
|         self.sym_sizes(), |         self.sym_sizes(), | ||||||
|         " stride: ", |         " stride: ", | ||||||
| @ -333,6 +331,9 @@ static at::Tensor _unsafe_view_functionalize(const at::Tensor & self, at::SymInt | |||||||
|         inferred_size, |         inferred_size, | ||||||
|         " in case of unbacked symbols consider adding torch.check to guide computing strides."); |         " in case of unbacked symbols consider adding torch.check to guide computing strides."); | ||||||
|  |  | ||||||
|  |     // With unbacked symints, computeStride could fail even on contiguous | ||||||
|  |     // tensors. In this case, we can use the strides of an empty tensor of | ||||||
|  |     // inferred_size. | ||||||
|     stride = at::detail::empty_symint_meta( |     stride = at::detail::empty_symint_meta( | ||||||
|                  inferred_size, |                  inferred_size, | ||||||
|                  std::nullopt, |                  std::nullopt, | ||||||
|  | |||||||
| @ -84,7 +84,7 @@ IntArrayRef BatchedTensorImpl::strides_custom() const { | |||||||
|  |  | ||||||
| // TODO: implement proper contiguity on batched tensor, then put | // TODO: implement proper contiguity on batched tensor, then put | ||||||
| // sizes_strides_policy back to Default | // sizes_strides_policy back to Default | ||||||
| bool BatchedTensorImpl::is_contiguous_custom(at::MemoryFormat memory_format) const { | c10::SymBool BatchedTensorImpl::sym_is_contiguous_custom(at::MemoryFormat memory_format) const { | ||||||
|   TORCH_CHECK(memory_format == MemoryFormat::Contiguous, |   TORCH_CHECK(memory_format == MemoryFormat::Contiguous, | ||||||
|       "NYI: querying is_contiguous inside of vmap for memory_format ", |       "NYI: querying is_contiguous inside of vmap for memory_format ", | ||||||
|       "other than torch.contiguous_format"); |       "other than torch.contiguous_format"); | ||||||
|  | |||||||
| @ -82,7 +82,8 @@ struct TORCH_API BatchedTensorImpl : public c10::TensorImpl { | |||||||
|   IntArrayRef strides_custom() const override; |   IntArrayRef strides_custom() const override; | ||||||
|   // Override a bunch of methods inherited from TensorImpl to return error |   // Override a bunch of methods inherited from TensorImpl to return error | ||||||
|   // messages. |   // messages. | ||||||
|   bool is_contiguous_custom(at::MemoryFormat memory_format) const override; |   c10::SymBool sym_is_contiguous_custom( | ||||||
|  |       at::MemoryFormat memory_format) const override; | ||||||
|   void set_size(int64_t dim, int64_t new_size) override; |   void set_size(int64_t dim, int64_t new_size) override; | ||||||
|   void set_stride(int64_t dim, int64_t new_stride) override; |   void set_stride(int64_t dim, int64_t new_stride) override; | ||||||
|   void set_storage_offset(int64_t storage_offset) override; |   void set_storage_offset(int64_t storage_offset) override; | ||||||
|  | |||||||
| @ -24,7 +24,7 @@ MemOverlap has_internal_overlap(TensorImpl* t) { | |||||||
|     } |     } | ||||||
|   } |   } | ||||||
|  |  | ||||||
|   if (t->is_non_overlapping_and_dense()) { |   if (t->is_non_overlapping_and_dense_or_false()) { | ||||||
|     return MemOverlap::No; |     return MemOverlap::No; | ||||||
|   } |   } | ||||||
|  |  | ||||||
| @ -63,7 +63,7 @@ MemOverlapStatus get_overlap_status(const TensorImpl* a, const TensorImpl* b) { | |||||||
|   if (a->numel() == 0 || b->numel() == 0) { |   if (a->numel() == 0 || b->numel() == 0) { | ||||||
|     return MemOverlapStatus::No; |     return MemOverlapStatus::No; | ||||||
|   } |   } | ||||||
|   if (!a->is_non_overlapping_and_dense() || !b->is_non_overlapping_and_dense()) { |   if (!a->is_non_overlapping_and_dense_or_false() || !b->is_non_overlapping_and_dense_or_false()) { | ||||||
|     return MemOverlapStatus::TooHard; |     return MemOverlapStatus::TooHard; | ||||||
|   } |   } | ||||||
|   // Test for storage equality, rather than pointer equality. |   // Test for storage equality, rather than pointer equality. | ||||||
|  | |||||||
| @ -273,7 +273,7 @@ c10::SymInt NestedTensorImpl::sym_numel_custom() const { | |||||||
|   return NestedTensorImpl::numel_custom(); |   return NestedTensorImpl::numel_custom(); | ||||||
| } | } | ||||||
|  |  | ||||||
| bool NestedTensorImpl::is_contiguous_custom(MemoryFormat) const { | c10::SymBool NestedTensorImpl::sym_is_contiguous_custom(MemoryFormat) const { | ||||||
|   return nested_tensor_impl_is_contiguous(this); |   return nested_tensor_impl_is_contiguous(this); | ||||||
| } | } | ||||||
| IntArrayRef NestedTensorImpl::sizes_custom() const { | IntArrayRef NestedTensorImpl::sizes_custom() const { | ||||||
|  | |||||||
| @ -115,7 +115,7 @@ struct TORCH_API NestedTensorImpl : public c10::TensorImpl { | |||||||
|   // with real implementations |   // with real implementations | ||||||
|   int64_t numel_custom() const override; |   int64_t numel_custom() const override; | ||||||
|   c10::SymInt sym_numel_custom() const override; |   c10::SymInt sym_numel_custom() const override; | ||||||
|   bool is_contiguous_custom(MemoryFormat) const override; |   c10::SymBool sym_is_contiguous_custom(MemoryFormat) const override; | ||||||
|   int64_t size_custom(int64_t d) const override { |   int64_t size_custom(int64_t d) const override { | ||||||
|     return this->size(d); |     return this->size(d); | ||||||
|   } |   } | ||||||
|  | |||||||
| @ -252,8 +252,7 @@ void SparseCsrTensorImpl::set_stride(int64_t dim, int64_t new_stride) { | |||||||
| void SparseCsrTensorImpl::set_storage_offset(int64_t storage_offset) { | void SparseCsrTensorImpl::set_storage_offset(int64_t storage_offset) { | ||||||
|   TORCH_CHECK(false, "Sparse ", at::sparse_csr::layoutToString(layout_, /*upper=*/true), " tensors do not have set_storage_offset."); |   TORCH_CHECK(false, "Sparse ", at::sparse_csr::layoutToString(layout_, /*upper=*/true), " tensors do not have set_storage_offset."); | ||||||
| } | } | ||||||
| bool SparseCsrTensorImpl::is_contiguous_custom(MemoryFormat) const { | c10::SymBool SparseCsrTensorImpl::sym_is_contiguous_custom(MemoryFormat) const { | ||||||
|   TORCH_CHECK(false, "Sparse ", at::sparse_csr::layoutToString(layout_, /*upper=*/true), " tensors do not have is_contiguous"); |   TORCH_CHECK(false, "Sparse ", at::sparse_csr::layoutToString(layout_, /*upper=*/true), " tensors do not have is_contiguous"); | ||||||
| } | } | ||||||
|  |  | ||||||
| } // namespace at | } // namespace at | ||||||
|  | |||||||
| @ -86,7 +86,7 @@ struct TORCH_API SparseCsrTensorImpl : public TensorImpl { | |||||||
|  protected: |  protected: | ||||||
|   IntArrayRef strides_custom() const override; |   IntArrayRef strides_custom() const override; | ||||||
|   SymIntArrayRef sym_strides_custom() const override; |   SymIntArrayRef sym_strides_custom() const override; | ||||||
|   bool is_contiguous_custom(MemoryFormat) const override; |   SymBool sym_is_contiguous_custom(MemoryFormat) const override; | ||||||
|  |  | ||||||
|  public: |  public: | ||||||
|   void set_size(int64_t dim, int64_t new_size) override; |   void set_size(int64_t dim, int64_t new_size) override; | ||||||
|  | |||||||
| @ -124,7 +124,7 @@ class TORCH_API TensorBase { | |||||||
|   } |   } | ||||||
|  |  | ||||||
|   TensorBase contiguous(MemoryFormat memory_format=MemoryFormat::Contiguous) const { |   TensorBase contiguous(MemoryFormat memory_format=MemoryFormat::Contiguous) const { | ||||||
|     if (is_contiguous(memory_format)) { |     if (is_contiguous_or_false(memory_format)) { | ||||||
|       return *this; |       return *this; | ||||||
|     } else { |     } else { | ||||||
|       return __dispatch_contiguous(memory_format); |       return __dispatch_contiguous(memory_format); | ||||||
| @ -265,6 +265,25 @@ class TORCH_API TensorBase { | |||||||
|     return impl_->is_contiguous(memory_format); |     return impl_->is_contiguous(memory_format); | ||||||
|   } |   } | ||||||
|  |  | ||||||
|  |   // Like is_contiguous, but more dynamic shape-friendly. Maybe returns a symbolic representation of | ||||||
|  |   // contiguity instead of SymTrue SymFalse, when results are data-dependent. | ||||||
|  |   c10::SymBool sym_is_contiguous(at::MemoryFormat memory_format=at::MemoryFormat::Contiguous) const { | ||||||
|  |     if (impl_->has_symbolic_sizes_strides()) { | ||||||
|  |       return impl_->sym_is_contiguous(memory_format); | ||||||
|  |     } | ||||||
|  |     return impl_->is_contiguous(memory_format); | ||||||
|  |   } | ||||||
|  |  | ||||||
|  |   // Like is_contiguous, but more dynamic shape-friendly. Can return | ||||||
|  |   // false instead of throwing data-dependent errors for tensors with unbacked | ||||||
|  |   // sizes or strides. | ||||||
|  |   bool is_contiguous_or_false(at::MemoryFormat memory_format=at::MemoryFormat::Contiguous) const { | ||||||
|  |     if (impl_->has_symbolic_sizes_strides()) { | ||||||
|  |       return impl_->sym_is_contiguous(memory_format).guard_or_false(__FILE__, __LINE__); | ||||||
|  |     } | ||||||
|  |     return impl_->is_contiguous(memory_format); | ||||||
|  |   } | ||||||
|  |  | ||||||
|   bool is_non_overlapping_and_dense() const { |   bool is_non_overlapping_and_dense() const { | ||||||
|     return impl_->is_non_overlapping_and_dense(); |     return impl_->is_non_overlapping_and_dense(); | ||||||
|   } |   } | ||||||
|  | |||||||
| @ -126,7 +126,7 @@ SymIntArrayRef BatchedTensorImpl::sym_strides_custom() const { | |||||||
|  |  | ||||||
| // TODO: implement proper contiguity on batched tensor, then put | // TODO: implement proper contiguity on batched tensor, then put | ||||||
| // sizes_strides_policy back to Default | // sizes_strides_policy back to Default | ||||||
| bool BatchedTensorImpl::is_contiguous_custom(at::MemoryFormat memory_format) const { | c10::SymBool BatchedTensorImpl::sym_is_contiguous_custom(at::MemoryFormat memory_format) const { | ||||||
|   TORCH_CHECK(memory_format == MemoryFormat::Contiguous, |   TORCH_CHECK(memory_format == MemoryFormat::Contiguous, | ||||||
|       "NYI: querying is_contiguous inside of vmap for memory_format ", |       "NYI: querying is_contiguous inside of vmap for memory_format ", | ||||||
|       "other than torch.contiguous_format"); |       "other than torch.contiguous_format"); | ||||||
|  | |||||||
| @ -69,7 +69,7 @@ struct TORCH_API BatchedTensorImpl : public c10::TensorImpl { | |||||||
|   IntArrayRef strides_custom() const override; |   IntArrayRef strides_custom() const override; | ||||||
|   SymIntArrayRef sym_strides_custom() const override; |   SymIntArrayRef sym_strides_custom() const override; | ||||||
|   // Override a bunch of methods inherited from TensorImpl to return error messages. |   // Override a bunch of methods inherited from TensorImpl to return error messages. | ||||||
|   bool is_contiguous_custom(at::MemoryFormat memory_format=at::MemoryFormat::Contiguous) const override; |   c10::SymBool sym_is_contiguous_custom(at::MemoryFormat memory_format) const override; | ||||||
|   void set_size(int64_t dim, int64_t new_size) override; |   void set_size(int64_t dim, int64_t new_size) override; | ||||||
|   void set_stride(int64_t dim, int64_t new_stride) override; |   void set_stride(int64_t dim, int64_t new_stride) override; | ||||||
|   c10::intrusive_ptr<TensorImpl> shallow_copy_and_detach( |   c10::intrusive_ptr<TensorImpl> shallow_copy_and_detach( | ||||||
|  | |||||||
| @ -93,7 +93,7 @@ Tensor linear(const Tensor& input, const Tensor& weight, const std::optional<Ten | |||||||
|   if (bias->defined() && !input.is_xla()) { |   if (bias->defined() && !input.is_xla()) { | ||||||
|     // Also hit the fused path for contiguous 3D input, if not using xla |     // Also hit the fused path for contiguous 3D input, if not using xla | ||||||
|     // backend. Reshaping/flattening has some performance implications on xla. |     // backend. Reshaping/flattening has some performance implications on xla. | ||||||
|     bool is_contiguous = definitely_contiguous(input.sym_sizes(), input.sym_strides(), input.sym_numel()); |     bool is_contiguous = input.is_contiguous_or_false(); | ||||||
|     if (is_contiguous && input_dim == 3) { |     if (is_contiguous && input_dim == 3) { | ||||||
|       return _flatten_nd_linear(input, weight, *bias); |       return _flatten_nd_linear(input, weight, *bias); | ||||||
|     } else if (is_contiguous && input.layout() == c10::kStrided && weight.layout() == c10::kStrided && bias->dim() == 1) { |     } else if (is_contiguous && input.layout() == c10::kStrided && weight.layout() == c10::kStrided && bias->dim() == 1) { | ||||||
|  | |||||||
| @ -113,7 +113,7 @@ Tensor& detach_(Tensor& self) { | |||||||
| } | } | ||||||
|  |  | ||||||
| Tensor contiguous(const Tensor& self, MemoryFormat memory_format) { | Tensor contiguous(const Tensor& self, MemoryFormat memory_format) { | ||||||
|   if (self.is_contiguous(memory_format)) { |   if (self.is_contiguous_or_false(memory_format)) { | ||||||
|     return self; |     return self; | ||||||
|   } |   } | ||||||
|   TORCH_CHECK( |   TORCH_CHECK( | ||||||
|  | |||||||
| @ -1998,19 +1998,18 @@ Tensor reshape_symint(const Tensor& self, c10::SymIntArrayRef proposed_shape) { | |||||||
|     TORCH_CHECK(false, "reshape is not implemented for sparse tensors"); |     TORCH_CHECK(false, "reshape is not implemented for sparse tensors"); | ||||||
|   } |   } | ||||||
|  |  | ||||||
|   auto sym_sizes = self.sym_sizes(); |   if (self.is_contiguous_or_false() && !self.is_mkldnn()) { | ||||||
|   auto sym_strides = self.sym_strides(); |  | ||||||
|   auto sym_numel = self.sym_numel(); |  | ||||||
|   if (definitely_contiguous(sym_sizes, sym_strides, sym_numel) && |  | ||||||
|       !self.is_mkldnn()) { |  | ||||||
|     return self.view_symint(proposed_shape); |     return self.view_symint(proposed_shape); | ||||||
|   } |   } | ||||||
|  |  | ||||||
|  |   auto sym_numel = self.sym_numel(); | ||||||
|   c10::SymDimVector shape = infer_size_dv(proposed_shape, sym_numel); |   c10::SymDimVector shape = infer_size_dv(proposed_shape, sym_numel); | ||||||
|  |  | ||||||
|   if (self.is_mkldnn()) { |   if (self.is_mkldnn()) { | ||||||
|     return at::_mkldnn_reshape(self, C10_AS_INTARRAYREF_SLOW(shape)); |     return at::_mkldnn_reshape(self, C10_AS_INTARRAYREF_SLOW(shape)); | ||||||
|   } |   } | ||||||
|  |   auto sym_sizes = self.sym_sizes(); | ||||||
|  |   auto sym_strides = self.sym_strides(); | ||||||
|  |  | ||||||
|   // `computeStride` returns the proper strides to use if this |   // `computeStride` returns the proper strides to use if this | ||||||
|   // `reshape` can be just a view. |   // `reshape` can be just a view. | ||||||
|  | |||||||
| @ -35,7 +35,7 @@ struct TORCH_API MetalTensorImpl : public OpaqueTensorImpl<OpaqueHandle> { | |||||||
|     return c10::fromIntArrayRefKnownNonNegative(strides_); |     return c10::fromIntArrayRefKnownNonNegative(strides_); | ||||||
|   } |   } | ||||||
|  |  | ||||||
|   bool is_contiguous_custom(c10::MemoryFormat memory_format) const override { |   c10::SymBool sym_is_contiguous_custom(c10::MemoryFormat memory_format) const override { | ||||||
|     return true; |     return true; | ||||||
|   } |   } | ||||||
|  |  | ||||||
|  | |||||||
| @ -776,7 +776,7 @@ Tensor scaled_dot_product_attention( | |||||||
| #ifdef USE_MPS | #ifdef USE_MPS | ||||||
|       const auto any_nested = query_.is_nested() || key.is_nested() || value.is_nested(); |       const auto any_nested = query_.is_nested() || key.is_nested() || value.is_nested(); | ||||||
|       const bool any_inputs_require_grad = query_.requires_grad() || key.requires_grad() || value.requires_grad(); |       const bool any_inputs_require_grad = query_.requires_grad() || key.requires_grad() || value.requires_grad(); | ||||||
|       const auto all_contiguous = query_.is_contiguous() && key.is_contiguous() && value.is_contiguous(); |       const auto all_contiguous = query_.is_contiguous_or_false() && key.is_contiguous_or_false() && value.is_contiguous_or_false(); | ||||||
|       if (query_device_type == DeviceType::MPS && dropout_p == 0.0 |       if (query_device_type == DeviceType::MPS && dropout_p == 0.0 | ||||||
|           && !(GradMode::is_enabled() && any_inputs_require_grad) |           && !(GradMode::is_enabled() && any_inputs_require_grad) | ||||||
|           && (all_contiguous || mps::is_macos_13_or_newer(mps::MacOSVersion::MACOS_VER_15_0_PLUS)) |           && (all_contiguous || mps::is_macos_13_or_newer(mps::MacOSVersion::MACOS_VER_15_0_PLUS)) | ||||||
|  | |||||||
| @ -33,7 +33,8 @@ struct VulkanOpaqueTensorImpl : public OpaqueTensorImpl<OpaqueHandle> { | |||||||
|     return c10::fromIntArrayRefKnownNonNegative(strides_); |     return c10::fromIntArrayRefKnownNonNegative(strides_); | ||||||
|   } |   } | ||||||
|  |  | ||||||
|   bool is_contiguous_custom(c10::MemoryFormat memory_format) const override { |   c10::SymBool sym_is_contiguous_custom( | ||||||
|  |       c10::MemoryFormat memory_format) const override { | ||||||
|     (void)memory_format; |     (void)memory_format; | ||||||
|     return true; |     return true; | ||||||
|   } |   } | ||||||
|  | |||||||
| @ -12,7 +12,7 @@ namespace c10 { | |||||||
|  |  | ||||||
| template <typename T> | template <typename T> | ||||||
| bool _compute_contiguous(ArrayRef<T> sizes, ArrayRef<T> strides, T numel) { | bool _compute_contiguous(ArrayRef<T> sizes, ArrayRef<T> strides, T numel) { | ||||||
|   if (TORCH_GUARD_SIZE_OBLIVIOUS(sym_eq(numel, 0))) { |   if (numel == 0) { | ||||||
|     return true; |     return true; | ||||||
|   } |   } | ||||||
|  |  | ||||||
| @ -20,11 +20,11 @@ bool _compute_contiguous(ArrayRef<T> sizes, ArrayRef<T> strides, T numel) { | |||||||
|   // NB: make sure we do signed arithmetic |   // NB: make sure we do signed arithmetic | ||||||
|   for (int64_t d = int64_t(sizes.size()) - 1; d >= 0; d--) { |   for (int64_t d = int64_t(sizes.size()) - 1; d >= 0; d--) { | ||||||
|     const auto& size_d = sizes[d]; |     const auto& size_d = sizes[d]; | ||||||
|     if (TORCH_GUARD_SIZE_OBLIVIOUS(sym_eq(size_d, 1))) { |     if (size_d == 1) { | ||||||
|       continue; |       continue; | ||||||
|     } |     } | ||||||
|  |  | ||||||
|     if (TORCH_GUARD_SIZE_OBLIVIOUS(sym_ne(strides[d], expected_stride))) { |     if (strides[d] != expected_stride) { | ||||||
|       return false; |       return false; | ||||||
|     } |     } | ||||||
|     expected_stride *= size_d; |     expected_stride *= size_d; | ||||||
| @ -32,29 +32,66 @@ bool _compute_contiguous(ArrayRef<T> sizes, ArrayRef<T> strides, T numel) { | |||||||
|   return true; |   return true; | ||||||
| } | } | ||||||
|  |  | ||||||
| // This function will return True if the tensor is contiguous, and False if the | // Return a SymBool with underlying symbolic expression that represents | ||||||
| // its not or if we can't determine if it is contiguous due to unbacked symbols | // contiguity. Guaranteed not to add guards. | ||||||
| // (it could be either in that case based on the actual runtime data). | inline static c10::SymBool _compute_contiguous_sym( | ||||||
| template <typename T> |     ArrayRef<c10::SymInt> sizes, | ||||||
| bool definitely_contiguous(ArrayRef<T> sizes, ArrayRef<T> strides, T numel) { |     ArrayRef<c10::SymInt> strides, | ||||||
|   if (TORCH_GUARD_OR_FALSE(sym_eq(numel, 0))) { |     const c10::SymInt& numel) { | ||||||
|  |   // If this returns true, the tensor is contiguous indeed. Otherwise it could be | ||||||
|  |   // either. | ||||||
|  |   auto is_contiguous_or_false = [&]() { | ||||||
|  |     if (TORCH_GUARD_OR_FALSE(sym_eq(numel, 0))) { | ||||||
|  |       return true; | ||||||
|  |     } | ||||||
|  |  | ||||||
|  |     // When calculating the expected stride, we can choose to multiply | ||||||
|  |     // with max(1, size[d]) or size[d]. Regardless, this is ok for this | ||||||
|  |     // function. Why? | ||||||
|  |     // (1) If size[d] == 0, then the tensor is contiguous and if | ||||||
|  |     //     we return true or false it won't break this function. | ||||||
|  |     // (2) If size[d] is not 0, then max(1,size[d]) and size[d] are equal. | ||||||
|  |     //     Therefore, if we choose to use max(1, size[d]) or size[d] to | ||||||
|  |     //     calculate the expected stride, the result is the same. | ||||||
|  |     // | ||||||
|  |     // We symbolically check both paths to maximize the cases where this | ||||||
|  |     // function returns true. This is because make_contiguous_strides_for adds | ||||||
|  |     // the max symbolically, and in some other situations the max might not be | ||||||
|  |     // there. And we want to ensure we return true in both cases. | ||||||
|  |     c10::SymInt expected_stride = 1; | ||||||
|  |     c10::SymInt expected_stride_max = 1; | ||||||
|  |     // NB: make sure we do signed arithmetic | ||||||
|  |     for (int64_t d = int64_t(sizes.size()) - 1; d >= 0; d--) { | ||||||
|  |       if (TORCH_GUARD_OR_FALSE(sym_eq(sizes[d], 1))) { | ||||||
|  |         continue; | ||||||
|  |       } | ||||||
|  |  | ||||||
|  |       if (TORCH_GUARD_OR_TRUE(sym_ne(strides[d], expected_stride)) && | ||||||
|  |           TORCH_GUARD_OR_TRUE(sym_ne(strides[d], expected_stride_max))) { | ||||||
|  |         return false; | ||||||
|  |       } | ||||||
|  |       expected_stride_max *= sizes[d].max(1); | ||||||
|  |       expected_stride *= sizes[d]; | ||||||
|  |     } | ||||||
|     return true; |     return true; | ||||||
|  |   }; | ||||||
|  |  | ||||||
|  |   if (is_contiguous_or_false()) { | ||||||
|  |     return c10::SymBool(true); | ||||||
|   } |   } | ||||||
|  |  | ||||||
|   T expected_stride = 1; |   // Build a single expression that represents contiguity and return it. | ||||||
|   // NB: make sure we do signed arithmetic |   c10::SymBool is_empty = sym_eq(numel, 0); | ||||||
|  |   c10::SymBool is_contiguous_cond = true; | ||||||
|  |  | ||||||
|  |   c10::SymInt expected_stride = 1; | ||||||
|   for (int64_t d = int64_t(sizes.size()) - 1; d >= 0; d--) { |   for (int64_t d = int64_t(sizes.size()) - 1; d >= 0; d--) { | ||||||
|     const auto& size_d = sizes[d]; |     const auto& size_d = sizes[d]; | ||||||
|     if (TORCH_GUARD_OR_FALSE(sym_eq(size_d, 1))) { |     is_contiguous_cond = is_contiguous_cond.sym_and( | ||||||
|       continue; |         size_d.sym_eq(1).sym_or(sym_eq(strides[d], expected_stride))); | ||||||
|     } |     expected_stride = expected_stride * size_d; | ||||||
|  |  | ||||||
|     if (TORCH_GUARD_OR_TRUE(sym_ne(strides[d], expected_stride))) { |  | ||||||
|       return false; |  | ||||||
|     } |  | ||||||
|     expected_stride *= size_d; |  | ||||||
|   } |   } | ||||||
|   return true; |   return is_contiguous_cond.sym_or(is_empty); | ||||||
| } | } | ||||||
|  |  | ||||||
| template <typename T> | template <typename T> | ||||||
|  | |||||||
| @ -79,18 +79,51 @@ SymBool SymbolicShapeMeta::compute_contiguous() const { | |||||||
|   } |   } | ||||||
|   c10::SymIntArrayRef sizes(sizes_); |   c10::SymIntArrayRef sizes(sizes_); | ||||||
|   c10::SymIntArrayRef strides(strides_); |   c10::SymIntArrayRef strides(strides_); | ||||||
|   return _compute_contiguous(sizes, strides, numel()); |  | ||||||
|  |   auto result = _compute_contiguous_sym(sizes, strides, numel()); | ||||||
|  |  | ||||||
|  |   // If the result is already determined without guarding, just return it. | ||||||
|  |   auto maybe_as_bool = result.maybe_as_bool(); | ||||||
|  |   if (maybe_as_bool.has_value()) { | ||||||
|  |     return maybe_as_bool.value(); | ||||||
|  |   } | ||||||
|  |  | ||||||
|  |   auto all_hinted = true; | ||||||
|  |   for (const auto& s : sizes) { | ||||||
|  |     if (!s.has_hint()) { | ||||||
|  |       all_hinted = false; | ||||||
|  |       break; | ||||||
|  |     } | ||||||
|  |   } | ||||||
|  |  | ||||||
|  |   if (all_hinted) { | ||||||
|  |     for (const auto& s : strides) { | ||||||
|  |       if (!s.has_hint()) { | ||||||
|  |         all_hinted = false; | ||||||
|  |         break; | ||||||
|  |       } | ||||||
|  |     } | ||||||
|  |   } | ||||||
|  |  | ||||||
|  |   if (all_hinted) { | ||||||
|  |     // We avoid going through the slow path if everything is hinted, | ||||||
|  |     // because evaluating a large SymPy expression can be expensive. | ||||||
|  |     // TODO exclude backed_size_oblivious from this path. | ||||||
|  |     return _compute_contiguous<SymInt>(sizes_, strides_, numel()); | ||||||
|  |   } | ||||||
|  |  | ||||||
|  |   return result; | ||||||
| } | } | ||||||
|  |  | ||||||
| // The rest of them | // The rest of them | ||||||
| #define DEFINE_EAGER_SYMBOOL_COMPUTE(name, nodeimpl, fallback) \ | #define DEFINE_EAGER_SYMBOOL_COMPUTE(name, fallback) \ | ||||||
|   SymBool SymbolicShapeMeta::name() const {                    \ |   SymBool SymbolicShapeMeta::name() const {          \ | ||||||
|     if (!strides_valid_) {                                     \ |     if (!strides_valid_) {                           \ | ||||||
|       return false;                                            \ |       return false;                                  \ | ||||||
|     }                                                          \ |     }                                                \ | ||||||
|     c10::SymIntArrayRef sizes(sizes_);                         \ |     c10::SymIntArrayRef sizes(sizes_);               \ | ||||||
|     c10::SymIntArrayRef strides(strides_);                     \ |     c10::SymIntArrayRef strides(strides_);           \ | ||||||
|     return fallback(sizes, strides);                           \ |     return fallback(sizes, strides);                 \ | ||||||
|   } |   } | ||||||
|  |  | ||||||
| #define DEFINE_SYMBOOL_COMPUTE(name, nodeimpl, fallback)        \ | #define DEFINE_SYMBOOL_COMPUTE(name, nodeimpl, fallback)        \ | ||||||
| @ -110,11 +143,13 @@ SymBool SymbolicShapeMeta::compute_contiguous() const { | |||||||
|   } |   } | ||||||
|  |  | ||||||
| // clang-format off | // clang-format off | ||||||
| DEFINE_EAGER_SYMBOOL_COMPUTE(compute_channels_last_contiguous_2d, is_channels_last_contiguous_2d, _compute_channels_last_contiguous_2d) | DEFINE_EAGER_SYMBOOL_COMPUTE(compute_channels_last_contiguous_2d, _compute_channels_last_contiguous_2d) | ||||||
| DEFINE_EAGER_SYMBOOL_COMPUTE(compute_channels_last_contiguous_3d, is_channels_last_contiguous_3d, _compute_channels_last_contiguous_3d) | DEFINE_EAGER_SYMBOOL_COMPUTE(compute_channels_last_contiguous_3d, _compute_channels_last_contiguous_3d) | ||||||
| DEFINE_EAGER_SYMBOOL_COMPUTE(compute_strides_like_channels_last_2d, is_channels_last_strides_2d, is_channels_last_strides_2d) | DEFINE_EAGER_SYMBOOL_COMPUTE(compute_strides_like_channels_last_2d, is_channels_last_strides_2d) | ||||||
| DEFINE_EAGER_SYMBOOL_COMPUTE(compute_strides_like_channels_last_3d, is_channels_last_strides_3d, is_channels_last_strides_3d) | DEFINE_EAGER_SYMBOOL_COMPUTE(compute_strides_like_channels_last_3d, is_channels_last_strides_3d) | ||||||
|  |  | ||||||
| DEFINE_SYMBOOL_COMPUTE(compute_non_overlapping_and_dense, is_non_overlapping_and_dense, _compute_non_overlapping_and_dense) | DEFINE_SYMBOOL_COMPUTE(compute_non_overlapping_and_dense, is_non_overlapping_and_dense, _compute_non_overlapping_and_dense) | ||||||
|  |  | ||||||
| // clang-format on | // clang-format on | ||||||
|  |  | ||||||
| #undef DEFINE_SYMBOOL_COMPUTE | #undef DEFINE_SYMBOOL_COMPUTE | ||||||
| @ -192,6 +227,7 @@ void SymbolicShapeMeta::set_numel(SymInt val) const { | |||||||
|   numel_ = std::move(val); |   numel_ = std::move(val); | ||||||
|   available_.fetch_or(numel_avail); |   available_.fetch_or(numel_avail); | ||||||
| } | } | ||||||
|  |  | ||||||
| void SymbolicShapeMeta::set_is_contiguous(SymBool val) const { | void SymbolicShapeMeta::set_is_contiguous(SymBool val) const { | ||||||
|   std::scoped_lock lock(mutables_); |   std::scoped_lock lock(mutables_); | ||||||
|   if (has_is_contiguous()) { |   if (has_is_contiguous()) { | ||||||
| @ -200,6 +236,7 @@ void SymbolicShapeMeta::set_is_contiguous(SymBool val) const { | |||||||
|   is_contiguous_ = std::move(val); |   is_contiguous_ = std::move(val); | ||||||
|   available_.fetch_or(is_contiguous_avail); |   available_.fetch_or(is_contiguous_avail); | ||||||
| } | } | ||||||
|  |  | ||||||
| void SymbolicShapeMeta::set_is_channels_last_contiguous(SymBool val) const { | void SymbolicShapeMeta::set_is_channels_last_contiguous(SymBool val) const { | ||||||
|   std::scoped_lock lock(mutables_); |   std::scoped_lock lock(mutables_); | ||||||
|   if (has_is_channels_last_contiguous()) { |   if (has_is_channels_last_contiguous()) { | ||||||
| @ -208,6 +245,7 @@ void SymbolicShapeMeta::set_is_channels_last_contiguous(SymBool val) const { | |||||||
|   is_channels_last_contiguous_ = std::move(val); |   is_channels_last_contiguous_ = std::move(val); | ||||||
|   available_.fetch_or(is_channels_last_contiguous_avail); |   available_.fetch_or(is_channels_last_contiguous_avail); | ||||||
| } | } | ||||||
|  |  | ||||||
| void SymbolicShapeMeta::set_is_channels_last_3d_contiguous(SymBool val) const { | void SymbolicShapeMeta::set_is_channels_last_3d_contiguous(SymBool val) const { | ||||||
|   std::scoped_lock lock(mutables_); |   std::scoped_lock lock(mutables_); | ||||||
|   if (has_is_channels_last_3d_contiguous()) { |   if (has_is_channels_last_3d_contiguous()) { | ||||||
| @ -216,6 +254,7 @@ void SymbolicShapeMeta::set_is_channels_last_3d_contiguous(SymBool val) const { | |||||||
|   is_channels_last_3d_contiguous_ = std::move(val); |   is_channels_last_3d_contiguous_ = std::move(val); | ||||||
|   available_.fetch_or(is_channels_last_3d_contiguous_avail); |   available_.fetch_or(is_channels_last_3d_contiguous_avail); | ||||||
| } | } | ||||||
|  |  | ||||||
| void SymbolicShapeMeta::set_is_channels_last(SymBool val) const { | void SymbolicShapeMeta::set_is_channels_last(SymBool val) const { | ||||||
|   std::scoped_lock lock(mutables_); |   std::scoped_lock lock(mutables_); | ||||||
|   if (has_is_channels_last()) { |   if (has_is_channels_last()) { | ||||||
| @ -224,6 +263,7 @@ void SymbolicShapeMeta::set_is_channels_last(SymBool val) const { | |||||||
|   is_channels_last_ = std::move(val); |   is_channels_last_ = std::move(val); | ||||||
|   available_.fetch_or(is_channels_last_avail); |   available_.fetch_or(is_channels_last_avail); | ||||||
| } | } | ||||||
|  |  | ||||||
| void SymbolicShapeMeta::set_is_channels_last_3d(SymBool val) const { | void SymbolicShapeMeta::set_is_channels_last_3d(SymBool val) const { | ||||||
|   std::scoped_lock lock(mutables_); |   std::scoped_lock lock(mutables_); | ||||||
|   if (has_is_channels_last_3d()) { |   if (has_is_channels_last_3d()) { | ||||||
|  | |||||||
| @ -1,4 +1,5 @@ | |||||||
| #pragma once | #pragma once | ||||||
|  | #include <c10/core/MemoryFormat.h> | ||||||
| #include <c10/core/SymBool.h> | #include <c10/core/SymBool.h> | ||||||
| #include <c10/core/SymInt.h> | #include <c10/core/SymInt.h> | ||||||
| #include <c10/macros/Export.h> | #include <c10/macros/Export.h> | ||||||
| @ -82,6 +83,15 @@ class C10_API SymbolicShapeMeta { | |||||||
|     return numel_; |     return numel_; | ||||||
|   } |   } | ||||||
|  |  | ||||||
|  |   const SymBool& is_contiguous(at::MemoryFormat memory_format) const { | ||||||
|  |     if (memory_format == at::MemoryFormat::ChannelsLast) { | ||||||
|  |       return this->is_channels_last_contiguous(); | ||||||
|  |     } else if (memory_format == at::MemoryFormat::ChannelsLast3d) { | ||||||
|  |       return this->is_channels_last_3d_contiguous(); | ||||||
|  |     } | ||||||
|  |     return this->is_contiguous(); | ||||||
|  |   } | ||||||
|  |  | ||||||
|   const SymBool& is_contiguous() const { |   const SymBool& is_contiguous() const { | ||||||
|     if (C10_UNLIKELY(!has_is_contiguous())) { |     if (C10_UNLIKELY(!has_is_contiguous())) { | ||||||
|       init_is_contiguous(); |       init_is_contiguous(); | ||||||
| @ -194,6 +204,7 @@ class C10_API SymbolicShapeMeta { | |||||||
|   // Lazily initialized variables, with the corresponding available_ flag |   // Lazily initialized variables, with the corresponding available_ flag | ||||||
|   // indicating whether the value has been initialized |   // indicating whether the value has been initialized | ||||||
|   mutable std::atomic<int> available_{0}; |   mutable std::atomic<int> available_{0}; | ||||||
|  |  | ||||||
|   enum avail { |   enum avail { | ||||||
|     numel_avail = 1 << 0, |     numel_avail = 1 << 0, | ||||||
|     is_contiguous_avail = 1 << 1, |     is_contiguous_avail = 1 << 1, | ||||||
|  | |||||||
| @ -310,12 +310,14 @@ void TensorImpl::throw_data_ptr_access_error() const { | |||||||
|       false, "Cannot access data pointer of Tensor that doesn't have storage"); |       false, "Cannot access data pointer of Tensor that doesn't have storage"); | ||||||
| } | } | ||||||
|  |  | ||||||
| bool TensorImpl::is_contiguous_custom(at::MemoryFormat memory_format) const { | c10::SymBool TensorImpl::sym_is_contiguous_custom( | ||||||
|  |     at::MemoryFormat memory_format) const { | ||||||
|   if (C10_UNLIKELY(matches_python_custom(SizesStridesPolicy::CustomStrides))) { |   if (C10_UNLIKELY(matches_python_custom(SizesStridesPolicy::CustomStrides))) { | ||||||
|     return pyobj_slot_.load_pyobj_interpreter()->is_contiguous( |     return pyobj_slot_.load_pyobj_interpreter()->is_contiguous( | ||||||
|         this, memory_format); |         this, memory_format); | ||||||
|   } |   } | ||||||
|   return is_contiguous_default(memory_format); |  | ||||||
|  |   return sym_is_contiguous_default(memory_format); | ||||||
| } | } | ||||||
|  |  | ||||||
| bool TensorImpl::is_strides_like_custom(at::MemoryFormat memory_format) const { | bool TensorImpl::is_strides_like_custom(at::MemoryFormat memory_format) const { | ||||||
| @ -326,12 +328,12 @@ bool TensorImpl::is_strides_like_custom(at::MemoryFormat memory_format) const { | |||||||
|   return is_strides_like_default(memory_format); |   return is_strides_like_default(memory_format); | ||||||
| } | } | ||||||
|  |  | ||||||
| bool TensorImpl::is_non_overlapping_and_dense_custom() const { | c10::SymBool TensorImpl::sym_is_non_overlapping_and_dense_custom() const { | ||||||
|   if (C10_UNLIKELY(matches_python_custom(SizesStridesPolicy::CustomStrides))) { |   if (C10_UNLIKELY(matches_python_custom(SizesStridesPolicy::CustomStrides))) { | ||||||
|     return pyobj_slot_.load_pyobj_interpreter()->is_non_overlapping_and_dense( |     return pyobj_slot_.load_pyobj_interpreter()->is_non_overlapping_and_dense( | ||||||
|         this); |         this); | ||||||
|   } |   } | ||||||
|   return is_non_overlapping_and_dense_default(); |   return sym_is_non_overlapping_and_dense_default(); | ||||||
| } | } | ||||||
|  |  | ||||||
| IntArrayRef TensorImpl::sizes_custom() const { | IntArrayRef TensorImpl::sizes_custom() const { | ||||||
|  | |||||||
| @ -812,6 +812,43 @@ struct C10_API TensorImpl : public c10::intrusive_ptr_target { | |||||||
|     } |     } | ||||||
|   } |   } | ||||||
|  |  | ||||||
|  |   c10::SymBool sym_is_contiguous( | ||||||
|  |       at::MemoryFormat memory_format = at::MemoryFormat::Contiguous) const { | ||||||
|  |     if (C10_UNLIKELY(matches_policy(SizesStridesPolicy::CustomStrides))) { | ||||||
|  |       return sym_is_contiguous_custom(memory_format); | ||||||
|  |     } | ||||||
|  |     return sym_is_contiguous_default(memory_format); | ||||||
|  |   } | ||||||
|  |  | ||||||
|  |   template <typename T> | ||||||
|  |   T is_contiguous_default_impl(at::MemoryFormat memory_format) const { | ||||||
|  |     if (!has_symbolic_sizes_strides_) { | ||||||
|  |       if (memory_format == at::MemoryFormat::ChannelsLast) { | ||||||
|  |         return is_channels_last_contiguous_; | ||||||
|  |       } else if (memory_format == at::MemoryFormat::ChannelsLast3d) { | ||||||
|  |         return is_channels_last_3d_contiguous_; | ||||||
|  |       } | ||||||
|  |       return is_contiguous_; | ||||||
|  |     } | ||||||
|  |  | ||||||
|  |     // Handle dynamic shapes. | ||||||
|  |     const auto& symbolic = symbolic_shape_meta().is_contiguous(memory_format); | ||||||
|  |  | ||||||
|  |     if constexpr (std::is_same_v<T, bool>) { | ||||||
|  |       return symbolic.guard_bool(__FILE__, __LINE__); | ||||||
|  |     } else { | ||||||
|  |       return symbolic; | ||||||
|  |     } | ||||||
|  |   } | ||||||
|  |  | ||||||
|  |   bool is_contiguous_default(at::MemoryFormat memory_format) const { | ||||||
|  |     return is_contiguous_default_impl<bool>(memory_format); | ||||||
|  |   } | ||||||
|  |  | ||||||
|  |   c10::SymBool sym_is_contiguous_default(at::MemoryFormat memory_format) const { | ||||||
|  |     return is_contiguous_default_impl<c10::SymBool>(memory_format); | ||||||
|  |   } | ||||||
|  |  | ||||||
|   /** |   /** | ||||||
|    * Whether or not a tensor is laid out in contiguous memory. |    * Whether or not a tensor is laid out in contiguous memory. | ||||||
|    * |    * | ||||||
| @ -827,30 +864,6 @@ struct C10_API TensorImpl : public c10::intrusive_ptr_target { | |||||||
|     return is_contiguous_default(memory_format); |     return is_contiguous_default(memory_format); | ||||||
|   } |   } | ||||||
|  |  | ||||||
|   // These are factored into separate functions in case subclasses |  | ||||||
|   // want to use them |  | ||||||
|   bool is_contiguous_default(at::MemoryFormat memory_format) const { |  | ||||||
|     if (has_symbolic_sizes_strides_) { |  | ||||||
|       if (memory_format == at::MemoryFormat::ChannelsLast) { |  | ||||||
|         return symbolic_shape_meta().is_channels_last_contiguous().guard_bool( |  | ||||||
|             __FILE__, __LINE__); |  | ||||||
|       } else if (memory_format == at::MemoryFormat::ChannelsLast3d) { |  | ||||||
|         return symbolic_shape_meta() |  | ||||||
|             .is_channels_last_3d_contiguous() |  | ||||||
|             .guard_bool(__FILE__, __LINE__); |  | ||||||
|       } |  | ||||||
|       return symbolic_shape_meta().is_contiguous().guard_bool( |  | ||||||
|           __FILE__, __LINE__); |  | ||||||
|     } |  | ||||||
|  |  | ||||||
|     if (memory_format == at::MemoryFormat::ChannelsLast) { |  | ||||||
|       return is_channels_last_contiguous_; |  | ||||||
|     } else if (memory_format == at::MemoryFormat::ChannelsLast3d) { |  | ||||||
|       return is_channels_last_3d_contiguous_; |  | ||||||
|     } |  | ||||||
|     return is_contiguous_; |  | ||||||
|   } |  | ||||||
|  |  | ||||||
|   bool is_strides_like_default(at::MemoryFormat memory_format) const { |   bool is_strides_like_default(at::MemoryFormat memory_format) const { | ||||||
|     if (has_symbolic_sizes_strides_) { |     if (has_symbolic_sizes_strides_) { | ||||||
|       if (memory_format == at::MemoryFormat::ChannelsLast) { |       if (memory_format == at::MemoryFormat::ChannelsLast) { | ||||||
| @ -873,9 +886,17 @@ struct C10_API TensorImpl : public c10::intrusive_ptr_target { | |||||||
|     } |     } | ||||||
|   } |   } | ||||||
|  |  | ||||||
|  |   SymBool sym_is_non_overlapping_and_dense_default() const { | ||||||
|  |     if (has_symbolic_sizes_strides_) { | ||||||
|  |       return symbolic_shape_meta().is_non_overlapping_and_dense(); | ||||||
|  |     } else { | ||||||
|  |       return is_non_overlapping_and_dense_; | ||||||
|  |     } | ||||||
|  |   } | ||||||
|  |  | ||||||
|   bool is_non_overlapping_and_dense_default() const { |   bool is_non_overlapping_and_dense_default() const { | ||||||
|     if (has_symbolic_sizes_strides_) { |     if (has_symbolic_sizes_strides_) { | ||||||
|       return symbolic_shape_meta().is_non_overlapping_and_dense().guard_bool( |       return sym_is_non_overlapping_and_dense_default().guard_bool( | ||||||
|           __FILE__, __LINE__); |           __FILE__, __LINE__); | ||||||
|     } else { |     } else { | ||||||
|       return is_non_overlapping_and_dense_; |       return is_non_overlapping_and_dense_; | ||||||
| @ -968,9 +989,24 @@ struct C10_API TensorImpl : public c10::intrusive_ptr_target { | |||||||
|    * for a tensor to have rank, but not well defined sizes. |    * for a tensor to have rank, but not well defined sizes. | ||||||
|    */ |    */ | ||||||
|   // sizes_strides_policy_ >= CustomStrides |   // sizes_strides_policy_ >= CustomStrides | ||||||
|   virtual bool is_contiguous_custom(at::MemoryFormat memory_format) const; |  | ||||||
|   virtual bool is_strides_like_custom(at::MemoryFormat memory_format) const; |   virtual bool is_strides_like_custom(at::MemoryFormat memory_format) const; | ||||||
|   virtual bool is_non_overlapping_and_dense_custom() const; |  | ||||||
|  |   virtual c10::SymBool sym_is_non_overlapping_and_dense_custom() const; | ||||||
|  |  | ||||||
|  |   bool is_non_overlapping_and_dense_custom() const { | ||||||
|  |     return sym_is_non_overlapping_and_dense_custom().guard_bool( | ||||||
|  |         __FILE__, __LINE__); | ||||||
|  |   } | ||||||
|  |  | ||||||
|  |   virtual c10::SymBool sym_is_contiguous_custom( | ||||||
|  |       at::MemoryFormat memory_format) const; | ||||||
|  |  | ||||||
|  |   bool is_contiguous_custom(at::MemoryFormat memory_format) const { | ||||||
|  |     return sym_is_contiguous_custom(memory_format) | ||||||
|  |         .guard_bool(__FILE__, __LINE__); | ||||||
|  |   } | ||||||
|  |  | ||||||
|   // sizes_strides_policy_ >= CustomSizes |   // sizes_strides_policy_ >= CustomSizes | ||||||
|   // Currently this method only exists to be overwritten by subclasses such as |   // Currently this method only exists to be overwritten by subclasses such as | ||||||
|   // NestedTensorImpl. |   // NestedTensorImpl. | ||||||
| @ -1004,9 +1040,9 @@ struct C10_API TensorImpl : public c10::intrusive_ptr_target { | |||||||
|   virtual c10::SymInt sym_storage_offset_custom() const; |   virtual c10::SymInt sym_storage_offset_custom() const; | ||||||
|  |  | ||||||
|  public: |  public: | ||||||
|   /** | /** | ||||||
|    * True if this tensor has storage. See storage() for details. |  * True if this tensor has storage. See storage() for details. | ||||||
|    */ |  */ | ||||||
| #ifdef DEBUG | #ifdef DEBUG | ||||||
|   // Allow subclasses to check that their storage_ is never getting set in debug |   // Allow subclasses to check that their storage_ is never getting set in debug | ||||||
|   // builds. |   // builds. | ||||||
| @ -1016,11 +1052,11 @@ struct C10_API TensorImpl : public c10::intrusive_ptr_target { | |||||||
| #endif | #endif | ||||||
|       bool |       bool | ||||||
|       has_storage() const |       has_storage() const | ||||||
|   // NOTE: we devirtualize this because it arguably shouldn't be an | // NOTE: we devirtualize this because it arguably shouldn't be an | ||||||
|   // error just to ask subclasses if they have storage. | // error just to ask subclasses if they have storage. | ||||||
|   // This used to throw for most subclasses, but OpaqueTensorImpl | // This used to throw for most subclasses, but OpaqueTensorImpl | ||||||
|   // wanted it to successfully return false, so we went ahead and made | // wanted it to successfully return false, so we went ahead and made | ||||||
|   // it a non-error. | // it a non-error. | ||||||
| #ifdef C10_DISABLE_TENSORIMPL_EXTENSIBILITY | #ifdef C10_DISABLE_TENSORIMPL_EXTENSIBILITY | ||||||
|   { |   { | ||||||
|     return storage_; |     return storage_; | ||||||
| @ -2447,6 +2483,11 @@ struct C10_API TensorImpl : public c10::intrusive_ptr_target { | |||||||
|     return is_strides_like(at::MemoryFormat::ChannelsLast3d); |     return is_strides_like(at::MemoryFormat::ChannelsLast3d); | ||||||
|   } |   } | ||||||
|  |  | ||||||
|  |   bool is_non_overlapping_and_dense_or_false() const { | ||||||
|  |     return sym_is_non_overlapping_and_dense().guard_or_false( | ||||||
|  |         __FILE__, __LINE__); | ||||||
|  |   } | ||||||
|  |  | ||||||
|   bool is_non_overlapping_and_dense() const { |   bool is_non_overlapping_and_dense() const { | ||||||
|     if (C10_UNLIKELY(matches_policy(SizesStridesPolicy::CustomStrides))) { |     if (C10_UNLIKELY(matches_policy(SizesStridesPolicy::CustomStrides))) { | ||||||
|       return is_non_overlapping_and_dense_custom(); |       return is_non_overlapping_and_dense_custom(); | ||||||
| @ -2454,6 +2495,13 @@ struct C10_API TensorImpl : public c10::intrusive_ptr_target { | |||||||
|     return is_non_overlapping_and_dense_default(); |     return is_non_overlapping_and_dense_default(); | ||||||
|   } |   } | ||||||
|  |  | ||||||
|  |   SymBool sym_is_non_overlapping_and_dense() const { | ||||||
|  |     if (C10_UNLIKELY(matches_policy(SizesStridesPolicy::CustomStrides))) { | ||||||
|  |       return sym_is_non_overlapping_and_dense_custom(); | ||||||
|  |     } | ||||||
|  |     return sym_is_non_overlapping_and_dense_default(); | ||||||
|  |   } | ||||||
|  |  | ||||||
|   // if this returns true, then it is guaranteed that this tensor has symbolic |   // if this returns true, then it is guaranteed that this tensor has symbolic | ||||||
|   // sizes/strides |   // sizes/strides | ||||||
|   bool has_symbolic_sizes_strides() const { |   bool has_symbolic_sizes_strides() const { | ||||||
|  | |||||||
| @ -12,7 +12,8 @@ UndefinedTensorImpl::UndefinedTensorImpl() | |||||||
|   set_custom_sizes_strides(SizesStridesPolicy::CustomStrides); |   set_custom_sizes_strides(SizesStridesPolicy::CustomStrides); | ||||||
| } | } | ||||||
|  |  | ||||||
| bool UndefinedTensorImpl::is_contiguous_custom(MemoryFormat format) const { | c10::SymBool UndefinedTensorImpl::sym_is_contiguous_custom( | ||||||
|  |     MemoryFormat format) const { | ||||||
|   return is_contiguous_default(format); |   return is_contiguous_default(format); | ||||||
| } | } | ||||||
| IntArrayRef UndefinedTensorImpl::strides_custom() const { | IntArrayRef UndefinedTensorImpl::strides_custom() const { | ||||||
|  | |||||||
| @ -32,7 +32,7 @@ struct C10_API UndefinedTensorImpl final : public TensorImpl { | |||||||
|   void set_storage_offset(int64_t offset) override; |   void set_storage_offset(int64_t offset) override; | ||||||
|  |  | ||||||
|  protected: |  protected: | ||||||
|   bool is_contiguous_custom(MemoryFormat format) const override; |   c10::SymBool sym_is_contiguous_custom(MemoryFormat format) const override; | ||||||
|   IntArrayRef strides_custom() const override; |   IntArrayRef strides_custom() const override; | ||||||
|   SymIntArrayRef sym_strides_custom() const override; |   SymIntArrayRef sym_strides_custom() const override; | ||||||
|  |  | ||||||
|  | |||||||
| @ -15467,6 +15467,48 @@ class TestExportCustomClass(TorchTestCase): | |||||||
|             MyModel(), inps, dynamic_shapes=spec, strict=True |             MyModel(), inps, dynamic_shapes=spec, strict=True | ||||||
|         ).run_decompositions({}) |         ).run_decompositions({}) | ||||||
|  |  | ||||||
|  |     def test_unbacked_contiguous(self): | ||||||
|  |         class MyModel(torch.nn.Module): | ||||||
|  |             def forward(self, x, mask): | ||||||
|  |                 masked_select = x.masked_select(mask) | ||||||
|  |                 view = masked_select.view(-1, 1548) | ||||||
|  |                 contig = view.contiguous() | ||||||
|  |                 return contig + 1 | ||||||
|  |  | ||||||
|  |         example_inputs = ( | ||||||
|  |             torch.randn((768, 1548), dtype=torch.bfloat16), | ||||||
|  |             torch.randint(low=0, high=1, size=(768, 1), dtype=torch.bool), | ||||||
|  |         ) | ||||||
|  |         spec = { | ||||||
|  |             "x": [Dim.STATIC, Dim.STATIC], | ||||||
|  |             "mask": [Dim.STATIC, Dim.STATIC], | ||||||
|  |         } | ||||||
|  |  | ||||||
|  |         traced = export(MyModel(), example_inputs, strict=True) | ||||||
|  |         self.assertExpectedInline( | ||||||
|  |             traced.graph_module.code, | ||||||
|  |             """\ | ||||||
|  | def forward(self, x, mask): | ||||||
|  |     masked_select = torch.ops.aten.masked_select.default(x, mask);  x = mask = None | ||||||
|  |     sym_size_int_1 = torch.ops.aten.sym_size.int(masked_select, 0) | ||||||
|  |     sym_constrain_range_for_size_default = torch.ops.aten.sym_constrain_range_for_size.default(sym_size_int_1);  sym_constrain_range_for_size_default = None | ||||||
|  |     ge = sym_size_int_1 >= 0 | ||||||
|  |     _assert_scalar_default = torch.ops.aten._assert_scalar.default(ge, "Runtime assertion failed for expression u0 >= 0 on node 'ge'");  ge = _assert_scalar_default = None | ||||||
|  |     le = sym_size_int_1 <= 1188864 | ||||||
|  |     _assert_scalar_default_1 = torch.ops.aten._assert_scalar.default(le, "Runtime assertion failed for expression u0 <= 1188864 on node 'le'");  le = _assert_scalar_default_1 = None | ||||||
|  |     mod = sym_size_int_1 % 1548 | ||||||
|  |     eq_2 = mod == 0;  mod = None | ||||||
|  |     _assert_scalar_default_2 = torch.ops.aten._assert_scalar.default(eq_2, "Runtime assertion failed for expression Eq(Mod(u0, 1548), 0) on node 'eq_2'");  eq_2 = _assert_scalar_default_2 = None | ||||||
|  |     floordiv = sym_size_int_1 // 1548 | ||||||
|  |     mul_2 = 1548 * floordiv;  floordiv = None | ||||||
|  |     eq_3 = sym_size_int_1 == mul_2;  sym_size_int_1 = mul_2 = None | ||||||
|  |     _assert_scalar_default_3 = torch.ops.aten._assert_scalar.default(eq_3, "Runtime assertion failed for expression Eq(u0, 1548*((u0//1548))) on node 'eq_3'");  eq_3 = _assert_scalar_default_3 = None | ||||||
|  |     view = torch.ops.aten.view.default(masked_select, [-1, 1548]);  masked_select = None | ||||||
|  |     add = torch.ops.aten.add.Tensor(view, 1);  view = None | ||||||
|  |     return (add,)""", | ||||||
|  |             ignore_empty_lines=True, | ||||||
|  |         ) | ||||||
|  |  | ||||||
|  |  | ||||||
| if __name__ == "__main__": | if __name__ == "__main__": | ||||||
|     run_tests() |     run_tests() | ||||||
|  | |||||||
| @ -3336,8 +3336,8 @@ def forward(self, arg0_1: "i64[2][1]cpu", arg1_1: "Sym(u2)", arg2_1: "Sym(u3)", | |||||||
|         _assert_scalar_4 = torch.ops.aten._assert_scalar.default(eq, "Runtime assertion failed for expression Eq(u2*u3, u0*u1) on node 'eq'");  eq = _assert_scalar_4 = None |         _assert_scalar_4 = torch.ops.aten._assert_scalar.default(eq, "Runtime assertion failed for expression Eq(u2*u3, u0*u1) on node 'eq'");  eq = _assert_scalar_4 = None | ||||||
|         clone: "f32[u2, u3][Max(1, u3), 1]cpu" = torch.ops.aten.clone.default(arg3_1, memory_format = torch.contiguous_format);  arg3_1 = None |         clone: "f32[u2, u3][Max(1, u3), 1]cpu" = torch.ops.aten.clone.default(arg3_1, memory_format = torch.contiguous_format);  arg3_1 = None | ||||||
|         view: "f32[u0, u1][Max(1, u1), 1]cpu" = torch.ops.aten.view.default(clone, [_local_scalar_dense, _local_scalar_dense_1]);  clone = _local_scalar_dense = _local_scalar_dense_1 = None |         view: "f32[u0, u1][Max(1, u1), 1]cpu" = torch.ops.aten.view.default(clone, [_local_scalar_dense, _local_scalar_dense_1]);  clone = _local_scalar_dense = _local_scalar_dense_1 = None | ||||||
|         mul_19: "f32[u0, u1][Max(1, u1), 1]cpu" = torch.ops.aten.mul.Tensor(view, 10);  view = None |         mul_21: "f32[u0, u1][Max(1, u1), 1]cpu" = torch.ops.aten.mul.Tensor(view, 10);  view = None | ||||||
|         return (mul_19,)""",  # noqa: B950 |         return (mul_21,)""",  # noqa: B950 | ||||||
|             ignore_comments=True, |             ignore_comments=True, | ||||||
|             ignore_empty_lines=True, |             ignore_empty_lines=True, | ||||||
|         ) |         ) | ||||||
| @ -3460,6 +3460,75 @@ def forward(self, arg0_1: "i64[2][1]cpu", arg1_1: "Sym(u2)", arg2_1: "Sym(u3)", | |||||||
|         func(torch.ones(5, 6, 9, 8)) |         func(torch.ones(5, 6, 9, 8)) | ||||||
|         self.assertEqual(cnt.frame_count, 3) |         self.assertEqual(cnt.frame_count, 3) | ||||||
|  |  | ||||||
|  |     @skipIfTorchDynamo("not allowed to trace mark_unbacked") | ||||||
|  |     @fresh_cache() | ||||||
|  |     def test_unbacked_contiguous(self): | ||||||
|  |         cnt = CompileCounterWithBackend("inductor") | ||||||
|  |  | ||||||
|  |         def func(x): | ||||||
|  |             contig = x.contiguous() | ||||||
|  |             return (contig + 1) * 100 | ||||||
|  |  | ||||||
|  |         compiled_func = torch.compile(fullgraph=True, backend=cnt, dynamic=True)(func) | ||||||
|  |  | ||||||
|  |         x = torch.randn(10, 10) | ||||||
|  |         # make x not contiguous. | ||||||
|  |         x = x.t_() | ||||||
|  |         torch._dynamo.decorators.mark_unbacked(x, 0) | ||||||
|  |         torch._dynamo.decorators.mark_unbacked(x, 1) | ||||||
|  |         log_stream, ctx = logs_to_string( | ||||||
|  |             "torch._inductor.compile_fx", "post_grad_graphs" | ||||||
|  |         ) | ||||||
|  |         with ctx(): | ||||||
|  |             compiled_func(x) | ||||||
|  |             self.assertEqual(compiled_func(x), func(x)) | ||||||
|  |             y = torch.rand(20, 20).t() | ||||||
|  |             self.assertEqual(compiled_func(y), func(y)) | ||||||
|  |             self.assertEqual(cnt.frame_count, 1) | ||||||
|  |         output = "\n".join(log_stream.getvalue().strip().split("\n")[4:]).strip() | ||||||
|  |         self.assertExpectedInline( | ||||||
|  |             output, | ||||||
|  |             """\ | ||||||
|  |         ge_1: "Sym(u0 >= 0)" = arg0_1 >= 0;  arg0_1 = None | ||||||
|  |         _assert_scalar = torch.ops.aten._assert_scalar.default(ge_1, "Runtime assertion failed for expression u0 >= 0 on node 'ge'");  ge_1 = _assert_scalar = None | ||||||
|  |         ge_3: "Sym(u1 >= 0)" = arg1_1 >= 0;  arg1_1 = None | ||||||
|  |         _assert_scalar_1 = torch.ops.aten._assert_scalar.default(ge_3, "Runtime assertion failed for expression u1 >= 0 on node 'ge_1'");  ge_3 = _assert_scalar_1 = None | ||||||
|  |         clone: "f32[u0, u1][Max(1, u1), 1]cpu" = torch.ops.aten.clone.default(arg2_1, memory_format = torch.contiguous_format);  arg2_1 = None | ||||||
|  |         add_3: "f32[u0, u1][Max(1, u1), 1]cpu" = torch.ops.aten.add.Tensor(clone, 1);  clone = None | ||||||
|  |         mul_6: "f32[u0, u1][Max(1, u1), 1]cpu" = torch.ops.aten.mul.Tensor(add_3, 100);  add_3 = None | ||||||
|  |         return (mul_6,)""",  # noqa: B950 | ||||||
|  |             ignore_comments=True, | ||||||
|  |             ignore_empty_lines=True, | ||||||
|  |         ) | ||||||
|  |  | ||||||
|  |         log_stream, ctx = logs_to_string( | ||||||
|  |             "torch._inductor.compile_fx", "post_grad_graphs" | ||||||
|  |         ) | ||||||
|  |         with ctx(): | ||||||
|  |             # recompilation will happen due to stride specialization. | ||||||
|  |             y = torch.rand(20, 20) | ||||||
|  |             torch._dynamo.decorators.mark_unbacked(y, 0) | ||||||
|  |             torch._dynamo.decorators.mark_unbacked(y, 1) | ||||||
|  |             self.assertEqual(compiled_func(y), func(y)) | ||||||
|  |             self.assertEqual(cnt.frame_count, 2) | ||||||
|  |  | ||||||
|  |         output = "\n".join(log_stream.getvalue().strip().split("\n")[4:]).strip() | ||||||
|  |  | ||||||
|  |         # No clone this time since input is contiguous. | ||||||
|  |         self.assertExpectedInline( | ||||||
|  |             output, | ||||||
|  |             """\ | ||||||
|  |         ge_1: "Sym(u0 >= 0)" = arg0_1 >= 0;  arg0_1 = None | ||||||
|  |         _assert_scalar = torch.ops.aten._assert_scalar.default(ge_1, "Runtime assertion failed for expression u0 >= 0 on node 'ge'");  ge_1 = _assert_scalar = None | ||||||
|  |         ge_3: "Sym(u1 >= 0)" = arg1_1 >= 0;  arg1_1 = None | ||||||
|  |         _assert_scalar_1 = torch.ops.aten._assert_scalar.default(ge_3, "Runtime assertion failed for expression u1 >= 0 on node 'ge_1'");  ge_3 = _assert_scalar_1 = None | ||||||
|  |         add: "f32[u0, u1][Max(1, u1), 1]cpu" = torch.ops.aten.add.Tensor(arg2_1, 1);  arg2_1 = None | ||||||
|  |         mul_5: "f32[u0, u1][Max(1, u1), 1]cpu" = torch.ops.aten.mul.Tensor(add, 100);  add = None | ||||||
|  |         return (mul_5,)""",  # noqa: B950 | ||||||
|  |             ignore_comments=True, | ||||||
|  |             ignore_empty_lines=True, | ||||||
|  |         ) | ||||||
|  |  | ||||||
|  |  | ||||||
| instantiate_parametrized_tests(TestUnbacked) | instantiate_parametrized_tests(TestUnbacked) | ||||||
|  |  | ||||||
|  | |||||||
| @ -1370,8 +1370,8 @@ def forward(self, crop_camera_1, mask_1): | |||||||
|     view_1 = torch.ops.aten.view.default(expand_1, [sym_size_int, sym_size_int_1, sym_size_int_2]);  expand_1 = sym_size_int_1 = sym_size_int_2 = None |     view_1 = torch.ops.aten.view.default(expand_1, [sym_size_int, sym_size_int_1, sym_size_int_2]);  expand_1 = sym_size_int_1 = sym_size_int_2 = None | ||||||
|     bmm = torch.ops.aten.bmm.default(view, view_1);  view = view_1 = None |     bmm = torch.ops.aten.bmm.default(view, view_1);  view = view_1 = None | ||||||
|     view_2 = torch.ops.aten.view.default(bmm, [sym_size_int, 3, 3]);  bmm = None |     view_2 = torch.ops.aten.view.default(bmm, [sym_size_int, 3, 3]);  bmm = None | ||||||
|     mul_6 = sym_size_int * 3 |     mul_9 = sym_size_int * 3 | ||||||
|     view_3 = torch.ops.aten.view.default(view_2, [mul_6, 3]);  view_2 = mul_6 = None |     view_3 = torch.ops.aten.view.default(view_2, [mul_9, 3]);  view_2 = mul_9 = None | ||||||
|     mm = torch.ops.aten.mm.default(view_3, eye);  view_3 = eye = None |     mm = torch.ops.aten.mm.default(view_3, eye);  view_3 = eye = None | ||||||
|     _unsafe_view = torch.ops.aten._unsafe_view.default(mm, [sym_size_int, 3, 3]);  mm = sym_size_int = None |     _unsafe_view = torch.ops.aten._unsafe_view.default(mm, [sym_size_int, 3, 3]);  mm = sym_size_int = None | ||||||
|     index_put_ = torch.ops.aten.index_put_.default(crop_camera_1, [mask_1], _unsafe_view);  crop_camera_1 = mask_1 = _unsafe_view = index_put_ = None |     index_put_ = torch.ops.aten.index_put_.default(crop_camera_1, [mask_1], _unsafe_view);  crop_camera_1 = mask_1 = _unsafe_view = index_put_ = None | ||||||
|  | |||||||
| @ -264,7 +264,7 @@ static PyObject * THPVariable_contiguous(PyObject* self, PyObject* args, PyObjec | |||||||
|   auto& self_ = THPVariable_Unpack(self); |   auto& self_ = THPVariable_Unpack(self); | ||||||
|   auto memory_format = r.memoryformat(0); |   auto memory_format = r.memoryformat(0); | ||||||
|   // avoids touching the GIL or current device if self is already contiguous |   // avoids touching the GIL or current device if self is already contiguous | ||||||
|   if (self_.is_contiguous(memory_format)) { |   if (self_.is_contiguous_or_false(memory_format)) { | ||||||
|     // NOTE: this logic is duplicated from VariableType.cpp. Since we need to |     // NOTE: this logic is duplicated from VariableType.cpp. Since we need to | ||||||
|     // record this call to contiguous() in the trace regardless of whether |     // record this call to contiguous() in the trace regardless of whether | ||||||
|     // we actually call contiguous here, we need to record this information |     // we actually call contiguous here, we need to record this information | ||||||
|  | |||||||
| @ -195,13 +195,14 @@ bool LTCTensorImpl::is_strides_like_custom( | |||||||
|   return false; |   return false; | ||||||
| } | } | ||||||
|  |  | ||||||
| bool LTCTensorImpl::is_non_overlapping_and_dense_custom() const { | c10::SymBool LTCTensorImpl::sym_is_non_overlapping_and_dense_custom() const { | ||||||
|   // This should be true, but false as a temporary fix for a PyTorch core issue, |   // This should be true, but false as a temporary fix for a PyTorch core issue, | ||||||
|   // according to https://github.com/pytorch/xla/pull/2682. |   // according to https://github.com/pytorch/xla/pull/2682. | ||||||
|   return false; |   return false; | ||||||
| } | } | ||||||
|  |  | ||||||
| bool LTCTensorImpl::is_contiguous_custom(c10::MemoryFormat _unused) const { | c10::SymBool LTCTensorImpl::sym_is_contiguous_custom( | ||||||
|  |     c10::MemoryFormat _unused) const { | ||||||
|   // TODO(ezyang): I don't think this branch is actually necessary |   // TODO(ezyang): I don't think this branch is actually necessary | ||||||
|   // TODO(ezyang): I don't think this logic is right, shouldn't we pass on |   // TODO(ezyang): I don't think this logic is right, shouldn't we pass on | ||||||
|   // the memory format? |   // the memory format? | ||||||
|  | |||||||
| @ -41,10 +41,11 @@ class TORCH_API LTCTensorImpl final : public c10::TensorImpl { | |||||||
|   int64_t numel_custom() const override; |   int64_t numel_custom() const override; | ||||||
|   int64_t storage_offset_custom() const override; |   int64_t storage_offset_custom() const override; | ||||||
|   int64_t dim_custom() const override; |   int64_t dim_custom() const override; | ||||||
|   bool is_contiguous_custom(at::MemoryFormat memory_format) const override; |  | ||||||
|   bool is_strides_like_custom(at::MemoryFormat memory_format) const override; |   bool is_strides_like_custom(at::MemoryFormat memory_format) const override; | ||||||
|   bool is_non_overlapping_and_dense_custom() const override; |   c10::SymBool sym_is_non_overlapping_and_dense_custom() const override; | ||||||
|  |  | ||||||
|  |   c10::SymBool sym_is_contiguous_custom( | ||||||
|  |       at::MemoryFormat memory_format) const override; | ||||||
|   c10::SymIntArrayRef sym_sizes_custom() const override; |   c10::SymIntArrayRef sym_sizes_custom() const override; | ||||||
|   c10::SymIntArrayRef sym_strides_custom() const override; |   c10::SymIntArrayRef sym_strides_custom() const override; | ||||||
|   c10::SymInt sym_numel_custom() const override; |   c10::SymInt sym_numel_custom() const override; | ||||||
|  | |||||||
| @ -20,6 +20,9 @@ class ExprPrinter(StrPrinter): | |||||||
|     def _print_Mul(self, expr: sympy.Expr) -> str: |     def _print_Mul(self, expr: sympy.Expr) -> str: | ||||||
|         return self.stringify(expr.args, "*", precedence(expr)) |         return self.stringify(expr.args, "*", precedence(expr)) | ||||||
|  |  | ||||||
|  |     def _print_Not(self, expr: sympy.Expr) -> str: | ||||||
|  |         return f"not ({self._print(expr.args[0])})" | ||||||
|  |  | ||||||
|     def _print_Add(self, expr: sympy.Expr, order: Optional[str] = None) -> str: |     def _print_Add(self, expr: sympy.Expr, order: Optional[str] = None) -> str: | ||||||
|         return self.stringify(expr.args, " + ", precedence(expr)) |         return self.stringify(expr.args, " + ", precedence(expr)) | ||||||
|  |  | ||||||
|  | |||||||
		Reference in New Issue
	
	Block a user