[dynamo] handle dim in size kwargs (#96992) (#97098)

Fixes #96992

Pull Request resolved: https://github.com/pytorch/pytorch/pull/97098
Approved by: https://github.com/ezyang
This commit is contained in:
nima10khodaveisi
2023-03-22 14:19:59 +00:00
committed by PyTorch MergeBot
parent 9d5ac03b9a
commit 5537792307
2 changed files with 18 additions and 1 deletions

View File

@ -1801,6 +1801,17 @@ class MiscTests(torch._dynamo.test_case.TestCase):
self.assertTrue(same(ref, res))
def test_size_dim(self):
cnts = torch._dynamo.testing.CompileCounter()
def fn(x, dim):
return x.size(dim=dim)
opt_fn = torch._dynamo.optimize(cnts, nopython=True)(fn)
x = torch.empty([4, 9, 8])
self.assertTrue(opt_fn(x, 1) == 9)
self.assertTrue(opt_fn(x, -2) == 9)
def test_torch_seed(self):
cnts = torch._dynamo.testing.CompileCounter()

View File

@ -12,7 +12,6 @@ from .. import config, variables
from ..exc import unimplemented
from ..guards import GuardBuilder
from ..source import AttrSource
from ..utils import (
fqn,
get_fake_value,
@ -252,6 +251,13 @@ class TensorVariable(VariableTracker):
elif name == "size" and self.size is not None:
sizes = [variables.ConstantVariable(x) for x in self.size]
constant_result = SizeVariable(sizes, **options)
if "dim" in kwargs:
dim = kwargs.pop("dim")
constant_result = constant_result.call_method(
tx, "__getitem__", [dim], {}
)
elif name == "size" and self.size is None and config.dynamic_shapes:
return wrap_fx_proxy(
tx,