[BE][pytree] cleanup parameterized pytree tests (#160842)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/160842 Approved by: https://github.com/Skylion007
Committed by: PyTorch MergeBot
Parent: 01edcd4df8
Commit: 2fa0520a64
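For context, this cleanup replaces per-implementation subTest loops with the parametrize/subtest/instantiate_parametrized_tests helpers. Below is a minimal, self-contained sketch of the same pattern; the helpers and the pytree modules are the real ones used in this diff, while the class ExamplePyTreeTest and its test body are illustrative, not part of the commit:

import torch
import torch.utils._pytree as python_pytree
from torch.testing._internal.common_utils import (
    instantiate_parametrized_tests,
    parametrize,
    run_tests,
    subtest,
    TestCase,
)

# Collect the available pytree implementations, keyed by a readable name.
pytree_modules = {"python": python_pytree}
if python_pytree._cxx_pytree_dynamo_traceable:
    import torch.utils._cxx_pytree as cxx_pytree

    pytree_modules["cxx"] = cxx_pytree

# Each subtest injects one module into the test as the `pytree` argument.
parametrize_pytree_module = parametrize(
    "pytree",
    [subtest(module, name=name) for name, module in pytree_modules.items()],
)


class ExamplePyTreeTest(TestCase):
    @parametrize_pytree_module
    def test_roundtrip(self, pytree):
        # tree_flatten/tree_unflatten should round-trip the container structure.
        xs = {"a": [torch.zeros(2)], "b": (torch.ones(3),)}
        leaves, spec = pytree.tree_flatten(xs)
        self.assertEqual(pytree.tree_unflatten(leaves, spec), xs)


# Expands the decorated test into one concrete test per implementation,
# with names along the lines of test_roundtrip_python / test_roundtrip_cxx.
instantiate_parametrized_tests(ExamplePyTreeTest)

if __name__ == "__main__":
    run_tests()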
@@ -87,12 +87,15 @@ from torch.testing._internal.common_methods_invocations import (
 )
 from torch.testing._internal.common_utils import (
     freeze_rng_state,
+    instantiate_parametrized_tests,
     IS_FBCODE,
+    parametrize,
     scoped_load_inline,
     set_default_dtype,
     skipIfHpu,
     skipIfNNModuleInlined,
     skipIfWindows,
+    subtest,
     TEST_HPU,
     TEST_XPU,
     wrapDeterministicFlagAPITest,
@@ -101,11 +104,21 @@ from torch.testing._internal.jit_utils import JitTestCase
 from torch.testing._internal.logging_utils import logs_to_string
 
 
+pytree_modules = {
+    "python": python_pytree,
+}
 if python_pytree._cxx_pytree_dynamo_traceable:
     import torch.utils._cxx_pytree as cxx_pytree
+
+    pytree_modules["cxx"] = cxx_pytree
 else:
     cxx_pytree = None
 
+parametrize_pytree_module = parametrize(
+    "pytree",
+    [subtest(module, name=name) for name, module in pytree_modules.items()],
+)
+
 MyTuple = collections.namedtuple("MyTuple", ["a", "b", "ab"])
 T = typing.TypeVar("T")
 
@@ -9107,71 +9120,6 @@ def ___make_guard_fn():
         opt = torch.compile(fn, backend="eager")
         opt()
 
-    def test_tracing_py_tree(self):
-        def fn(xs):
-            flat_xs, spec = python_pytree.tree_flatten(xs)
-            res = [x.clone() for x in flat_xs]
-            return python_pytree.tree_unflatten(res, spec)
-
-        xs = [torch.tensor(i) for i in range(3)]
-
-        counter = CompileCounter()
-        torch.compile(fn, backend=counter, fullgraph=True)(xs)
-        self.assertEqual(counter.frame_count, 1)
-        self.assertEqual(counter.op_count, 3)
-
-    def test_tracing_nested_py_tree(self):
-        def fn(xs):
-            flat_xs, spec = python_pytree.tree_flatten(xs)
-            res = [x.clone() for x in flat_xs]
-            return python_pytree.tree_unflatten(res, spec)
-
-        xs = [torch.tensor(i) for i in range(3)]
-        xsl = [xs, xs, xs, xs]
-
-        counter = CompileCounter()
-        comp_out = torch.compile(fn, backend=counter, fullgraph=True)(xsl)
-        real_out = fn(xsl)
-        self.assertEqual(comp_out, real_out)
-        self.assertEqual(counter.frame_count, 1)
-        self.assertEqual(counter.op_count, 12)
-
-    def test_tracing_nested_py_tree_tuples(self):
-        def fn(xs):
-            flat_xs, spec = python_pytree.tree_flatten(xs)
-            res = [x.clone() for x in flat_xs]
-            return python_pytree.tree_unflatten(res, spec)
-
-        xs = [torch.tensor(i) for i in range(3)]
-        xsl = (xs, xs, xs, xs)
-
-        counter = CompileCounter()
-        comp_out = torch.compile(fn, backend=counter, fullgraph=True)(xsl)
-        real_out = fn(xsl)
-        self.assertEqual(comp_out, real_out)
-        self.assertEqual(counter.frame_count, 1)
-        self.assertEqual(counter.op_count, 12)
-
-    def test_tracing_nested_py_tree_dicts(self):
-        def fn(xs):
-            flat_xs, spec = python_pytree.tree_flatten(xs)
-            res = [x.clone() for x in flat_xs]
-            return python_pytree.tree_unflatten(res, spec)
-
-        xs = [torch.tensor(i) for i in range(3)]
-        xsl = {
-            "a": xs,
-            "b": xs,
-            "c": xs,
-        }
-
-        counter = CompileCounter()
-        comp_out = torch.compile(fn, backend=counter, fullgraph=True)(xsl)
-        real_out = fn(xsl)
-        self.assertEqual(comp_out, real_out)
-        self.assertEqual(counter.frame_count, 1)
-        self.assertEqual(counter.op_count, 9)
-
     def test_dynamic_one_hot(self):
         def fn(x):
             x = x + 1
@@ -9188,28 +9136,6 @@ def ___make_guard_fn():
         self.assertEqual(counter.frame_count, 2)
         self.assertEqual(counter.op_count, 2)
 
-    def test_tracing_nested_py_tree_mixed_all(self):
-        def fn(xs):
-            flat_xs, spec = python_pytree.tree_flatten(xs)
-            res = [x.clone() for x in flat_xs]
-            return python_pytree.tree_unflatten(res, spec)
-
-        xs = [torch.tensor(i) for i in range(3)]
-        xsa = (xs, xs)
-        xsb = {"aa": xsa, "ab": xs}
-        xsl = {
-            "a": xs,
-            "b": xsa,
-            "c": xsb,
-        }
-
-        counter = CompileCounter()
-        comp_out = torch.compile(fn, backend=counter, fullgraph=True)(xsl)
-        real_out = fn(xsl)
-        self.assertEqual(comp_out, real_out)
-        self.assertEqual(counter.frame_count, 1)
-        self.assertEqual(counter.op_count, 18)
-
     def test_any_all_symnode(self):
         cnt = CompileCounter()
 
@@ -9236,46 +9162,6 @@ def ___make_guard_fn():
         self.assertEqual(fn(y3), y3 - 3)
         self.assertEqual(cnt.frame_count, 2)
 
-    def test_tracing_py_tree_tensor_subclass(self):
-        from torch.testing._internal.two_tensor import TwoTensor
-        from torch.utils.checkpoint import checkpoint
-
-        def fn(xs):
-            nested_xs = [[xs]]
-            flat_xs, spec = python_pytree.tree_flatten(xs)
-            return flat_xs[0].clone()
-
-        # use checkpoint to trigger a "sourceless" tensor subclass
-        def checkpoint_fn(xs):
-            return checkpoint(fn, xs, use_reentrant=True)
-
-        xs = TwoTensor(torch.ones(2, 2), torch.ones(2, 2))
-
-        counter = CompileCounter()
-        torch.compile(checkpoint_fn, backend=counter, fullgraph=True)(xs)
-        self.assertEqual(counter.frame_count, 1)
-        self.assertEqual(counter.op_count, 2)
-
-    def test_tracing_tree_map_only(self):
-        def fn(xs):
-            def mapper(x):
-                return x.clone()
-
-            y = python_pytree.tree_map_only(torch.Tensor, mapper, xs)
-            return y
-
-        xs = [torch.tensor(i) for i in range(3)] + ["hi"]
-        xsa = (xs, xs)
-        xsb = {"aa": xsa, "ab": xs}
-
-        counter = CompileCounter()
-        comp_out = torch.compile(fn, backend=counter, fullgraph=True)(xsb)
-        real_out = fn(xsb)
-
-        self.assertEqual(comp_out, real_out)
-        self.assertEqual(counter.frame_count, 1)
-        self.assertEqual(counter.op_count, 9)
-
     @torch._dynamo.config.patch(
         capture_scalar_outputs=True, capture_dynamic_output_shape_ops=True
     )
@@ -10718,139 +10604,6 @@ def ___make_guard_fn():
         expected = fn(*inps)
         self.assertEqual(actual, expected)
 
-    def test_pytree_tree_leaves(self):
-        implementations = [("python", python_pytree)]
-        if cxx_pytree is not None:
-            implementations.append(("cxx", cxx_pytree))
-
-        for name, module in implementations:
-            with self.subTest(f"pytree implement: {name}"):
-
-                def fn(x):
-                    tree = {
-                        "a": [x, x - 1],
-                        "b": x + 2,
-                        "c": (
-                            x,
-                            3.0,
-                            collections.deque([0.0, -x, 1, 2], maxlen=3),
-                        ),
-                        "d": collections.OrderedDict(
-                            {
-                                "e": torch.return_types.qr((2 * x, None)),
-                                "f": MyTuple(x, x + 1, torch.zeros(4, 3)),
-                            },
-                        ),
-                    }
-                    leaves = module.tree_leaves(tree)
-                    return leaves
-
-                x = torch.randn(3, 2)
-                expected = fn(x)
-                fn_opt = torch.compile(fullgraph=True)(fn)
-                actual = fn_opt(x)
-
-                self.assertEqual(actual, expected)
-
-    def test_pytree_tree_flatten_unflatten(self):
-        implementations = [("python", python_pytree)]
-        if cxx_pytree is not None:
-            implementations.append(("cxx", cxx_pytree))
-
-        for name, module in implementations:
-            with self.subTest(f"pytree implement: {name}"):
-
-                def fn(x, y):
-                    tree = {
-                        "a": [x, x - 1],
-                        "b": x + 2,
-                        "c": (
-                            x,
-                            3.0,
-                            collections.deque([0.0, -x, 1, 2], maxlen=3),
-                        ),
-                        "d": collections.OrderedDict(
-                            {
-                                "e": torch.return_types.qr((2 * x, None)),
-                                "f": MyTuple(x, x + 1, torch.zeros(4, 3)),
-                            },
-                        ),
-                    }
-                    leaves, treespec = module.tree_flatten(tree)
-                    new_leaves = [
-                        x - 1,
-                        y,
-                        x * y,
-                        3.0,
-                        y - 2,
-                        1,
-                        torch.zeros(2, 2),
-                        2 * y,
-                        -y,
-                        x + y,
-                        x - y,
-                        torch.ones(3, 2),
-                        1,
-                    ]
-                    new_tree = module.tree_unflatten(new_leaves, treespec)
-                    return leaves, new_tree
-
-                x = torch.randn(3, 2)
-                y = torch.randn(3, 2)
-                expected = fn(x, y)
-                fn_opt = torch.compile(fullgraph=True)(fn)
-                actual = fn_opt(x, y)
-
-                self.assertEqual(actual, expected)
-
-    def test_pytree_tree_map(self):
-        implementations = [("python", python_pytree)]
-        if cxx_pytree is not None:
-            implementations.append(("cxx", cxx_pytree))
-
-        for name, module in implementations:
-            with self.subTest(f"pytree implement: {name}"):
-
-                def fn(x, y):
-                    tree1 = {
-                        "a": [x, x - 1],
-                        "b": x + 2,
-                        "c": (
-                            x,
-                            3.0,
-                            collections.deque([0.0, -x, 1, 2], maxlen=3),
-                        ),
-                        "d": collections.OrderedDict(
-                            {
-                                "e": torch.return_types.qr((2 * x, None)),
-                                "f": MyTuple(x, x + 1, torch.zeros(4, 3)),
-                            },
-                        ),
-                    }
-                    tree2 = collections.OrderedDict(
-                        [
-                            ("c", (y, 3.0, collections.deque([1, -y, 10.0]))),
-                            ("a", [y, y + 1]),
-                            ("b", y + 2),
-                            (
-                                "d",
-                                {
-                                    "f": MyTuple(torch.ones(4, 3), -y, y + 1),
-                                    "e": torch.return_types.qr((2 * y, None)),
-                                },
-                            ),
-                        ],
-                    )
-                    return module.tree_map(lambda u, v: (u, v), tree1, tree2)
-
-                x = torch.randn(3, 2)
-                y = torch.randn(3, 2)
-                expected = fn(x, y)
-                fn_opt = torch.compile(fullgraph=True)(fn)
-                actual = fn_opt(x, y)
-
-                self.assertEqual(actual, expected)
-
     def test_shape_env_no_recording(self):
         main = ShapeEnv(should_record_events=False)
 
@@ -12886,6 +12639,257 @@ fn
         self.assertRaises(Unsupported, f, "1 + j")
 
 
+class MiscTestsPyTree(torch._inductor.test_case.TestCase):
+    @parametrize_pytree_module
+    def test_tracing_pytree(self, pytree):
+        def fn(xs):
+            flat_xs, spec = pytree.tree_flatten(xs)
+            res = [x.clone() for x in flat_xs]
+            return pytree.tree_unflatten(res, spec)
+
+        xs = [torch.tensor(i) for i in range(3)]
+
+        counter = CompileCounter()
+        torch.compile(fn, backend=counter, fullgraph=True)(xs)
+        self.assertEqual(counter.frame_count, 1)
+        self.assertEqual(counter.op_count, 3)
+
+    @parametrize_pytree_module
+    def test_tracing_nested_pytree(self, pytree):
+        def fn(xs):
+            flat_xs, spec = pytree.tree_flatten(xs)
+            res = [x.clone() for x in flat_xs]
+            return pytree.tree_unflatten(res, spec)
+
+        xs = [torch.tensor(i) for i in range(3)]
+        xsl = [xs, xs, xs, xs]
+
+        counter = CompileCounter()
+        comp_out = torch.compile(fn, backend=counter, fullgraph=True)(xsl)
+        real_out = fn(xsl)
+        self.assertEqual(comp_out, real_out)
+        self.assertEqual(counter.frame_count, 1)
+        self.assertEqual(counter.op_count, 12)
+
+    @parametrize_pytree_module
+    def test_tracing_nested_tuples(self, pytree):
+        def fn(xs):
+            flat_xs, spec = pytree.tree_flatten(xs)
+            res = [x.clone() for x in flat_xs]
+            return pytree.tree_unflatten(res, spec)
+
+        xs = [torch.tensor(i) for i in range(3)]
+        xsl = (xs, xs, xs, xs)
+
+        counter = CompileCounter()
+        comp_out = torch.compile(fn, backend=counter, fullgraph=True)(xsl)
+        real_out = fn(xsl)
+        self.assertEqual(comp_out, real_out)
+        self.assertEqual(counter.frame_count, 1)
+        self.assertEqual(counter.op_count, 12)
+
+    @parametrize_pytree_module
+    def test_tracing_nested_dicts(self, pytree):
+        def fn(xs):
+            flat_xs, spec = pytree.tree_flatten(xs)
+            res = [x.clone() for x in flat_xs]
+            return pytree.tree_unflatten(res, spec)
+
+        xs = [torch.tensor(i) for i in range(3)]
+        xsl = {
+            "a": xs,
+            "b": xs,
+            "c": xs,
+        }
+
+        counter = CompileCounter()
+        comp_out = torch.compile(fn, backend=counter, fullgraph=True)(xsl)
+        real_out = fn(xsl)
+        self.assertEqual(comp_out, real_out)
+        self.assertEqual(counter.frame_count, 1)
+        self.assertEqual(counter.op_count, 9)
+
+    @parametrize_pytree_module
+    def test_tracing_nested_mixed_all(self, pytree):
+        def fn(xs):
+            flat_xs, spec = pytree.tree_flatten(xs)
+            res = [x.clone() for x in flat_xs]
+            return pytree.tree_unflatten(res, spec)
+
+        xs = [torch.tensor(i) for i in range(3)]
+        xsa = (xs, xs)
+        xsb = {"aa": xsa, "ab": xs}
+        xsl = {
+            "a": xs,
+            "b": xsa,
+            "c": xsb,
+        }
+
+        counter = CompileCounter()
+        comp_out = torch.compile(fn, backend=counter, fullgraph=True)(xsl)
+        real_out = fn(xsl)
+        self.assertEqual(comp_out, real_out)
+        self.assertEqual(counter.frame_count, 1)
+        self.assertEqual(counter.op_count, 18)
+
+    @parametrize_pytree_module
+    def test_tracing_nested_tensor_subclass(self, pytree):
+        from torch.testing._internal.two_tensor import TwoTensor
+        from torch.utils.checkpoint import checkpoint
+
+        def fn(xs):
+            nested_xs = [[xs]]
+            flat_xs, spec = pytree.tree_flatten(xs)
+            return flat_xs[0].clone()
+
+        # use checkpoint to trigger a "sourceless" tensor subclass
+        def checkpoint_fn(xs):
+            return checkpoint(fn, xs, use_reentrant=True)
+
+        xs = TwoTensor(torch.ones(2, 2), torch.ones(2, 2))
+
+        counter = CompileCounter()
+        torch.compile(checkpoint_fn, backend=counter, fullgraph=True)(xs)
+        self.assertEqual(counter.frame_count, 1)
+        self.assertEqual(counter.op_count, 2)
+
+    @parametrize_pytree_module
+    def test_pytree_tree_leaves(self, pytree):
+        def fn(x):
+            tree = {
+                "a": [x, x - 1],
+                "b": x + 2,
+                "c": (
+                    x,
+                    3.0,
+                    collections.deque([0.0, -x, 1, 2], maxlen=3),
+                ),
+                "d": collections.OrderedDict(
+                    {
+                        "e": torch.return_types.qr((2 * x, None)),
+                        "f": MyTuple(x, x + 1, torch.zeros(4, 3)),
+                    },
+                ),
+            }
+            leaves = pytree.tree_leaves(tree)
+            return leaves
+
+        x = torch.randn(3, 2)
+        expected = fn(x)
+        fn_opt = torch.compile(fullgraph=True)(fn)
+        actual = fn_opt(x)
+
+        self.assertEqual(actual, expected)
+
+    @parametrize_pytree_module
+    def test_pytree_tree_flatten_unflatten(self, pytree):
+        def fn(x, y):
+            tree = {
+                "a": [x, x - 1],
+                "b": x + 2,
+                "c": (
+                    x,
+                    3.0,
+                    collections.deque([0.0, -x, 1, 2], maxlen=3),
+                ),
+                "d": collections.OrderedDict(
+                    {
+                        "e": torch.return_types.qr((2 * x, None)),
+                        "f": MyTuple(x, x + 1, torch.zeros(4, 3)),
+                    },
+                ),
+            }
+            leaves, treespec = pytree.tree_flatten(tree)
+            new_leaves = [
+                x - 1,
+                y,
+                x * y,
+                3.0,
+                y - 2,
+                1,
+                torch.zeros(2, 2),
+                2 * y,
+                -y,
+                x + y,
+                x - y,
+                torch.ones(3, 2),
+                1,
+            ]
+            new_tree = pytree.tree_unflatten(new_leaves, treespec)
+            return leaves, new_tree
+
+        x = torch.randn(3, 2)
+        y = torch.randn(3, 2)
+        expected = fn(x, y)
+        fn_opt = torch.compile(fullgraph=True)(fn)
+        actual = fn_opt(x, y)
+
+        self.assertEqual(actual, expected)
+
+    @parametrize_pytree_module
+    def test_pytree_tree_map(self, pytree):
+        def fn(x, y):
+            tree1 = {
+                "a": [x, x - 1],
+                "b": x + 2,
+                "c": (
+                    x,
+                    3.0,
+                    collections.deque([0.0, -x, 1, 2], maxlen=3),
+                ),
+                "d": collections.OrderedDict(
+                    {
+                        "e": torch.return_types.qr((2 * x, None)),
+                        "f": MyTuple(x, x + 1, torch.zeros(4, 3)),
+                    },
+                ),
+            }
+            tree2 = collections.OrderedDict(
+                [
+                    ("c", (y, 3.0, collections.deque([1, -y, 10.0]))),
+                    ("a", [y, y + 1]),
+                    ("b", y + 2),
+                    (
+                        "d",
+                        {
+                            "f": MyTuple(torch.ones(4, 3), -y, y + 1),
+                            "e": torch.return_types.qr((2 * y, None)),
+                        },
+                    ),
+                ],
+            )
+            return pytree.tree_map(lambda u, v: (u, v), tree1, tree2)
+
+        x = torch.randn(3, 2)
+        y = torch.randn(3, 2)
+        expected = fn(x, y)
+        fn_opt = torch.compile(fullgraph=True)(fn)
+        actual = fn_opt(x, y)
+
+        self.assertEqual(actual, expected)
+
+    @parametrize_pytree_module
+    def test_pytree_tree_map_only(self, pytree):
+        def fn(xs):
+            def mapper(x):
+                return x.clone()
+
+            y = pytree.tree_map_only(torch.Tensor, mapper, xs)
+            return y
+
+        xs = [torch.tensor(i) for i in range(3)] + ["hi"]
+        xsa = (xs, xs)
+        xsb = {"aa": xsa, "ab": xs}
+
+        counter = CompileCounter()
+        comp_out = torch.compile(fn, backend=counter, fullgraph=True)(xsb)
+        real_out = fn(xsb)
+
+        self.assertEqual(comp_out, real_out)
+        self.assertEqual(counter.frame_count, 1)
+        self.assertEqual(counter.op_count, 9)
+
+
 class TestTracer(JitTestCase):
     def test_jit_save(self):
         def fn():
@@ -13266,10 +13270,14 @@ class MiscTestsDevice(torch._inductor.test_case.TestCase):
         # RuntimeError: value cannot be converted to type at::Half without overflow
 
 
+instantiate_parametrized_tests(MiscTestsPyTree)
+
 devices = ("cuda", "hpu", "xpu")
 instantiate_device_type_tests(
     MiscTestsDevice, globals(), only_for=devices, allow_xpu=True
 )
 
+
 if __name__ == "__main__":
     from torch._dynamo.test_case import run_tests
+