propagate XLA's metadata after functional sync (#131076)

Fixes https://github.com/pytorch/xla/issues/7174

Pull Request resolved: https://github.com/pytorch/pytorch/pull/131076
Approved by: https://github.com/bdhirsh
Author: JackCaoG
Date: 2024-07-31 18:20:00 +00:00
Committed by: PyTorch MergeBot
Commit: b40249b462 (parent: 7eb2a99585)
3 changed files with 26 additions and 2 deletions

@@ -660,6 +660,21 @@ void propagate_xla_data(const ITensorListRef functional_tensor, ITensorListRef o
   }
 }
+void propagate_xla_data_direct(const Tensor& tensor, const Tensor& other) {
+  if (tensor.key_set().has(c10::DispatchKey::XLA)) {
+    at::_propagate_xla_data(tensor, other);
+  }
+}
+void propagate_xla_data_direct(const ITensorListRef tensor,
+                               ITensorListRef other) {
+  auto tensor_it = tensor.begin();
+  auto other_it = other.begin();
+  for (C10_UNUSED const auto i : c10::irange(tensor.size())) {
+    propagate_xla_data_direct(*tensor_it++, *other_it++);
+  }
+}
 void commit_update(const Tensor& functional_tensor) {
   TORCH_INTERNAL_ASSERT_DEBUG_ONLY(isFunctionalTensor(functional_tensor));
   unsafeGetFunctionalWrapper(functional_tensor)->commit_update();

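For context, here is a minimal sketch (not part of this commit) of how the new helper behaves from a caller's point of view: the single-tensor overload is a no-op unless `tensor` carries the XLA dispatch key, in which case it forwards to `at::_propagate_xla_data`, and the list overload applies the same check pairwise over two lists. The snippet assumes a libtorch build; with plain CPU tensors both calls simply fall through.

```cpp
#include <ATen/ATen.h>
#include <ATen/FunctionalTensorWrapper.h>
#include <vector>

int main() {
  // CPU tensors do not carry DispatchKey::XLA, so this call is a no-op here;
  // on XLA tensors it would forward to at::_propagate_xla_data(src, dst).
  at::Tensor src = at::randn({2, 3});
  at::Tensor dst = at::randn({2, 3});
  at::functionalization::impl::propagate_xla_data_direct(src, dst);

  // The list overload walks both lists in lockstep and applies the
  // single-tensor overload to each pair.
  std::vector<at::Tensor> srcs = {src};
  std::vector<at::Tensor> dsts = {dst};
  at::functionalization::impl::propagate_xla_data_direct(
      at::ITensorListRef(srcs), at::ITensorListRef(dsts));
  return 0;
}
```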
@@ -343,6 +343,13 @@ TORCH_API void propagate_xla_data(
     const ITensorListRef functional_tensor,
     ITensorListRef other);
+TORCH_API void propagate_xla_data_direct(
+    const Tensor& tensor,
+    const Tensor& other);
+TORCH_API void propagate_xla_data_direct(
+    const ITensorListRef tensor,
+    ITensorListRef other);
 Tensor create_functional_tensor_with_view_meta(
     const Tensor& view_to_wrap,
     const Tensor& base,

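The gate on `DispatchKey::XLA` means the metadata copy only happens when the backend has an override of `aten::_propagate_xla_data` registered. As a rough sketch of that mechanism (not torch_xla's actual code; `my_propagate_xla_data` is a hypothetical name and the body is a placeholder), an out-of-tree backend could register its kernel like this:

```cpp
#include <ATen/ATen.h>
#include <torch/library.h>

// Hypothetical XLA-side kernel: copy backend-specific metadata
// (e.g. custom op-name info) from `input` onto `output`.
void my_propagate_xla_data(const at::Tensor& input, const at::Tensor& output) {
  // Backend-specific metadata copy goes here.
}

// Register the kernel for the XLA dispatch key so that
// at::_propagate_xla_data(input, output) routes to it for XLA tensors.
TORCH_LIBRARY_IMPL(aten, XLA, m) {
  m.impl("_propagate_xla_data", my_propagate_xla_data);
}
```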
@@ -590,10 +590,12 @@ def wrap_propagate_mutations_and_return(
     ):
         updates.append(
             f"""\
   at::functionalization::impl::propagate_xla_data({outer_arg}, {inner_ret});
   auto {outer_arg}_inner = at::functionalization::impl::from_functional_tensor({outer_arg});
   at::functionalization::impl::replace_({outer_arg}, {inner_ret});
   at::functionalization::impl::commit_update({outer_arg});
-  at::functionalization::impl::sync({outer_arg});"""
+  at::functionalization::impl::sync({outer_arg});
+  auto {outer_arg}_inner_updated = at::functionalization::impl::from_functional_tensor({outer_arg});
+  at::functionalization::impl::propagate_xla_data_direct({outer_arg}_inner, {outer_arg}_inner_updated);"""
         )
     # Finally, we return:
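
To make the codegen change concrete, the block below is a hedged sketch of what the emitted per-argument update code roughly expands to after this change, wrapped in a hypothetical helper so it compiles on its own. The names `apply_mutation_to_functional_arg`, `self`, and `tmp_output` are illustrative stand-ins for `{outer_arg}` and `{inner_ret}`, and the snippet assumes `self` is a functionalized tensor; it is not the literal generated code.

```cpp
#include <ATen/ATen.h>
#include <ATen/FunctionalTensorWrapper.h>

// Hypothetical wrapper around the generated update block for one mutated
// argument, with outer_arg = "self" and inner_ret = "tmp_output".
// Precondition: `self` is a functional tensor (FunctionalTensorWrapper).
void apply_mutation_to_functional_arg(
    const at::Tensor& self,
    const at::Tensor& tmp_output) {
  namespace impl = at::functionalization::impl;
  // Existing behavior: push XLA metadata onto the freshly computed value.
  impl::propagate_xla_data(self, tmp_output);
  // Keep a handle on the inner tensor before the update lands.
  auto self_inner = impl::from_functional_tensor(self);
  impl::replace_(self, tmp_output);
  impl::commit_update(self);
  impl::sync(self);
  // New in this commit: sync() can materialize a fresh inner tensor, so copy
  // the XLA metadata from the pre-sync inner tensor onto the post-sync one.
  auto self_inner_updated = impl::from_functional_tensor(self);
  impl::propagate_xla_data_direct(self_inner, self_inner_updated);
}
```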