mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 12:54:11 +08:00
Revert "use sym_numel, to allow fake tensors to work (#163831)"
This reverts commit e71c75680f2d6ce5f61ad4b2125f4934087762eb. Reverted https://github.com/pytorch/pytorch/pull/163831 on behalf of https://github.com/isuruf due to test failure on mps introduced ([comment](https://github.com/pytorch/pytorch/pull/163831#issuecomment-3400131730))
This commit is contained in:
@ -21,7 +21,7 @@ namespace {
|
|||||||
|
|
||||||
using namespace at;
|
using namespace at;
|
||||||
|
|
||||||
Tensor _triu_mask(c10::SymInt n, int64_t dims, bool diagonal, TensorOptions opt) {
|
Tensor _triu_mask(int64_t n, int64_t dims, bool diagonal, TensorOptions opt) {
|
||||||
// get a mask that has value 1 whose indices satisfies i < j < k < ...
|
// get a mask that has value 1 whose indices satisfies i < j < k < ...
|
||||||
// or i <= j <= k <= ... (depending on diagonal)
|
// or i <= j <= k <= ... (depending on diagonal)
|
||||||
Tensor range = at::arange(n, opt.dtype(kLong));
|
Tensor range = at::arange(n, opt.dtype(kLong));
|
||||||
@ -63,7 +63,7 @@ Tensor combinations(const Tensor& self, int64_t r, bool with_replacement) {
|
|||||||
if (r == 0) {
|
if (r == 0) {
|
||||||
return at::empty({0}, self.options());
|
return at::empty({0}, self.options());
|
||||||
}
|
}
|
||||||
const auto num_elements = self.sym_numel();
|
int64_t num_elements = self.numel();
|
||||||
std::vector<Tensor> grids = at::meshgrid(std::vector<Tensor>(r, self), "ij");
|
std::vector<Tensor> grids = at::meshgrid(std::vector<Tensor>(r, self), "ij");
|
||||||
Tensor mask = _triu_mask(num_elements, r, with_replacement, self.options());
|
Tensor mask = _triu_mask(num_elements, r, with_replacement, self.options());
|
||||||
for(Tensor &t : grids) {
|
for(Tensor &t : grids) {
|
||||||
|
@ -653,33 +653,6 @@ class TestInductorDynamic(TestCase):
|
|||||||
|
|
||||||
self.assertEqual(foo_c(t, y), foobar(t, y))
|
self.assertEqual(foo_c(t, y), foobar(t, y))
|
||||||
|
|
||||||
@parametrize("with_replacement", [False, True])
|
|
||||||
def test_dynamic_shapes_r2_matches_eager(self, with_replacement):
|
|
||||||
def _eager(x, r):
|
|
||||||
out = torch.combinations(
|
|
||||||
x.flatten(), r=r, with_replacement=with_replacement
|
|
||||||
)
|
|
||||||
# Canonicalize for stable comparison
|
|
||||||
return out.to(torch.float32).sort(dim=0).values
|
|
||||||
|
|
||||||
def _compiled(r):
|
|
||||||
def fn(x):
|
|
||||||
return torch.combinations(
|
|
||||||
x.flatten(), r=r, with_replacement=with_replacement
|
|
||||||
)
|
|
||||||
|
|
||||||
# The original bug repro failed under aot_eager + dynamic=True
|
|
||||||
return torch.compile(fn, backend="aot_eager", dynamic=True)
|
|
||||||
|
|
||||||
def _assert_match(compiled, x, r):
|
|
||||||
out = compiled(x)
|
|
||||||
exp = _eager(x, r=r)
|
|
||||||
self.assertEqual(out.to(torch.float32).sort(dim=0).values, exp)
|
|
||||||
|
|
||||||
compiled = _compiled(r=2)
|
|
||||||
_assert_match(compiled, torch.tensor([1, 2, 3, 4], dtype=torch.int64), r=2)
|
|
||||||
_assert_match(compiled, torch.tensor([5, 6, 7], dtype=torch.int64), r=2)
|
|
||||||
|
|
||||||
def test_floor(self):
|
def test_floor(self):
|
||||||
def fn(x):
|
def fn(x):
|
||||||
n = x.size(-1)
|
n = x.size(-1)
|
||||||
|
Reference in New Issue
Block a user