[dynamo] handle dim in size kwargs (#96992) (#97098)

Fixes #96992

Pull Request resolved: https://github.com/pytorch/pytorch/pull/97098
Approved by: https://github.com/ezyang
This commit is contained in:
nima10khodaveisi
2023-03-22 14:19:59 +00:00
committed by PyTorch MergeBot
parent 9d5ac03b9a
commit 5537792307
2 changed files with 18 additions and 1 deletions

View File

@ -1801,6 +1801,17 @@ class MiscTests(torch._dynamo.test_case.TestCase):
self.assertTrue(same(ref, res)) self.assertTrue(same(ref, res))
def test_size_dim(self):
cnts = torch._dynamo.testing.CompileCounter()
def fn(x, dim):
return x.size(dim=dim)
opt_fn = torch._dynamo.optimize(cnts, nopython=True)(fn)
x = torch.empty([4, 9, 8])
self.assertTrue(opt_fn(x, 1) == 9)
self.assertTrue(opt_fn(x, -2) == 9)
def test_torch_seed(self): def test_torch_seed(self):
cnts = torch._dynamo.testing.CompileCounter() cnts = torch._dynamo.testing.CompileCounter()

View File

@ -12,7 +12,6 @@ from .. import config, variables
from ..exc import unimplemented from ..exc import unimplemented
from ..guards import GuardBuilder from ..guards import GuardBuilder
from ..source import AttrSource from ..source import AttrSource
from ..utils import ( from ..utils import (
fqn, fqn,
get_fake_value, get_fake_value,
@ -252,6 +251,13 @@ class TensorVariable(VariableTracker):
elif name == "size" and self.size is not None: elif name == "size" and self.size is not None:
sizes = [variables.ConstantVariable(x) for x in self.size] sizes = [variables.ConstantVariable(x) for x in self.size]
constant_result = SizeVariable(sizes, **options) constant_result = SizeVariable(sizes, **options)
if "dim" in kwargs:
dim = kwargs.pop("dim")
constant_result = constant_result.call_method(
tx, "__getitem__", [dim], {}
)
elif name == "size" and self.size is None and config.dynamic_shapes: elif name == "size" and self.size is None and config.dynamic_shapes:
return wrap_fx_proxy( return wrap_fx_proxy(
tx, tx,