Compare commits


1 Commit

Commit 1a8934af84 (2024-07-07 15:29:48 -07:00)

[dynamo] Graph break on random_ op

Fixes https://github.com/pytorch/pytorch/issues/121621

ghstack-source-id: 098b44305ae2aaab334e6973f6e94f937e61f9a0
Pull Request resolved: https://github.com/pytorch/pytorch/pull/130222
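
For context, the linked issue reports that torch.compile crashed on Tensor.random_ when the from bound was passed by keyword; this patch makes dynamo graph-break on that call instead. A minimal sketch to observe the break, assuming the PyTorch 2.x torch._dynamo.explain helper (illustrative, not part of the PR):

import torch
import torch._dynamo

def random_op(tensor, params):
    # `from` is a Python keyword, so it can only be passed via **kwargs
    return tensor.random_(**params)

explanation = torch._dynamo.explain(random_op)(
    torch.randn([2, 3]), {"from": -10, "to": 10}
)
print(explanation.graph_break_count)  # expected: 1 after this patch
print(explanation.break_reasons)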
2 changed files with 32 additions and 10 deletions


@@ -5299,6 +5299,17 @@ def forward(self, s0 : torch.SymInt, s1 : torch.SymInt, L_x_ : torch.Tensor):
         # the second call causes a failure
         m()
 
+    # https://github.com/pytorch/pytorch/issues/121621
+    def test_tensor_random(self):
+        def random_op(tensor, params):
+            res = tensor.random_(**params)
+            return res
+
+        random_op = torch.compile(random_op)
+        params = {"from": -10, "to": 10}
+        tensor = torch.randn([2, 3])
+        res = random_op(tensor, params)
+
 
 instantiate_parametrized_tests(ReproTests)
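
Because the fix is a graph break (dynamo falls back to eager for the op) rather than actual support for random_ with from, compiling the same function with fullgraph=True should now raise instead of silently splitting the graph. A minimal sketch of that check, assuming dynamo's usual Unsupported error (illustrative, not part of the PR):

import torch

def random_op(tensor, params):
    return tensor.random_(**params)

# fullgraph=True turns any graph break into a hard error, so the new
# unimplemented("random_ op is called with from keyword") message surfaces.
compiled = torch.compile(random_op, fullgraph=True)
try:
    compiled(torch.randn([2, 3]), {"from": -10, "to": 10})
except Exception as exc:  # expected: torch._dynamo.exc.Unsupported
    print(type(exc).__name__, exc)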


@@ -1495,16 +1495,27 @@ class InstructionTranslatorBase(
         null = self.pop()
         assert isinstance(null, NullVariable)
 
-        if (
-            isinstance(fn, GetAttrVariable)
-            and isinstance(fn.obj, TensorVariable)
-            and fn.name == "view"
-            and isinstance(argsvars, (ConstantVariable, TensorVariable))
-        ):
-            # Hack to handle special case in some bert models. Converts
-            # x.view(*shape) into x.view(shape), which is correct for view()
-            # but not generally. See test_transpose_for_scores().
-            argsvars = TupleVariable([argsvars])
+        if isinstance(fn, GetAttrVariable) and isinstance(fn.obj, TensorVariable):
+            if fn.name == "view" and isinstance(
+                argsvars, (ConstantVariable, TensorVariable)
+            ):
+                # Hack to handle special case in some bert models. Converts
+                # x.view(*shape) into x.view(shape), which is correct for view()
+                # but not generally. See test_transpose_for_scores().
+                argsvars = TupleVariable([argsvars])
+            elif (
+                fn.name == "random_"
+                and isinstance(argsvars, TupleVariable)
+                and len(argsvars.items) == 0
+                and isinstance(kwargsvars, ConstDictVariable)
+                and ConstantVariable.create("from") in kwargsvars
+            ):
+                # `from` is a Python keyword. Adding random_ with `from` in the
+                # FX graph causes a syntax error. Even if we convert the kwargs
+                # to args, aot_autograd/inductor generates aten.random.from
+                # while lowering, again causing syntax errors. Since this use
+                # case is uncommon, graph break.
+                unimplemented("random_ op is called with from keyword")
 
         if not isinstance(
             argsvars, BaseListVariable
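
The comment's rationale is easy to verify in isolation: from is a reserved keyword, so generated source that spells the call with an explicit from= argument cannot even be parsed, and per the comment the **kwargs spelling only pushes the problem into aot_autograd/inductor's aten.random.from lowering. A standalone demonstration of the parser half (illustrative, not part of the PR):

import ast

# Explicit `from=` keyword: rejected by the parser. Emitting this call into
# generated FX graph code is what produced the original SyntaxError.
try:
    ast.parse("tensor.random_(from=-10, to=10)")
except SyntaxError as exc:
    print("SyntaxError:", exc.msg)

# The **kwargs spelling (as in the test above) parses fine...
ast.parse("tensor.random_(**{'from': -10, 'to': 10})")
# ...but lowering still generates aten.random.from, hence the graph break.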