mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-21 13:44:15 +08:00
Fixes #96992 Pull Request resolved: https://github.com/pytorch/pytorch/pull/97098 Approved by: https://github.com/ezyang
This commit is contained in:
committed by
PyTorch MergeBot
parent
9d5ac03b9a
commit
5537792307
@ -1801,6 +1801,17 @@ class MiscTests(torch._dynamo.test_case.TestCase):
|
|||||||
|
|
||||||
self.assertTrue(same(ref, res))
|
self.assertTrue(same(ref, res))
|
||||||
|
|
||||||
|
def test_size_dim(self):
|
||||||
|
cnts = torch._dynamo.testing.CompileCounter()
|
||||||
|
|
||||||
|
def fn(x, dim):
|
||||||
|
return x.size(dim=dim)
|
||||||
|
|
||||||
|
opt_fn = torch._dynamo.optimize(cnts, nopython=True)(fn)
|
||||||
|
x = torch.empty([4, 9, 8])
|
||||||
|
self.assertTrue(opt_fn(x, 1) == 9)
|
||||||
|
self.assertTrue(opt_fn(x, -2) == 9)
|
||||||
|
|
||||||
def test_torch_seed(self):
|
def test_torch_seed(self):
|
||||||
cnts = torch._dynamo.testing.CompileCounter()
|
cnts = torch._dynamo.testing.CompileCounter()
|
||||||
|
|
||||||
|
@ -12,7 +12,6 @@ from .. import config, variables
|
|||||||
from ..exc import unimplemented
|
from ..exc import unimplemented
|
||||||
from ..guards import GuardBuilder
|
from ..guards import GuardBuilder
|
||||||
from ..source import AttrSource
|
from ..source import AttrSource
|
||||||
|
|
||||||
from ..utils import (
|
from ..utils import (
|
||||||
fqn,
|
fqn,
|
||||||
get_fake_value,
|
get_fake_value,
|
||||||
@ -252,6 +251,13 @@ class TensorVariable(VariableTracker):
|
|||||||
elif name == "size" and self.size is not None:
|
elif name == "size" and self.size is not None:
|
||||||
sizes = [variables.ConstantVariable(x) for x in self.size]
|
sizes = [variables.ConstantVariable(x) for x in self.size]
|
||||||
constant_result = SizeVariable(sizes, **options)
|
constant_result = SizeVariable(sizes, **options)
|
||||||
|
|
||||||
|
if "dim" in kwargs:
|
||||||
|
dim = kwargs.pop("dim")
|
||||||
|
constant_result = constant_result.call_method(
|
||||||
|
tx, "__getitem__", [dim], {}
|
||||||
|
)
|
||||||
|
|
||||||
elif name == "size" and self.size is None and config.dynamic_shapes:
|
elif name == "size" and self.size is None and config.dynamic_shapes:
|
||||||
return wrap_fx_proxy(
|
return wrap_fx_proxy(
|
||||||
tx,
|
tx,
|
||||||
|
Reference in New Issue
Block a user