feat(dynamo): IS#160752 make F.one_hot work with jacfwd + torch.compile(dynamic=True) (#160837)

Fixes #160752

# Background:
`torch.func.jacfwd` is implemented as vmap over forward-mode JVP. With torch.compile(dynamic=True), FakeTensor + SymInt shape reasoning is used while tracing through the transform. The old vmap rule for one_hot decomposed into “zeros_symint + scatter,” which interacted poorly with the transform stack and dynamic shapes, leading to failures mid-trace. Using a functional equality construction makes one_hot composable with vmap/JVP and friendly to dynamic shape tracing.

# Changes:
- functorch vmap batching rule for `aten::one_hot` now uses a purely functional formulation:
- Replace “zeros + scatter” with eq(self.unsqueeze(-1), arange(num_classes)).to(kLong) under FuncTorchBatched.
- one_hot native path remains unchanged for regular eager; vmap transform no longer relies on scatter, which was fragile under dynamic shape tracing.

The minimal repro from the issue is now fixed:
```python
import torch
import torch.nn.functional as F

MAX, BATCH = 3, 37

def func(x, idxs):
    return x.square() * F.one_hot(idxs, MAX)

def jacfunc(x, idxs):
    return torch.func.jacfwd(func, argnums=0)(x, idxs)

idxs = torch.randint(MAX, (BATCH,), dtype=torch.int64)
x = torch.rand((BATCH, MAX), dtype=torch.float64)

# eager
out_eager = jacfunc(x, idxs)

# compiled dynamic
jacfunc_c = torch.compile(jacfunc, dynamic=True)
out_comp = jacfunc_c(x, idxs)

torch.testing.assert_close(out_eager, out_comp)
```

Pull Request resolved: https://github.com/pytorch/pytorch/pull/160837
Approved by: https://github.com/guilhermeleobas, https://github.com/zou3519
This commit is contained in:
Michael Gathara
2025-10-15 02:48:41 +00:00
committed by PyTorch MergeBot
parent 4f400ab520
commit b4fd47179e
3 changed files with 57 additions and 34 deletions

View File

@ -213,40 +213,22 @@ static cudnn_grid_sample_backward_batch_rule(
return grid_sample_backward_helper_out(std::move(bw_out), 0, 0, bdim_size);
}
// TODO: replace with targetable functionalization
// uses functional formulation for one_hot under vmap to be compatible with
// fakeTensor/dynamic shapes and compiled functorch transforms.
// mirrors the meta path in aten/src/ATen/native/Onehot.cpp,
// but requires explicit positive num_classes under vmap to avoid
// data-dependent output shapes.
static Tensor one_hot_decomposition_hack(const Tensor &self, int64_t num_classes) {
TORCH_CHECK(self.dtype() == kLong, "one_hot is only applicable to index tensor.");
auto shape = self.sym_sizes().vec();
// empty tensor could be converted to one hot representation,
// but shape inference is not possible.
if (self.sym_numel() == 0) {
if (num_classes <= 0) {
TORCH_CHECK(false, "Can not infer total number of classes from empty tensor.");
} else {
shape.emplace_back(num_classes);
return at::empty_symint(shape, self.options());
}
}
// disallow implicit inference under vmap; this would be data-dependent
// and is intentionally guarded by Dynamo in torch/_dynamo/variables/torch.py.
TORCH_CHECK(num_classes > 0, "When vmap-ing torch.nn.functional.one_hot, please "
"provide an explicit positive num_classes argument.");
// Disabling all of the following checks. This is OK because scatter has checks too.
// Maybe one_hot should be a primitive wrt autograd so we don't have to deal with this.
// // non-empty tensor
// if (self.device().type() != at::kCUDA) {
// //for cuda, rely on device assert thrown by scatter
// TORCH_CHECK(self.min().item().toLong() >= 0, "Class values must be non-negative.");
// }
// if (self.device().type() != at::kCUDA) {
// //rely on device asserts from scatter to avoid sync here
// TORCH_CHECK(num_classes > self.max().item().toLong(), "Class values must be smaller than num_classes.");
// }
shape.emplace_back(num_classes);
Tensor ret = at::zeros_symint(shape, self.options());
return ret.scatter(-1, self.unsqueeze(-1), 1);
const auto options = self.options();
at::Tensor index = at::arange(num_classes, options);
return at::eq(self.unsqueeze(-1), index).to(at::kLong);
}
template <typename A, A a, typename C>

View File

@ -34,16 +34,16 @@ Tensor one_hot(const Tensor &self, int64_t num_classes) {
}
}
auto shape = self.sizes().vec();
auto shape = self.sym_sizes().vec();
// empty tensor could be converted to one hot representation,
// but shape inference is not possible.
if (self.numel() == 0) {
if (self.sym_numel() == 0) {
if (num_classes <= 0) {
TORCH_CHECK(false, "Can not infer total number of classes from empty tensor.");
} else {
shape.push_back(num_classes);
return at::empty(shape, self.options());
shape.emplace_back(num_classes);
return at::empty_symint(shape, self.options());
}
}
@ -66,8 +66,8 @@ Tensor one_hot(const Tensor &self, int64_t num_classes) {
}
}
shape.push_back(num_classes);
Tensor ret = at::zeros(shape, self.options());
shape.emplace_back(num_classes);
Tensor ret = at::zeros_symint(shape, self.options());
ret.scatter_(-1, self.unsqueeze(-1), 1);
return ret;
}

View File

@ -9192,6 +9192,47 @@ def ___make_guard_fn():
self.assertEqual(counter.frame_count, 2)
self.assertEqual(counter.op_count, 2)
def test_jacfwd_one_hot_dynamic_compile(self):
import torch.nn.functional as F
MAX, BATCH = 3, 37
def func(x, idxs):
return x.square() * F.one_hot(idxs, MAX)
def jacfunc(x, idxs):
return torch.func.jacfwd(func, argnums=(0,))(x, idxs)
idxs = torch.randint(MAX, (BATCH,), dtype=torch.int64)
x = torch.rand((BATCH, MAX), dtype=torch.float64)
eager = jacfunc(x, idxs)
compiled = torch.compile(jacfunc, backend="eager", dynamic=True)
out_comp = compiled(x, idxs)
self.assertEqual(eager[0], out_comp[0])
def test_tracing_nested_py_tree_mixed_all(self):
def fn(xs):
flat_xs, spec = python_pytree.tree_flatten(xs)
res = [x.clone() for x in flat_xs]
return python_pytree.tree_unflatten(res, spec)
xs = [torch.tensor(i) for i in range(3)]
xsa = (xs, xs)
xsb = {"aa": xsa, "ab": xs}
xsl = {
"a": xs,
"b": xsa,
"c": xsb,
}
counter = CompileCounter()
comp_out = torch.compile(fn, backend=counter, fullgraph=True)(xsl)
real_out = fn(xsl)
self.assertEqual(comp_out, real_out)
self.assertEqual(counter.frame_count, 1)
self.assertEqual(counter.op_count, 18)
def test_any_all_symnode(self):
cnt = CompileCounter()