Fixes #160752

# Background:

`torch.func.jacfwd` is implemented as vmap over forward-mode JVP. With `torch.compile(dynamic=True)`, FakeTensor + SymInt shape reasoning is used while tracing through the transform. The old vmap batching rule for `one_hot` decomposed into `zeros_symint` + `scatter`. That interacted poorly with the transform stack and dynamic shapes, causing failures mid-trace. Building the result functionally from an equality comparison makes `one_hot` composable with vmap/JVP and friendly to dynamic-shape tracing.

# Changes:

- The functorch vmap batching rule for `aten::one_hot` now uses a purely functional formulation:
  - Replace "zeros + scatter" with `eq(self.unsqueeze(-1), arange(num_classes)).to(kLong)` under `FuncTorchBatched`.
- The native `one_hot` path is unchanged for regular eager execution; only the vmap transform stops relying on `scatter`, which was fragile under dynamic-shape tracing.

The minimal repro from the issue is now fixed:

```python
import torch
import torch.nn.functional as F

MAX, BATCH = 3, 37

def func(x, idxs):
    return x.square() * F.one_hot(idxs, MAX)

def jacfunc(x, idxs):
    # jacfwd = vmap over forward-mode JVP, so one_hot goes through its vmap rule
    return torch.func.jacfwd(func, argnums=0)(x, idxs)

idxs = torch.randint(MAX, (BATCH,), dtype=torch.int64)
x = torch.rand((BATCH, MAX), dtype=torch.float64)

# eager
out_eager = jacfunc(x, idxs)

# compiled with dynamic shapes (previously failed while tracing)
jacfunc_c = torch.compile(jacfunc, dynamic=True)
out_comp = jacfunc_c(x, idxs)

torch.testing.assert_close(out_eager, out_comp)
```

Pull Request resolved: https://github.com/pytorch/pytorch/pull/160837
Approved by: https://github.com/guilhermeleobas, https://github.com/zou3519
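For illustration, the two formulations described under Changes can be mirrored in Python. This is a minimal sketch, not the ATen batching rule itself: `one_hot_scatter` and `one_hot_functional` are hypothetical helper names. The sketch only checks that both agree with `F.one_hot` in eager mode and that the functional version composes with `torch.vmap` without any in-place write into a zeros buffer.

```python
import torch
import torch.nn.functional as F


def one_hot_scatter(idx: torch.Tensor, num_classes: int) -> torch.Tensor:
    # Old-style formulation: allocate zeros, then scatter 1s in place.
    out = torch.zeros((*idx.shape, num_classes), dtype=torch.long, device=idx.device)
    return out.scatter_(-1, idx.unsqueeze(-1), 1)


def one_hot_functional(idx: torch.Tensor, num_classes: int) -> torch.Tensor:
    # Functional formulation: compare each index against [0, ..., num_classes - 1].
    # Broadcasting (*idx.shape, 1) against (num_classes,) gives a boolean mask of
    # shape (*idx.shape, num_classes), which is then cast to int64.
    classes = torch.arange(num_classes, device=idx.device)
    return torch.eq(idx.unsqueeze(-1), classes).to(torch.long)


idxs = torch.tensor([0, 2, 1, 2])
ref = F.one_hot(idxs, 3)
assert torch.equal(one_hot_scatter(idxs, 3), ref)
assert torch.equal(one_hot_functional(idxs, 3), ref)

# The functional version needs no mutation, so it runs unchanged under vmap.
batched = torch.randint(3, (5, 4))
vmapped = torch.vmap(lambda i: one_hot_functional(i, 3))(batched)
assert torch.equal(vmapped, F.one_hot(batched, 3))
```

One design note: unlike `F.one_hot` in eager mode, this sketch performs no range check on the indices, so an out-of-range index simply produces an all-zero row.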