mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-21 05:34:18 +08:00
[BE][5/6] fix typos in test/ (test/dynamo/) (#157639)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/157639 Approved by: https://github.com/yewentao256, https://github.com/jansel ghstack dependencies: #157638
This commit is contained in:
committed by
PyTorch MergeBot
parent
17687eb792
commit
02715d0876
@ -1169,7 +1169,6 @@ exclude_patterns = [
|
|||||||
'test/**',
|
'test/**',
|
||||||
'test/test_*',
|
'test/test_*',
|
||||||
'test/[a-hA-h]*/**',
|
'test/[a-hA-h]*/**',
|
||||||
'test/dynamo/**',
|
|
||||||
'test/distributed/**',
|
'test/distributed/**',
|
||||||
'torch/**',
|
'torch/**',
|
||||||
'torch/_*/**',
|
'torch/_*/**',
|
||||||
|
@ -1235,7 +1235,7 @@ SeqNr|OrigAten|SrcFn|FwdSrcFn
|
|||||||
def test_donated_buffer2(self):
|
def test_donated_buffer2(self):
|
||||||
logger_name = "torch._functorch._aot_autograd.jit_compile_runtime_wrappers"
|
logger_name = "torch._functorch._aot_autograd.jit_compile_runtime_wrappers"
|
||||||
|
|
||||||
# we will re-use the graph for g across f1 and f2
|
# we will reuse the graph for g across f1 and f2
|
||||||
@torch.compile()
|
@torch.compile()
|
||||||
def g(activation, param2):
|
def g(activation, param2):
|
||||||
return torch.matmul(activation, param2)
|
return torch.matmul(activation, param2)
|
||||||
@ -1257,7 +1257,7 @@ SeqNr|OrigAten|SrcFn|FwdSrcFn
|
|||||||
def test_donated_buffer3(self):
|
def test_donated_buffer3(self):
|
||||||
logger_name = "torch._functorch._aot_autograd.jit_compile_runtime_wrappers"
|
logger_name = "torch._functorch._aot_autograd.jit_compile_runtime_wrappers"
|
||||||
|
|
||||||
# we will re-use the graph for g across f1 and f2
|
# we will reuse the graph for g across f1 and f2
|
||||||
@torch.compile()
|
@torch.compile()
|
||||||
def g(activation, param2):
|
def g(activation, param2):
|
||||||
return torch.matmul(activation, param2)
|
return torch.matmul(activation, param2)
|
||||||
|
@ -471,7 +471,7 @@ class AOTAutogradCacheTests(InductorTestCase):
|
|||||||
)
|
)
|
||||||
def test_view_replay_bypass(self):
|
def test_view_replay_bypass(self):
|
||||||
"""
|
"""
|
||||||
Shoud bypass when view replay is turned on
|
Should bypass when view replay is turned on
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def fn(a):
|
def fn(a):
|
||||||
|
@ -1429,7 +1429,7 @@ class GraphModule(torch.nn.Module):
|
|||||||
result = grad_output * dx + grad_dx * 6 * x
|
result = grad_output * dx + grad_dx * 6 * x
|
||||||
# Intentionally return a wrong value to test if the backward is triggered twice.
|
# Intentionally return a wrong value to test if the backward is triggered twice.
|
||||||
# Since if the first MyCube.apply returns values w/o requires_grad=True,
|
# Since if the first MyCube.apply returns values w/o requires_grad=True,
|
||||||
# this backward would be only triggered once (the first MyCube.appy call),
|
# this backward would be only triggered once (the first MyCube.apply call),
|
||||||
# as the second MyCube.apply is inlined by Dynamo and the corresponding backward
|
# as the second MyCube.apply is inlined by Dynamo and the corresponding backward
|
||||||
# would be generated by autograd engine.
|
# would be generated by autograd engine.
|
||||||
return result * 0.5
|
return result * 0.5
|
||||||
|
@ -388,7 +388,7 @@ class CtxManagerTests(torch._dynamo.test_case.TestCase):
|
|||||||
|
|
||||||
ref1 = fn(x, s1, s1)
|
ref1 = fn(x, s1, s1)
|
||||||
res1 = opt_fn(x, s1, s1)
|
res1 = opt_fn(x, s1, s1)
|
||||||
# We have a re-compilation because of chaning inputs
|
# We have a re-compilation because of changing inputs
|
||||||
self.assertEqual(cnts.frame_count, 2)
|
self.assertEqual(cnts.frame_count, 2)
|
||||||
self.assertEqual(ref1, res1)
|
self.assertEqual(ref1, res1)
|
||||||
|
|
||||||
@ -403,7 +403,7 @@ class CtxManagerTests(torch._dynamo.test_case.TestCase):
|
|||||||
|
|
||||||
ref0 = fn(x, s0, s1)
|
ref0 = fn(x, s0, s1)
|
||||||
res0 = opt_fn(x, s0, s1)
|
res0 = opt_fn(x, s0, s1)
|
||||||
# We have a re-compilation because of chaning inputs
|
# We have a re-compilation because of changing inputs
|
||||||
self.assertEqual(cnts.frame_count, 2)
|
self.assertEqual(cnts.frame_count, 2)
|
||||||
self.assertEqual(ref0, res0)
|
self.assertEqual(ref0, res0)
|
||||||
|
|
||||||
|
@ -904,7 +904,7 @@ class DictTests(torch._dynamo.test_case.TestCase):
|
|||||||
|
|
||||||
def fn(x):
|
def fn(x):
|
||||||
# Dynamo should not cause a graph break here because it knows that
|
# Dynamo should not cause a graph break here because it knows that
|
||||||
# the existing proxy cant point to this new dict
|
# the existing proxy can't point to this new dict
|
||||||
other_dict = {}
|
other_dict = {}
|
||||||
other_dict["d"] = 4
|
other_dict["d"] = 4
|
||||||
y = torch.sin(x * mp["c"])
|
y = torch.sin(x * mp["c"])
|
||||||
|
@ -292,7 +292,7 @@ class ExceptionTests(torch._dynamo.test_case.TestCase):
|
|||||||
|
|
||||||
x = torch.randn(4)
|
x = torch.randn(4)
|
||||||
fn(x)
|
fn(x)
|
||||||
# Cant use fullgraph=True because RERAISE is not supported
|
# Can't use fullgraph=True because RERAISE is not supported
|
||||||
opt_fn = torch.compile(fn, backend="eager")
|
opt_fn = torch.compile(fn, backend="eager")
|
||||||
opt_fn(x)
|
opt_fn(x)
|
||||||
|
|
||||||
|
@ -3536,7 +3536,7 @@ class GraphModule(torch.nn.Module):
|
|||||||
[3, 3, 4, 5],
|
[3, 3, 4, 5],
|
||||||
[true_graph, true_graph, false_graph, false_graph],
|
[true_graph, true_graph, false_graph, false_graph],
|
||||||
[true_guard_code, true_guard_code, false_guard_code, false_guard_code],
|
[true_guard_code, true_guard_code, false_guard_code, false_guard_code],
|
||||||
# Outter shape env should have no guards in it because we never specialize on the outter symbool.
|
# Outer shape env should have no guards in it because we never specialize on the outer symbool.
|
||||||
[[], [], [], []],
|
[[], [], [], []],
|
||||||
)
|
)
|
||||||
|
|
||||||
|
@ -1240,7 +1240,7 @@ class FunctionTests(torch._dynamo.test_case.TestCase):
|
|||||||
|
|
||||||
@make_test
|
@make_test
|
||||||
def test_inline_softmax(x, y):
|
def test_inline_softmax(x, y):
|
||||||
# This is common in sme huggingface models
|
# This is common in some huggingface models
|
||||||
return torch.nn.Softmax(dim=-1)(x + y * 2)
|
return torch.nn.Softmax(dim=-1)(x + y * 2)
|
||||||
|
|
||||||
@make_test
|
@make_test
|
||||||
|
@ -2133,7 +2133,7 @@ def forward(self, child : torch.Tensor, const_unused : int):
|
|||||||
and node.target == torch.ops.higher_order.cond
|
and node.target == torch.ops.higher_order.cond
|
||||||
):
|
):
|
||||||
_, _, _, operands = node.args
|
_, _, _, operands = node.args
|
||||||
# Since we compile wit dynamic, each branch takes 4 inputs (buffer, x, z, s1)
|
# Since we compile with dynamic, each branch takes 4 inputs (buffer, x, z, s1)
|
||||||
self.assertEqual(len(operands), 4)
|
self.assertEqual(len(operands), 4)
|
||||||
if node.op == "get_attr":
|
if node.op == "get_attr":
|
||||||
if str(node.target) in ("cond_true_0, cond_false_0"):
|
if str(node.target) in ("cond_true_0, cond_false_0"):
|
||||||
|
@ -746,7 +746,7 @@ class HooksTests(torch._dynamo.test_case.TestCase):
|
|||||||
if cnts:
|
if cnts:
|
||||||
self.assertEqual(cnts.frame_count, 1)
|
self.assertEqual(cnts.frame_count, 1)
|
||||||
# These same exact assertions run on both eager and compiled
|
# These same exact assertions run on both eager and compiled
|
||||||
# X goes to x*2 becaue of mul_
|
# X goes to x*2 because of mul_
|
||||||
self.assertEqual(x, torch.tensor([0.5, 0.5, 0.5]) * 2)
|
self.assertEqual(x, torch.tensor([0.5, 0.5, 0.5]) * 2)
|
||||||
# This test proves grad aliasing works -
|
# This test proves grad aliasing works -
|
||||||
self.assertEqual(x.grad, b * 5)
|
self.assertEqual(x.grad, b * 5)
|
||||||
|
@ -57,7 +57,7 @@ unittest.expectedFailure(
|
|||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
# These tests do string comparisson on the graphs, and since buffers are now inlined, they
|
# These tests do string comparison on the graphs, and since buffers are now inlined, they
|
||||||
# are named different, resulting in failure
|
# are named different, resulting in failure
|
||||||
unittest.expectedFailure(
|
unittest.expectedFailure(
|
||||||
InlineAndInstallExportTests.test_param_buffer_safe_from_mutation_simple_inline_and_install # noqa: F821
|
InlineAndInstallExportTests.test_param_buffer_safe_from_mutation_simple_inline_and_install # noqa: F821
|
||||||
|
@ -64,7 +64,7 @@ class TestMetricsContext(TestCase):
|
|||||||
|
|
||||||
def test_update_disallow_overwrite(self):
|
def test_update_disallow_overwrite(self):
|
||||||
"""
|
"""
|
||||||
Validate update won't overwite.
|
Validate update won't overwrite.
|
||||||
"""
|
"""
|
||||||
with MetricsContext(self._on_exit) as context:
|
with MetricsContext(self._on_exit) as context:
|
||||||
context.update({"m1": 1, "m2": 2})
|
context.update({"m1": 1, "m2": 2})
|
||||||
@ -73,7 +73,7 @@ class TestMetricsContext(TestCase):
|
|||||||
|
|
||||||
def test_update_allow_overwrite(self):
|
def test_update_allow_overwrite(self):
|
||||||
"""
|
"""
|
||||||
Validate update will overwite when given param.
|
Validate update will overwrite when given param.
|
||||||
"""
|
"""
|
||||||
with MetricsContext(self._on_exit) as context:
|
with MetricsContext(self._on_exit) as context:
|
||||||
context.update({"m1": 1, "m2": 2})
|
context.update({"m1": 1, "m2": 2})
|
||||||
|
@ -4052,7 +4052,7 @@ utils_device.CURRENT_DEVICE == None""".split("\n"):
|
|||||||
y = x
|
y = x
|
||||||
|
|
||||||
def make_x_get_set():
|
def make_x_get_set():
|
||||||
# NOTE: this `x` is a different cell object than the outter `x`.
|
# NOTE: this `x` is a different cell object than the outer `x`.
|
||||||
x = y
|
x = y
|
||||||
|
|
||||||
def set_x(v):
|
def set_x(v):
|
||||||
@ -4844,7 +4844,7 @@ utils_device.CURRENT_DEVICE == None""".split("\n"):
|
|||||||
self.assertEqual(cnts.frame_count, 2)
|
self.assertEqual(cnts.frame_count, 2)
|
||||||
|
|
||||||
def test_id_guarded_object(self):
|
def test_id_guarded_object(self):
|
||||||
class UDO:
|
class UserDefinedObject:
|
||||||
@torch.compile(backend="eager")
|
@torch.compile(backend="eager")
|
||||||
def call(self, x, ref_id):
|
def call(self, x, ref_id):
|
||||||
self_id = id(self)
|
self_id = id(self)
|
||||||
@ -4857,11 +4857,11 @@ utils_device.CURRENT_DEVICE == None""".split("\n"):
|
|||||||
# Make sure we do recompile when id(self) is executed on
|
# Make sure we do recompile when id(self) is executed on
|
||||||
# different self objects.
|
# different self objects.
|
||||||
x = torch.ones(2)
|
x = torch.ones(2)
|
||||||
obj1 = UDO()
|
obj1 = UserDefinedObject()
|
||||||
obj1_id = id(obj1)
|
obj1_id = id(obj1)
|
||||||
self.assertEqual(obj1.call(x, obj1_id), torch.ones(2))
|
self.assertEqual(obj1.call(x, obj1_id), torch.ones(2))
|
||||||
|
|
||||||
obj2 = UDO()
|
obj2 = UserDefinedObject()
|
||||||
# if we do not install ID_MATCH: ___check_obj_id(L['self'], xxx) this fails.
|
# if we do not install ID_MATCH: ___check_obj_id(L['self'], xxx) this fails.
|
||||||
self.assertEqual(obj2.call(x, obj1_id), torch.zeros(2))
|
self.assertEqual(obj2.call(x, obj1_id), torch.zeros(2))
|
||||||
|
|
||||||
@ -8698,7 +8698,7 @@ utils_device.CURRENT_DEVICE == None""".split("\n"):
|
|||||||
),
|
),
|
||||||
testcase(expr="f(m.n[0], '1').x.y.z", expected="f(_var3, '1').x.y.z"),
|
testcase(expr="f(m.n[0], '1').x.y.z", expected="f(_var3, '1').x.y.z"),
|
||||||
testcase(expr="f(m.n[0], '2').x.y.z", expected="f(_var3, '2').x.y.z"),
|
testcase(expr="f(m.n[0], '2').x.y.z", expected="f(_var3, '2').x.y.z"),
|
||||||
# The whole expressiong gets CSE-d, as well as all of its sub-expressions.
|
# The whole expression gets CSE-d, as well as all of its sub-expressions.
|
||||||
testcase(
|
testcase(
|
||||||
expr="self.g(a, b).k",
|
expr="self.g(a, b).k",
|
||||||
preface=["_var4 = self.g", "_var5 = _var4(a, b)", "_var6 = _var5.k"],
|
preface=["_var4 = self.g", "_var5 = _var4(a, b)", "_var6 = _var5.k"],
|
||||||
@ -10493,11 +10493,11 @@ def ___make_guard_fn():
|
|||||||
self.assertEqual(actual, expected)
|
self.assertEqual(actual, expected)
|
||||||
|
|
||||||
def test_pytree_tree_leaves(self):
|
def test_pytree_tree_leaves(self):
|
||||||
implemtations = [("python", python_pytree)]
|
implementations = [("python", python_pytree)]
|
||||||
if cxx_pytree is not None:
|
if cxx_pytree is not None:
|
||||||
implemtations.append(("cxx", cxx_pytree))
|
implementations.append(("cxx", cxx_pytree))
|
||||||
|
|
||||||
for name, module in implemtations:
|
for name, module in implementations:
|
||||||
with self.subTest(f"pytree implement: {name}"):
|
with self.subTest(f"pytree implement: {name}"):
|
||||||
|
|
||||||
def fn(x):
|
def fn(x):
|
||||||
@ -10527,11 +10527,11 @@ def ___make_guard_fn():
|
|||||||
self.assertEqual(actual, expected)
|
self.assertEqual(actual, expected)
|
||||||
|
|
||||||
def test_pytree_tree_flatten_unflatten(self):
|
def test_pytree_tree_flatten_unflatten(self):
|
||||||
implemtations = [("python", python_pytree)]
|
implementations = [("python", python_pytree)]
|
||||||
if cxx_pytree is not None:
|
if cxx_pytree is not None:
|
||||||
implemtations.append(("cxx", cxx_pytree))
|
implementations.append(("cxx", cxx_pytree))
|
||||||
|
|
||||||
for name, module in implemtations:
|
for name, module in implementations:
|
||||||
with self.subTest(f"pytree implement: {name}"):
|
with self.subTest(f"pytree implement: {name}"):
|
||||||
|
|
||||||
def fn(x, y):
|
def fn(x, y):
|
||||||
@ -10578,11 +10578,11 @@ def ___make_guard_fn():
|
|||||||
self.assertEqual(actual, expected)
|
self.assertEqual(actual, expected)
|
||||||
|
|
||||||
def test_pytree_tree_map(self):
|
def test_pytree_tree_map(self):
|
||||||
implemtations = [("python", python_pytree)]
|
implementations = [("python", python_pytree)]
|
||||||
if cxx_pytree is not None:
|
if cxx_pytree is not None:
|
||||||
implemtations.append(("cxx", cxx_pytree))
|
implementations.append(("cxx", cxx_pytree))
|
||||||
|
|
||||||
for name, module in implemtations:
|
for name, module in implementations:
|
||||||
with self.subTest(f"pytree implement: {name}"):
|
with self.subTest(f"pytree implement: {name}"):
|
||||||
|
|
||||||
def fn(x, y):
|
def fn(x, y):
|
||||||
@ -11731,7 +11731,7 @@ fn
|
|||||||
|
|
||||||
# Ensure that the generated graph returns only one output. We want the
|
# Ensure that the generated graph returns only one output. We want the
|
||||||
# add_ on the grad to be part of the graph itself, so that inductor can
|
# add_ on the grad to be part of the graph itself, so that inductor can
|
||||||
# theoretically move the add_ and resutling copy_ nodes at the right
|
# theoretically move the add_ and resulting copy_ nodes at the right
|
||||||
# place to free memory.
|
# place to free memory.
|
||||||
self.assertEqual(len(list(cnt.graphs[0].graph.nodes)[-1].all_input_nodes), 1)
|
self.assertEqual(len(list(cnt.graphs[0].graph.nodes)[-1].all_input_nodes), 1)
|
||||||
self.assertEqual(z, ref_y)
|
self.assertEqual(z, ref_y)
|
||||||
@ -12293,7 +12293,7 @@ fn
|
|||||||
self.ne_called = False
|
self.ne_called = False
|
||||||
|
|
||||||
def __ne__(self, other):
|
def __ne__(self, other):
|
||||||
# ne_called attr is later checked to ensure that overrideen
|
# ne_called attr is later checked to ensure that overridden
|
||||||
# `__ne__` is traced
|
# `__ne__` is traced
|
||||||
self.ne_called = True
|
self.ne_called = True
|
||||||
return not self.__eq__(other)
|
return not self.__eq__(other)
|
||||||
|
@ -1987,7 +1987,7 @@ class OptimizedModuleTest(torch._dynamo.test_case.TestCase):
|
|||||||
# Check order of _modules
|
# Check order of _modules
|
||||||
def fn(x):
|
def fn(x):
|
||||||
for idx, p in enumerate(mod.modules()):
|
for idx, p in enumerate(mod.modules()):
|
||||||
# Something silly to force depedency on the order
|
# Something silly to force dependency on the order
|
||||||
x += coeffs_for_mod[p] * coeffs[idx]
|
x += coeffs_for_mod[p] * coeffs[idx]
|
||||||
for idx, p in enumerate(mod.named_modules()):
|
for idx, p in enumerate(mod.named_modules()):
|
||||||
x += coeffs_for_mod[p[1]] * coeffs[idx]
|
x += coeffs_for_mod[p[1]] * coeffs[idx]
|
||||||
|
@ -6159,7 +6159,7 @@ def forward(self, s77 : torch.SymInt, s27 : torch.SymInt, L_x_ : torch.Tensor):
|
|||||||
self.assertEqual(out_ref, out_test)
|
self.assertEqual(out_ref, out_test)
|
||||||
|
|
||||||
@requires_cuda
|
@requires_cuda
|
||||||
# This test will fail as flip in combination with particular input lenghts
|
# This test will fail as flip in combination with particular input lengths
|
||||||
# produces weird results.
|
# produces weird results.
|
||||||
# This is under investigations in
|
# This is under investigations in
|
||||||
# https://github.com/pytorch/pytorch/issues/131805
|
# https://github.com/pytorch/pytorch/issues/131805
|
||||||
@ -7449,7 +7449,7 @@ class ReproTestsDevice(torch._dynamo.test_case.TestCase):
|
|||||||
# *are* saved for backward, and become back inputs.
|
# *are* saved for backward, and become back inputs.
|
||||||
# The easier-to-test thing I'm checking for here is that the recompute
|
# The easier-to-test thing I'm checking for here is that the recompute
|
||||||
# on primals_2 happens in the backward. With the recompute,
|
# on primals_2 happens in the backward. With the recompute,
|
||||||
# there are 5 _to_copy ops in the backwrad. Without it, there are 4
|
# there are 5 _to_copy ops in the backward. Without it, there are 4
|
||||||
# (aka if you set torch._functorch.config.treat_parameters_as_free_to_save = False)
|
# (aka if you set torch._functorch.config.treat_parameters_as_free_to_save = False)
|
||||||
self.assertEqual(mode.ops_counter[torch.ops.aten._to_copy.default], 5)
|
self.assertEqual(mode.ops_counter[torch.ops.aten._to_copy.default], 5)
|
||||||
|
|
||||||
|
@ -1368,7 +1368,7 @@ class GraphModule(torch.nn.Module):
|
|||||||
)
|
)
|
||||||
self.assertTrue(torch._is_functional_tensor(backend.example_inputs[1][0]))
|
self.assertTrue(torch._is_functional_tensor(backend.example_inputs[1][0]))
|
||||||
|
|
||||||
# Cannot re-use the version from AOTAutograd, since that uses python functional tensors.
|
# Cannot reuse the version from AOTAutograd, since that uses python functional tensors.
|
||||||
def to_fun(x):
|
def to_fun(x):
|
||||||
x_functional = torch._to_functional_tensor(x)
|
x_functional = torch._to_functional_tensor(x)
|
||||||
torch._mirror_autograd_meta_to(x, x_functional)
|
torch._mirror_autograd_meta_to(x, x_functional)
|
||||||
@ -2017,7 +2017,7 @@ class GraphModule(torch.nn.Module):
|
|||||||
exp_frame_count=[1, 1, 2, 2],
|
exp_frame_count=[1, 1, 2, 2],
|
||||||
exp_shape_env_guards=[
|
exp_shape_env_guards=[
|
||||||
[],
|
[],
|
||||||
# s0 is specialized and guarded in outter shape_env when dynamo checks the guards
|
# s0 is specialized and guarded in outer shape_env when dynamo checks the guards
|
||||||
["Eq(Piecewise((1, Eq(s0, 3)), (0, True)), 1)"],
|
["Eq(Piecewise((1, Eq(s0, 3)), (0, True)), 1)"],
|
||||||
[
|
[
|
||||||
"Eq(Piecewise((1, Eq(s0, 3)), (0, True)), 1)",
|
"Eq(Piecewise((1, Eq(s0, 3)), (0, True)), 1)",
|
||||||
@ -2039,7 +2039,7 @@ class GraphModule(torch.nn.Module):
|
|||||||
exp_frame_count=[1, 1, 2, 2],
|
exp_frame_count=[1, 1, 2, 2],
|
||||||
exp_shape_env_guards=[
|
exp_shape_env_guards=[
|
||||||
[],
|
[],
|
||||||
# s0 is specialized and guarded in outter shape_env when dynamo checks the guards
|
# s0 is specialized and guarded in outer shape_env when dynamo checks the guards
|
||||||
["Ne(Piecewise((1, Eq(s0, 5)), (0, True)), 1)"],
|
["Ne(Piecewise((1, Eq(s0, 5)), (0, True)), 1)"],
|
||||||
[
|
[
|
||||||
"Ne(Piecewise((1, Eq(s0, 5)), (0, True)), 1)",
|
"Ne(Piecewise((1, Eq(s0, 5)), (0, True)), 1)",
|
||||||
@ -3085,7 +3085,7 @@ class GraphModule(torch.nn.Module):
|
|||||||
# triggers the eager logic to run, updating the counter and registry.
|
# triggers the eager logic to run, updating the counter and registry.
|
||||||
#
|
#
|
||||||
# Notably however, compile differs in two ways from eager:
|
# Notably however, compile differs in two ways from eager:
|
||||||
# (1) The order in which the offsets are assigned ids is differnet
|
# (1) The order in which the offsets are assigned ids is different
|
||||||
# the registry would be set in the order the offsets are returned
|
# the registry would be set in the order the offsets are returned
|
||||||
# which is not necessarily the same order as they were constructed.
|
# which is not necessarily the same order as they were constructed.
|
||||||
# (2) If a NestedTensor is not returned, then the AOTAutograd wrapping
|
# (2) If a NestedTensor is not returned, then the AOTAutograd wrapping
|
||||||
|
@ -401,7 +401,7 @@ class SubGraphTests(torch._dynamo.test_case.TestCase):
|
|||||||
y = torch.randn(3)
|
y = torch.randn(3)
|
||||||
self.assertEqual(opt_fn(x, y), fn(x, y))
|
self.assertEqual(opt_fn(x, y), fn(x, y))
|
||||||
self.assertEqual(opt_fn(x, x), fn(x, x))
|
self.assertEqual(opt_fn(x, x), fn(x, x))
|
||||||
# NB: This COULD validly be 2, but we don't test disjointness in the
|
# NB: This COULD validly be 2, but we don't test disjointedness in the
|
||||||
# guards for when x and y didn't duck size together, so we end up
|
# guards for when x and y didn't duck size together, so we end up
|
||||||
# with a generic graph that also works when x and y happen to duck
|
# with a generic graph that also works when x and y happen to duck
|
||||||
# size together.
|
# size together.
|
||||||
|
@ -126,7 +126,7 @@ def gen_allowed_objs_and_ids(record=False, c_binding_only=True) -> AllowedObject
|
|||||||
torch_name_rule_map = {}
|
torch_name_rule_map = {}
|
||||||
|
|
||||||
# In some platforms, these functions were loaded as classes instead of functions.
|
# In some platforms, these functions were loaded as classes instead of functions.
|
||||||
# To mitigate these weired cases, we need this special check.
|
# To mitigate these weird cases, we need this special check.
|
||||||
def is_special_functions(obj):
|
def is_special_functions(obj):
|
||||||
return hashable(obj) and obj in {
|
return hashable(obj) and obj in {
|
||||||
torch._C._cuda_isCurrentStreamCapturing,
|
torch._C._cuda_isCurrentStreamCapturing,
|
||||||
|
Reference in New Issue
Block a user