mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 12:54:11 +08:00
propagate XLA's metadata after functional sync (#131076)
Fixes https://github.com/pytorch/xla/issues/7174 Pull Request resolved: https://github.com/pytorch/pytorch/pull/131076 Approved by: https://github.com/bdhirsh
This commit is contained in:
committed by
PyTorch MergeBot
parent
7eb2a99585
commit
b40249b462
@ -660,6 +660,21 @@ void propagate_xla_data(const ITensorListRef functional_tensor, ITensorListRef o
|
||||
}
|
||||
}
|
||||
|
||||
void propagate_xla_data_direct(const Tensor& tensor, const Tensor& other) {
|
||||
if (tensor.key_set().has(c10::DispatchKey::XLA)) {
|
||||
at::_propagate_xla_data(tensor, other);
|
||||
}
|
||||
}
|
||||
|
||||
void propagate_xla_data_direct(const ITensorListRef tensor,
|
||||
ITensorListRef other) {
|
||||
auto tensor_it = tensor.begin();
|
||||
auto other_it = other.begin();
|
||||
for (C10_UNUSED const auto i : c10::irange(tensor.size())) {
|
||||
propagate_xla_data_direct(*tensor_it++, *other_it++);
|
||||
}
|
||||
}
|
||||
|
||||
void commit_update(const Tensor& functional_tensor) {
|
||||
TORCH_INTERNAL_ASSERT_DEBUG_ONLY(isFunctionalTensor(functional_tensor));
|
||||
unsafeGetFunctionalWrapper(functional_tensor)->commit_update();
|
||||
|
@ -343,6 +343,13 @@ TORCH_API void propagate_xla_data(
|
||||
const ITensorListRef functional_tensor,
|
||||
ITensorListRef other);
|
||||
|
||||
TORCH_API void propagate_xla_data_direct(
|
||||
const Tensor& tensor,
|
||||
const Tensor& other);
|
||||
TORCH_API void propagate_xla_data_direct(
|
||||
const ITensorListRef tensor,
|
||||
ITensorListRef other);
|
||||
|
||||
Tensor create_functional_tensor_with_view_meta(
|
||||
const Tensor& view_to_wrap,
|
||||
const Tensor& base,
|
||||
|
@ -590,10 +590,12 @@ def wrap_propagate_mutations_and_return(
|
||||
):
|
||||
updates.append(
|
||||
f"""\
|
||||
at::functionalization::impl::propagate_xla_data({outer_arg}, {inner_ret});
|
||||
auto {outer_arg}_inner = at::functionalization::impl::from_functional_tensor({outer_arg});
|
||||
at::functionalization::impl::replace_({outer_arg}, {inner_ret});
|
||||
at::functionalization::impl::commit_update({outer_arg});
|
||||
at::functionalization::impl::sync({outer_arg});"""
|
||||
at::functionalization::impl::sync({outer_arg});
|
||||
auto {outer_arg}_inner_updated = at::functionalization::impl::from_functional_tensor({outer_arg});
|
||||
at::functionalization::impl::propagate_xla_data_direct({outer_arg}_inner, {outer_arg}_inner_updated);"""
|
||||
)
|
||||
|
||||
# Finally, we return:
|
||||
|
Reference in New Issue
Block a user