Support NJT chunk() backward on batch dim (#144584)

Part of my BE project addressing NJT bugs surfaced via OpInfo tests.

Implements `chunk()` backward on the batch dim, which was previously left out. The backward unbinds the components and calls `copy_()` on each to pass along the appropriate gradients.
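
For context, here is a minimal usage sketch of the newly supported path (the shapes and chunk count are made up for illustration; this is not the PR's actual test):

```python
import torch

# Jagged-layout nested tensor with a ragged second dim.
nt = torch.nested.nested_tensor(
    [torch.randn(2, 5), torch.randn(3, 5), torch.randn(4, 5)],
    layout=torch.jagged,
    requires_grad=True,
)

# chunk() along the batch dim (dim=0); backward through this now works.
chunks = nt.chunk(2, dim=0)

# Reduce each chunk to a scalar and backprop; gradients flow back to nt.
loss = sum(c.sum() for c in chunks)
loss.backward()
assert nt.grad is not None
```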
Pull Request resolved: https://github.com/pytorch/pytorch/pull/144584
Approved by: https://github.com/soulitzer
ghstack dependencies: #144582, #144583

Author: Joel Schlosser
Date: 2025-01-17 16:53:45 -05:00
Committed by: PyTorch MergeBot
commit 3ee531f8b9
parent 8a57234033
4 changed files with 25 additions and 32 deletions

@@ -652,9 +652,20 @@ def copy_default(func, *args, **kwargs):
     inp = new_kwargs.pop("input")
     src = new_kwargs.pop("src")
     if inp._size != src._size:
-        raise RuntimeError(
-            "copy_ only supports Nested Tensors that have same size and the exact same offset tensor."
-        )
+        # try to recursively copy_ on unbound components to get around nested int mismatch
+        # TODO: eventually do a direct copy when this is possible
+        inp_comps = inp.unbind()
+        inp_comp_shapes = [c.shape for c in inp_comps]
+        src_comps = src.unbind()
+        src_comp_shapes = [c.shape for c in src_comps]
+        if inp_comp_shapes != src_comp_shapes:
+            raise RuntimeError(
+                "copy_(): expected compatible input and src shapes, but got: "
+                f"{inp.shape} and {src.shape}"
+            )
+        for inp_comp, src_comp in zip(inp_comps, src_comps):
+            inp_comp.copy_(src_comp)
     # AOTD allows mutations of inputs only, (not views of the inputs).
     # NJT.values() returns _values.detach() to workaround some issues.
     # To keep mutation in the graph, AOTD manually calls copy_ on the input (NJT).
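
For illustration, a rough standalone sketch of what the component-wise fallback amounts to (the two NJTs are built separately so they hold distinct offsets tensors and thus mismatched nested ints; names and shapes here are made up, and the loop mirrors the new branch rather than exercising `copy_` through dispatch):

```python
import torch

# Two NJTs with the same ragged structure, constructed separately so their
# _size metadata uses different nested ints.
shapes = [(2, 4), (3, 4)]
inp = torch.nested.nested_tensor([torch.zeros(*s) for s in shapes], layout=torch.jagged)
src = torch.nested.nested_tensor([torch.ones(*s) for s in shapes], layout=torch.jagged)

# Unbind both NJTs and copy component-by-component; this only requires the
# unbound shapes to match, not the nested ints.
for inp_comp, src_comp in zip(inp.unbind(), src.unbind()):
    inp_comp.copy_(src_comp)

print(torch.equal(inp.values(), src.values()))  # True
```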