From 5537792307c9f354e1f7b356cd3cb15e452bec0a Mon Sep 17 00:00:00 2001 From: nima10khodaveisi Date: Wed, 22 Mar 2023 14:19:59 +0000 Subject: [PATCH] [dynamo] handle dim in size kwargs (#96992) (#97098) Fixes #96992 Pull Request resolved: https://github.com/pytorch/pytorch/pull/97098 Approved by: https://github.com/ezyang --- test/dynamo/test_misc.py | 11 +++++++++++ torch/_dynamo/variables/tensor.py | 8 +++++++- 2 files changed, 18 insertions(+), 1 deletion(-) diff --git a/test/dynamo/test_misc.py b/test/dynamo/test_misc.py index aea5fd32473f..d86a6b1a87f7 100644 --- a/test/dynamo/test_misc.py +++ b/test/dynamo/test_misc.py @@ -1801,6 +1801,17 @@ class MiscTests(torch._dynamo.test_case.TestCase): self.assertTrue(same(ref, res)) + def test_size_dim(self): + cnts = torch._dynamo.testing.CompileCounter() + + def fn(x, dim): + return x.size(dim=dim) + + opt_fn = torch._dynamo.optimize(cnts, nopython=True)(fn) + x = torch.empty([4, 9, 8]) + self.assertTrue(opt_fn(x, 1) == 9) + self.assertTrue(opt_fn(x, -2) == 9) + def test_torch_seed(self): cnts = torch._dynamo.testing.CompileCounter() diff --git a/torch/_dynamo/variables/tensor.py b/torch/_dynamo/variables/tensor.py index afbffe155a28..23ad79636bd9 100644 --- a/torch/_dynamo/variables/tensor.py +++ b/torch/_dynamo/variables/tensor.py @@ -12,7 +12,6 @@ from .. import config, variables from ..exc import unimplemented from ..guards import GuardBuilder from ..source import AttrSource - from ..utils import ( fqn, get_fake_value, @@ -252,6 +251,13 @@ class TensorVariable(VariableTracker): elif name == "size" and self.size is not None: sizes = [variables.ConstantVariable(x) for x in self.size] constant_result = SizeVariable(sizes, **options) + + if "dim" in kwargs: + dim = kwargs.pop("dim") + constant_result = constant_result.call_method( + tx, "__getitem__", [dim], {} + ) + elif name == "size" and self.size is None and config.dynamic_shapes: return wrap_fx_proxy( tx,