mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 21:14:14 +08:00
Fixes #96992 Pull Request resolved: https://github.com/pytorch/pytorch/pull/97098 Approved by: https://github.com/ezyang
This commit is contained in:
committed by
PyTorch MergeBot
parent
9d5ac03b9a
commit
5537792307
@ -1801,6 +1801,17 @@ class MiscTests(torch._dynamo.test_case.TestCase):
|
||||
|
||||
self.assertTrue(same(ref, res))
|
||||
|
||||
def test_size_dim(self):
|
||||
cnts = torch._dynamo.testing.CompileCounter()
|
||||
|
||||
def fn(x, dim):
|
||||
return x.size(dim=dim)
|
||||
|
||||
opt_fn = torch._dynamo.optimize(cnts, nopython=True)(fn)
|
||||
x = torch.empty([4, 9, 8])
|
||||
self.assertTrue(opt_fn(x, 1) == 9)
|
||||
self.assertTrue(opt_fn(x, -2) == 9)
|
||||
|
||||
def test_torch_seed(self):
|
||||
cnts = torch._dynamo.testing.CompileCounter()
|
||||
|
||||
|
@ -12,7 +12,6 @@ from .. import config, variables
|
||||
from ..exc import unimplemented
|
||||
from ..guards import GuardBuilder
|
||||
from ..source import AttrSource
|
||||
|
||||
from ..utils import (
|
||||
fqn,
|
||||
get_fake_value,
|
||||
@ -252,6 +251,13 @@ class TensorVariable(VariableTracker):
|
||||
elif name == "size" and self.size is not None:
|
||||
sizes = [variables.ConstantVariable(x) for x in self.size]
|
||||
constant_result = SizeVariable(sizes, **options)
|
||||
|
||||
if "dim" in kwargs:
|
||||
dim = kwargs.pop("dim")
|
||||
constant_result = constant_result.call_method(
|
||||
tx, "__getitem__", [dim], {}
|
||||
)
|
||||
|
||||
elif name == "size" and self.size is None and config.dynamic_shapes:
|
||||
return wrap_fx_proxy(
|
||||
tx,
|
||||
|
Reference in New Issue
Block a user