feat(dynamo): IS#160752 make F.one_hot work with jacfwd + torch.compile(dynamic=True) (#160837)
Fixes #160752

# Background:

`torch.func.jacfwd` is implemented as vmap over forward-mode JVP. With `torch.compile(dynamic=True)`, FakeTensor + SymInt shape reasoning is used while tracing through the transform. The old vmap rule for `one_hot` decomposed into "zeros_symint + scatter", which interacted poorly with the transform stack and dynamic shapes, leading to failures mid-trace. Using a functional equality construction makes `one_hot` composable with vmap/JVP and friendly to dynamic-shape tracing.

# Changes:

- The functorch vmap batching rule for `aten::one_hot` now uses a purely functional formulation:
  - Replace "zeros + scatter" with `eq(self.unsqueeze(-1), arange(num_classes)).to(kLong)` under `FuncTorchBatched`.
  - The `one_hot` native path remains unchanged for regular eager; the vmap transform no longer relies on scatter, which was fragile under dynamic-shape tracing.

The minimal repro from the issue is now fixed:

```python
import torch
import torch.nn.functional as F

MAX, BATCH = 3, 37

def func(x, idxs):
    return x.square() * F.one_hot(idxs, MAX)

def jacfunc(x, idxs):
    return torch.func.jacfwd(func, argnums=0)(x, idxs)

idxs = torch.randint(MAX, (BATCH,), dtype=torch.int64)
x = torch.rand((BATCH, MAX), dtype=torch.float64)

# eager
out_eager = jacfunc(x, idxs)

# compiled, dynamic
jacfunc_c = torch.compile(jacfunc, dynamic=True)
out_comp = jacfunc_c(x, idxs)

torch.testing.assert_close(out_eager, out_comp)
```

Pull Request resolved: https://github.com/pytorch/pytorch/pull/160837
Approved by: https://github.com/guilhermeleobas, https://github.com/zou3519
Committed by: PyTorch MergeBot
Parent: 4f400ab520
Commit: b4fd47179e
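For reference, a minimal Python sketch of the functional formulation the new vmap rule uses (this is plain Python for illustration, not the C++ batching rule itself; the helper name `one_hot_functional` is made up). Because the output shape depends only on `num_classes` and not on the tensor's values, the construction stays traceable under FakeTensor/SymInt shapes.

```python
import torch
import torch.nn.functional as F

def one_hot_functional(idxs: torch.Tensor, num_classes: int) -> torch.Tensor:
    # Compare each index against arange(num_classes) and cast the boolean
    # mask to int64 -- no zeros + scatter, hence no data-dependent writes.
    index = torch.arange(num_classes, device=idxs.device)
    return torch.eq(idxs.unsqueeze(-1), index).to(torch.long)

idxs = torch.randint(3, (5,), dtype=torch.int64)
torch.testing.assert_close(one_hot_functional(idxs, 3), F.one_hot(idxs, 3))
```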
```diff
@@ -213,40 +213,22 @@ static cudnn_grid_sample_backward_batch_rule(
   return grid_sample_backward_helper_out(std::move(bw_out), 0, 0, bdim_size);
 }
 
-// TODO: replace with targetable functionalization
+// uses functional formulation for one_hot under vmap to be compatible with
+// fakeTensor/dynamic shapes and compiled functorch transforms.
+// mirrors the meta path in aten/src/ATen/native/Onehot.cpp,
+// but requires explicit positive num_classes under vmap to avoid
+// data-dependent output shapes.
 static Tensor one_hot_decomposition_hack(const Tensor &self, int64_t num_classes) {
   TORCH_CHECK(self.dtype() == kLong, "one_hot is only applicable to index tensor.");
-  auto shape = self.sym_sizes().vec();
-
-  // empty tensor could be converted to one hot representation,
-  // but shape inference is not possible.
-  if (self.sym_numel() == 0) {
-    if (num_classes <= 0) {
-      TORCH_CHECK(false, "Can not infer total number of classes from empty tensor.");
-    } else {
-      shape.emplace_back(num_classes);
-      return at::empty_symint(shape, self.options());
-    }
-  }
 
+  // disallow implicit inference under vmap; this would be data-dependent
+  // and is intentionally guarded by Dynamo in torch/_dynamo/variables/torch.py.
   TORCH_CHECK(num_classes > 0, "When vmap-ing torch.nn.functional.one_hot, please "
       "provide an explicit positive num_classes argument.");
 
-  // Disabling all of the following checks. This is OK because scatter has checks too.
-  // Maybe one_hot should be a primitive wrt autograd so we don't have to deal with this.
-  // // non-empty tensor
-  // if (self.device().type() != at::kCUDA) {
-  //   //for cuda, rely on device assert thrown by scatter
-  //   TORCH_CHECK(self.min().item().toLong() >= 0, "Class values must be non-negative.");
-  // }
-  // if (self.device().type() != at::kCUDA) {
-  //   //rely on device asserts from scatter to avoid sync here
-  //   TORCH_CHECK(num_classes > self.max().item().toLong(), "Class values must be smaller than num_classes.");
-  // }
-
-  shape.emplace_back(num_classes);
-  Tensor ret = at::zeros_symint(shape, self.options());
-  return ret.scatter(-1, self.unsqueeze(-1), 1);
+  const auto options = self.options();
+  at::Tensor index = at::arange(num_classes, options);
+  return at::eq(self.unsqueeze(-1), index).to(at::kLong);
 }
 
 template <typename A, A a, typename C>
```
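A small usage sketch of the path this batching rule covers (an illustrative example, not taken from the PR's test suite): vmapping `F.one_hot` with an explicit, positive `num_classes` should agree with the unbatched call.

```python
import torch
import torch.nn.functional as F

NUM_CLASSES = 4
idxs = torch.randint(NUM_CLASSES, (8, 5), dtype=torch.int64)

# vmap over the leading dimension dispatches to the aten::one_hot batching
# rule; under vmap, num_classes must be passed explicitly and be positive.
batched = torch.func.vmap(lambda row: F.one_hot(row, NUM_CLASSES))(idxs)

torch.testing.assert_close(batched, F.one_hot(idxs, NUM_CLASSES))
```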
```diff
@@ -34,16 +34,16 @@ Tensor one_hot(const Tensor &self, int64_t num_classes) {
         }
     }
 
-    auto shape = self.sizes().vec();
+    auto shape = self.sym_sizes().vec();
 
     // empty tensor could be converted to one hot representation,
     // but shape inference is not possible.
-    if (self.numel() == 0) {
+    if (self.sym_numel() == 0) {
         if (num_classes <= 0) {
             TORCH_CHECK(false, "Can not infer total number of classes from empty tensor.");
         } else {
-            shape.push_back(num_classes);
-            return at::empty(shape, self.options());
+            shape.emplace_back(num_classes);
+            return at::empty_symint(shape, self.options());
         }
     }
 
@@ -66,8 +66,8 @@ Tensor one_hot(const Tensor &self, int64_t num_classes) {
         }
     }
 
-    shape.push_back(num_classes);
-    Tensor ret = at::zeros(shape, self.options());
+    shape.emplace_back(num_classes);
+    Tensor ret = at::zeros_symint(shape, self.options());
     ret.scatter_(-1, self.unsqueeze(-1), 1);
     return ret;
 }
```
```diff
@@ -9192,6 +9192,47 @@ def ___make_guard_fn():
         self.assertEqual(counter.frame_count, 2)
         self.assertEqual(counter.op_count, 2)
 
+    def test_jacfwd_one_hot_dynamic_compile(self):
+        import torch.nn.functional as F
+
+        MAX, BATCH = 3, 37
+
+        def func(x, idxs):
+            return x.square() * F.one_hot(idxs, MAX)
+
+        def jacfunc(x, idxs):
+            return torch.func.jacfwd(func, argnums=(0,))(x, idxs)
+
+        idxs = torch.randint(MAX, (BATCH,), dtype=torch.int64)
+        x = torch.rand((BATCH, MAX), dtype=torch.float64)
+        eager = jacfunc(x, idxs)
+
+        compiled = torch.compile(jacfunc, backend="eager", dynamic=True)
+        out_comp = compiled(x, idxs)
+        self.assertEqual(eager[0], out_comp[0])
+
+    def test_tracing_nested_py_tree_mixed_all(self):
+        def fn(xs):
+            flat_xs, spec = python_pytree.tree_flatten(xs)
+            res = [x.clone() for x in flat_xs]
+            return python_pytree.tree_unflatten(res, spec)
+
+        xs = [torch.tensor(i) for i in range(3)]
+        xsa = (xs, xs)
+        xsb = {"aa": xsa, "ab": xs}
+        xsl = {
+            "a": xs,
+            "b": xsa,
+            "c": xsb,
+        }
+
+        counter = CompileCounter()
+        comp_out = torch.compile(fn, backend=counter, fullgraph=True)(xsl)
+        real_out = fn(xsl)
+        self.assertEqual(comp_out, real_out)
+        self.assertEqual(counter.frame_count, 1)
+        self.assertEqual(counter.op_count, 18)
+
     def test_any_all_symnode(self):
         cnt = CompileCounter()
 
```