Mirror of https://github.com/pytorch/pytorch.git (synced 2025-10-20 12:54:11 +08:00)
feat(dynamo): IS#160752 make F.one_hot work with jacfwd + torch.compile(dynamic=True) (#160837)
Fixes #160752

# Background:

`torch.func.jacfwd` is implemented as vmap over forward-mode JVP. With `torch.compile(dynamic=True)`, FakeTensor + SymInt shape reasoning is used while tracing through the transform. The old vmap rule for `one_hot` decomposed into "zeros_symint + scatter", which interacted poorly with the transform stack and dynamic shapes, leading to failures mid-trace. Using a functional equality construction makes `one_hot` composable with vmap/JVP and friendly to dynamic-shape tracing.

# Changes:

- The functorch vmap batching rule for `aten::one_hot` now uses a purely functional formulation:
  - Replace "zeros + scatter" with `eq(self.unsqueeze(-1), arange(num_classes)).to(kLong)` under FuncTorchBatched.
- The `one_hot` native path remains unchanged for regular eager; the vmap transform no longer relies on scatter, which was fragile under dynamic shape tracing.

The minimal repro from the issue now passes:

```python
import torch
import torch.nn.functional as F

MAX, BATCH = 3, 37

def func(x, idxs):
    return x.square() * F.one_hot(idxs, MAX)

def jacfunc(x, idxs):
    return torch.func.jacfwd(func, argnums=0)(x, idxs)

idxs = torch.randint(MAX, (BATCH,), dtype=torch.int64)
x = torch.rand((BATCH, MAX), dtype=torch.float64)

# eager
out_eager = jacfunc(x, idxs)

# compiled dynamic
jacfunc_c = torch.compile(jacfunc, dynamic=True)
out_comp = jacfunc_c(x, idxs)

torch.testing.assert_close(out_eager, out_comp)
```

Pull Request resolved: https://github.com/pytorch/pytorch/pull/160837
Approved by: https://github.com/guilhermeleobas, https://github.com/zou3519
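For context, here is a minimal Python sketch (not part of the patch; tensor names are illustrative) showing that the equality-based construction produces the same result as the scatter-based one it replaces:

```python
import torch

idxs = torch.randint(0, 3, (5,))   # integer class indices
num_classes = 3

# scatter-based construction (what the old vmap rule decomposed into)
via_scatter = torch.zeros(idxs.shape + (num_classes,), dtype=torch.long)
via_scatter.scatter_(-1, idxs.unsqueeze(-1), 1)

# functional construction: compare each index against arange(num_classes)
via_eq = (idxs.unsqueeze(-1) == torch.arange(num_classes)).long()

torch.testing.assert_close(via_scatter, via_eq)
```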
Committed by: PyTorch MergeBot
Parent: 4f400ab520
Commit: b4fd47179e
functorch vmap batching rule for `aten::one_hot`:

```diff
@@ -213,40 +213,22 @@ static cudnn_grid_sample_backward_batch_rule(
   return grid_sample_backward_helper_out(std::move(bw_out), 0, 0, bdim_size);
 }
 
 // TODO: replace with targetable functionalization
+// uses functional formulation for one_hot under vmap to be compatible with
+// fakeTensor/dynamic shapes and compiled functorch transforms.
+// mirrors the meta path in aten/src/ATen/native/Onehot.cpp,
+// but requires explicit positive num_classes under vmap to avoid
+// data-dependent output shapes.
 static Tensor one_hot_decomposition_hack(const Tensor &self, int64_t num_classes) {
   TORCH_CHECK(self.dtype() == kLong, "one_hot is only applicable to index tensor.");
   auto shape = self.sym_sizes().vec();
 
   // empty tensor could be converted to one hot representation,
   // but shape inference is not possible.
   if (self.sym_numel() == 0) {
     if (num_classes <= 0) {
       TORCH_CHECK(false, "Can not infer total number of classes from empty tensor.");
     } else {
       shape.emplace_back(num_classes);
       return at::empty_symint(shape, self.options());
     }
   }
 
+  // disallow implicit inference under vmap; this would be data-dependent
+  // and is intentionally guarded by Dynamo in torch/_dynamo/variables/torch.py.
   TORCH_CHECK(num_classes > 0, "When vmap-ing torch.nn.functional.one_hot, please "
       "provide an explicit positive num_classes argument.");
 
-  // Disabling all of the following checks. This is OK because scatter has checks too.
-  // Maybe one_hot should be a primitive wrt autograd so we don't have to deal with this.
-  // // non-empty tensor
-  // if (self.device().type() != at::kCUDA) {
-  //   //for cuda, rely on device assert thrown by scatter
-  //   TORCH_CHECK(self.min().item().toLong() >= 0, "Class values must be non-negative.");
-  // }
-  // if (self.device().type() != at::kCUDA) {
-  //   //rely on device asserts from scatter to avoid sync here
-  //   TORCH_CHECK(num_classes > self.max().item().toLong(), "Class values must be smaller than num_classes.");
-  // }
-
-  shape.emplace_back(num_classes);
-  Tensor ret = at::zeros_symint(shape, self.options());
-  return ret.scatter(-1, self.unsqueeze(-1), 1);
+  const auto options = self.options();
+  at::Tensor index = at::arange(num_classes, options);
+  return at::eq(self.unsqueeze(-1), index).to(at::kLong);
 }
 
 template <typename A, A a, typename C>
```
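A small usage sketch (not from the PR) of the explicit `num_classes` requirement enforced by this batching rule; it assumes a PyTorch build with `torch.func` available:

```python
import torch
import torch.nn.functional as F

idxs = torch.randint(0, 4, (8,))

# With an explicit, positive num_classes the per-example output shape is known
# up front, so vmap can batch one_hot.
out = torch.func.vmap(lambda i: F.one_hot(i, num_classes=4))(idxs)
assert out.shape == (8, 4)

# Without num_classes the output shape would be data-dependent; the batching
# rule rejects this with the error shown in the diff above.
try:
    torch.func.vmap(lambda i: F.one_hot(i))(idxs)
except RuntimeError as err:
    print(err)  # "When vmap-ing torch.nn.functional.one_hot, please provide ..."
```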
Native `one_hot` path in aten/src/ATen/native/Onehot.cpp (SymInt-aware shape handling):

```diff
@@ -34,16 +34,16 @@ Tensor one_hot(const Tensor &self, int64_t num_classes) {
     }
   }
 
-  auto shape = self.sizes().vec();
+  auto shape = self.sym_sizes().vec();
 
   // empty tensor could be converted to one hot representation,
   // but shape inference is not possible.
-  if (self.numel() == 0) {
+  if (self.sym_numel() == 0) {
     if (num_classes <= 0) {
       TORCH_CHECK(false, "Can not infer total number of classes from empty tensor.");
     } else {
-      shape.push_back(num_classes);
-      return at::empty(shape, self.options());
+      shape.emplace_back(num_classes);
+      return at::empty_symint(shape, self.options());
     }
   }
 
@@ -66,8 +66,8 @@ Tensor one_hot(const Tensor &self, int64_t num_classes) {
     }
   }
 
-  shape.push_back(num_classes);
-  Tensor ret = at::zeros(shape, self.options());
+  shape.emplace_back(num_classes);
+  Tensor ret = at::zeros_symint(shape, self.options());
   ret.scatter_(-1, self.unsqueeze(-1), 1);
   return ret;
 }
```
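Finally, a minimal sketch (assuming a build that includes this patch) exercising the SymInt-aware native path under `torch.compile(dynamic=True)` across several batch sizes:

```python
import torch
import torch.nn.functional as F

def encode(idxs):
    # Output shape is (batch, 5); with dynamic=True the batch dimension is
    # traced as a SymInt, which the sym_sizes/zeros_symint changes above cater to.
    return F.one_hot(idxs, num_classes=5).float().sum(dim=0)

encode_c = torch.compile(encode, dynamic=True)

for batch in (3, 7, 11):
    idxs = torch.randint(0, 5, (batch,))
    torch.testing.assert_close(encode_c(idxs), encode(idxs))
```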