mirror of
				https://github.com/pytorch/pytorch.git
				synced 2025-10-20 21:14:14 +08:00 
			
		
		
		
	propagate XLA's metadata after functional sync (#131076)
Fixes https://github.com/pytorch/xla/issues/7174 Pull Request resolved: https://github.com/pytorch/pytorch/pull/131076 Approved by: https://github.com/bdhirsh
This commit is contained in:
		
				
					committed by
					
						 PyTorch MergeBot
						PyTorch MergeBot
					
				
			
			
				
	
			
			
			
						parent
						
							7eb2a99585
						
					
				
				
					commit
					b40249b462
				
			| @ -660,6 +660,21 @@ void propagate_xla_data(const ITensorListRef functional_tensor, ITensorListRef o | |||||||
|   } |   } | ||||||
| } | } | ||||||
|  |  | ||||||
|  | void propagate_xla_data_direct(const Tensor& tensor, const Tensor& other) { | ||||||
|  |   if (tensor.key_set().has(c10::DispatchKey::XLA)) { | ||||||
|  |     at::_propagate_xla_data(tensor, other); | ||||||
|  |   } | ||||||
|  |  } | ||||||
|  |  | ||||||
|  | void propagate_xla_data_direct(const ITensorListRef tensor, | ||||||
|  |                                ITensorListRef other) { | ||||||
|  |   auto tensor_it = tensor.begin(); | ||||||
|  |   auto other_it = other.begin(); | ||||||
|  |   for (C10_UNUSED const auto i : c10::irange(tensor.size())) { | ||||||
|  |     propagate_xla_data_direct(*tensor_it++, *other_it++); | ||||||
|  |   } | ||||||
|  | } | ||||||
|  |  | ||||||
| void commit_update(const Tensor& functional_tensor) { | void commit_update(const Tensor& functional_tensor) { | ||||||
|   TORCH_INTERNAL_ASSERT_DEBUG_ONLY(isFunctionalTensor(functional_tensor)); |   TORCH_INTERNAL_ASSERT_DEBUG_ONLY(isFunctionalTensor(functional_tensor)); | ||||||
|   unsafeGetFunctionalWrapper(functional_tensor)->commit_update(); |   unsafeGetFunctionalWrapper(functional_tensor)->commit_update(); | ||||||
|  | |||||||
| @ -343,6 +343,13 @@ TORCH_API void propagate_xla_data( | |||||||
|     const ITensorListRef functional_tensor, |     const ITensorListRef functional_tensor, | ||||||
|     ITensorListRef other); |     ITensorListRef other); | ||||||
|  |  | ||||||
|  | TORCH_API void propagate_xla_data_direct( | ||||||
|  |     const Tensor& tensor, | ||||||
|  |     const Tensor& other); | ||||||
|  | TORCH_API void propagate_xla_data_direct( | ||||||
|  |     const ITensorListRef tensor, | ||||||
|  |     ITensorListRef other); | ||||||
|  |  | ||||||
| Tensor create_functional_tensor_with_view_meta( | Tensor create_functional_tensor_with_view_meta( | ||||||
|     const Tensor& view_to_wrap, |     const Tensor& view_to_wrap, | ||||||
|     const Tensor& base, |     const Tensor& base, | ||||||
|  | |||||||
| @ -590,10 +590,12 @@ def wrap_propagate_mutations_and_return( | |||||||
|     ): |     ): | ||||||
|         updates.append( |         updates.append( | ||||||
|             f"""\ |             f"""\ | ||||||
|   at::functionalization::impl::propagate_xla_data({outer_arg}, {inner_ret}); |   auto {outer_arg}_inner = at::functionalization::impl::from_functional_tensor({outer_arg}); | ||||||
|   at::functionalization::impl::replace_({outer_arg}, {inner_ret}); |   at::functionalization::impl::replace_({outer_arg}, {inner_ret}); | ||||||
|   at::functionalization::impl::commit_update({outer_arg}); |   at::functionalization::impl::commit_update({outer_arg}); | ||||||
|   at::functionalization::impl::sync({outer_arg});""" |   at::functionalization::impl::sync({outer_arg}); | ||||||
|  |   auto {outer_arg}_inner_updated = at::functionalization::impl::from_functional_tensor({outer_arg}); | ||||||
|  |   at::functionalization::impl::propagate_xla_data_direct({outer_arg}_inner, {outer_arg}_inner_updated);""" | ||||||
|         ) |         ) | ||||||
|  |  | ||||||
|     # Finally, we return: |     # Finally, we return: | ||||||
|  | |||||||
		Reference in New Issue
	
	Block a user