Mirror of https://github.com/pytorch/pytorch.git (synced 2025-11-01 13:34:57 +08:00)
Update on "[WIP] Add a simple cache mechanism to accelerate torch.compile-for-eager"

This PR is a follow-up to RFC https://github.com/pytorch/pytorch/issues/115545. In this PR, we provide a cache mechanism to accelerate torch.compile-for-eager.

cc voznesenskym penguinwu jgong5 Guobing-Chen XiaobingSuper zhuhaozhe blzheng wenzhe-nrv jiayisunx peterbell10 ipiszy yf225 chenyang78 kadeng muchulee8 aakhundov ColinPeppler

[ghstack-poisoned]
@@ -205,6 +205,7 @@ class Context final {
class UniformParamsBuffer final {
 private:
  Context* context_p_;
  size_t nbytes_;
  VulkanBuffer vulkan_buffer_;

 public:
@@ -213,6 +214,7 @@ class UniformParamsBuffer final {
  template <typename Block>
  UniformParamsBuffer(Context* context_p, const Block& block)
      : context_p_(context_p),
        nbytes_(sizeof(block)),
        vulkan_buffer_(
            context_p_->adapter_ptr()->vma().create_params_buffer(block)) {}

@@ -231,6 +233,21 @@ class UniformParamsBuffer final {
  VulkanBuffer& buffer() {
    return vulkan_buffer_;
  }

  template <typename Block>
  void update(const Block& block) {
    if (sizeof(block) != nbytes_) {
      VK_THROW(
          "Attempted to update UniformParamsBuffer with data of different size");
    }
    // Fill the uniform buffer with data in block
    {
      MemoryMap mapping(vulkan_buffer_, MemoryAccessType::WRITE);
      Block* data_ptr = mapping.template data<Block>();

      *data_ptr = block;
    }
  }
};

class StorageBuffer final {
@@ -238,6 +255,7 @@ class StorageBuffer final {
  Context* context_p_;
  ScalarType dtype_;
  size_t numel_;
  size_t nbytes_;
  VulkanBuffer vulkan_buffer_;

 public:
@@ -249,8 +267,9 @@ class StorageBuffer final {
      : context_p_(context_p),
        dtype_(dtype),
        numel_(numel),
        nbytes_(element_size(dtype_) * numel_),
        vulkan_buffer_(context_p_->adapter_ptr()->vma().create_storage_buffer(
            element_size(dtype_) * numel_,
            nbytes_,
            gpuonly)) {}

  StorageBuffer(const StorageBuffer&) = delete;
@@ -270,6 +289,14 @@ class StorageBuffer final {
  inline VulkanBuffer& buffer() {
    return vulkan_buffer_;
  }

  inline size_t numel() {
    return numel_;
  }

  inline size_t nbytes() {
    return nbytes_;
  }
};

bool available();

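The new `update()` method rewrites the mapped memory in place, but only when the incoming block matches the size recorded at construction, so a uniform buffer can be refreshed without reallocating. A minimal Python sketch of the same fixed-size in-place update pattern (the `ParamsBuffer` class here is hypothetical and only models the C++ behavior):

```python
import struct

class ParamsBuffer:
    """Sketch of a fixed-size parameter buffer that can be updated in place."""

    def __init__(self, block: bytes):
        self._nbytes = len(block)
        self._storage = bytearray(block)  # stands in for the mapped VulkanBuffer

    def update(self, block: bytes) -> None:
        # Mirrors the sizeof(block) != nbytes_ guard in UniformParamsBuffer::update
        if len(block) != self._nbytes:
            raise ValueError(
                "Attempted to update ParamsBuffer with data of different size")
        self._storage[:] = block  # overwrite contents without reallocating

params = ParamsBuffer(struct.pack("4i", 1, 2, 3, 4))
params.update(struct.pack("4i", 5, 6, 7, 8))  # ok: same size, reuses storage
```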
@@ -151,6 +151,10 @@ class VulkanBuffer final {
    return (memory_.allocation != VK_NULL_HANDLE);
  }

  inline bool owns_memory() const {
    return owns_memory_;
  }

  operator bool() const {
    return (handle_ != VK_NULL_HANDLE);
  }
@@ -372,6 +376,10 @@ class VulkanImage final {
    return (memory_.allocation != VK_NULL_HANDLE);
  }

  inline bool owns_memory() const {
    return owns_memory_;
  }

  inline operator bool() const {
    return (handles_.image != VK_NULL_HANDLE);
  }

@@ -12,6 +12,9 @@
#define VK_KERNEL(shader_name) \
  ::at::native::vulkan::api::shader_registry().get_shader_info(#shader_name)

#define VK_KERNEL_FROM_STR(shader_name_str) \
  ::at::native::vulkan::api::shader_registry().get_shader_info(shader_name_str)

namespace at {
namespace native {
namespace vulkan {

@@ -318,8 +318,8 @@ api::UniformParamsBuffer make_metadata_uniform(
  }

  vTensor::BufferMetadata metadata{
      api::utils::make_nchw_uvec4(sizes),
      api::utils::make_nchw_uvec4(strides),
      api::utils::make_whcn_uvec4(sizes),
      api::utils::make_whcn_uvec4(strides),
      api::utils::safe_downcast<uint32_t>(sizes.size()),
      api::utils::safe_downcast<uint32_t>(api::utils::multiply_integers(sizes)),
  };
@@ -347,12 +347,13 @@ vTensor::vTensor(
      strides_{calc_strides(sizes, memory_layout_, storage_type)},
      gpu_sizes_{calc_gpu_sizes(sizes, memory_layout_, storage_type)},
      gpu_strides_{calc_strides(gpu_sizes_, memory_layout_, storage_type)},
      // Vulkan uniform buffer containing sizes and stride info
      metadata_uniform_{make_metadata_uniform(
          context,
          gpu_sizes_,
          gpu_strides_,
          storage_type)},
      virtual_extents_(
          create_image_extents(gpu_sizes_, storage_type, memory_layout)),
      // Utility Uniform Buffers that can be passed to shaders as arguments
      metadata_uniform_(),
      cpu_sizes_uniform_(nullptr),
      gpu_sizes_uniform_(nullptr),
      extents_uniform_(nullptr),
      // Construct Tensor storage
      view_(std::make_shared<vTensorStorage>(
          context,
@@ -377,12 +378,13 @@ vTensor::vTensor(
      strides_{calc_strides(sizes, memory_layout_, storage_type)},
      gpu_sizes_{calc_gpu_sizes(sizes, memory_layout_, storage_type)},
      gpu_strides_{calc_strides(gpu_sizes_, memory_layout_, storage_type)},
      virtual_extents_(
          create_image_extents(gpu_sizes_, storage_type, memory_layout)),
      // Vulkan uniform buffer containing sizes and stride info
      metadata_uniform_{make_metadata_uniform(
          context,
          gpu_sizes_,
          gpu_strides_,
          storage_type)},
      metadata_uniform_(),
      cpu_sizes_uniform_(nullptr),
      gpu_sizes_uniform_(nullptr),
      extents_uniform_(nullptr),
      // Quantization params
      is_quantized_{true},
      q_scale_{q_scale},
@@ -425,10 +427,47 @@ api::VulkanBuffer& vTensor::buffer(
  return view_->buffer_;
}

api::VulkanBuffer& vTensor::buffer_metadata() {
  if (!metadata_uniform_.buffer()) {
    metadata_uniform_ = make_metadata_uniform(
        view_->context_, gpu_sizes_, gpu_strides_, storage_type());
  }
  return metadata_uniform_.buffer();
}

std::shared_ptr<api::UniformParamsBuffer> vTensor::cpu_sizes_ubo() {
  if (!cpu_sizes_uniform_) {
    cpu_sizes_uniform_.reset(new api::UniformParamsBuffer(
        view_->context_, api::utils::make_whcn_ivec4(sizes_)));
  }
  return cpu_sizes_uniform_;
}

std::shared_ptr<api::UniformParamsBuffer> vTensor::gpu_sizes_ubo() {
  if (!gpu_sizes_uniform_) {
    gpu_sizes_uniform_.reset(new api::UniformParamsBuffer(
        view_->context_, api::utils::make_whcn_ivec4(gpu_sizes_)));
  }
  return gpu_sizes_uniform_;
}

std::shared_ptr<api::UniformParamsBuffer> vTensor::extents_ubo() {
  if (!extents_uniform_) {
    extents_uniform_.reset(new api::UniformParamsBuffer(
        view_->context_,
        api::utils::uvec4(
            {view_->extents_.data[0],
             view_->extents_.data[1],
             view_->extents_.data[2],
             1u})));
  }
  return extents_uniform_;
}

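Each of these accessors creates its uniform buffer object lazily on first use and then hands out the cached shared pointer, so tensors whose metadata never reaches a shader never pay for the allocation. A small Python sketch of the lazy-caching accessor pattern (the `LazyTensorMeta` class and its fields are hypothetical, only modeling the C++ flow):

```python
from typing import Optional

class LazyTensorMeta:
    """Sketch: create shader size metadata only on first request."""

    def __init__(self, sizes: list[int]):
        self.sizes = sizes
        self._sizes_ubo: Optional[list[int]] = None  # stands in for the UBO

    def sizes_ubo(self) -> list[int]:
        if self._sizes_ubo is None:                       # first call allocates
            self._sizes_ubo = list(reversed(self.sizes))  # WHCN-style order
        return self._sizes_ubo                            # later calls reuse it

meta = LazyTensorMeta([2, 3, 4, 5])
assert meta.sizes_ubo() is meta.sizes_ubo()  # same cached object both times
```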
vTensor::BufferMetadata vTensor::get_cpu_buffer_metadata() const {
  return {
      api::utils::make_nchw_uvec4(sizes_),
      api::utils::make_nchw_uvec4(strides_),
      api::utils::make_whcn_uvec4(sizes_),
      api::utils::make_whcn_uvec4(strides_),
      api::utils::safe_downcast<uint32_t>(sizes_.size()),
      api::utils::safe_downcast<uint32_t>(
          api::utils::multiply_integers(sizes_)),
@@ -473,6 +512,65 @@ void vTensor::bind_allocation(const api::MemoryAllocation& allocation) {
  }
}

void vTensor::update_size_metadata(const std::vector<int64_t>& new_sizes) {
  sizes_ = new_sizes;
  gpu_sizes_ = calc_gpu_sizes(sizes_, memory_layout_, storage_type());
  virtual_extents_ =
      create_image_extents(gpu_sizes_, storage_type(), memory_layout_);

  if (cpu_sizes_uniform_) {
    cpu_sizes_uniform_->update(api::utils::make_whcn_ivec4(sizes_));
  }

  if (gpu_sizes_uniform_) {
    gpu_sizes_uniform_->update(api::utils::make_whcn_ivec4(gpu_sizes_));
  }

  if (extents_uniform_) {
    extents_uniform_->update(api::utils::uvec4(
        {virtual_extents_.data[0],
         virtual_extents_.data[1],
         virtual_extents_.data[2],
         1u}));
  }
}

void vTensor::reallocate(const std::vector<int64_t>& new_sizes) {
  update_size_metadata(new_sizes);
  view_->discard_and_reallocate(
      calc_gpu_sizes(new_sizes, memory_layout_, storage_type()),
      memory_layout_,
      dtype_);
}

void vTensor::virtual_resize(const std::vector<int64_t>& new_sizes) {
  update_size_metadata(new_sizes);
  if (storage_type() == api::StorageType::BUFFER) {
    if (gpu_nbytes() > view_->buffer_.mem_size()) {
      VK_THROW(
          "Cannot virtual_resize a vTensor with sizes that require a larger "
          "buffer! reallocate() should be used instead.");
    }
  } else {
    bool valid_resize = true;
    if (virtual_extents_.data[0] > view_->extents_.data[0]) {
      valid_resize = false;
    }
    if (virtual_extents_.data[1] > view_->extents_.data[1]) {
      valid_resize = false;
    }
    if (virtual_extents_.data[2] > view_->extents_.data[2]) {
      valid_resize = false;
    }

    if (!valid_resize) {
      VK_THROW(
          "Cannot virtual_resize a vTensor with sizes that require a larger "
          "image texture! reallocate() should be used instead.");
    }
  }
}

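`virtual_resize()` only rewrites the size metadata and refuses any shape whose footprint would outgrow the already-allocated buffer or texture; `reallocate()` is the escape hatch for actual growth. A short Python sketch of the texture-path check above (standalone, with plain tuples standing in for the extents structs):

```python
def virtual_resize(new_extents: tuple[int, int, int],
                   allocated_extents: tuple[int, int, int]) -> None:
    """Sketch of the grow-check in vTensor::virtual_resize for textures."""
    # A virtual resize is valid only if the requested extents fit inside the
    # image texture that is already allocated, dimension by dimension.
    if any(new > have for new, have in zip(new_extents, allocated_extents)):
        raise RuntimeError(
            "Cannot virtual_resize to a larger texture; use reallocate()")

virtual_resize((4, 4, 1), (8, 8, 2))     # ok: shrinking the logical view
# virtual_resize((16, 4, 1), (8, 8, 2))  # would raise: needs more memory
```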
//
// vTensorStorage
//
@@ -569,11 +667,16 @@ vTensorStorage::vTensorStorage(
      last_access_{} {}

vTensorStorage::~vTensorStorage() {
  flush();
}

void vTensorStorage::flush() {
  if (image_) {
    context_->register_image_cleanup(image_);
  } else if (buffer_) {
    context_->register_buffer_cleanup(buffer_);
  }
  last_access_ = {};
}

void vTensorStorage::transition(
@@ -663,6 +766,28 @@ void add_buffer_barrier(
  }
}

void vTensorStorage::discard_and_reallocate(
    const std::vector<int64_t>& gpu_sizes,
    const api::GPUMemoryLayout gpu_memory_layout,
    const api::ScalarType dtype) {
  const bool image_owns_memory = image_.owns_memory();
  const bool buffer_owns_memory = buffer_.owns_memory();

  flush();

  extents_ = create_image_extents(gpu_sizes, storage_type_, gpu_memory_layout);
  image_ = allocate_image(
      context_,
      extents_,
      storage_type_,
      api::to_vkformat(dtype),
      image_owns_memory);

  buffer_length_ = api::utils::multiply_integers(gpu_sizes);
  buffer_ = allocate_buffer(
      context_, buffer_length_, storage_type_, dtype, buffer_owns_memory);
}

} // namespace vulkan
} // namespace native
} // namespace at

@@ -66,6 +66,9 @@ class vTensorStorage final {
  LastAccess last_access_;

 private:
  // Registers underlying memory for cleanup
  void flush();

  // Memory barrier insertion
  void transition(
      api::PipelineBarrier&,
@@ -79,6 +82,11 @@ class vTensorStorage final {
  inline VkFormat texture_format() {
    return image_.format();
  }

  void discard_and_reallocate(
      const std::vector<int64_t>& gpu_sizes,
      const api::GPUMemoryLayout gpu_memory_layout,
      const api::ScalarType dtype);
};

class vTensor final {
@@ -141,10 +149,29 @@ class vTensor final {
  std::vector<int64_t> gpu_sizes_;
  std::vector<int64_t> gpu_strides_;

  // The extents that correspond to the tensor's size metadata. Note that this
  // may not be the same as the extents of the underlying image texture because
  // vTensor can be virtually resized via virtual_resize() which will cause it
  // to be interpreted as a tensor with a different size.
  api::utils::uvec3 virtual_extents_;

  // A Vulkan uniform buffer containing sizes and strides of the GPU buffer that
  // can be passed into a shader.
  api::UniformParamsBuffer metadata_uniform_;

  // A Vulkan uniform buffer containing the tensor sizes that can be passed into
  // a shader.
  std::shared_ptr<api::UniformParamsBuffer> cpu_sizes_uniform_;

  // A Vulkan uniform buffer containing the GPU tensor sizes that can be passed
  // into a shader. GPU sizes refer to the sizes of the tensor after padding
  // has been applied to one dimension to align it to the next multiple of 4.
  std::shared_ptr<api::UniformParamsBuffer> gpu_sizes_uniform_;

  // A Vulkan uniform buffer containing the image extents of the underlying
  // image texture that can be passed into a shader.
  std::shared_ptr<api::UniformParamsBuffer> extents_uniform_;

  // Quantization params
  bool is_quantized_{false};
  double q_scale_{1.0f};
@@ -250,13 +277,36 @@ class vTensor final {
    return gpu_strides_;
  }

  inline const api::utils::uvec3& virtual_extents() const {
    return virtual_extents_;
  }

  /*
   * Get a uniform buffer containing sizes and strides information of the GPU
   * buffer
   */
  inline api::VulkanBuffer& buffer_metadata() {
    return metadata_uniform_.buffer();
  }
  api::VulkanBuffer& buffer_metadata();

  /*
   * Get a uniform buffer object containing the tensor sizes to use in a compute
   * shader. Note that the UBO will be created the first time this function is
   * called.
   */
  std::shared_ptr<api::UniformParamsBuffer> cpu_sizes_ubo();

  /*
   * Get a uniform buffer object containing the tensor GPU sizes to use in a
   * compute shader. Note that the UBO will be created the first time this
   * function is called.
   */
  std::shared_ptr<api::UniformParamsBuffer> gpu_sizes_ubo();

  /*
   * Get a uniform buffer object containing the image extents to use in a
   * compute shader. Note that the UBO will be created the first time this
   * function is called.
   */
  std::shared_ptr<api::UniformParamsBuffer> extents_ubo();

  /*
   * Constructs a BufferMetadata struct based on the original sizes and strides
@@ -308,7 +358,7 @@ class vTensor final {
   * Returns numel but based on gpu_sizes_ instead of sizes_
   */
  inline size_t gpu_numel() const {
    return view_->buffer_length_;
    return api::utils::multiply_integers(gpu_sizes_);
  }

  /*
@@ -332,6 +382,27 @@ class vTensor final {
   * Binds the underlying resource to the given memory allocation
   */
  void bind_allocation(const api::MemoryAllocation& allocation);

 private:
  /*
   * Update the size metadata of the vTensor to be new sizes. Should not be used
   * directly; reallocate() or virtual_resize() should be used instead.
   */
  void update_size_metadata(const std::vector<int64_t>& new_sizes);

 public:
  /*
   * Discard the underlying VkImage or VkBuffer and re-allocate based on new
   * tensor sizes
   */
  void reallocate(const std::vector<int64_t>& new_sizes);

  /*
   * Perform a virtual resize of the vTensor by modifying the size metadata that
   * gets used in compute shaders. This allows the shader to treat the
   * underlying resource as if it were a different size.
   */
  void virtual_resize(const std::vector<int64_t>& new_sizes);
};

void add_buffer_barrier(

@@ -328,10 +328,10 @@ inline ivec3 make_ivec3(uvec3 ints) {
}

/*
 * Given a vector of up to 4 int64_t representing the sizes of a tensor,
 * Given a vector of up to 4 uint64_t representing the sizes of a tensor,
 * constructs a uvec4 containing those elements in reverse order.
 */
inline uvec4 make_nchw_uvec4(const std::vector<int64_t>& arr) {
inline uvec4 make_whcn_uvec4(const std::vector<int64_t>& arr) {
  uint32_t w = safe_downcast<uint32_t>(val_at(-1, arr));
  uint32_t h = safe_downcast<uint32_t>(val_at(-2, arr));
  uint32_t c = safe_downcast<uint32_t>(val_at(-3, arr));
@@ -340,6 +340,19 @@ inline uvec4 make_nchw_uvec4(const std::vector<int64_t>& arr) {
  return {w, h, c, n};
}

/*
 * Given a vector of up to 4 int64_t representing the sizes of a tensor,
 * constructs an ivec4 containing those elements in reverse order.
 */
inline ivec4 make_whcn_ivec4(const std::vector<int64_t>& arr) {
  int32_t w = val_at(-1, arr);
  int32_t h = val_at(-2, arr);
  int32_t c = val_at(-3, arr);
  int32_t n = val_at(-4, arr);

  return {w, h, c, n};
}

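These helpers pack tensor sizes in reverse (WHCN) order so shaders can index them as width, height, channels, batch. A quick Python illustration of the reversal, assuming (as the uvec4 variant suggests) that missing leading dimensions default to 1:

```python
def make_whcn(sizes: list[int]) -> tuple[int, int, int, int]:
    """Sketch of make_whcn_ivec4: NCHW-style sizes -> (W, H, C, N)."""
    def val_at(i: int) -> int:
        # negative index from the end; out-of-range dims default to 1
        return sizes[i] if len(sizes) >= -i else 1
    return (val_at(-1), val_at(-2), val_at(-3), val_at(-4))

assert make_whcn([2, 3, 4, 5]) == (5, 4, 3, 2)  # N=2, C=3, H=4, W=5
assert make_whcn([3, 4, 5]) == (5, 4, 3, 1)     # missing batch dim -> 1
```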
/*
 * Wrapper around std::accumulate that accumulates values of a container of
 * integral types into int64_t. Taken from `multiply_integers` in

@@ -331,7 +331,7 @@ class ProcessGroupNCCLTest(MultiProcessTestCase):
            a = torch.tensor([[2, 4, 0], [8, 0, 12]]).to(self.rank)
            self.assertEqual(tensor_list[0], a)
        except RuntimeError as e:
            if "allreduce_sparse is only available in the NCCL experimental branch." in str(e):
            if "NCCL does not support all_reduce with sparse tensors" in str(e):
                pass
            else:
                # Rethrow the exception if it's a different error
@@ -4052,7 +4052,7 @@ class SparseCollective(MultiProcessTestCase):
            loss.backward()
            self.assertTrue(ddp_model.module.embedding.weight.grad.indices, indices)
        except RuntimeError as e:
            if "allreduce_sparse is only available in the NCCL experimental branch." in str(e):
            if "NCCL does not support all_reduce with sparse tensors" in str(e):
                pass
            else:
                # Rethrow the exception if it's a different error

@@ -972,6 +972,7 @@ class TestExport(TestCase):
        self._test_export_same_as_eager(kw_func, args, kwargs)

    @testing.expectedFailureSerDer  # we don't save placeholder metadata
    @testing.expectedFailureSerDerPreDispatch
    @testing.expectedFailureNonStrict
    def test_linear_conv(self):
        class MyLinear(torch.nn.Module):
@@ -1462,6 +1463,7 @@ class TestExport(TestCase):
        self.assertEqual(buffer[2].shape, torch.Size([]))  # num_batches_tracked

    @testing.expectedFailureNonStrict
    @testing.expectedFailureSerDerPreDispatch  # tracked via: T181382045
    def test_export_dynamo_config(self):
        class MyModule(torch.nn.Module):
            def __init__(self):
@@ -1833,6 +1835,7 @@ def forward(self, arg_0):
        )

    @testing.expectedFailureNonStrict  # non-strict does not add deferred runtime assertions
    @testing.expectedFailureSerDerPreDispatch  # .item call becomes aten.item in predispatch IR
    def test_automatic_constrain_size(self):
        class M(torch.nn.Module):
            def forward(self, x, y):
@@ -1888,6 +1891,7 @@ def forward(self, arg_0):
                self.assertTrue(isinstance(node.meta["val"], (Tensor, int)))

    @testing.expectedFailureNonStrict
    @testing.expectedFailureSerDerPreDispatch  # .item() becomes aten.item in predispatch IR
    def test_export_with_inline_constraints(self):
        class Module(torch.nn.Module):
            def forward(self, x):
@@ -2249,6 +2253,7 @@ def forward(self, arg_0):
            )

    @testing.expectedFailureSerDer  # We don't preserve metadata on graph module
    @testing.expectedFailureSerDerPreDispatch
    @testing.expectedFailureNonStrict
    def test_retrace_graph_level_meta_preservation(self):
        class Foo(torch.nn.Module):
@@ -2479,6 +2484,7 @@ def forward(self, arg_0):
        ):
            exported_program.module()(torch.rand(2, 3), torch.rand(2, 3))

    @testing.expectedFailureSerDerPreDispatch  # linear shouldn't decompose
    def test_export_decomps_simple(self):
        class M(torch.nn.Module):
            def __init__(self):
@@ -3052,6 +3058,7 @@ def forward(self, arg_0):
        self.assertEqual(ep.module()(*inputs), m(*inputs))

    @testing.expectedFailureSerDer  # symfloat nyi
    @testing.expectedFailureSerDerPreDispatch  # symfloat nyi
    @testing.expectedFailureRetraceability
    def test_sym_sqrt(self):
        import math

@@ -9,6 +9,7 @@ except ImportError:
    import testing

from torch.export import export, load, save
from torch.export._trace import _export

test_classes = {}

@@ -22,10 +23,21 @@ def mocked_serder_export(*args, **kwargs):
    return loaded_ep


def mocked_serder_export_pre_dispatch(*args, **kwargs):
    ep = _export(*args, **kwargs, pre_dispatch=True)
    buffer = io.BytesIO()
    save(ep, buffer)
    buffer.seek(0)
    loaded_ep = load(buffer)
    return loaded_ep


def make_dynamic_cls(cls):
    suffix = "_serdes"
    suffix_pre_dispatch = "_serdes_pre_dispatch"

    cls_prefix = "SerDesExport"
    cls_prefix_pre_dispatch = "SerDesExportPreDispatch"

    test_class = testing.make_test_cls_with_mocked_export(
        cls,
@@ -35,11 +47,21 @@ def make_dynamic_cls(cls):
        xfail_prop="_expected_failure_serdes",
    )

    test_class_pre_dispatch = testing.make_test_cls_with_mocked_export(
        cls,
        cls_prefix_pre_dispatch,
        suffix_pre_dispatch,
        mocked_serder_export_pre_dispatch,
        xfail_prop="_expected_failure_serdes_pre_dispatch",
    )

    test_classes[test_class.__name__] = test_class
    test_classes[test_class_pre_dispatch.__name__] = test_class_pre_dispatch
    # REMOVING THIS LINE WILL STOP TESTS FROM RUNNING
    globals()[test_class.__name__] = test_class
    globals()[test_class_pre_dispatch.__name__] = test_class_pre_dispatch
    test_class.__module__ = __name__
    return test_class
    test_class_pre_dispatch.__module__ = __name__


tests = [

@@ -58,3 +58,8 @@ def expectedFailureRetraceability(fn):
def expectedFailureSerDer(fn):
    fn._expected_failure_serdes = True
    return fn


def expectedFailureSerDerPreDispatch(fn):
    fn._expected_failure_serdes_pre_dispatch = True
    return fn

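The decorator only tags the test function with an attribute; the generated test class later checks for that attribute (via the `xfail_prop` argument seen above) and marks tagged tests as expected failures. A rough sketch of how such a tag could be consumed, with the hypothetical `make_xfail_aware_cls` standing in for `make_test_cls_with_mocked_export`:

```python
import unittest

def expectedFailureSerDerPreDispatch(fn):
    # The decorator only sets a marker attribute on the test function.
    fn._expected_failure_serdes_pre_dispatch = True
    return fn

def make_xfail_aware_cls(cls, xfail_prop: str):
    """Hypothetical sketch: wrap tagged tests in unittest.expectedFailure."""
    members = {}
    for name, value in list(vars(cls).items()):
        if name.startswith("test_") and getattr(value, xfail_prop, False):
            members[name] = unittest.expectedFailure(value)
    # Subclass so untouched tests are inherited; only tagged ones are wrapped.
    return type(cls.__name__ + "XFailAware", (cls,), members)

class MyTests(unittest.TestCase):
    @expectedFailureSerDerPreDispatch
    def test_known_gap(self):
        self.fail("not supported in pre-dispatch serde yet")

MyTestsXFail = make_xfail_aware_cls(
    MyTests, "_expected_failure_serdes_pre_dispatch")
```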
@@ -1182,6 +1182,35 @@ class MutationTests(torch._dynamo.test_case.TestCase):
            ["out_ptr"],
        )

    @make_mutation_test
    def test_reduce_sum():
        @triton.jit
        def reduce_sum_kernel(a_ptr, c_ptr, stride_am, stride_an):
            offs_am = tl.arange(0, 4)
            offs_an = tl.arange(0, 4)
            a_ptrs = a_ptr + (
                offs_am[:, None] * stride_am + offs_an[None, :] * stride_an
            )
            a = tl.load(a_ptrs)
            m = tl.sum(a, axis=1)
            tl.store(c_ptr + tl.arange(0, 4), m)

        return (
            reduce_sum_kernel,
            {
                "a_ptr": torch.randn(4, 4),
                "c_ptr": torch.randn(4),
                "stride_am": 4,
                "stride_an": 4,
            },
            # TODO(aakhundov): tt.reduce is now supported, but only
            # in the new MLIR-based Triton analysis pass (not in the
            # old TTIR string parsing-based one). Change the line
            # below to ["c_ptr"] when the new Triton pin lands and this
            # test starts failing.
            ["a_ptr", "c_ptr"],
        )

    @make_mutation_test
    def test_argmax():
        @triton.jit
@@ -1204,7 +1233,11 @@ class MutationTests(torch._dynamo.test_case.TestCase):
                "stride_am": 4,
                "stride_an": 4,
            },
            # TODO(oulgen): tt.reduce closures are not implemented yet
            # TODO(aakhundov): tt.reduce is now supported, but only
            # in the new MLIR-based Triton analysis pass (not in the
            # old TTIR string parsing-based one). Change the line
            # below to ["c_ptr"] when the new Triton pin lands and this
            # test starts failing.
            ["a_ptr", "c_ptr"],
        )


@@ -66,6 +66,8 @@ CUSPARSE_SPMM_COMPLEX128_SUPPORTED = (
    IS_WINDOWS and torch.version.cuda and version.parse(torch.version.cuda) > version.parse("11.2")
) or (not IS_WINDOWS and not TEST_WITH_ROCM)

HIPSPARSE_SPMM_COMPLEX128_SUPPORTED = torch.version.hip and version.parse(torch.version.hip.split("-")[0]) >= version.parse("6.0")

def all_sparse_layouts(test_name='layout', include_strided=False):
    return parametrize(test_name, [
        subtest(torch.strided, name='Strided'),

@@ -21,7 +21,7 @@ from torch.testing._internal.common_dtype import (
    floating_types, all_types_and_complex_and, floating_and_complex_types, floating_types_and,
    all_types_and_complex, floating_and_complex_types_and)
from torch.testing._internal.opinfo.definitions.sparse import validate_sample_input_sparse
from test_sparse import CUSPARSE_SPMM_COMPLEX128_SUPPORTED
from test_sparse import CUSPARSE_SPMM_COMPLEX128_SUPPORTED, HIPSPARSE_SPMM_COMPLEX128_SUPPORTED
import operator

if TEST_SCIPY:
@@ -2024,7 +2024,9 @@ class TestSparseCSR(TestCase):
    @dtypesIfCUDA(*floating_types_and(torch.complex64,
                                      *[torch.bfloat16] if SM80OrLater else [],
                                      *[torch.half] if SM53OrLater else [],
                                      *[torch.complex128] if CUSPARSE_SPMM_COMPLEX128_SUPPORTED else []))
                                      *[torch.complex128]
                                      if CUSPARSE_SPMM_COMPLEX128_SUPPORTED or HIPSPARSE_SPMM_COMPLEX128_SUPPORTED
                                      else []))
    @precisionOverride({torch.double: 1e-8, torch.float: 1e-4, torch.bfloat16: 0.6,
                        torch.half: 1e-1, torch.cfloat: 1e-4, torch.cdouble: 1e-8})
    def test_addmm_sizes_all_sparse_csr(self, device, dtype, m, n, k):

@@ -28,13 +28,76 @@ except ImportError:

CPP_H_NAME = "spv.h"
CPP_SRC_NAME = "spv.cpp"
DEFAULT_ENV = {

DEFAULT_ENV: Dict[str, Any] = {
    "PRECISION": "highp",
    "FLOAT_IMAGE_FORMAT": "rgba16f",
    "INT_IMAGE_FORMAT": "rgba32i",
    "UINT_IMAGE_FORMAT": "rgba32ui",
}

TYPES_ENV: Dict[str, Any] = {
    "IMAGE_FORMAT": {
        "float": "rgba32f",
        "half": "rgba16f",
        "int": "rgba32i",
        "uint": "rgba32ui",
        "int8": "rgba8i",
        "uint8": "rgba8ui",
    },
    "IMAGE_T": {
        3: {
            "float": "image3D",
            "half": "image3D",
            "int": "iimage3D",
            "uint": "uimage3D",
        },
        2: {
            "float": "image2D",
            "half": "image2D",
            "int": "iimage2D",
            "uint": "uimage2D",
        },
    },
    "SAMPLER_T": {
        3: {
            "float": "sampler3D",
            "half": "sampler3D",
            "int": "isampler3D",
            "uint": "usampler3D",
        },
        2: {
            "float": "sampler2D",
            "half": "sampler2D",
            "int": "isampler2D",
            "uint": "usampler2D",
        },
    },
    "VEC4_T": {
        "float": "vec4",
        "half": "vec4",
        "int": "ivec4",
        "uint": "uvec4",
        "int8": "vec4",
        "uint8": "uvec4",
    },
    "T": {
        "float": "float",
        "half": "float",
        "int": "int",
        "uint": "uint",
        "int8": "int",
        "uint8": "uint8",
    },
}

FUNCS_ENV: Dict[str, Any] = {
    "GET_POS": {
        3: lambda pos: pos,
        2: lambda pos: f"{pos}.xy",
    }
}


def extract_filename(path: str, keep_ext: bool = True) -> Any:
    if keep_ext:
@@ -671,7 +734,10 @@ def main(argv: List[str]) -> int:
    )
    options = parser.parse_args()

    DEFAULT_ENV.update(TYPES_ENV)
    DEFAULT_ENV.update(FUNCS_ENV)
    env = DEFAULT_ENV

    for key, value in parse_arg_env(options.env).items():
        env[key] = value

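The codegen env is a plain nested dict: type tables keyed by GLSL dtype (and, for images and samplers, by dimensionality) plus small formatting lambdas, all merged into one namespace before template expansion. A short example of resolving entries from such an env (trimmed tables; the merge mirrors the `DEFAULT_ENV.update(...)` calls above):

```python
from typing import Any, Dict

DEFAULT_ENV: Dict[str, Any] = {"PRECISION": "highp"}
TYPES_ENV: Dict[str, Any] = {"IMAGE_T": {3: {"float": "image3D"}}}
FUNCS_ENV: Dict[str, Any] = {"GET_POS": {2: lambda pos: f"{pos}.xy"}}

DEFAULT_ENV.update(TYPES_ENV)   # same merge the generator performs
DEFAULT_ENV.update(FUNCS_ENV)
env = DEFAULT_ENV

print(env["IMAGE_T"][3]["float"])   # -> image3D
print(env["GET_POS"][2]("pos"))     # -> pos.xy
```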
@@ -113,6 +113,39 @@ graph_sizes_log = torch._logging.getArtifactLogger(__name__, "graph_sizes")
trace_call_log = torch._logging.getArtifactLogger(__name__, "trace_call")


@dataclass(frozen=True)
class VariableTrackerCacheKey:
    vt_id: int
    # Two different sources can point to the same object. However, Dynamo
    # handles global and local sources differently when it comes to guards and
    # possibly some other parts as well. So, the cache also relies on the source.
    source: Source


class VariableTrackerCache:
    def __init__(self):
        self.cache = {}

    def lookup(self, value, source):
        key = VariableTrackerCacheKey(id(value), source)
        if key not in self.cache:
            return None
        return self.cache[key]

    def add(self, value, source, vt):
        key = VariableTrackerCacheKey(id(value), source)
        self.cache[key] = vt

    def clone(self):
        # Needed for copy and restore graph state
        new_cache = VariableTrackerCache()
        new_cache.cache.update(self.cache)
        return new_cache

    def clear(self):
        self.cache.clear()


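The cache keys on `(id(value), source)`, so a repeated `LOAD_GLOBAL` or `LOAD_ATTR` of the same Python object through the same source can reuse the VariableTracker that was already built. A small standalone illustration of the lookup-then-add flow (strings stand in for real `Source` and VariableTracker objects):

```python
class VariableTrackerCache:
    # Trimmed copy of the class above, usable standalone for illustration.
    def __init__(self):
        self.cache = {}

    def lookup(self, value, source):
        return self.cache.get((id(value), source))

    def add(self, value, source, vt):
        self.cache[(id(value), source)] = vt

cache = VariableTrackerCache()
some_global = object()

vt = cache.lookup(some_global, "GlobalSource(foo)")
if vt is None:                    # first encounter: build and memoize
    vt = "ConstantVariable(foo)"  # stands in for the wrapped tracker
    cache.add(some_global, "GlobalSource(foo)", vt)

assert cache.lookup(some_global, "GlobalSource(foo)") is vt
assert cache.lookup(some_global, "LocalSource(foo)") is None  # other source
```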
class OutputGraphState(NamedTuple):
    input_source_to_var: Dict[Source, VariableTracker]
    tracked_fakes: List[TrackedFake]
@@ -122,6 +155,7 @@ class OutputGraphState(NamedTuple):
    global_state: Optional[Dict[str, bool]]
    param_name_to_source: Optional[Dict[str, Source]]
    side_effects: SideEffects
    variable_tracker_cache: VariableTrackerCache
    timestamp: int
    non_compliant_ops: Set[torch._ops.OpOverload]
    compliant_custom_ops: Set[torch._ops.OpOverload]
@@ -320,6 +354,9 @@ class OutputGraph(Checkpointable[OutputGraphState]):
        # Stores the full fqn of a param or buffer to the relevant source.
        self.param_name_to_source: Optional[Dict[str, Source]] = dict()
        self.side_effects = SideEffects()
        # Cached variable trackers. This makes symbolic analysis of LOAD_GLOBAL
        # and LOAD_ATTR for same python objects free.
        self.variable_tracker_cache = VariableTrackerCache()
        self.code_options = dict(code_options)
        self.output_instructions: List[Instruction] = []
        # used to track nodes that are added between calls of copy_graphstate
@@ -592,6 +629,7 @@ class OutputGraph(Checkpointable[OutputGraphState]):
            global_state,
            dict(self.param_name_to_source),
            self.side_effects.clone(),
            self.variable_tracker_cache.clone(),
            self.timestamp,
            set(self.non_compliant_ops),
            set(self.compliant_custom_ops),
@@ -610,6 +648,7 @@ class OutputGraph(Checkpointable[OutputGraphState]):
            global_state,
            self.param_name_to_source,
            self.side_effects,
            self.variable_tracker_cache,
            self.timestamp,
            self.non_compliant_ops,
            self.compliant_custom_ops,
@@ -1571,6 +1610,7 @@ class OutputGraph(Checkpointable[OutputGraphState]):
        self.real_value_cache.clear()
        self.input_name_to_proxy.clear()
        self.side_effects.clear()
        self.variable_tracker_cache.clear()
        self.register_finalizer_fns.clear()
        self.dynamo_flat_name_to_original_fqn.clear()
        self.tracing_context.clear()

@@ -267,10 +267,17 @@ class VariableBuilder:
            if dup_guard:
                self.install_guards(dup_guard)
            return side_effect_result

        cached_vt = self.tx.output.variable_tracker_cache.lookup(value, self.source)
        if cached_vt:
            return cached_vt

        vt = self._wrap(value)
        vt.source = self.source
        if self._can_lift_attrs_to_inputs(vt):
            vt = self.tx.output.side_effects.track_object_existing(value, vt)

        self.tx.output.variable_tracker_cache.add(value, self.source, vt)
        return vt

    def _can_lift_attrs_to_inputs(self, vt):

@@ -248,29 +248,50 @@ def ttir_to_functions(ttir_module) -> Dict[str, Dict[Intermediate, List[Op]]]:
            fn_name = op.get_str_attr("sym_name")
            functions[fn_name] = fn_ops
        elif child_block_ids:
            if name in ("scf.if", "scf.for", "scf.while"):
                # for blocked control flow ops: inline the enclosed
                # ops into the parent block + rewire the last op in
                # each child block (yield) to return the scf result
                yield_ops = []
            if name in ("scf.if", "scf.for", "scf.while", "tt.reduce"):
                # for blocked ops: inline the enclosed ops into
                # the parent block + rewire the last op in each
                # child block to return the block result
                return_ops = []
                for block_id in child_block_ids:
                    # the block args used as operands of the ops in the block
                    # (and nested blocks inlined in the current block by now)
                    # are replaced by new fake Intermediates to avoid "this
                    # operand is not returned by anything other op in the fn"
                    # error in the downstream analysis
                    for idx in block_id_to_block_arg_ids[block_id]:
                        next_fake_intermediate -= 1
                        replacements[idx] = Intermediate(next_fake_intermediate)
                    if name.startswith("scf."):
                        # the scf block args are ignored by the pass. but, as they
                        # may be used as operands of the ops inside the block
                        # (and nested blocks inlined in the current block by now),
                        # they are replaced by new fake Intermediates to avoid "this
                        # operand is not returned by any other op in the fn" error
                        # in the downstream analysis
                        for idx in block_id_to_block_arg_ids[block_id]:
                            next_fake_intermediate -= 1
                            replacements[idx] = Intermediate(next_fake_intermediate)
                    else:
                        # for tt.reduce, wire the block arguments to the op arguments
                        num_operands = len(operand_ids)
                        block_arg_ids = block_id_to_block_arg_ids[block_id]
                        assert len(block_arg_ids) == 2 * num_operands, (
                            "tt.reduce is expected to have twice as "
                            "many block arguments as op arguments: "
                            f"{operand_ids=}, {block_arg_ids=}."
                        )
                        for i, idx in enumerate(block_arg_ids):
                            # for a tt.reduce op with N arguments, the block
                            # arguments comprise N reduced values followed by
                            # N current values corresponding to the N op args
                            replacements[idx] = Intermediate(
                                operand_ids[i % num_operands]
                            )

                    if block_id in op_stack:
                        block_ops = op_stack.pop(block_id)
                        if not block_ops:
                            continue
                        last_ret, last_ops = block_ops.popitem()
                        if all(op.name == "scf.yield" for op in last_ops):
                            # if last_ops are scf.yield, treat them separately
                            yield_ops.extend(last_ops)
                        if all(
                            op.name in ("scf.yield", "tt.reduce.return")
                            for op in last_ops
                        ):
                            # if last_ops are all return ops, treat them separately
                            return_ops.extend(last_ops)
                        else:
                            # otherwise, return last_ops to the block
                            block_ops[last_ret] = last_ops
@@ -279,10 +300,9 @@ def ttir_to_functions(ttir_module) -> Dict[str, Dict[Intermediate, List[Op]]]:

                scf_results = [Intermediate(idx) for idx in result_ids]
                for scf_result in scf_results:
                    for yield_op in yield_ops:
                        op_stack[parent_block_id][scf_result].append(yield_op)
                    for return_op in return_ops:
                        op_stack[parent_block_id][scf_result].append(return_op)
            else:
                # TODO(oulgen): add support for tt.reduce
                raise Exception(
                    f"Unknown blocked function: {name}. Can't capture the TTIR."
                )

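For a `tt.reduce` over N operands, the combine region's 2N block arguments are the N accumulator values followed by the N incoming values, so block argument i maps back to op operand `i % N`. A tiny standalone illustration of that index arithmetic (the IDs are made up):

```python
# Hypothetical IDs for a tt.reduce with two operands (e.g. value + index),
# whose combine block therefore has four block arguments.
operand_ids = [10, 11]            # SSA ids of the op's operands
block_arg_ids = [20, 21, 22, 23]  # acc_val, acc_idx, cur_val, cur_idx

num_operands = len(operand_ids)
assert len(block_arg_ids) == 2 * num_operands

replacements = {
    idx: operand_ids[i % num_operands] for i, idx in enumerate(block_arg_ids)
}
# {20: 10, 21: 11, 22: 10, 23: 11}: both the accumulator and the current
# value for each position are rewired to the corresponding op operand.
print(replacements)
```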
@@ -2002,7 +2002,7 @@ class CppCodeCache:
                        cpp_compile_command(
                            input=input_path,
                            output=output_path,
                            vec_isa=cls.vec_isa,
                            vec_isa=picked_vec_isa,
                            **cls.cpp_compile_command_flags,
                        )
                    )

[File diff suppressed because it is too large]
@@ -1,9 +1,9 @@
import torch
from torch.library import Library, impl
from torch.ao.quantization.utils import determine_qparams, validate_qmin_qmax
from typing import Tuple
from torch._refs import _unsqueeze_multiple
from typing import Optional, Tuple

import torch
from torch._refs import _unsqueeze_multiple
from torch.ao.quantization.utils import determine_qparams, validate_qmin_qmax
from torch.library import impl, Library

# Note: decomposed means decomposed quantized tensor, using decomposed so that the
# name is not too long
@@ -13,7 +13,7 @@ _DTYPE_TO_QVALUE_BOUNDS = {
    torch.uint8: (0, 255),
    torch.int8: (-128, 127),
    torch.int16: (-(2**15), 2**15 - 1),
    torch.int32: (-(2**31), 2**31 - 1)
    torch.int32: (-(2**31), 2**31 - 1),
}

# Helper to check the passed in quant min and max are valid for the dtype
@@ -60,13 +60,26 @@ def quantize_per_tensor(
    """
    if input.dtype == torch.bfloat16:
        input = input.to(torch.float32)

    assert input.dtype == torch.float32, f"Expecting input to have dtype torch.float32, but got dtype: {input.dtype}"
    _quant_min_max_bounds_check(quant_min, quant_max, dtype)

    inv_scale = 1.0 / scale
    return torch.clamp(torch.round(input * inv_scale) + zero_point, quant_min, quant_max).to(dtype)

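`quantize_per_tensor` applies the affine mapping q = clamp(round(x / scale) + zero_point, quant_min, quant_max). A quick worked example with made-up quantization parameters:

```python
import torch

x = torch.tensor([-1.0, 0.0, 0.05, 1.0])
scale, zero_point = 0.1, 128   # example qparams, not from the PR
quant_min, quant_max = 0, 255  # torch.uint8 bounds

q = torch.clamp(torch.round(x * (1.0 / scale)) + zero_point,
                quant_min, quant_max).to(torch.uint8)
print(q)  # tensor([118, 128, 128, 138], dtype=torch.uint8)
# note: 0.05 / 0.1 = 0.5 rounds to 0 (round-half-to-even), hence 128
```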
| @impl(quantized_decomposed_lib, "quantize_per_tensor", "Meta") | ||||
| def quantize_per_tensor_meta( | ||||
|         input: torch.Tensor, | ||||
|         scale: float, | ||||
|         zero_point: int, | ||||
|         quant_min: int, | ||||
|         quant_max: int, | ||||
|         dtype: torch.dtype | ||||
| ) -> torch.Tensor: | ||||
|     if input.dtype == torch.bfloat16: | ||||
|         input = input.to(torch.float32) | ||||
|     assert input.dtype == torch.float32, f"Expecting input to have dtype torch.float32, but got dtype: {input.dtype}" | ||||
|     return torch.empty_like(input, dtype=dtype) | ||||
|  | ||||
| quantized_decomposed_lib.define( | ||||
|     "quantize_per_tensor.tensor(Tensor input, Tensor scale, Tensor zero_point, " | ||||
|     "int quant_min, int quant_max, ScalarType dtype) -> Tensor") | ||||
| @ -90,7 +103,14 @@ def quantize_per_tensor_tensor( | ||||
|     return quantize_per_tensor(input, scale.item(), zero_point.item(), quant_min, quant_max, dtype) | ||||
|  | ||||
| @impl(quantized_decomposed_lib, "quantize_per_tensor.tensor", "Meta") | ||||
| def quantize_per_tensor_tensor_meta(input, scale, zero_point, quant_min, quant_max, dtype): | ||||
| def quantize_per_tensor_tensor_meta( | ||||
|         input: torch.Tensor, | ||||
|         scale: torch.Tensor, | ||||
|         zero_point: torch.Tensor, | ||||
|         quant_min: int, | ||||
|         quant_max: int, | ||||
|         dtype: torch.dtype | ||||
| ) -> torch.Tensor: | ||||
|     if input.dtype == torch.bfloat16: | ||||
|         input = input.to(torch.float32) | ||||
|     assert zero_point.numel() == 1, f"Expecting zero_point tensor to be one element, but received : {zero_point.numel()}" | ||||
| @ -122,7 +142,14 @@ def quantize_per_tensor_tensor2( | ||||
|     return quantize_per_tensor(input, scale.item(), zero_point.item(), quant_min.item(), quant_max.item(), dtype) | ||||
|  | ||||
| @impl(quantized_decomposed_lib, "quantize_per_tensor.tensor2", "Meta") | ||||
| def quantize_per_tensor_tensor2_meta(input, scale, zero_point, quant_min, quant_max, dtype): | ||||
| def quantize_per_tensor_tensor2_meta( | ||||
|         input: torch.Tensor, | ||||
|         scale: torch.Tensor, | ||||
|         zero_point: torch.Tensor, | ||||
|         quant_min: torch.Tensor, | ||||
|         quant_max: torch.Tensor, | ||||
|         dtype: torch.dtype | ||||
| ) -> torch.Tensor: | ||||
|     return quantize_per_tensor_tensor_meta(input, scale, zero_point, quant_min, quant_max, dtype) | ||||
|  | ||||
| # Note: quant_min/quant_max/dtype are not used in the operator, but for now it's kept in | ||||
| @ -131,7 +158,7 @@ def quantize_per_tensor_tensor2_meta(input, scale, zero_point, quant_min, quant_ | ||||
| # We will revisit this later if we found there are no use cases for it | ||||
| quantized_decomposed_lib.define( | ||||
|     "dequantize_per_tensor(Tensor input, float scale, int zero_point, " | ||||
|     "int quant_min, int quant_max, ScalarType dtype) -> Tensor") | ||||
|     "int quant_min, int quant_max, ScalarType dtype, *, ScalarType? out_dtype=None) -> Tensor") | ||||
|  | ||||
| @impl(quantized_decomposed_lib, "dequantize_per_tensor", "CompositeExplicitAutograd") | ||||
| def dequantize_per_tensor( | ||||
| @ -140,7 +167,9 @@ def dequantize_per_tensor( | ||||
|         zero_point: int, | ||||
|         quant_min: int, | ||||
|         quant_max: int, | ||||
|         dtype: torch.dtype | ||||
|         dtype: torch.dtype, | ||||
|         *, | ||||
|         out_dtype: Optional[torch.dtype] = None | ||||
| ) -> torch.Tensor: | ||||
|     """ Affine dequantization for the Tensor using the same quantization parameters to map | ||||
|     from quantized values to floating point values | ||||
| @ -163,22 +192,40 @@ def dequantize_per_tensor( | ||||
|        dtype (torch.dtype): dtype for input Tensor (not used in computation, | ||||
|        reserved for pattern matching) | ||||
|  | ||||
|        out_dtype (torch.dtype?): optional dtype for output Tensor | ||||
|  | ||||
|     Returns: | ||||
|        dequantized float32 Tensor | ||||
|     """ | ||||
|     assert input.dtype == dtype, f"Expecting input to have dtype: {dtype}, but got {input.dtype}" | ||||
|     if out_dtype is None: | ||||
|         out_dtype = torch.float32 | ||||
|     if dtype in _DTYPE_TO_QVALUE_BOUNDS: | ||||
|         # TODO: investigate why | ||||
|         # (input - zero_point).to(torch.float32) * scale | ||||
|         # failed the test | ||||
|         return (input.to(torch.float32) - zero_point) * scale | ||||
|         return (input.to(out_dtype) - zero_point) * scale | ||||
|     else: | ||||
|         raise ValueError(f"Unsupported dtype in dequantize_per_tensor: {dtype}") | ||||
|  | ||||
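With the new keyword-only `out_dtype`, callers can dequantize directly into a dtype other than the float32 default. A minimal illustration of the math and the dtype plumbing (plain tensor ops, not calls to the registered op):

```python
import torch

q = torch.tensor([118, 128, 138], dtype=torch.uint8)
scale, zero_point = 0.1, 128

# Default behaviour: dequantize into float32.
x32 = (q.to(torch.float32) - zero_point) * scale

# With out_dtype: land directly in, e.g., bfloat16.
x_bf16 = (q.to(torch.bfloat16) - zero_point) * scale

print(x32)           # approximately [-1.0, 0.0, 1.0]
print(x_bf16.dtype)  # torch.bfloat16
```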
| @impl(quantized_decomposed_lib, "dequantize_per_tensor", "Meta") | ||||
| def dequantize_per_tensor_meta( | ||||
|     input: torch.Tensor, | ||||
|     scale: torch.Tensor, | ||||
|     zero_pointe: torch.Tensor, | ||||
|     quant_min: int, | ||||
|     quant_max: int, | ||||
|     dtype: torch.dtype, | ||||
|     *, | ||||
|     out_dtype: Optional[torch.dtype] = None | ||||
| ) -> torch.Tensor: | ||||
|     if out_dtype is None: | ||||
|         out_dtype = torch.float32 | ||||
|     return torch.empty_like(input, dtype=out_dtype) | ||||
|  | ||||
| quantized_decomposed_lib.define( | ||||
|     "dequantize_per_tensor.tensor(Tensor input, Tensor scale, Tensor zero_point, " | ||||
|     "int quant_min, int quant_max, ScalarType dtype) -> Tensor") | ||||
|     "int quant_min, int quant_max, ScalarType dtype, *, ScalarType? out_dtype=None) -> Tensor") | ||||
|  | ||||
| @impl(quantized_decomposed_lib, "dequantize_per_tensor.tensor", "CompositeExplicitAutograd") | ||||
| def dequantize_per_tensor_tensor( | ||||
| @ -187,7 +234,9 @@ def dequantize_per_tensor_tensor( | ||||
|         zero_point: torch.Tensor, | ||||
|         quant_min: int, | ||||
|         quant_max: int, | ||||
|         dtype: torch.dtype | ||||
|         dtype: torch.dtype, | ||||
|         *, | ||||
|         out_dtype: Optional[torch.dtype] = None | ||||
| ) -> torch.Tensor: | ||||
|     """ Affine dequantization for the Tensor using the same quantization parameters to map | ||||
|     from quantized values to floating point values | ||||
| @ -196,22 +245,33 @@ def dequantize_per_tensor_tensor( | ||||
|     """ | ||||
|     assert zero_point.numel() == 1, f"Expecting zero_point tensor to be one element, but received : {zero_point.numel()}" | ||||
|     assert scale.numel() == 1, f"Expecting scale tensor to be one element, but received : {scale.numel()}" | ||||
|     return dequantize_per_tensor(input, scale.item(), zero_point.item(), quant_min, quant_max, dtype) | ||||
|     return dequantize_per_tensor(input, scale.item(), zero_point.item(), quant_min, quant_max, dtype, out_dtype=out_dtype) | ||||
|  | ||||
| @impl(quantized_decomposed_lib, "dequantize_per_tensor.tensor", "Meta") | ||||
| def dequantize_per_tensor_tensor_meta(input, scale, zero_point, quant_min, quant_max, dtype): | ||||
| def dequantize_per_tensor_tensor_meta( | ||||
|         input: torch.Tensor, | ||||
|         scale: torch.Tensor, | ||||
|         zero_point: torch.Tensor, | ||||
|         quant_min: int, | ||||
|         quant_max: int, | ||||
|         dtype: torch.dtype, | ||||
|         *, | ||||
|         out_dtype: Optional[torch.dtype] = None | ||||
| ) -> torch.Tensor: | ||||
|     if out_dtype is None: | ||||
|         out_dtype = torch.float32 | ||||
|     assert zero_point.numel() == 1, f"Expecting zero_point tensor to be one element, but received : {zero_point.numel()}" | ||||
|     assert scale.numel() == 1, f"Expecting scale tensor to be one element, but received : {scale.numel()}" | ||||
|     assert input.dtype == dtype, f"Expecting input to have dtype: {dtype}" | ||||
|     if dtype in _DTYPE_TO_QVALUE_BOUNDS: | ||||
|         return torch.empty_like(input, dtype=torch.float32) | ||||
|         return torch.empty_like(input, dtype=out_dtype) | ||||
|     else: | ||||
|         raise ValueError(f"Unsupported dtype in dequantize_per_tensor: {dtype}") | ||||
|  | ||||
| # TODO: remove other variants and keep this one | ||||
| quantized_decomposed_lib.define( | ||||
|     "dequantize_per_tensor.tensor2(Tensor input, Tensor scale, Tensor zero_point, " | ||||
|     "Tensor quant_min, Tensor quant_max, ScalarType dtype) -> Tensor") | ||||
|     "Tensor quant_min, Tensor quant_max, ScalarType dtype, *, ScalarType? out_dtype=None) -> Tensor") | ||||
|  | ||||
| @impl(quantized_decomposed_lib, "dequantize_per_tensor.tensor2", "CompositeExplicitAutograd") | ||||
| def dequantize_per_tensor_tensor2( | ||||
| @ -220,7 +280,9 @@ def dequantize_per_tensor_tensor2( | ||||
|         zero_point: torch.Tensor, | ||||
|         quant_min: torch.Tensor, | ||||
|         quant_max: torch.Tensor, | ||||
|         dtype: torch.dtype | ||||
|         dtype: torch.dtype, | ||||
|         *, | ||||
|         out_dtype: Optional[torch.dtype] = None | ||||
| ) -> torch.Tensor: | ||||
|     """ Affine dequantization for the Tensor using the same quantization parameters to map | ||||
|     from quantized values to floating point values | ||||
| @ -229,11 +291,21 @@ def dequantize_per_tensor_tensor2( | ||||
|     """ | ||||
|     assert zero_point.numel() == 1, f"Expecting zero_point tensor to be one element, but received : {zero_point.numel()}" | ||||
|     assert scale.numel() == 1, f"Expecting scale tensor to be one element, but received : {scale.numel()}" | ||||
|     return dequantize_per_tensor(input, scale.item(), zero_point.item(), quant_min.item(), quant_max.item(), dtype) | ||||
|     return dequantize_per_tensor( | ||||
|         input, scale.item(), zero_point.item(), quant_min.item(), quant_max.item(), dtype, out_dtype=out_dtype) | ||||
|  | ||||
| @impl(quantized_decomposed_lib, "dequantize_per_tensor.tensor2", "Meta") | ||||
| def dequantize_per_tensor_tensor2_meta(input, scale, zero_point, quant_min, quant_max, dtype): | ||||
|     return dequantize_per_tensor_tensor_meta(input, scale, zero_point, quant_min, quant_max, dtype) | ||||
| def dequantize_per_tensor_tensor2_meta( | ||||
|         input, | ||||
|         scale, | ||||
|         zero_point, | ||||
|         quant_min, | ||||
|         quant_max, | ||||
|         dtype, | ||||
|         *, | ||||
|         out_dtype: Optional[torch.dtype] = None | ||||
| ) -> torch.Tensor: | ||||
|     return dequantize_per_tensor_tensor_meta(input, scale, zero_point, quant_min, quant_max, dtype, out_dtype=out_dtype) | ||||
|  | ||||
| quantized_decomposed_lib.define( | ||||
|     "choose_qparams.tensor(Tensor input, int quant_min, int quant_max, " | ||||
| @ -415,7 +487,7 @@ def quantize_per_channel_meta( | ||||
| # We will revisit this later if we found there are no use cases for it | ||||
| quantized_decomposed_lib.define( | ||||
|     "dequantize_per_channel(Tensor input, Tensor scales, Tensor zero_points, int axis, " | ||||
|     "int quant_min, int quant_max, ScalarType dtype) -> Tensor") | ||||
|     "int quant_min, int quant_max, ScalarType dtype, *, ScalarType? out_dtype=None) -> Tensor") | ||||
|  | ||||
| @impl(quantized_decomposed_lib, "dequantize_per_channel", "CompositeExplicitAutograd") | ||||
| def dequantize_per_channel( | ||||
| @ -425,7 +497,9 @@ def dequantize_per_channel( | ||||
|         axis: int, | ||||
|         quant_min: int, | ||||
|         quant_max: int, | ||||
|         dtype: torch.dtype | ||||
|         dtype: torch.dtype, | ||||
|         *, | ||||
|         out_dtype: Optional[torch.dtype] = None | ||||
| ) -> torch.Tensor: | ||||
|     """ Affine per channel dequantization for the Tensor using the same quantization | ||||
|     parameters for each channel/axis to map from quantized values to floating point values | ||||
| @ -450,20 +524,24 @@ def dequantize_per_channel( | ||||
|        dtype (torch.dtype): requested dtype for output Tensor (not used in computation, | ||||
|        reserved for pattern matching) | ||||
|  | ||||
|        out_dtype (torch.dtype?): optional dtype for output Tensor | ||||
|  | ||||
|     Returns: | ||||
|        dequantized float32 Tensor | ||||
|     """ | ||||
|     assert input.dtype == dtype, f"Expecting input to have dtype {dtype}, but got dtype: {input.dtype}" | ||||
|     if out_dtype is None: | ||||
|         out_dtype = torch.float32 | ||||
|     assert axis < input.dim(), f"Expecting axis to be < {input.dim()}" | ||||
|     _quant_min_max_bounds_check(quant_min, quant_max, dtype) | ||||
|     input, permute_axis_list = _permute_to_axis_zero(input, axis) | ||||
|     res = torch.zeros_like(input, dtype=torch.float32) | ||||
|     res = torch.zeros_like(input, dtype=out_dtype) | ||||
|  | ||||
|     for i in range(input.size(0)): | ||||
|         # TODO: investigate why | ||||
|         # (input[i] - zero_points[i]).to(torch.float32) * scales[i] | ||||
|         # (input[i] - zero_points[i]).to(out_dtype) * scales[i] | ||||
|         # failed the test | ||||
|         res[i] = (input[i].to(torch.float32) - zero_points[i]) * scales[i] | ||||
|         res[i] = (input[i].to(out_dtype) - zero_points[i]) * scales[i] | ||||
|  | ||||
|     out = res.permute(tuple(permute_axis_list)) | ||||
|     return out | ||||
| @ -476,12 +554,16 @@ def dequantize_per_channel_meta( | ||||
|         axis: int, | ||||
|         quant_min: int, | ||||
|         quant_max: int, | ||||
|         dtype: torch.dtype | ||||
|         dtype: torch.dtype, | ||||
|         *, | ||||
|         out_dtype: Optional[torch.dtype] = None | ||||
| ) -> torch.Tensor: | ||||
|     assert input.dtype == dtype, f"Expecting input to have dtype {dtype}, but got dtype: {input.dtype}" | ||||
|     if out_dtype is None: | ||||
|         out_dtype = torch.float32 | ||||
|     assert axis < input.dim(), f"Expecting axis to be < {input.dim()}" | ||||
|     _quant_min_max_bounds_check(quant_min, quant_max, dtype) | ||||
|     return torch.empty_like(input, dtype=torch.float32) | ||||
|     return torch.empty_like(input, dtype=out_dtype) | ||||
|  | ||||
| quantized_decomposed_lib.define( | ||||
|     "fake_quant_per_channel(Tensor input, Tensor scales, Tensor zero_points, int axis, " | ||||
|  | ||||
@@ -2918,7 +2918,7 @@ c10::intrusive_ptr<Work> ProcessGroupNCCL::allreduce_sparse(
  // If the nccl branch is not "exp" then we just error
  C10_THROW_ERROR(
      Error,
      "NCCL does not support all_reduce with sparse tensors. Please use dense tensors instead.");
#endif
}
