Support NJT chunk() backward on batch dim (#144584)
Part of my BE project addressing NJT bugs surfaced via OpInfo tests. Implements `chunk()` backward on the batch dim, which was previously left out. This PR unbinds the grad components and invokes `copy_()` on each of them to pass along the appropriate gradients.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/144584
Approved by: https://github.com/soulitzer
ghstack dependencies: #144582, #144583
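As a minimal usage sketch of the newly supported path (my own illustration, not from the PR; the component shapes are arbitrary):

```python
import torch

# A jagged NJT with four variable-length components.
nt = torch.nested.nested_tensor(
    [torch.randn(2, 4), torch.randn(3, 4), torch.randn(5, 4), torch.randn(4, 4)],
    layout=torch.jagged,
    requires_grad=True,
)

# chunk() along dim=0 (the batch dim); backward through this op
# is what the PR adds.
chunks = nt.chunk(2, dim=0)

# Any reduction to a scalar works; gradients now flow back through the chunk.
loss = sum(c.sum() for c in chunks)
loss.backward()
assert nt.grad is not None
```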
committed by: PyTorch MergeBot
parent: 8a57234033
commit: 3ee531f8b9
```diff
@@ -652,9 +652,20 @@ def copy_default(func, *args, **kwargs):
     inp = new_kwargs.pop("input")
     src = new_kwargs.pop("src")
     if inp._size != src._size:
-        raise RuntimeError(
-            "copy_ only supports Nested Tensors that have same size and the exact same offset tensor."
-        )
+        # try to recursively copy_ on unbound components to get around nested int mismatch
+        # TODO: eventually do a direct copy when this is possible
+        inp_comps = inp.unbind()
+        inp_comp_shapes = [c.shape for c in inp_comps]
+        src_comps = src.unbind()
+        src_comp_shapes = [c.shape for c in src_comps]
+        if inp_comp_shapes != src_comp_shapes:
+            raise RuntimeError(
+                "copy_(): expected compatible input and src shapes, but got: "
+                f"{inp.shape} and {src.shape}"
+            )
+        for inp_comp, src_comp in zip(inp_comps, src_comps):
+            inp_comp.copy_(src_comp)
 
     # AOTD allows mutations of inputs only, (not views of the inputs).
     # NJT.values() returns _values.detach() to workaround some issues.
     # To keep mutation in the graph, AOTD manually calls copy_ on the input (NJT).
```