mirror of https://github.com/pytorch/pytorch.git
synced 2025-10-20 21:14:14 +08:00
propagate XLA's metadata after functional sync (#131076)
Fixes https://github.com/pytorch/xla/issues/7174

Pull Request resolved: https://github.com/pytorch/pytorch/pull/131076
Approved by: https://github.com/bdhirsh
committed by PyTorch MergeBot
parent 7eb2a99585
commit b40249b462
@@ -660,6 +660,21 @@ void propagate_xla_data(const ITensorListRef functional_tensor, ITensorListRef other) {
   }
 }
 
+void propagate_xla_data_direct(const Tensor& tensor, const Tensor& other) {
+  if (tensor.key_set().has(c10::DispatchKey::XLA)) {
+    at::_propagate_xla_data(tensor, other);
+  }
+}
+
+void propagate_xla_data_direct(const ITensorListRef tensor,
+                               ITensorListRef other) {
+  auto tensor_it = tensor.begin();
+  auto other_it = other.begin();
+  for (C10_UNUSED const auto i : c10::irange(tensor.size())) {
+    propagate_xla_data_direct(*tensor_it++, *other_it++);
+  }
+}
+
 void commit_update(const Tensor& functional_tensor) {
   TORCH_INTERNAL_ASSERT_DEBUG_ONLY(isFunctionalTensor(functional_tensor));
   unsafeGetFunctionalWrapper(functional_tensor)->commit_update();
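
Note, for context (not part of the diff): unlike the existing propagate_xla_data, which takes functional-tensor wrappers, the new propagate_xla_data_direct overloads above take already-unwrapped inner tensors and silently no-op unless the source tensor carries the XLA dispatch key. A minimal usage sketch, assuming the helpers are declared in ATen/FunctionalTensorWrapper.h and `wrapped` is a hypothetical functional tensor:

    #include <ATen/FunctionalTensorWrapper.h>

    void sync_and_propagate(const at::Tensor& wrapped) {
      // Grab the inner tensor before the wrapper is synced...
      at::Tensor inner_before =
          at::functionalization::impl::from_functional_tensor(wrapped);
      at::functionalization::impl::sync(wrapped);
      // ...and again afterwards, then forward XLA metadata between the two.
      at::Tensor inner_after =
          at::functionalization::impl::from_functional_tensor(wrapped);
      // No-op on non-XLA backends, thanks to the dispatch-key guard above.
      at::functionalization::impl::propagate_xla_data_direct(inner_before, inner_after);
    }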

@@ -343,6 +343,13 @@ TORCH_API void propagate_xla_data(
     const ITensorListRef functional_tensor,
     ITensorListRef other);
 
+TORCH_API void propagate_xla_data_direct(
+    const Tensor& tensor,
+    const Tensor& other);
+TORCH_API void propagate_xla_data_direct(
+    const ITensorListRef tensor,
+    ITensorListRef other);
+
 Tensor create_functional_tensor_with_view_meta(
     const Tensor& view_to_wrap,
     const Tensor& base,

@@ -590,10 +590,12 @@ def wrap_propagate_mutations_and_return(
     ):
         updates.append(
             f"""\
-at::functionalization::impl::propagate_xla_data({outer_arg}, {inner_ret});
+auto {outer_arg}_inner = at::functionalization::impl::from_functional_tensor({outer_arg});
 at::functionalization::impl::replace_({outer_arg}, {inner_ret});
 at::functionalization::impl::commit_update({outer_arg});
-at::functionalization::impl::sync({outer_arg});"""
+at::functionalization::impl::sync({outer_arg});
+auto {outer_arg}_inner_updated = at::functionalization::impl::from_functional_tensor({outer_arg});
+at::functionalization::impl::propagate_xla_data_direct({outer_arg}_inner, {outer_arg}_inner_updated);"""
         )
 
     # Finally, we return:
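
Note, for context (not part of the diff): the ordering is the substance of the fix. XLA metadata was previously propagated from the inner return value before replace_/commit_update/sync ran; it is now captured before the sync and propagated after it, matching the commit title. Illustratively, with the template placeholders {outer_arg} and {inner_ret} substituted as the hypothetical names self and tmp_output, the generated kernel code expands to roughly:

    // Hypothetical expansion of the codegen template above (fragment).
    auto self_inner = at::functionalization::impl::from_functional_tensor(self);
    at::functionalization::impl::replace_(self, tmp_output);
    at::functionalization::impl::commit_update(self);
    at::functionalization::impl::sync(self);
    auto self_inner_updated = at::functionalization::impl::from_functional_tensor(self);
    // Metadata flows from the pre-sync inner tensor to the post-sync one.
    at::functionalization::impl::propagate_xla_data_direct(self_inner, self_inner_updated);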